# GAN Image Example: Tiny ImageNet

* _Author_: Sebastian Nowozin (Sebastian.Nowozin@microsoft.com)
* _Date_: 16th July 2018

In [2]:
import math
import numpy as np
import matplotlib.pyplot as plt

In [3]:
from tensorboardX import SummaryWriter

In [4]:
import torchvision
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils

In [5]:
# Preprocessing: convert each image to a tensor, then shift/scale every RGB
# channel by mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split of Tiny ImageNet, read from a class-per-folder layout.
tinyimagenet = torchvision.datasets.ImageFolder(
    root='/mnt/nfs/users/senowozi/data/tiny-imagenet/tiny-imagenet-200/train',
    transform=transform,
)
len(tinyimagenet)

100000

## GAN Model

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
from torch.autograd import Variable

In [8]:
class ConjugateDualFunction:
    """Variational pieces of an f-divergence, selected by name.

    For a raw discriminator output ``v`` this class provides the output
    activation ``T(v)`` and the composed convex conjugate ``f^*(T(v))``.

    Supported divergence names: "kl", "klrev", "pearson", "neyman",
    "hellinger", "jensen", "gan". The name is not validated at construction
    time; an unknown name raises ``ValueError`` when ``T`` or ``fstarT`` is
    called.
    """

    def __init__(self, divergence_name):
        """Store the divergence name used to dispatch in ``T`` and ``fstarT``."""
        self.divergence_name = divergence_name

    def T(self, v):
        """Compute T(v) representation

        Arguments
        v -- The scalar output (full real number range) of the discriminator

        Returns a tensor of the same shape as ``v``.
        Raises ValueError if the divergence name is not recognised.
        """
        # NOTE: torch.nn.functional has no ``exp``; the element-wise
        # exponential lives in the top-level ``torch`` namespace.
        if self.divergence_name == "kl":
            return v
        elif self.divergence_name == "klrev":
            return -torch.exp(v)
        elif self.divergence_name == "pearson":
            return v
        elif self.divergence_name == "neyman":
            return 1.0 - torch.exp(v)
        elif self.divergence_name == "hellinger":
            return 1.0 - torch.exp(v)
        elif self.divergence_name == "jensen":
            return math.log(2.0) - F.softplus(-v)
        elif self.divergence_name == "gan":
            return -F.softplus(-v)
        else:
            raise ValueError("Unknown divergence name in t function.")

    def fstarT(self, v):
        """Compute the f^*(T(v)) representation

        The expressions below are already composed with the matching ``T``,
        so the argument is the raw network output, not ``T(v)``.

        Arguments
        v -- The scalar output of the variational function neural network.

        Returns a tensor of the same shape as ``v``.
        Raises ValueError if the divergence name is not recognised.
        """
        if self.divergence_name == "kl":
            return torch.exp(v - 1.0)
        elif self.divergence_name == "klrev":
            return -1.0 - v
        elif self.divergence_name == "pearson":
            return 0.25*v*v + v
        elif self.divergence_name == "neyman":
            return 2.0 - 2.0*torch.exp(0.5*v)
        elif self.divergence_name == "hellinger":
            return torch.exp(-v) - 1.0
        elif self.divergence_name == "jensen":
            return F.softplus(v) - math.log(2.0)
        elif self.divergence_name == "gan":
            return F.softplus(v)
        else:
            raise ValueError("Unknown divergence name in fstar_t function.")

### DCGAN architecture

