# Conditional GAN

In [1]:
# supress warnings
import warnings 
warnings.filterwarnings('ignore')

In [2]:
# import dataset 
from keras.datasets import mnist

(X_train,y_train),(X_test,y_test) = mnist.load_data()
X_train.shape

Using TensorFlow backend.


(60000, 28, 28)

In [3]:
img_width, img_height =28,28
img_channel = 1
img_shape = (img_width, img_height, img_channel)
num_classes = 10
z_dim = 100

# Build Generator

In [4]:
from keras.layers import UpSampling2D, Reshape, Activation, Conv2D, BatchNormalization, LeakyReLU, Input, Flatten, multiply
from keras.layers import Dense, Embedding
from keras.models import Sequential, Model

def build_generator():
    model = Sequential()
    model.add(Dense(128*7*7, activation = 'relu', input_shape = (z_dim, )))
    model.add(Reshape((7,7,128)))
    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size = 3, padding = 'same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha = 0.02))
    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size = 3, padding = 'same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha = 0.02))
    model.add(Conv2D(1, kernel_size = 3 , padding='same'))
    model.add(Activation('tanh'))
    
    z = Input(shape= (z_dim,))
    label = Input(shape=(1,), dtype = 'int32')
    
    label_embedding = Embedding(num_classes, z_dim, input_length = 1)(label)
    label_embedding = Flatten()(label_embedding)
    joined = multiply([z, label_embedding])
    
    img = model(joined)
    return Model([z, label], img)

generator = build_generator()
generator.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 1, 100)       1000        input_2[0][0]                    
__________________________________________________________________________________________________
input_1 (InputLayer)            (None, 100)          0                                            
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 100)          0           embedding_1[0][0]                
__________________________________________________________________________________________________
multiply_1

# Build Discriminator

In [7]:
from keras.layers import Dropout, Concatenate
import numpy as np

def build_discriminator():
    model = Sequential()
    model.add(Conv2D(32, kernel_size = 3, strides = 2, input_shape = (28,28,2), padding = 'same'))
    model.add(LeakyReLU(alpha = 0.02))
    model.add(Dropout(0.25))
    model.add(Conv2D(64, kernel_size = 3, strides = 2, input_shape = (28,28,2), padding = 'same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha = 0.02))
    model.add(Dropout(0.25))
    model.add(Conv2D(128, kernel_size = 3, strides = 2, input_shape = (28,28,2), padding = 'same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha = 0.02))
    model.add(Dropout(0.25))
    model.add(Conv2D(256, kernel_size = 3, strides = 2, input_shape = (28,28,2), padding = 'same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha = 0.02))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(1, activation = 'sigmoid'))
    
    img = Input(shape= (img_shape))
    label = Input(shape= (1,), dtype = 'int32')
    
    label_embedding = Embedding(input_dim = num_classes, output_dim = np.prod(img_shape), input_length = 1)(label)
    label_embedding = Flatten()(label_embedding)
    label_embedding = Reshape(img_shape)(label_embedding)
    
    concat = Concatenate(axis = -1)([img, label_embedding])
    prediction = model(concat)
    return Model([img, label], prediction)

discriminator = build_discriminator()
discriminator.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 1, 784)       7840        input_4[0][0]                    
__________________________________________________________________________________________________
flatten_3 (Flatten)             (None, 784)          0           embedding_2[0][0]                
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
reshape_2 

# Compile and Join Model G/D

In [9]:
from keras.optimizers import Adam

discriminator.compile(loss = 'binary_crossentropy', optimizer = Adam(0.0002, 0.5), metrics = ['accuracy'])

z = Input(shape=(z_dim,))
label = Input(shape= (1,))
img = generator([z,label])

discriminator.trainable = False
prediction = discriminator([img, label])

cgan = Model([z, label], prediction)
cgan.compile(loss= 'binary_crossentropy', optimizer = Adam(0.0002,0.5))

# Build  a function for training G/D

In [13]:
def train(epochs, batch_size, save_interval):
    (X_train,y_train),(X_test, y_test) = mnist.load_data()
    X_train = X_train/127.5 - 1
    X_train = np.expand_dims(X_train, axis=3)
    
    real = np.ones(shape= (batch_size, 1))
    fake = np.zeros(shape = (batch_size,1))
    
    for epoch in range(epochs):
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        img, labels = X_train[idx], y_train[idx]
        
        z = np.random.normal(0,1,size = (batch_size, z_dim))
        gen_img = generator.predict([z,labels])
        
        d_loss_real = discriminator.train_on_batch([img, labels], real)
        d_loss_fake = discriminator.train_on_batch([gen_img, labels], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        z = np.random.randint(0,1, size = (batch_size, z_dim))
        labels = np.random.randint(0, num_classes, batch_size).reshape(-1,1)
        
        g_loss = cgan.train_on_batch([z,labels], real)
        
        if epoch % save_interval == 0:
            print('{} [D_loss: {} , accuracy: {:.2f}] [G_loss: {}]'.format(epoch, d_loss[0], 100*d_loss[1], g_loss))
            save_image(epoch)

In [14]:
import matplotlib.pyplot as plt
def save_image(epoch):
    r, c = 2,5
    z = np.random.normal(0,1,(r*c, z_dim))
    labels = np.arange(0,10).reshape(-1,1)
    gen_image = generator.predict([z,labels])
    gen_image = 0.5 * gen_image + 0.5
    
    fig, axes = plt.subplots(r,c, figsize = (10,10))
    count = 0
    for i in range(r):
        for j in range(c):
            axes[i,j].imshow(gen_image[count,:,:,0],cmap = 'gray')
            axes[i,j].axis('off')
            axes[i,j].set_title("Digit: %d" % labels[count])
            count+=1
    plt.savefig('images/cgan_%d.jpg' % epoch)
    plt.close()

In [None]:
# train network
train(50000, 128, 1000)

0 [D_loss: 1.4379487037658691 , accuracy: 48.05] [G_loss: 0.2636014223098755]
1000 [D_loss: 0.0002288183750351891 , accuracy: 100.00] [G_loss: 0.0005109526682645082]
