# Deep Conditional Generative Adversarial Networks

In [None]:
from __future__ import print_function
import os
import matplotlib as mpl
import tarfile
import matplotlib.image as mpimg
from matplotlib import pyplot as plt

import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet.gluon import nn, utils
from mxnet import autograd
import numpy as np

## Set training parameters

In [None]:
# Training schedule and size of the random latent vector.
epochs = 1
batch_size = 256
latent_z_size = 100

# Crop size (pixels) handed to the augmenter, and the scalar mean / scale
# used for every colour channel when normalising the images.
width = 176
height = 220
mmean = 113.48657
mstd = 73.67449 * 2

# Device selection: GPU when requested, CPU otherwise.
use_gpu = True
if use_gpu:
    ctx = mx.gpu()
else:
    ctx = mx.cpu()

# Adam hyper-parameters passed to every trainer.
lr = 0.0002
beta1 = 0.5

## CelebA Dataset

In [None]:
# Normalisation constants: the same scalar repeated for the three channels.
channel_mean = nd.array([mmean] * 3)
channel_std = nd.array([mstd] * 3)

# Default augmenter chain for a (3, height, width) crop with mean/std
# normalisation, followed by an extra ResizeAug(64) step.
augs = mx.image.CreateAugmenter(
    data_shape=(3, height, width),
    mean=channel_mean,
    std=channel_std
)
augs.append(mx.image.ResizeAug(64))

In [None]:
# Settings for the CelebA training iterator: record/index/list files, the
# (3, 80, 64) image shape produced by the augmenters, 40 label values per
# image, shuffled batches of `batch_size`.
celeba_iter_args = dict(
    batch_size=batch_size,
    data_shape=(3, 80, 64),
    label_width=40,
    path_imgrec='dataset/celeba_train.rec',
    path_imgidx='dataset/celeba_train.idx',
    path_imglist='dataset/celeba_train.lst',
    shuffle=True,
    aug_list=augs
)
train_data = mx.image.ImageIter(**celeba_iter_args)

In [None]:
def visualize(img_arr):
    """Show one image with matplotlib and hide the axes.

    ``img_arr`` must provide ``asnumpy()``. The resulting array is
    transposed with ``(1, 2, 0)``, shifted by +1.0, scaled by 127.5 and
    cast to ``uint8`` before being passed to ``plt.imshow``.
    """
    pixels = img_arr.asnumpy()
    pixels = pixels.transpose(1, 2, 0)
    pixels = (pixels + 1.0) * 127.5
    plt.imshow(pixels.astype(np.uint8))
    plt.axis('off')

# Preview four training images in one row. Every panel fetches a fresh
# batch from the iterator and shows sample number (panel + 10) of it.
for panel in range(4):
    plt.subplot(1, 4, panel + 1)
    batch_images = train_data.next().data[0]
    visualize(batch_images[panel + 10])
plt.show()

## The networks

In [None]:
# build the generator
nc = 3    # channels of the generated image (output of the last layer)
ngf = 64  # base number of generator feature maps

# (output channels, stride, padding) of every upsampling stage that is
# followed by BatchNorm + ReLU. All stages use 4x4 kernels and no bias.
# State sizes are for a 1 x 1 spatial input (the latent vector Z).
generator_stages = [
    (ngf * 8, 1, 0),       # state size. (ngf*8) x 4 x 4
    (ngf * 4, 2, (0, 1)),  # state size. (ngf*4) x 10 x 8
    (ngf * 2, 2, 1),       # state size. (ngf*2) x 20 x 16
    (ngf, 2, 1),           # state size. (ngf) x 40 x 32
]

netG = nn.Sequential()
with netG.name_scope():
    for out_channels, stride, padding in generator_stages:
        netG.add(nn.Conv2DTranspose(out_channels, 4, stride, padding, use_bias=False))
        netG.add(nn.BatchNorm())
        netG.add(nn.Activation('relu'))
    # output stage: tanh keeps the pixel values in [-1, 1]
    netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
    netG.add(nn.Activation('tanh'))
    # state size. (nc) x 80 x 64