In [9]:
class DCGANGenerator(nn.Module):
    """Transposed-convolution generator mapping a noise vector to a 3-channel image.

    A linear layer produces a 512-channel 4x4 map, four stride-2 transposed
    convolutions each double the spatial size, and a final 3x3 convolution
    produces 3 output channels. No output non-linearity is applied.
    """

    def __init__(self, nrand):
        """Build the layers.

        Arguments
        nrand -- Size of the input noise vector.
        """
        super().__init__()

        # Noise -> flattened 512 x 4 x 4 feature map, small-gain Xavier init.
        self.lin1 = nn.Linear(nrand, 4*4*512)
        init.xavier_uniform_(self.lin1.weight, gain=0.1)
        self.lin1bn = nn.BatchNorm1d(4*4*512)

        # Upsampling stack: kernel 4, stride 2, padding 1 doubles H and W.
        self.dc1 = nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1)
        self.dc1bn = nn.BatchNorm2d(256)
        self.dc2 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.dc2bn = nn.BatchNorm2d(128)
        self.dc3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.dc3bn = nn.BatchNorm2d(64)
        self.dc4a = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.dc4abn = nn.BatchNorm2d(32)

        # Size-preserving projection down to 3 channels.
        self.dc4b = nn.Conv2d(32, 3, 3, stride=1, padding=1)

    def forward(self, z):
        """Map a batch of noise vectors ``z`` to a batch of images."""
        feat = F.relu(self.lin1bn(self.lin1(z)))
        feat = feat.reshape(-1, 512, 4, 4)

        upsample_stages = (
            (self.dc1, self.dc1bn),
            (self.dc2, self.dc2bn),
            (self.dc3, self.dc3bn),
            (self.dc4a, self.dc4abn),
        )
        for deconv, norm in upsample_stages:
            feat = F.relu(norm(deconv(feat)))

        return self.dc4b(feat)

class DCGANDiscriminator(nn.Module):
    """Convolutional discriminator producing one unbounded scalar per image.

    Four stride-2 convolutions (each followed by batch norm and ELU) halve the
    spatial size, after which the features are flattened to 4*4*512 values and
    passed through two linear layers. The flatten size implies the input must
    reduce to a 4x4 map after the four convolutions.
    """

    def __init__(self):
        """Build the convolutional stack and the linear head."""
        super().__init__()

        # Downsampling stack: kernel 4, stride 2, padding 1 halves H and W.
        self.conv1 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
        self.conv1bn = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.conv2bn = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 4, stride=2, padding=1)
        self.conv3bn = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, 4, stride=2, padding=1)
        self.conv4bn = nn.BatchNorm2d(512)

        # Fully connected head ending in a single raw score.
        self.lin1 = nn.Linear(4*4*512, 512)
        self.lin1bn = nn.BatchNorm1d(512)
        self.lin2 = nn.Linear(512, 1)

    def forward(self, x):
        """Return the raw score (no output activation) for each image in ``x``."""
        conv_stages = (
            (self.conv1, self.conv1bn),
            (self.conv2, self.conv2bn),
            (self.conv3, self.conv3bn),
            (self.conv4, self.conv4bn),
        )
        feat = x
        for conv, norm in conv_stages:
            feat = F.elu(norm(conv(feat)))

        flat = feat.reshape(-1, 4*4*512)
        hidden = F.elu(self.lin1bn(self.lin1(flat)))
        return self.lin2(hidden)

In [10]:
class FGANLearningObjective(nn.Module):
    """Bundles generator, discriminator and f-divergence into one loss module.

    Calling the module returns the pair ``(loss_gen, loss_disc)``; both are
    meant to be minimised by their respective player.
    """

    def __init__(self, gen, disc, divergence_name="gan", gamma=10.0):
        """
        Arguments
        gen -- Generator module mapping noise to data.
        disc -- Discriminator module mapping data to a raw score.
        divergence_name -- Name forwarded to ConjugateDualFunction.
        gamma -- Gradient penalty weight; the penalty is skipped unless
                 gamma/2 is strictly positive.
        """
        super().__init__()
        self.gen = gen
        self.disc = disc
        self.conj = ConjugateDualFunction(divergence_name)
        self.gammahalf = 0.5*gamma

    def forward(self, xreal, zmodel):
        """Compute both losses for one batch.

        Arguments
        xreal -- Batch of real data. When the gradient penalty is active it is
                 differentiated with respect to, so it must require gradients.
        zmodel -- Batch of noise fed to the generator.
        """
        # Discriminator on real samples, mapped into T-space.
        T_on_real = self.conj.T(self.disc(xreal))

        # Discriminator on generated samples, mapped to f^*(T).
        fstar_on_fake = self.conj.fstarT(self.disc(self.gen(zmodel)))
        fake_term = fstar_on_fake.mean()

        # Generator loss, and discriminator loss (negated objective because
        # both are minimised).
        loss_gen = -fake_term
        loss_disc = fake_term - T_on_real.mean()

        if not self.gammahalf > 0.0:
            return loss_gen, loss_disc

        # Gradient penalty as per (Mescheder et al., ICML 2018): squared norm
        # of d T(real) / d xreal per sample, averaged over the batch.
        # create_graph=True keeps the penalty itself differentiable.
        n = xreal.size(0)
        (grad_wrt_real,) = torch.autograd.grad(
            T_on_real.sum(), xreal, create_graph=True)
        sq_norm_per_sample = grad_wrt_real.pow(2).view(n, -1).sum(1)
        loss_disc = loss_disc + self.gammahalf * sq_norm_per_sample.mean()

        return loss_gen, loss_disc

