# Energy-based Generative Adversarial Networks

In [our introduction to generative adversarial networks (GANs)](./gan-intro.ipynb), 
we introduced the basic ideas behind how GANs work.
We showed that they can draw samples from some simple, easy-to-sample distribution,
like a uniform or normal distribution, 
and transform them into samples that appear to match the distribution of some data set. 
And while our example of matching a 2D Gaussian distribution got the point across, it's not especially exciting.

In this notebook, we'll demonstrate how you can use GANs 
to generate photorealistic images. 
We'll be basing our models on the deep convolutional GANs introduced in [this paper](https://arxiv.org/abs/1511.06434). 
We'll borrow the convolutional architectures that have proven so successful for discriminative computer vision problems
and show how, via GANs, they can be leveraged to generate photorealistic images. 

In this tutorial, we concentrate on the [LFW Face Dataset](http://vis-www.cs.umass.edu/lfw/), 
which contains roughly 13,000 images of faces. 
By the end of the tutorial, you'll know how to generate photorealistic images of your own, given any dataset of images. First, we'll get the preliminaries out of the way.


In [1]:
from __future__ import print_function
import os
import matplotlib as mpl
import tarfile
import matplotlib.image as mpimg
from matplotlib import pyplot as plt

import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet.gluon import nn, utils
from mxnet import autograd
import numpy as np

## Set training parameters

In [2]:
epochs = 2 # Set low by default for tests, set higher when you actually run this code.
batch_size = 64
latent_z_size = 100

use_gpu = False
ctx = mx.gpu() if use_gpu else mx.cpu()

lr = 0.0002
beta1 = 0.5
margin = 80
pt_weight = 0.1

## Download and preprocess the LFW Face Dataset

In [3]:
#celeba_url = 'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADSNUu0bseoCKuxuI5ZeTl1a/Img?dl=0&preview=img_align_celeba.zip'
#data_path = 'img_align_celeba'
#if not os.path.exists(data_path):
#    os.makedirs(data_path)
#    data_file = utils.download(celeba_url)
#    with tarfile.open(data_file) as tar:
#        tar.extractall(path=data_path)
        
lfw_url = 'http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz'
data_path = 'lfw_dataset'
if not os.path.exists(data_path):
    os.makedirs(data_path)
    data_file = utils.download(lfw_url)
    with tarfile.open(data_file) as tar:
        tar.extractall(path=data_path)

First, we resize images to size $64\times64$. Then, we normalize all pixel values to the $[-1, 1]$ range.

In [4]:
target_wd = 64
target_ht = 64
img_list = []

def transform(data, target_wd, target_ht):
    # resize to target_wd x target_ht
    data = mx.image.imresize(data, target_wd, target_ht)
    # transpose from (height, width, channels)
    # to (channels, height, width)
    data = nd.transpose(data, (2,0,1))
    # normalize to [-1, 1]
    data = data.astype(np.float32)/127.5 - 1
    # if image is greyscale, repeat 3 times to get RGB image.
    if data.shape[0] == 1:
        data = nd.tile(data, (3, 1, 1))
    return data.reshape((1,) + data.shape)

for path, _, fnames in os.walk(data_path):
    for fname in fnames:
        if not fname.endswith('.jpg'):
            continue
        img = os.path.join(path, fname)
        img_arr = mx.image.imread(img)
        img_arr = transform(img_arr, target_wd, target_ht)
        img_list.append(img_arr)
train_data = mx.io.NDArrayIter(data=nd.concatenate(img_list), batch_size=batch_size)
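
The normalization in `transform` and its inverse in the `visualize` helper can be checked in isolation. The following is a minimal NumPy sketch, not part of the original notebook; the helper names `to_model_range` and `to_pixel_range` are hypothetical and only mirror the arithmetic used in the cells.

```python
import numpy as np

def to_model_range(pixels):
    # Map uint8 pixel values in [0, 255] to float32 values in [-1, 1],
    # using the same arithmetic as transform().
    return pixels.astype(np.float32) / 127.5 - 1.0

def to_pixel_range(values):
    # Inverse mapping, using the same arithmetic as visualize().
    # astype(np.uint8) truncates, so a value may come back one level lower.
    return ((values + 1.0) * 127.5).astype(np.uint8)

pixels = np.array([0, 64, 128, 255], dtype=np.uint8)
scaled = to_model_range(pixels)
restored = to_pixel_range(scaled)

print(scaled.min(), scaled.max())
print(restored)
```

The `tanh` output of the generator lives in the same $[-1, 1]$ range, which is why the real images are scaled this way before training.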

Visualize 4 images:

In [5]:
def visualize(img_arr):
    plt.imshow(((img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))
    plt.axis('off')

for i in range(4):
    plt.subplot(1,4,i+1)
    visualize(img_list[i + 10][0])
plt.show()

## Defining the networks

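The DCGAN architecture uses a standard CNN for the discriminative model. For the generator,
convolutions are replaced with upconvolutions (transposed convolutions), so the representation at each layer of the generator is successively larger, as it maps from a low-dimensional latent vector onto a high-dimensional image. In the energy-based variant built here, the discriminator is an autoencoder: a convolutional encoder followed by an upconvolutional decoder, whose reconstruction error is used as the loss. The DCGAN paper's architecture guidelines are: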

* Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).

* Use batch normalization in both the generator and the discriminator.

* Remove fully connected hidden layers for deeper architectures.

* Use ReLU activation in generator for all layers except for the output, which uses Tanh.

* Use LeakyReLU activation in the discriminator for all layers.

![](../img/dcgan.png "DCGAN Architecture")
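
To see why the generator ends at $64\times64$, it can help to trace the spatial size through each layer. The sketch below uses the standard output-size formulas for convolutions and transposed convolutions (no dilation, no output padding) with the kernel, stride, and padding values used in this notebook's networks; the function names are hypothetical.

```python
def conv_out(size, kernel, stride, pad):
    # Output size of a strided convolution.
    return (size + 2 * pad - kernel) // stride + 1

def deconv_out(size, kernel, stride, pad):
    # Output size of a transposed (fractionally strided) convolution.
    return (size - 1) * stride - 2 * pad + kernel

# Generator: a 1x1 latent map, then five transposed convolutions
# given as (kernel, stride, pad).
gen_layers = [(4, 1, 0)] + [(4, 2, 1)] * 4
gen_sizes = [1]
for kernel, stride, pad in gen_layers:
    gen_sizes.append(deconv_out(gen_sizes[-1], kernel, stride, pad))
print(gen_sizes)  # [1, 4, 8, 16, 32, 64]

# Discriminator encoder: four strided convolutions on a 64x64 image.
enc_sizes = [64]
for kernel, stride, pad in [(4, 2, 1)] * 4:
    enc_sizes.append(conv_out(enc_sizes[-1], kernel, stride, pad))
print(enc_sizes)  # [64, 32, 16, 8, 4]
```

Each stride-2 layer with kernel 4 and padding 1 exactly doubles or halves the spatial size, which is what the `state size` comments in the network definitions record.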

In [6]:
# build the generator
class Generator(nn.HybridBlock):
    def __init__(self, nc=3, ngf=64):
        super(Generator, self).__init__()
        with self.name_scope():
            self.model = nn.HybridSequential()
            with self.model.name_scope():
                # input is Z, going into a convolution
                self.model.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
                self.model.add(nn.BatchNorm())
                self.model.add(nn.Activation('relu'))
                # state size. (ngf*8) x 4 x 4
                self.model.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))
                self.model.add(nn.BatchNorm())
                self.model.add(nn.Activation('relu'))
                # state size. (ngf*4) x 8 x 8
                self.model.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))
                self.model.add(nn.BatchNorm())
                self.model.add(nn.Activation('relu'))
                # state size. (ngf*2) x 16 x 16
                self.model.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))
                self.model.add(nn.BatchNorm())
                self.model.add(nn.Activation('relu'))
                # state size. (ngf) x 32 x 32
                self.model.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
                self.model.add(nn.Activation('tanh'))
                # state size. (nc) x 64 x 64
                
    def hybrid_forward(self, F, x):
        return self.model(x)

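# Pulling-away term (PT): mean squared cosine similarity between
# the encodings of different samples in a batch
def pt(F, encode_out):
    # flatten each encoding to a vector and scale it to unit L2 norm
    norm_out = F.L2Normalization(F.flatten(encode_out))
    similarity = F.dot(norm_out, norm_out, transpose_b=True)
    # the diagonal (self-similarity) sums to batch_size, so subtract it
    pt_loss = (F.sum(F.square(similarity)) - batch_size) / (batch_size * (batch_size - 1))
    return pt_loss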

# build the discriminator
class Discriminator(nn.HybridBlock):
    def __init__(self, nc=3, ndf=64, use_pt=True):
        super(Discriminator, self).__init__()
        self.use_pt = use_pt
        with self.name_scope():
            self.encoder = nn.HybridSequential()
            with self.encoder.name_scope():
                # input is (nc) x 64 x 64
                self.encoder.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
                self.encoder.add(nn.LeakyReLU(0.2))
                # state size. (ndf) x 32 x 32
                self.encoder.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))
                self.encoder.add(nn.BatchNorm())
                self.encoder.add(nn.LeakyReLU(0.2))
                # state size. (ndf * 2) x 16 x 16
                self.encoder.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))
                self.encoder.add(nn.BatchNorm())
                self.encoder.add(nn.LeakyReLU(0.2))
                # state size. (ndf * 4) x 8 x 8
                self.encoder.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))
                self.encoder.add(nn.BatchNorm())
                self.encoder.add(nn.LeakyReLU(0.2))
                # state size. (ndf * 8) x 4 x 4
                
            self.decoder = nn.HybridSequential()
            with self.decoder.name_scope():
                self.decoder.add(nn.Conv2DTranspose(ndf * 4, 4, 2, 1, use_bias=False))
                self.decoder.add(nn.BatchNorm())
                self.decoder.add(nn.LeakyReLU(0.2))
                # state size. (ndf * 4) x 8 x 8
                self.decoder.add(nn.Conv2DTranspose(ndf * 2, 4, 2, 1, use_bias=False))
                self.decoder.add(nn.BatchNorm())
                self.decoder.add(nn.LeakyReLU(0.2))
                # state size. (ndf * 2) x 16 x 16
                self.decoder.add(nn.Conv2DTranspose(ndf, 4, 2, 1, use_bias=False))
                self.decoder.add(nn.BatchNorm())
                self.decoder.add(nn.LeakyReLU(0.2))
                # state size. (ndf) x 32 x 32
                self.decoder.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
                # state size. nc x 64 x 64
                
    def hybrid_forward(self, F, x):
        enc_out = self.encoder(x)
        dec_out = self.decoder(enc_out)

        # pulling-away term on the encodings; stays 0 when disabled
        pt_loss = 0
        if self.use_pt:
            pt_loss = pt(F, enc_out)

        return dec_out, pt_loss
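
The pulling-away term is easier to reason about on small inputs. Below is a minimal NumPy sketch of the squared-cosine form described in the EBGAN paper; it is an illustration under that assumption, and the function name `pulling_away_term` is hypothetical rather than part of this notebook.

```python
import numpy as np

def pulling_away_term(codes):
    # codes: array of shape (batch, ...) holding one encoding per sample.
    n = codes.shape[0]
    flat = codes.reshape(n, -1)
    # Scale every encoding to unit L2 norm so dot products are cosines.
    unit = flat / np.linalg.norm(flat, axis=1, keepdims=True)
    cos_sq = (unit @ unit.T) ** 2
    # The n diagonal entries are all 1 (self-similarity); exclude them
    # and average over the n * (n - 1) remaining pairs.
    return (cos_sq.sum() - n) / (n * (n - 1))

identical = np.tile(np.array([[1.0, 2.0, 3.0]]), (4, 1))
orthogonal = np.eye(4)

print(pulling_away_term(identical))   # close to 1: all codes point the same way
print(pulling_away_term(orthogonal))  # close to 0: codes are mutually orthogonal
```

A batch whose encodings are nearly parallel is penalized, which pushes the generator toward more varied samples.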

## Setup Loss Function and Optimizer
We use the L2 reconstruction loss of the autoencoder discriminator as our loss function and use the Adam optimizer. We initialize the network's parameters by sampling from a normal distribution.

In [None]:
def param_init(param):
    if param.name.find('conv') != -1:
        if param.name.find('weight') != -1:
            param.initialize(init=mx.init.Normal(0.02), ctx=ctx)
        else:
            param.initialize(init=mx.init.Zero(), ctx=ctx)
    elif param.name.find('batchnorm') != -1:
        param.initialize(init=mx.init.Zero(), ctx=ctx)

def network_init(net, sample_input):
    for param in net.collect_params().values():
        param_init(param)

    # Run one forward pass so that deferred initialization completes
    net(sample_input)

    # Initialize gamma from normal distribution with mean 1 and std 0.02
    for param in net.collect_params().values():
        if param.name.find('batchnorm') != -1 and param.name.find('gamma') != -1:
            param.set_data(nd.random_normal(1, 0.02, param.data().shape))

netG = Generator()
netD = Discriminator()

# The generator consumes latent vectors; the discriminator consumes images
network_init(netG, nd.random_normal(0, 1, shape=(batch_size, latent_z_size, 1, 1), ctx=ctx))
network_init(netD, train_data.next().data[0].as_in_context(ctx))

# loss
loss = gluon.loss.L2Loss()

# trainer for the generator and the discriminator
trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
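
The way the reconstruction loss and the `margin` combine during training can be sketched in plain NumPy. This is an illustrative sketch, not the notebook's implementation: `l2_energy` is intended to mirror the per-sample value of `gluon.loss.L2Loss` (half the mean squared error), and all three function names are hypothetical.

```python
import numpy as np

def l2_energy(reconstruction, target):
    # Per-sample energy: half the mean squared reconstruction error.
    diff = (reconstruction - target).reshape(target.shape[0], -1)
    return 0.5 * np.mean(diff ** 2, axis=1)

def discriminator_loss(energy_real, energy_fake, margin):
    # Low energy for real images; fake images are pushed up only
    # while their energy is still below the margin.
    return energy_real + np.clip(margin - energy_fake, 0, margin)

def generator_loss(energy_fake, pt_term=0.0, pt_weight=0.1):
    # The generator tries to produce images the autoencoder reconstructs well.
    return energy_fake + pt_weight * pt_term

energy_real = np.array([1.0, 1.0])
energy_fake = np.array([10.0, 100.0])
d_loss = discriminator_loss(energy_real, energy_fake, margin=80)
g_loss = generator_loss(energy_fake, pt_term=0.5, pt_weight=0.1)
print(d_loss)  # second sample: fake energy already exceeds the margin
print(g_loss)
```

Once a fake sample's energy exceeds the margin, it no longer contributes to the discriminator loss, so the discriminator stops spending capacity on samples that are already easy to reject.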

## Training Loop
We recommend that you use a GPU for training this model. After a few epochs, we can see that human-face-like images are generated.

In [None]:
from datetime import datetime
import time
import logging

metric = mx.metric.MSE()

stamp =  datetime.now().strftime('%Y_%m_%d-%H_%M')
logging.basicConfig(level=logging.DEBUG)

#netG.hybridize()
#netD.hybridize()

for epoch in range(epochs):
    tic = time.time()
    btic = time.time()
    train_data.reset()
    iter = 0
    for batch in train_data:
        ############################
        # (1) Update D network: minimize MSE(D(x), x) + max(0, margin - MSE(D(G(z)), G(z)))
        ###########################
        data = batch.data[0].as_in_context(ctx)
        latent_z = mx.nd.random_normal(0, 1, shape=(batch_size, latent_z_size, 1, 1), ctx=ctx)

        with autograd.record():
            # train with real image
            output, _ = netD(data) 
            errD_real = loss(output, data)
            metric.update([data,], [output,])

            # train with fake image
            fake = netG(latent_z)
            output, _ = netD(fake.detach())
            errD_fake = loss(output, fake.detach())
            errD = errD_real + mx.nd.clip(margin - errD_fake, 0, margin)
            errD.backward()
            metric.update([fake,], [output,])

        trainerD.step(batch.data[0].shape[0])

        ############################
        # (2) Update G network: minimize MSE(D(G(z)), G(z)) + pt_weight * PT
        ###########################
        with autograd.record():
            fake = netG(latent_z)
            # pt_term avoids shadowing the pt() function used inside netD
            output, pt_term = netD(fake)
            errG = loss(output, fake) + pt_weight * pt_term
            errG.backward()

        trainerG.step(batch.data[0].shape[0])

        # Print log information every ten batches
        if iter % 10 == 0:
            name, mse = metric.get()
            logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
            logging.info('discriminator loss = %f, generator loss = %f, mean square error = %f at iter %d epoch %d' 
                     %(nd.mean(errD).asscalar(), 
                       nd.mean(errG).asscalar(), mse, iter, epoch))
        iter = iter + 1
        btic = time.time()

    name, mse = metric.get()
    metric.reset()
    logging.info('\nreconstruction error at epoch %d: %s=%f' % (epoch, name, mse))
    logging.info('time: %f' % (time.time() - tic))

    # Visualize one generated image for each epoch
    fake_img = fake[0]
    visualize(fake_img)
    # plt.show()

## Results
Given a trained generator, we can generate some images of faces.

In [None]:
num_image = 8
for i in range(num_image):
    latent_z = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
    img = netG(latent_z)
    plt.subplot(2,4,i+1)
    visualize(img[0])
plt.show()

We can also move along the manifold of generated images by stepping linearly through the latent space and visualizing the corresponding images. We can see that small changes in the latent space result in smooth changes in the generated images.

In [None]:
num_image = 12
latent_z = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
step = 0.05
for i in range(num_image):
    img = netG(latent_z)
    plt.subplot(3,4,i+1)
    visualize(img[0])
    latent_z += step
plt.show()
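
The cell above shifts a single latent vector by a constant amount at each step. An alternative is to interpolate between two independently sampled latent vectors. The following NumPy sketch shows one way to build such a path; the function name `interpolate_latents` and the fixed seed are assumptions made for illustration.

```python
import numpy as np

def interpolate_latents(z_start, z_end, num_steps):
    # Blend weights run from 0 (pure z_start) to 1 (pure z_end).
    weights = np.linspace(0.0, 1.0, num_steps).reshape(-1, 1, 1, 1)
    # Broadcasting yields one latent vector per step.
    return (1.0 - weights) * z_start + weights * z_end

rng = np.random.RandomState(0)
z_start = rng.normal(size=(1, 100, 1, 1))
z_end = rng.normal(size=(1, 100, 1, 1))

path = interpolate_latents(z_start, z_end, num_steps=12)
print(path.shape)  # (12, 100, 1, 1)
```

Each slice `path[i:i+1]` has the shape the generator expects, so it could be converted with `mx.nd.array(..., ctx=ctx)` and passed to `netG` in place of `latent_z`.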

For whinges or inquiries, [open an issue on GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)