In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import cifar10
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers import MaxPooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras import losses
from keras.utils import to_categorical
import keras.backend as K

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [4]:
class CCGAN:
    """Context-conditional GAN that in-paints a randomly masked region of
    32x32 CIFAR-10 cat/dog images.

    The generator maps a masked image to a full image. The discriminator has
    two heads: a real/fake score and a (num_classes + 1)-way class prediction,
    where the extra last class marks generated images.
    """

    def __init__(self):
        """Build and compile the discriminator, the generator and the stacked
        model (``self.combined``) through which the generator is trained."""
        self.img_rows = 32
        self.img_cols = 32
        self.mask_height = 10
        self.mask_width = 10
        self.channels = 3
        # Cats (0) and dogs (1); the discriminator adds one extra "fake" class.
        self.num_classes = 2
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator (validity head + class head).
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=['binary_crossentropy', 'categorical_crossentropy'],
            loss_weights=[0.5, 0.5],
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build and compile the generator. train() never calls
        # self.generator.train_on_batch; the generator is updated only
        # through self.combined below.
        self.generator = self.build_generator()
        self.generator.compile(loss=['binary_crossentropy'],
            optimizer=optimizer)

        # The generator takes a masked image as input and in-paints it.
        masked_img = Input(shape=self.img_shape)
        gen_img = self.generator(masked_img)

        # For the combined model we only train the generator. Keras applies
        # `trainable` when a model is compiled, so self.discriminator
        # (already compiled above) still learns in its own train_on_batch.
        self.discriminator.trainable = False

        # The discriminator's validity head scores the in-painted image.
        valid, _ = self.discriminator(gen_img)

        # Combined model: masked_img -> in-painted image -> validity.
        # The mse reconstruction term dominates the loss (0.999 vs 0.001).
        self.combined = Model(masked_img, [gen_img, valid])
        self.combined.compile(loss=['mse', 'binary_crossentropy'],
            loss_weights=[0.999, 0.001],
            optimizer=optimizer)

    def build_generator(self):
        """Return an encoder-decoder Model mapping a masked image of shape
        ``self.img_shape`` to an image of the same shape.

        Three stride-2 convolutions downsample 32 -> 4, three UpSampling2D +
        Conv2D stages upsample back to 32; the final tanh keeps outputs in
        (-1, 1), matching the scaling applied to training images.
        """
        model = Sequential()

        # Encoder
        model.add(Conv2D(64, kernel_size=4, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(Activation('relu'))
        model.add(Conv2D(128, kernel_size=4, strides=2, padding="same"))
        model.add(Activation('relu'))
        model.add(Conv2D(256, kernel_size=4, strides=2, padding="same"))
        model.add(Activation('relu'))

        # Decoder
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=4, padding="same"))
        model.add(Activation('relu'))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=4, padding="same"))
        model.add(Activation('relu'))
        model.add(UpSampling2D())
        model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
        model.add(Activation('tanh'))

        model.summary()

        masked_img = Input(shape=self.img_shape)
        img = model(masked_img)

        return Model(masked_img, img)

    def build_discriminator(self):
        """Return a Model mapping an image to ``[valid, label]``.

        ``valid`` is a single sigmoid unit (real vs. generated); ``label`` is
        a softmax over ``num_classes + 1`` classes, where index
        ``num_classes`` is the "generated image" class used by train().
        """
        model = Sequential()

        model.add(Conv2D(32, kernel_size=3, input_shape=self.img_shape, padding="same"))
        model.add(Activation('relu'))

        model.add(MaxPooling2D())

        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation('relu'))

        model.add(MaxPooling2D())

        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation('relu'))
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation('relu'))

        model.add(MaxPooling2D())

        model.add(Conv2D(256, kernel_size=3, padding="same"))
        model.add(Activation('relu'))
        model.add(Conv2D(256, kernel_size=3, padding="same"))
        model.add(Activation('relu'))

        model.add(MaxPooling2D())

        model.add(Flatten())

        model.summary()

        img = Input(shape=self.img_shape)
        features = model(img)

        valid = Dense(1, activation="sigmoid")(features)
        label = Dense(self.num_classes + 1, activation="softmax")(features)

        return Model(img, [valid, label])

    def mask_randomly(self, imgs):
        """Zero out one randomly placed mask_height x mask_width rectangle per image.

        Args:
            imgs: batch of images indexed as ``imgs[i, row, col, channel]``,
                with ``self.img_rows`` rows and ``self.img_cols`` columns.

        Returns:
            A new array with the same shape and dtype as ``imgs`` where, for
            each image, the rectangle is set to 0 in every channel (0 is
            mid-grey once images are scaled to [-1, 1]). ``imgs`` is not
            modified.

        Raises:
            ValueError: if the mask is larger than the image.
        """
        masked_imgs = np.array(imgs, copy=True)
        n = masked_imgs.shape[0]

        # randint's upper bound is exclusive: "+ 1" lets the rectangle reach
        # the last row/column. Rows bound the vertical offset, columns the
        # horizontal one.
        y1 = np.random.randint(0, self.img_rows - self.mask_height + 1, n)
        x1 = np.random.randint(0, self.img_cols - self.mask_width + 1, n)
        y2 = y1 + self.mask_height
        x2 = x1 + self.mask_width

        for i in range(n):
            masked_imgs[i, y1[i]:y2[i], x1[i]:x2[i], :] = 0

        return masked_imgs

    @staticmethod
    def _select_cats_and_dogs(X, y, cat_label=3, dog_label=5):
        """Keep only cat and dog samples and relabel them 0 (cat) / 1 (dog).

        Cats come first, then dogs; the original order within each class is
        preserved.

        Args:
            X: samples indexed along axis 0.
            y: one integer label per sample, e.g. shape (N,) or (N, 1).
            cat_label: label value treated as "cat".
            dog_label: label value treated as "dog".

        Returns:
            ``(X_sel, y_sel)`` with ``y_sel`` of shape ``(len(X_sel), 1)``.
        """
        y = np.asarray(y).reshape(-1)
        is_cat = y == cat_label
        is_dog = y == dog_label

        X_sel = np.concatenate((X[is_cat], X[is_dog]))
        # np.concatenate rather than np.vstack: vstack on two 1-D arrays stacks
        # them as rows and only works when both classes have the same count.
        y_sel = np.concatenate((
            np.zeros(int(is_cat.sum()), dtype=y.dtype),
            np.ones(int(is_dog.sum()), dtype=y.dtype),
        )).reshape(-1, 1)
        return X_sel, y_sel

    def _load_dataset(self):
        """Load CIFAR-10 (train + test), keep cats/dogs and scale pixels to [-1, 1].

        Returns:
            ``(X, y)``: float images in [-1, 1] and labels of shape (N, 1),
            0 for cats and 1 for dogs.
        """
        (X_train, y_train), (X_test, y_test) = cifar10.load_data()

        X_all = np.concatenate((X_train, X_test))
        y_all = np.concatenate((y_train, y_test))

        X_sel, y_sel = self._select_cats_and_dogs(X_all, y_all)

        # Rescale [0, 255] -> [-1, 1] to match the generator's tanh output.
        X_sel = X_sel / 255
        X_sel = 2 * X_sel - 1
        return X_sel, y_sel

    def train(self, epochs, batch_size=128, save_interval=50):
        """Alternate discriminator and generator updates on CIFAR-10 cats/dogs.

        Each iteration (called an "epoch" here) trains the discriminator on a
        half batch of real images and a half batch of in-painted images, then
        trains the generator through ``self.combined`` on a full batch.

        Args:
            epochs: number of training iterations.
            batch_size: images per generator step; the discriminator sees
                ``batch_size // 2`` real and as many generated images.
            save_interval: every this many iterations, save a sample image
                grid and the model files. 0 or None disables saving.

        Raises:
            ValueError: if ``batch_size < 2`` (the half batch would be empty).
        """
        if batch_size < 2:
            raise ValueError("batch_size must be at least 2, got %r" % (batch_size,))

        X_train, y_train = self._load_dataset()

        half_batch = batch_size // 2

        # Class weights:
        # To balance the difference in occurrences of class labels.
        # 50% of labels that the discriminator trains on are 'fake'.
        # Weight = 1 / frequency
        cw1 = {0: 1, 1: 1}
        cw2 = {i: self.num_classes / half_batch for i in range(self.num_classes)}
        cw2[self.num_classes] = 1 / half_batch
        class_weights = [cw1, cw2]

        # Targets that do not change between iterations.
        valid_half = np.ones((half_batch, 1))
        fake_half = np.zeros((half_batch, 1))
        valid_full = np.ones((batch_size, 1))
        # Index num_classes is the extra "generated image" class.
        fake_labels = to_categorical(np.full((half_batch, 1), self.num_classes),
                                     num_classes=self.num_classes + 1)

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of real images and in-paint masked copies.
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]
            labels = y_train[idx]

            masked_imgs = self.mask_randomly(imgs)
            gen_imgs = self.generator.predict(masked_imgs)

            labels = to_categorical(labels, num_classes=self.num_classes + 1)

            d_loss_real = self.discriminator.train_on_batch(imgs, [valid_half, labels], class_weight=class_weights)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, [fake_half, fake_labels], class_weight=class_weights)
            # Indices follow self.discriminator.metrics_names; the print below
            # assumes [3] is the validity accuracy and [4] the class accuracy.
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Select a random full batch of images.
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            masked_imgs = self.mask_randomly(imgs)

            # The generator should reconstruct the original image (mse) and
            # make the discriminator label it as valid.
            g_loss = self.combined.train_on_batch(masked_imgs, [imgs, valid_full])

            # Plot the progress
            print("%d [D loss: %f, acc: %.2f%%, op_acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[3], 100*d_loss[4], g_loss[0], g_loss[1]))

            # If at save interval => save generated image samples and models.
            if save_interval and epoch % save_interval == 0:
                # save_imgs plots 6 columns, so sample exactly 6 images.
                idx = np.random.randint(0, X_train.shape[0], 6)
                self.save_imgs(epoch, X_train[idx])
                self.save_model()

    def save_imgs(self, epoch, imgs):
        """Save a 3x6 grid to images/cifar_<epoch>.png.

        Rows are: original images, their masked versions, and the generator's
        in-painting of the masked versions.

        Args:
            epoch: iteration number used in the output filename.
            imgs: at least 6 images scaled to [-1, 1].
        """
        r, c = 3, 6

        masked_imgs = self.mask_randomly(imgs)
        gen_imgs = self.generator.predict(masked_imgs)

        fig, axs = plt.subplots(r, c)
        for row, batch in enumerate((imgs, masked_imgs, gen_imgs)):
            # Undo the [-1, 1] scaling for display.
            batch = 0.5 * batch + 0.5
            for col in range(c):
                axs[row, col].imshow(batch[col, :, :])
                axs[row, col].axis('off')

        os.makedirs("images", exist_ok=True)
        fig.savefig("images/cifar_%d.png" % epoch)
        plt.close(fig)

    def save_model(self):
        """Write generator and discriminator architectures (JSON) and weights
        (HDF5) under ccgan/saved_model/, creating the directory if needed."""
        model_dir = "ccgan/saved_model"
        os.makedirs(model_dir, exist_ok=True)

        def save(model, model_name):
            model_path = "ccgan/saved_model/%s.json" % model_name
            weights_path = "ccgan/saved_model/%s_weights.hdf5" % model_name
            # Context manager so the file handle is always closed.
            with open(model_path, 'w') as f:
                f.write(model.to_json())
            model.save_weights(weights_path)

        save(self.generator, "ccgan_generator")
        save(self.discriminator, "ccgan_discriminator")


In [6]:
# Training configuration. train() calls each iteration an "epoch"; every
# iteration is one discriminator update pair plus one generator update.
N_ITERATIONS = 20000
BATCH_SIZE = 32
SAVE_INTERVAL = 50

ccgan = CCGAN()
ccgan.train(epochs=N_ITERATIONS, batch_size=BATCH_SIZE, save_interval=SAVE_INTERVAL)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_25 (Conv2D)           (None, 32, 32, 32)        896       
_________________________________________________________________
activation_25 (Activation)   (None, 32, 32, 32)        0         
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_26 (Conv2D)           (None, 16, 16, 64)        18496     
_________________________________________________________________
activation_26 (Activation)   (None, 16, 16, 64)        0         
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_27 (Conv2D)           (None, 8, 8, 128)         73856     
__________

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.403407, acc: 0.00%, op_acc: 18.75%] [G loss: 0.283719, mse: 0.283314]
1 [D loss: 0.379286, acc: 50.00%, op_acc: 15.62%] [G loss: 0.231170, mse: 0.230731]
2 [D loss: 0.370387, acc: 50.00%, op_acc: 31.25%] [G loss: 0.183978, mse: 0.183545]
3 [D loss: 0.385339, acc: 50.00%, op_acc: 25.00%] [G loss: 0.222561, mse: 0.222231]
4 [D loss: 0.351296, acc: 50.00%, op_acc: 21.88%] [G loss: 0.211636, mse: 0.211300]
5 [D loss: 0.383958, acc: 50.00%, op_acc: 31.25%] [G loss: 0.220926, mse: 0.220606]
6 [D loss: 0.366975, acc: 50.00%, op_acc: 25.00%] [G loss: 0.165207, mse: 0.164823]
7 [D loss: 0.364516, acc: 50.00%, op_acc: 15.62%] [G loss: 0.138129, mse: 0.137672]
8 [D loss: 0.348829, acc: 50.00%, op_acc: 31.25%] [G loss: 0.135600, mse: 0.135143]
9 [D loss: 0.347891, acc: 50.00%, op_acc: 31.25%] [G loss: 0.142456, mse: 0.142029]
10 [D loss: 0.355096, acc: 50.00%, op_acc: 31.25%] [G loss: 0.120759, mse: 0.120301]
11 [D loss: 0.344686, acc: 50.00%, op_acc: 21.88%] [G loss: 0.119382, mse: 0

97 [D loss: 0.055120, acc: 96.88%, op_acc: 65.62%] [G loss: 0.075783, mse: 0.070841]
98 [D loss: 0.030246, acc: 100.00%, op_acc: 78.12%] [G loss: 0.074588, mse: 0.069689]
99 [D loss: 0.030445, acc: 100.00%, op_acc: 78.12%] [G loss: 0.081512, mse: 0.076047]
100 [D loss: 0.027202, acc: 100.00%, op_acc: 71.88%] [G loss: 0.073164, mse: 0.067431]
101 [D loss: 0.022555, acc: 100.00%, op_acc: 84.38%] [G loss: 0.065548, mse: 0.059171]
102 [D loss: 0.029827, acc: 100.00%, op_acc: 68.75%] [G loss: 0.071756, mse: 0.064604]
103 [D loss: 0.026669, acc: 100.00%, op_acc: 71.88%] [G loss: 0.074969, mse: 0.067980]
104 [D loss: 0.021192, acc: 100.00%, op_acc: 81.25%] [G loss: 0.072390, mse: 0.065490]
105 [D loss: 0.056361, acc: 96.88%, op_acc: 75.00%] [G loss: 0.068037, mse: 0.062051]
106 [D loss: 0.021015, acc: 100.00%, op_acc: 75.00%] [G loss: 0.066955, mse: 0.060254]
107 [D loss: 0.025359, acc: 100.00%, op_acc: 71.88%] [G loss: 0.072915, mse: 0.065748]
108 [D loss: 0.023541, acc: 100.00%, op_acc: 75.

193 [D loss: 0.076683, acc: 96.88%, op_acc: 71.88%] [G loss: 0.052065, mse: 0.047983]
194 [D loss: 0.084460, acc: 93.75%, op_acc: 68.75%] [G loss: 0.049455, mse: 0.045266]
195 [D loss: 0.039746, acc: 100.00%, op_acc: 75.00%] [G loss: 0.054676, mse: 0.049917]
196 [D loss: 0.027963, acc: 100.00%, op_acc: 71.88%] [G loss: 0.061090, mse: 0.055905]
197 [D loss: 0.091243, acc: 93.75%, op_acc: 84.38%] [G loss: 0.058782, mse: 0.055678]
198 [D loss: 0.090251, acc: 90.62%, op_acc: 71.88%] [G loss: 0.063800, mse: 0.059528]
199 [D loss: 0.025555, acc: 100.00%, op_acc: 87.50%] [G loss: 0.049518, mse: 0.044453]
200 [D loss: 0.036531, acc: 100.00%, op_acc: 71.88%] [G loss: 0.056119, mse: 0.050894]
201 [D loss: 0.025236, acc: 100.00%, op_acc: 81.25%] [G loss: 0.052912, mse: 0.047554]
202 [D loss: 0.030587, acc: 100.00%, op_acc: 75.00%] [G loss: 0.050462, mse: 0.045040]
203 [D loss: 0.025909, acc: 100.00%, op_acc: 81.25%] [G loss: 0.049798, mse: 0.044702]
204 [D loss: 0.029167, acc: 100.00%, op_acc: 71

289 [D loss: 0.028826, acc: 100.00%, op_acc: 90.62%] [G loss: 0.049286, mse: 0.044805]
290 [D loss: 0.032355, acc: 100.00%, op_acc: 75.00%] [G loss: 0.044913, mse: 0.040208]
291 [D loss: 0.026044, acc: 100.00%, op_acc: 78.12%] [G loss: 0.050975, mse: 0.046335]
292 [D loss: 0.037362, acc: 100.00%, op_acc: 81.25%] [G loss: 0.053847, mse: 0.047563]
293 [D loss: 0.026894, acc: 100.00%, op_acc: 78.12%] [G loss: 0.051074, mse: 0.044217]
294 [D loss: 0.062195, acc: 96.88%, op_acc: 78.12%] [G loss: 0.052492, mse: 0.046930]
295 [D loss: 0.023755, acc: 100.00%, op_acc: 71.88%] [G loss: 0.046722, mse: 0.042056]
296 [D loss: 0.027822, acc: 100.00%, op_acc: 75.00%] [G loss: 0.052077, mse: 0.046662]
297 [D loss: 0.028266, acc: 100.00%, op_acc: 71.88%] [G loss: 0.045843, mse: 0.040063]
298 [D loss: 0.026155, acc: 100.00%, op_acc: 78.12%] [G loss: 0.050723, mse: 0.044287]
299 [D loss: 0.026456, acc: 100.00%, op_acc: 75.00%] [G loss: 0.050785, mse: 0.044085]
300 [D loss: 0.027735, acc: 100.00%, op_acc:

385 [D loss: 0.050204, acc: 96.88%, op_acc: 71.88%] [G loss: 0.040770, mse: 0.035735]
386 [D loss: 0.079529, acc: 96.88%, op_acc: 75.00%] [G loss: 0.046321, mse: 0.041431]
387 [D loss: 0.056313, acc: 96.88%, op_acc: 71.88%] [G loss: 0.044855, mse: 0.040403]
388 [D loss: 0.048814, acc: 96.88%, op_acc: 78.12%] [G loss: 0.038950, mse: 0.034290]
389 [D loss: 0.028537, acc: 100.00%, op_acc: 68.75%] [G loss: 0.044433, mse: 0.039195]
390 [D loss: 0.139002, acc: 93.75%, op_acc: 68.75%] [G loss: 0.036710, mse: 0.033424]
391 [D loss: 0.073109, acc: 93.75%, op_acc: 62.50%] [G loss: 0.042099, mse: 0.037892]
392 [D loss: 0.035428, acc: 100.00%, op_acc: 84.38%] [G loss: 0.035886, mse: 0.031166]
393 [D loss: 0.102137, acc: 93.75%, op_acc: 62.50%] [G loss: 0.040192, mse: 0.036079]
394 [D loss: 0.108432, acc: 93.75%, op_acc: 81.25%] [G loss: 0.039404, mse: 0.035167]
395 [D loss: 0.037579, acc: 100.00%, op_acc: 71.88%] [G loss: 0.045452, mse: 0.040021]
396 [D loss: 0.023682, acc: 100.00%, op_acc: 81.25%

481 [D loss: 0.031099, acc: 100.00%, op_acc: 75.00%] [G loss: 0.042434, mse: 0.036164]
482 [D loss: 0.078891, acc: 93.75%, op_acc: 71.88%] [G loss: 0.048927, mse: 0.044492]
483 [D loss: 0.052710, acc: 96.88%, op_acc: 65.62%] [G loss: 0.040849, mse: 0.035188]
484 [D loss: 0.025394, acc: 100.00%, op_acc: 75.00%] [G loss: 0.040264, mse: 0.034380]
485 [D loss: 0.143939, acc: 93.75%, op_acc: 78.12%] [G loss: 0.038994, mse: 0.032779]
486 [D loss: 0.020558, acc: 100.00%, op_acc: 81.25%] [G loss: 0.040882, mse: 0.034543]
487 [D loss: 0.065803, acc: 96.88%, op_acc: 78.12%] [G loss: 0.039240, mse: 0.034301]
488 [D loss: 0.036803, acc: 100.00%, op_acc: 87.50%] [G loss: 0.045328, mse: 0.039063]
489 [D loss: 0.079768, acc: 96.88%, op_acc: 59.38%] [G loss: 0.037518, mse: 0.032738]
490 [D loss: 0.096419, acc: 90.62%, op_acc: 75.00%] [G loss: 0.036763, mse: 0.032022]
491 [D loss: 0.050524, acc: 96.88%, op_acc: 84.38%] [G loss: 0.038472, mse: 0.032670]
492 [D loss: 0.035273, acc: 100.00%, op_acc: 75.00

577 [D loss: 0.034698, acc: 100.00%, op_acc: 68.75%] [G loss: 0.037153, mse: 0.032350]
578 [D loss: 0.054550, acc: 96.88%, op_acc: 71.88%] [G loss: 0.041525, mse: 0.036799]
579 [D loss: 0.035864, acc: 100.00%, op_acc: 71.88%] [G loss: 0.039578, mse: 0.034533]
580 [D loss: 0.096588, acc: 96.88%, op_acc: 68.75%] [G loss: 0.033517, mse: 0.029506]
581 [D loss: 0.058470, acc: 96.88%, op_acc: 78.12%] [G loss: 0.037035, mse: 0.033224]
582 [D loss: 0.063327, acc: 96.88%, op_acc: 75.00%] [G loss: 0.039113, mse: 0.034113]
583 [D loss: 0.089238, acc: 96.88%, op_acc: 71.88%] [G loss: 0.036630, mse: 0.031633]
584 [D loss: 0.024707, acc: 100.00%, op_acc: 75.00%] [G loss: 0.038281, mse: 0.033075]
585 [D loss: 0.033385, acc: 100.00%, op_acc: 75.00%] [G loss: 0.040282, mse: 0.035780]
586 [D loss: 0.040436, acc: 100.00%, op_acc: 81.25%] [G loss: 0.047007, mse: 0.042369]
587 [D loss: 0.020962, acc: 100.00%, op_acc: 84.38%] [G loss: 0.047780, mse: 0.042171]
588 [D loss: 0.036076, acc: 100.00%, op_acc: 68.

KeyboardInterrupt: 