In [14]:
# Prefer GPU 3, but only when that index actually exists on this host;
# otherwise use the first GPU, or the CPU when CUDA is unavailable. A bare
# "cuda:3" on a machine with fewer than four GPUs fails once tensors are moved.
_preferred_gpu = 3
if torch.cuda.is_available():
    _gpu_index = _preferred_gpu if torch.cuda.device_count() > _preferred_gpu else 0
    device = torch.device("cuda:%d" % _gpu_index)
else:
    device = torch.device("cpu")
device

device(type='cuda', index=3)

In [27]:
# Size of the generator's input noise vector.
nrand = 128

gen = DCGANGenerator(nrand)
disc = DCGANDiscriminator()

# "gan" divergence with gradient penalty weight gamma=1000; the whole
# objective (and hence both sub-networks) is moved to the selected device.
fgan = FGANLearningObjective(gen, disc, "gan", gamma=1000.0).to(device)
# Multi-GPU wrapping is currently disabled:
#fgan = torch.nn.DataParallel(fgan)
#fgan.to(device)

In [28]:
batchsize = 32

# One RMSprop optimiser per player, both with the same learning rate.
# (Adam with the same learning rate is the disabled alternative.)
#optimizer_gen = optim.Adam(fgan.gen.parameters(), lr=1.0e-2)
#optimizer_disc = optim.Adam(fgan.disc.parameters(), lr=1.0e-2)
_learning_rate = 1.0e-2
optimizer_gen = optim.RMSprop(fgan.gen.parameters(), lr=_learning_rate)
optimizer_disc = optim.RMSprop(fgan.disc.parameters(), lr=_learning_rate)

# Shuffled mini-batches of the training set, loaded by 8 worker processes.
trainloader = torch.utils.data.DataLoader(
    dataset=tinyimagenet,
    batch_size=batchsize,
    shuffle=True,
    num_workers=8,
)

In [29]:
writer = SummaryWriter(log_dir="runs/TinyImageNet", comment="f-GAN-JS")

# The parameter sets do not change during training; collect them once so the
# per-player gradients can be requested explicitly below.
gen_params = list(fgan.gen.parameters())
disc_params = list(fgan.disc.parameters())