# build the discriminator
# First stage: a single strided convolution over the image. Its output is
# later concatenated with the attribute map before entering netD2.
ndf = 64  # base number of discriminator feature maps
netD = nn.Sequential()
with netD.name_scope():
    # input is (nc) x 80 x 64
    netD.add(nn.Conv2D(channels=ndf, kernel_size=4, strides=2, padding=1,
                       use_bias=False))
    netD.add(nn.LeakyReLU(alpha=0.2))
    # state size. (ndf) x 40 x 32

# Second discriminator stage. In the training loop its input is the netD
# output concatenated with a one-channel 40 x 32 attribute map.
# (output channels, padding) of every downsampling stage that is followed by
# BatchNorm + LeakyReLU(0.2); all use 4x4 kernels, stride 2 and no bias.
discriminator_stages = [
    (ndf * 2, 1),       # state size. (ndf*2) x 20 x 16
    (ndf * 4, 1),       # state size. (ndf*4) x 10 x 8
    (ndf * 8, (0, 1)),  # state size. (ndf*8) x 4 x 4
]

netD2 = nn.Sequential()
with netD2.name_scope():
    for out_channels, padding in discriminator_stages:
        netD2.add(nn.Conv2D(out_channels, 4, 2, padding, use_bias=False))
        netD2.add(nn.BatchNorm())
        netD2.add(nn.LeakyReLU(0.2))
    # final 4x4 convolution collapses the map to one raw score per image
    # (no sigmoid here)
    netD2.add(nn.Conv2D(1, 4, 1, 0, use_bias=False))

## Setup Loss Function and Optimizer

In [None]:
# loss: binary cross-entropy applied to the raw discriminator scores
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

# initialize the generator and both discriminator stages on the chosen device
for net in (netG, netD, netD2):
    net.initialize(mx.init.Normal(0.02), ctx=ctx)

# one Adam trainer per network, all with the same settings; every trainer
# receives its own copy of the settings dict
adam_settings = {'learning_rate': lr, 'beta1': beta1}
trainerG = gluon.Trainer(netG.collect_params(), 'adam', dict(adam_settings))
trainerD = gluon.Trainer(netD.collect_params(), 'adam', dict(adam_settings))
trainerD2 = gluon.Trainer(netD2.collect_params(), 'adam', dict(adam_settings))

## Training Loop

In [None]:
from datetime import datetime
import time
import logging

# Loss targets, one entry per sample of a batch: ones for "real / matching",
# zeros for "fake / mismatching".
real_label = nd.ones(shape=(batch_size,), ctx=ctx)
fake_label = nd.zeros(shape=(batch_size,), ctx=ctx)

def facc(label, pred):
    """Return the fraction of predictions that fall on the correct side.

    The discriminator in this file ends in a plain convolution and is trained
    with ``SigmoidBinaryCrossEntropyLoss``, so ``pred`` holds raw scores
    (logits), not probabilities. A probability above 0.5 corresponds to a
    logit above 0, hence the threshold used here.

    Args:
        label: array of 0/1 targets (anything providing ``ravel()``).
        pred: array of raw discriminator scores with the same number of
            elements as ``label``.

    Returns:
        Mean of the element-wise agreement between ``pred > 0`` and ``label``.
    """
    pred = pred.ravel()
    label = label.ravel()
    # sigmoid(x) > 0.5  <=>  x > 0
    return ((pred > 0) == label).mean()
# Training cell: alternates a discriminator update and a generator update for
# every batch, appends one summary line per epoch to the log file and saves
# one generated sample image per epoch.
metric = mx.metric.CustomMetric(facc)

stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
logging.basicConfig(level=logging.DEBUG)

# The per-epoch sample images are written into this folder; create it up
# front so a missing directory cannot abort the run after a full epoch.
if not os.path.isdir('samples_per_epoch'):
    os.makedirs('samples_per_epoch')

# Row of ones used to tile each attribute value across the 32 columns of the
# discriminator's attribute map. It never changes, so it is built once.
ones_row = mx.nd.ones([1, 32], ctx)

