In [19]:
import datetime
import os

import numpy as np
from keras.layers import Input
from keras.models import Model
from sub_net import *

In [33]:
class CycleGAN():
    """CycleGAN for unpaired image-to-image translation between domains A and B.

    The constructor builds two generators (``generator_A_to_B`` and
    ``generator_B_to_A``), one discriminator per domain, and a ``combined``
    model that trains the generators against the frozen discriminators using
    adversarial (MSE), cycle-consistency (MAE) and identity (MAE) losses.

    Args:
        img_shape: ``(rows, cols, channels)`` of the input images.
        g_filter: Filters in the first generator layer, passed to ``UNET_G``.
        d_filter: Filters in the first discriminator layer, passed to ``BASIC_D``.
        lamdba_cycle: Loss weight of the two cycle-consistency outputs. The
            misspelled name is kept so existing keyword callers keep working.
        lambda_id: Loss weight of the two identity-mapping outputs.
        dataset_name: Sub-directory of ``images/`` that ``sample_images``
            writes its grids into.
    """

    def __init__(self, img_shape=(128, 128, 3), g_filter=32, d_filter=64,
                 lamdba_cycle=10.0, lambda_id=1.0, dataset_name='cyclegan'):
        # Input shape
        self.img_rows, self.img_cols, self.channels = self.img_shape = img_shape

        # Output directory name for sample_images(). The data loader is
        # assigned by train() and read by sample_images().
        self.dataset_name = dataset_name
        self.data_loader = None

        # Output shape of D (PatchGAN). The 2**4 factor presumably matches the
        # downsampling done by BASIC_D — NOTE(review): confirm in sub_net.
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D, taken from the
        # constructor arguments.
        self.generator_filter, self.discriminator_filter = g_filter, d_filter

        # Loss weights
        self.lambda_cycle, self.lambda_id = lamdba_cycle, lambda_id

        # NOTE: all models are compiled with 'rmsprop'; an earlier draft of
        # this notebook used Adam(0.0002, 0.5) instead.

        # Build and compile the discriminators
        self.discriminator_A = self.build_discriminator()
        self.discriminator_A.name = 'discriminator_A'
        self.discriminator_A.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])
        self.discriminator_B = self.build_discriminator()
        self.discriminator_B.name = 'discriminator_B'
        self.discriminator_B.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])

        #-------------------------
        # Construct Computational
        #   Graph of Generators
        #-------------------------

        # Build the generators
        self.generator_A_to_B = self.build_generator()
        self.generator_A_to_B.name = 'generator_A_to_B'
        self.generator_B_to_A = self.build_generator()
        self.generator_B_to_A.name = 'generator_B_to_A'

        # Input images from both domains
        img_A = Input(shape=self.img_shape, name='input_A')
        img_B = Input(shape=self.img_shape, name='input_B')

        # Translate images to the other domain
        fake_B = self.generator_A_to_B(img_A)
        fake_A = self.generator_B_to_A(img_B)
        # Translate images back to original domain
        reconstr_A = self.generator_B_to_A(fake_B)
        reconstr_B = self.generator_A_to_B(fake_A)
        # Identity mapping of images
        img_A_id = self.generator_B_to_A(img_A)
        img_B_id = self.generator_A_to_B(img_B)

        # For the combined model we will only train the generators. The
        # discriminators were compiled above while trainable, so their own
        # train_on_batch calls still update them.
        self.discriminator_A.trainable = False
        self.discriminator_B.trainable = False

        # Discriminators determine validity of translated images
        valid_A = self.discriminator_A(fake_A)
        valid_B = self.discriminator_B(fake_B)

        # Combined model trains generators to fool discriminators
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[valid_A, valid_B,
                                       reconstr_A, reconstr_B,
                                       img_A_id, img_B_id])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                              loss_weights=[1, 1,
                                            self.lambda_cycle, self.lambda_cycle,
                                            self.lambda_id, self.lambda_id],
                              optimizer='rmsprop')

    def build_generator(self):
        """Return a new ``UNET_G`` generator sized by ``img_rows`` and ``generator_filter``."""
        return UNET_G(self.img_rows, num_generator_filter=self.generator_filter)

    def build_discriminator(self):
        """Return a new ``BASIC_D`` discriminator for ``channels`` inputs and ``discriminator_filter``."""
        return BASIC_D(self.channels, self.discriminator_filter)

    def train(self, dataloader, epochs, batch_size=1, sample_interval=50):
        """Alternately train the discriminators and the generators.

        Args:
            dataloader: Object providing ``load_batch(batch_size)`` (iterable of
                ``(imgs_A, imgs_B)`` array pairs) and ``n_batches`` (shown in
                the progress line). ``sample_images`` also calls its
                ``load_data(...)``.
            epochs: Number of passes over ``dataloader.load_batch(batch_size)``.
            batch_size: Batch size requested from the loader.
            sample_interval: Save sample images every ``sample_interval``
                batches; ``0``/``None`` disables sampling.
        """
        # Keep the loader so sample_images() can draw test images from it.
        self.data_loader = dataloader
        start_time = datetime.datetime.now()

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(dataloader.load_batch(batch_size)):

                # Adversarial loss ground truths, sized from the actual batch
                # so a shorter final batch does not cause a shape mismatch.
                n_samples = len(imgs_A)
                valid = np.ones((n_samples,) + self.disc_patch)
                fake = np.zeros((n_samples,) + self.disc_patch)

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = self.generator_A_to_B.predict(imgs_A)
                fake_A = self.generator_B_to_A.predict(imgs_B)

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.discriminator_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.discriminator_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.discriminator_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.discriminator_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss: [loss, accuracy]
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                # g_loss = [total, adv_A, adv_B, recon_A, recon_B, id_A, id_B]
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid,
                                                       imgs_A, imgs_B,
                                                       imgs_A, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress
                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                       % (epoch, epochs,
                          batch_i, dataloader.n_batches,
                          d_loss[0], 100*d_loss[1],
                          g_loss[0],
                          np.mean(g_loss[1:3]),
                          np.mean(g_loss[3:5]),
                          np.mean(g_loss[5:7]),
                          elapsed_time))

                # If at save interval => save generated image samples
                if sample_interval and batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

    def sample_images(self, epoch, batch_i):
        """Save a 2x3 grid of original / translated / reconstructed test images.

        Row 0 is domain A, row 1 is domain B. One test image per domain is
        read through ``self.data_loader`` (set by ``train``), and the figure is
        written to ``images/<dataset_name>/<epoch>_<batch_i>.png``.

        Raises:
            RuntimeError: if no data loader has been set yet.
        """
        if self.data_loader is None:
            raise RuntimeError('sample_images() requires a data loader; call train() first')

        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 3

        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)

        # Translate images to the other domain
        fake_B = self.generator_A_to_B.predict(imgs_A)
        fake_A = self.generator_B_to_A.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.generator_B_to_A.predict(fake_B)
        reconstr_B = self.generator_A_to_B.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale images 0 - 1 (assumes generator outputs lie in [-1, 1] —
        # NOTE(review): confirm against UNET_G's output activation)
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        # Close this specific figure to avoid accumulating open figures.
        plt.close(fig)

In [34]:
# Instantiate the CycleGAN for 128x128 RGB images (the default input shape).
model = CycleGAN(img_shape=(128, 128, 3))

In [31]:
# Print the layer table of the combined generator-training model.
model.combined.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_30 (InputLayer)           (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
input_29 (InputLayer)           (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
generator_B_to_A (Model)        (None, 128, 128, 3)  10465699    input_30[0][0]                   
                                                                 generator_A_to_B[1][0]           
                                                                 input_29[0][0]                   
__________________________________________________________________________________________________
generator_