nepochs = 500
niter = 0   # global step counter across all epochs
try:
    for epoch in range(nepochs):
        # Log a grid of generated samples at the start of every epoch.
        # no_grad: these samples are only visualised, so no autograd graph
        # needs to be built or kept alive for them.
        with torch.no_grad():
            zmodel = torch.rand((batchsize, nrand), device=device)
            xmodel = fgan.gen(zmodel)
        xmodelimg = vutils.make_grid(xmodel,
            normalize=True, scale_each=True)
        writer.add_image('Generated', xmodelimg, global_step=niter)

        for i, (imgs, _labels) in enumerate(trainloader, 0):
            niter += 1

            # Real data needs requires_grad so the objective can differentiate
            # with respect to it for the gradient penalty.
            xreal = imgs.to(device).requires_grad_(True)
            # Noise batch follows the size of the real batch, which also
            # covers a shorter final batch.
            zmodel = torch.rand((xreal.size(0), nrand), device=device)

            loss_gen, loss_disc = fgan(xreal, zmodel)
            loss_disc_value = loss_disc.item()
            loss_gen_value = loss_gen.item()
            writer.add_scalar('obj/disc', loss_disc_value, niter)
            writer.add_scalar('obj/gen', loss_gen_value, niter)
            if i == 0:
                print("epoch %d  iter %d  obj(D) %.4f  obj(G) %.4f" % (epoch, niter, loss_disc_value, loss_gen_value))

            # Compute BOTH players' gradients from the shared graph before
            # either optimiser runs. Stepping the generator first and then
            # back-propagating loss_disc through the same graph would touch
            # weights that were modified in place, which autograd rejects.
            # Each player only receives the gradient of its own loss.
            grads_gen = torch.autograd.grad(loss_gen, gen_params, retain_graph=True)
            grads_disc = torch.autograd.grad(loss_disc, disc_params)

            for param, grad in zip(gen_params, grads_gen):
                param.grad = grad
            for param, grad in zip(disc_params, grads_disc):
                param.grad = grad

            optimizer_gen.step()
            optimizer_disc.step()

    writer.export_scalars_to_json("./all_scalars.json")
finally:
    # Always release the event file, including on interrupt or error.
    writer.close()

epoch 0  iter 1  obj(D) 6.6641  obj(G) -0.7253
epoch 1  iter 3126  obj(D) 1.3824  obj(G) -0.6950
epoch 2  iter 6251  obj(D) 1.3864  obj(G) -0.7571
epoch 3  iter 9376  obj(D) 1.3964  obj(G) -0.6927
epoch 4  iter 12501  obj(D) 1.3835  obj(G) -0.6992
epoch 5  iter 15626  obj(D) 1.3909  obj(G) -0.6978
epoch 6  iter 18751  obj(D) 1.3808  obj(G) -0.7022
epoch 7  iter 21876  obj(D) 1.3772  obj(G) -0.6898
epoch 8  iter 25001  obj(D) 1.3821  obj(G) -0.6843
epoch 9  iter 28126  obj(D) 1.3848  obj(G) -0.6885
epoch 10  iter 31251  obj(D) 1.3720  obj(G) -0.7093
epoch 11  iter 34376  obj(D) 1.3639  obj(G) -0.7423
epoch 12  iter 37501  obj(D) 1.3854  obj(G) -0.7045
epoch 13  iter 40626  obj(D) 1.3804  obj(G) -0.6734
epoch 14  iter 43751  obj(D) 1.3822  obj(G) -0.6956
epoch 15  iter 46876  obj(D) 1.2399  obj(G) -0.5686
epoch 16  iter 50001  obj(D) 1.3795  obj(G) -0.6881
epoch 17  iter 53126  obj(D) 1.3893  obj(G) -0.6956
epoch 18  iter 56251  obj(D) 1.3801  obj(G) -0.7175
epoch 19  iter 59376  obj(D) 