# NOTE(review): machine-specific absolute log path; the directory must exist.
# The context manager guarantees the log is closed even if training raises.
with open("/home/mcy/Dropbox/1,biometrics/condGAN_experiments/condGAN.log", "a") as f:
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        batch_idx = 0
        # Sums of the per-batch mean losses over the epoch (not averaged).
        errD_total = 0
        errG_total = 0
        for batch in train_data:
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            data = batch.data[0].as_in_context(ctx)
            # attribute values of the batch, copied to the device once
            attributes = batch.label[0].as_in_context(ctx)
            # right attribute label y, with two trailing singleton axes so it
            # can be concatenated with z
            label = mx.nd.expand_dims(mx.nd.expand_dims(attributes, axis=2), axis=3)
            # wrong attribute label y_hat: random +1/-1 values
            label_hat = mx.nd.random_normal(0, 1, shape=(batch_size, 40, 32), ctx=ctx)
            label_hat = (label_hat >= 0)*2 - 1
            # latent vector z, followed by the attributes along the channel axis
            latent_z = mx.nd.random_normal(0, 1, shape=(batch_size, latent_z_size, 1, 1), ctx=ctx)
            latent_z = mx.nd.concat(latent_z, label, dim=1)
            # right discriminator labels: every attribute repeated over 32 columns
            label_d = mx.nd.dot(mx.nd.expand_dims(attributes, axis=2), ones_row)
            label_d = mx.nd.reshape(label_d, shape=[batch_size, 1, 40, 32])
            # wrong discriminator labels
            label_d_hat = mx.nd.reshape(label_hat, shape=[batch_size, 1, 40, 32])
            with autograd.record():
                # train with real image
                # The first-stage features of the real images are needed for
                # both attribute maps, so they are computed only once.
                real_features = netD(data)

                # real images right attributes
                output = netD2(mx.nd.concat(real_features, label_d)).reshape((-1, 1))
                errD_real = loss(output, real_label)
                metric.update([real_label,], [output,])

                # real images wrong attributes
                output = netD2(mx.nd.concat(real_features, label_d_hat)).reshape((-1, 1))
                errD_real = (errD_real + loss(output, fake_label))/2
                metric.update([fake_label,], [output,])

                # train with fake image
                # fake images right attributes; detach() keeps this update
                # from propagating gradients into the generator
                fake = netG(latent_z)
                output = netD(fake.detach())
                output = netD2(mx.nd.concat(output, label_d)).reshape((-1, 1))
                errD_fake = loss(output, fake_label)
                errD = errD_real + errD_fake
                errD.backward()
                metric.update([fake_label,], [output,])
            errD_total = errD_total + nd.mean(errD).asscalar()

            trainerD.step(batch.data[0].shape[0])
            trainerD2.step(batch.data[0].shape[0])

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            with autograd.record():
                fake = netG(latent_z)
                output = netD(fake)
                output = netD2(mx.nd.concat(output, label_d)).reshape((-1, 1))
                errG = loss(output, real_label)
                errG.backward()
            errG_total = errG_total + nd.mean(errG).asscalar()

            trainerG.step(batch.data[0].shape[0])

            # Print log information every ten batches
            # if batch_idx % 10 == 0:
            #     name, acc = metric.get()
            #     logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
            #     logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
            #              %(errD_total,
            #                errG_total, acc, batch_idx, epoch))
            batch_idx = batch_idx + 1
            btic = time.time()
        name, acc = metric.get()
        metric.reset()

        f.write('discriminator loss = %f, generator loss = %f, binary training acc = %f at epoch %d \n'
                % (errD_total, errG_total, acc, epoch))

        # logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        # logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch (first sample of the
        # last generated batch)
        fake_img = fake[0]
        visualize(fake_img)
        plt.savefig('samples_per_epoch/samples-%d.png'%(epoch),dpi=150)
        plt.show()

## Results
Given a trained generator, we can generate some images of faces.

In [None]:
num_image = 8