epoch 155  iter 484376  obj(D) 1.3916  obj(G) -0.7053
epoch 156  iter 487501  obj(D) 1.3853  obj(G) -0.6937
epoch 157  iter 490626  obj(D) 1.3850  obj(G) -0.6920
epoch 158  iter 493751  obj(D) 1.4139  obj(G) -0.7013
epoch 159  iter 496876  obj(D) 1.3871  obj(G) -0.6921
epoch 160  iter 500001  obj(D) 1.3808  obj(G) -0.6955
epoch 161  iter 503126  obj(D) 1.3880  obj(G) -0.6924
epoch 162  iter 506251  obj(D) 1.3847  obj(G) -0.6926
epoch 163  iter 509376  obj(D) 7.9985  obj(G) -0.6821
epoch 164  iter 512501  obj(D) 1.3861  obj(G) -0.7004
epoch 165  iter 515626  obj(D) 1.3977  obj(G) -0.6809
epoch 166  iter 518751  obj(D) 1.3870  obj(G) -0.7168
epoch 167  iter 521876  obj(D) 1.3765  obj(G) -0.6810
epoch 168  iter 525001  obj(D) 1.3856  obj(G) -0.7167
epoch 169  iter 528126  obj(D) 1.3851  obj(G) -0.6951
epoch 170  iter 531251  obj(D) 1.3837  obj(G) -0.6769
epoch 171  iter 534376  obj(D) 1.3867  obj(G) -0.6723
epoch 172  iter 537501  obj(D) 1.3810  obj(G) -0.6850
epoch 173  iter 540626  obj(

epoch 307  iter 959376  obj(D) 1.3871  obj(G) -0.6589
epoch 308  iter 962501  obj(D) 1.3866  obj(G) -0.7000
epoch 309  iter 965626  obj(D) 1.3862  obj(G) -0.6971
epoch 310  iter 968751  obj(D) 1.3867  obj(G) -0.6789
epoch 311  iter 971876  obj(D) 1.3859  obj(G) -0.7353
epoch 312  iter 975001  obj(D) 1.3831  obj(G) -0.6911
epoch 313  iter 978126  obj(D) 1.3772  obj(G) -0.6866
epoch 314  iter 981251  obj(D) 1.3879  obj(G) -0.6867
epoch 315  iter 984376  obj(D) 1.3837  obj(G) -0.6998
epoch 316  iter 987501  obj(D) 1.3834  obj(G) -0.6968
epoch 317  iter 990626  obj(D) 1.3822  obj(G) -0.6732
epoch 318  iter 993751  obj(D) 1.4178  obj(G) -0.8882
epoch 319  iter 996876  obj(D) 1.3888  obj(G) -0.6926
epoch 320  iter 1000001  obj(D) 1.3848  obj(G) -0.6926
epoch 321  iter 1003126  obj(D) 1.3867  obj(G) -0.6753
epoch 322  iter 1006251  obj(D) 1.3968  obj(G) -0.8031
epoch 323  iter 1009376  obj(D) 1.3890  obj(G) -0.6987
epoch 324  iter 1012501  obj(D) 1.3861  obj(G) -0.7090
epoch 325  iter 1015626

epoch 457  iter 1428126  obj(D) 1.3777  obj(G) -0.6071
epoch 458  iter 1431251  obj(D) 0.5282  obj(G) -0.4509
epoch 459  iter 1434376  obj(D) 1.4685  obj(G) -0.9778
epoch 460  iter 1437501  obj(D) 1.3892  obj(G) -0.7618
epoch 461  iter 1440626  obj(D) 0.6763  obj(G) -0.5024
epoch 462  iter 1443751  obj(D) 0.0066  obj(G) -0.0037
epoch 463  iter 1446876  obj(D) 1.4182  obj(G) -0.6457
epoch 464  iter 1450001  obj(D) 1.3844  obj(G) -0.6889
epoch 465  iter 1453126  obj(D) 1.3892  obj(G) -0.6332
epoch 466  iter 1456251  obj(D) 1.3869  obj(G) -0.7106
epoch 467  iter 1459376  obj(D) 1.4039  obj(G) -0.5696
epoch 468  iter 1462501  obj(D) 1.3905  obj(G) -0.6350
epoch 469  iter 1465626  obj(D) 10.2421  obj(G) -10.2394
epoch 470  iter 1468751  obj(D) 1.3954  obj(G) -0.6356
epoch 471  iter 1471876  obj(D) 1.3997  obj(G) -0.8028
epoch 472  iter 1475001  obj(D) 1.3940  obj(G) -0.7747
epoch 473  iter 1478126  obj(D) 1.3951  obj(G) -0.6236
epoch 474  iter 1481251  obj(D) 1.3854  obj(G) -0.6955
epoch 47