# Attribute vector used to condition the generator: one (name, value) pair
# per channel, in channel order; +1 switches an attribute on, -1 off.
attribute_settings = [
    ('5_o_Clock_Shadow', 1),
    ('Arched_Eyebrows', -1),
    ('Attractive', -1),
    ('Bags_Under_Eyes', 1),
    ('Bald', -1),
    ('Bangs', -1),
    ('Big_Lips', -1),
    ('Big_Nose', 1),
    ('Black_Hair', 1),
    ('Blond_Hair', -1),
    ('Blurry', -1),
    ('Brown_Hair', -1),
    ('Bushy_Eyebrows', -1),
    ('Chubby', -1),
    ('Double_Chin', -1),
    ('Eyeglasses', 1),
    ('Goatee', 1),
    ('Gray_Hair', -1),
    ('Heavy_Makeup', -1),
    ('High_Cheekbones', -1),
    ('Male', 1),
    ('Mouth_Slightly_Open', -1),
    ('Mustache', -1),
    ('Narrow_Eyes', -1),
    ('No_Beard', -1),
    ('Oval_Face', -1),
    ('Pale_Skin', -1),
    ('Pointy_Nose', -1),
    ('Receding_Hairline', -1),
    ('Rosy_Cheeks', -1),
    ('Sideburns', 1),
    ('Smiling', -1),
    ('Straight_Hair', 1),
    ('Wavy_Hair', -1),
    ('Wearing_Earrings', -1),
    ('Wearing_Hat', -1),
    ('Wearing_Lipstick', -1),
    ('Wearing_Necklace', -1),
    ('Wearing_Necktie', -1),
    ('Young', 1),
]

# right attribute label y
label = mx.nd.zeros([1, 40, 1, 1], ctx)
for channel, (_attribute_name, attribute_value) in enumerate(attribute_settings):
    label[:, channel, :, :] = attribute_value

In [None]:
# Generate `num_image` faces for the fixed attribute vector `label`, each from
# a fresh random latent vector, and show them in a 2 x 4 grid.
for i in range(num_image):
    # latent vector z, followed by the attributes along the channel axis
    latent_z = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
    latent_z = mx.nd.concat(latent_z,label,dim=1)
    img = netG(latent_z)
    plt.subplot(2,4,i+1)
    visualize(img[0])
# Save the figure once, after every panel has been drawn, instead of
# rewriting the same file on each iteration.
plt.savefig('/home/mcy/Dropbox/1,biometrics/condGAN_experiments/samples.png',dpi=150)
plt.show()

We can also interpolate along the manifold between images by interpolating linearly between points in the latent space and visualizing the corresponding images. We can see that small changes in the latent space result in smooth changes in the generated images.

In [None]:
# Walk through the latent space in equal steps while keeping one random
# attribute vector fixed, and show the generated faces in a 3 x 4 grid.
num_image = 12
# random +1/-1 attribute vector with two trailing singleton axes
label = mx.nd.random_normal(0, 1, shape=(1, 40), ctx=ctx)
label = (label >= 0)*2 - 1
label = mx.nd.expand_dims(mx.nd.expand_dims(label,axis=2),axis=3)
# starting point in the latent space
noise = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
step = 0.05
for i in range(num_image):
    # Only the noise part moves; the attribute channels are re-attached
    # unchanged on every step so the conditioning stays at +1/-1.
    latent_z = mx.nd.concat(noise,label,dim=1)
    img = netG(latent_z)
    plt.subplot(3,4,i+1)
    visualize(img[0])
    noise += step
# Save the figure once, after every panel has been drawn.
plt.savefig('/home/mcy/Dropbox/1,biometrics/condGAN_experiments/interpolation_samples.png',dpi=150)
plt.show()

In [None]:
# Write the weights of the generator and of both discriminator stages to
# the params/ folder.
for net, params_file in (
        (netG, "params/condGAN-generator.params"),
        (netD, "params/condGAN-discriminator.params"),
        (netD2, "params/condGAN-discriminator2.params")):
    net.save_parameters(params_file)