# GAN Image Example: Tiny ImageNet

* _Author_: Sebastian Nowozin (Sebastian.Nowozin@microsoft.com)
* _Date_: 16th July 2018

In [4]:
import math
import numpy as np
import matplotlib.pyplot as plt


ModuleNotFoundError: No module named 'tensorboardX'

In [5]:
!pip install tensorboardX

Collecting tensorboardX
  Downloading tensorboardX-2.0-py2.py3-none-any.whl (195 kB)
[K     |████████████████████████████████| 195 kB 2.4 MB/s eta 0:00:01
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.0


In [6]:
from tensorboardX import SummaryWriter

In [7]:
import torchvision
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils

In [9]:
# ToTensor scales pixel values to [0, 1]; Normalize with mean 0.5 and std 0.5
# per RGB channel then maps them to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder treats each sub-directory of the training split as one class.
tinyimagenet = torchvision.datasets.ImageFolder(
    '../data/tiny-imagenet-200/train',
    transform=transform,
)
len(tinyimagenet)

100000

## GAN Model

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
from torch.autograd import Variable

In [11]:
class ConjugateDualFunction:
    """Output activation and Fenchel-conjugate term for an f-GAN divergence.

    The discriminator produces an unconstrained real value v. ``T(v)`` maps
    it into the domain of the conjugate f^*, and ``fstarT(v)`` evaluates
    f^*(T(v)) in closed form as a function of v.

    Supported divergence names: "kl", "klrev", "pearson", "neyman",
    "hellinger", "jensen", "gan". An unknown name raises ValueError when
    ``T`` or ``fstarT`` is called.
    """

    def __init__(self, divergence_name):
        self.divergence_name = divergence_name

    def T(self, v):
        """Compute the T(v) representation.

        Arguments
        v -- The scalar output (full real number range) of the discriminator,
             as a torch tensor.

        Returns a tensor computed elementwise from v.
        Raises ValueError if the divergence name is unknown.
        """
        if self.divergence_name == "kl":
            return v
        elif self.divergence_name == "klrev":
            # The tensor exponential is torch.exp; torch.nn.functional has no exp.
            return -torch.exp(v)
        elif self.divergence_name == "pearson":
            return v
        elif self.divergence_name == "neyman":
            return 1.0 - torch.exp(v)
        elif self.divergence_name == "hellinger":
            return 1.0 - torch.exp(v)
        elif self.divergence_name == "jensen":
            return math.log(2.0) - F.softplus(-v)
        elif self.divergence_name == "gan":
            return -F.softplus(-v)
        else:
            raise ValueError("Unknown divergence name in t function.")

    def fstarT(self, v):
        """Compute the f^*(T(v)) representation.

        Arguments
        v -- The scalar output of the variational function neural network,
             as a torch tensor.

        Returns a tensor computed elementwise from v.
        Raises ValueError if the divergence name is unknown.
        """
        if self.divergence_name == "kl":
            return torch.exp(v - 1.0)
        elif self.divergence_name == "klrev":
            return -1.0 - v
        elif self.divergence_name == "pearson":
            return 0.25*v*v + v
        elif self.divergence_name == "neyman":
            # 2 - 2*sqrt(1 - T) with T = 1 - exp(v)
            return 2.0 - 2.0*torch.exp(0.5*v)
        elif self.divergence_name == "hellinger":
            # T / (1 - T) with T = 1 - exp(v)
            return torch.exp(-v) - 1.0
        elif self.divergence_name == "jensen":
            return F.softplus(v) - math.log(2.0)
        elif self.divergence_name == "gan":
            return F.softplus(v)
        else:
            raise ValueError("Unknown divergence name in fstar_t function.")

### DCGAN architecture

In [12]:
class DCGANGenerator(nn.Module):
    def __init__(self, nrand):
        super(DCGANGenerator, self).__init__()
        self.lin1 = nn.Linear(nrand, 4*4*512)
        init.xavier_uniform_(self.lin1.weight, gain=0.1)
        self.lin1bn = nn.BatchNorm1d(4*4*512)
        self.dc1 = nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1)
        self.dc1bn = nn.BatchNorm2d(256)
        self.dc2 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.dc2bn = nn.BatchNorm2d(128)
        self.dc3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.dc3bn = nn.BatchNorm2d(64)
        self.dc4a = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.dc4abn = nn.BatchNorm2d(32)
        self.dc4b = nn.Conv2d(32, 3, 3, stride=1, padding=1)

    def forward(self, z):
        h = F.relu(self.lin1bn(self.lin1(z)))
        h = torch.reshape(h, (-1, 512, 4, 4))

        # deconv stack
        h = F.relu(self.dc1bn(self.dc1(h)))
        h = F.relu(self.dc2bn(self.dc2(h)))
        h = F.relu(self.dc3bn(self.dc3(h)))
        h = F.relu(self.dc4abn(self.dc4a(h)))
        x = self.dc4b(h)

        return x

class DCGANDiscriminator(nn.Module):
    """DCGAN-style discriminator producing one unconstrained scalar per image.

    Four stride-2, kernel-4, padding-1 convolutions halve the spatial size
    each time. The 4*4*512 input width of lin1 therefore implies 64x64 RGB
    inputs; other sizes now fail with a shape error at lin1.

    NOTE(review): BatchNorm in training mode normalizes with batch statistics,
    so the per-sample input gradients used by a gradient penalty also depend
    on the other samples in the batch -- confirm this is acceptable.
    """

    def __init__(self):
        super(DCGANDiscriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
        self.conv1bn = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.conv2bn = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 4, stride=2, padding=1)
        self.conv3bn = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, 4, stride=2, padding=1)
        self.conv4bn = nn.BatchNorm2d(512)
        self.lin1 = nn.Linear(4*4*512, 512)
        self.lin1bn = nn.BatchNorm1d(512)
        self.lin2 = nn.Linear(512, 1)

    def forward(self, x):
        """Return the discriminator output v with shape (batch, 1).

        Arguments
        x -- image batch, presumably (batch, 3, 64, 64) (see class docstring)
        """
        h = F.elu(self.conv1bn(self.conv1(x)))
        h = F.elu(self.conv2bn(self.conv2(h)))
        h = F.elu(self.conv3bn(self.conv3(h)))
        h = F.elu(self.conv4bn(self.conv4(h)))
        # Flatten per sample, keeping the batch dimension. Reshaping to a fixed
        # row width would silently fold extra spatial positions of a
        # wrongly-sized input into the batch; flatten(1) makes lin1 reject it.
        h = torch.flatten(h, 1)

        h = F.elu(self.lin1bn(self.lin1(h)))
        v = self.lin2(h)

        return v

In [13]:
class FGANLearningObjective(nn.Module):
    """Variational f-GAN losses for a generator/discriminator pair.

    Arguments
    gen -- generator module, called as gen(zmodel)
    disc -- discriminator module returning the unconstrained output v
    divergence_name -- f-divergence name passed to ConjugateDualFunction
    gamma -- weight of the gradient penalty on real data (0 disables it);
             stored as gammahalf = 0.5 * gamma
    """

    def __init__(self, gen, disc, divergence_name="gan", gamma=10.0):
        super(FGANLearningObjective, self).__init__()
        self.gen = gen
        self.disc = disc
        self.conj = ConjugateDualFunction(divergence_name)
        self.gammahalf = 0.5*gamma

    def forward(self, xreal, zmodel):
        """Compute the generator and discriminator losses for one batch.

        Arguments
        xreal -- batch of real samples (first dimension is the batch)
        zmodel -- batch of noise vectors fed to the generator

        Returns (loss_gen, loss_disc) as scalar tensors. Both come from the
        same forward pass and share parts of one autograd graph.
        """
        if self.gammahalf > 0.0 and not xreal.requires_grad:
            # The penalty differentiates with respect to xreal, so it must be
            # part of the graph. A tensor that does not require grad has no
            # history, so detaching it loses nothing and leaves the caller's
            # tensor unchanged.
            xreal = xreal.detach().requires_grad_(True)

        # Real data
        vreal = self.disc(xreal)    # Real data discriminator output
        Treal = self.conj.T(vreal)  # Mapped to T-space

        # Model data
        xmodel = self.gen(zmodel)   # Map noise to data
        vmodel = self.disc(xmodel)  # Model data discriminator output
        fstar_Tmodel = self.conj.fstarT(vmodel)   # Mapped to f^*(T)

        # Generator loss: -E_model[f^*(T(v))]
        loss_gen = -fstar_Tmodel.mean()

        # Discriminator maximizes E_real[T(v)] - E_model[f^*(T(v))];
        # negated because we minimize
        loss_disc = fstar_Tmodel.mean() - Treal.mean()

        # Gradient penalty on real data as per (Mescheder et al., ICML 2018):
        # gammahalf * E_real ||grad_x T(v(x))||^2.
        # NOTE(review): the paper states the R1 penalty for the discriminator
        # output itself; this penalizes the gradient of T(v) rather than of v
        # -- confirm that is intended.
        if self.gammahalf > 0.0:
            batchsize = xreal.size(0)
            grad_pd = torch.autograd.grad(Treal.sum(), xreal,
                create_graph=True)[0]
            # Squared L2 norm of each sample's input gradient; reshape (unlike
            # view) also accepts a non-contiguous gradient tensor.
            grad_pd_norm2 = grad_pd.pow(2).reshape(batchsize, -1).sum(1)
            gradient_penalty = self.gammahalf * grad_pd_norm2.mean()
            # Out-of-place add: avoid mutating a tensor that is part of the graph.
            loss_disc = loss_disc + gradient_penalty

        return loss_gen, loss_disc

In [14]:
# Prefer GPU 3 when at least four GPUs are visible, otherwise use the first
# GPU; fall back to the CPU when CUDA is unavailable. A fixed "cuda:3" would
# make the later .to(device) fail on machines with fewer GPUs.
PREFERRED_CUDA_INDEX = 3
if torch.cuda.is_available():
    if torch.cuda.device_count() > PREFERRED_CUDA_INDEX:
        cuda_index = PREFERRED_CUDA_INDEX
    else:
        cuda_index = 0
    device = torch.device("cuda:%d" % cuda_index)
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [15]:
# Latent noise dimension fed to the generator.
nrand = 128
gen = DCGANGenerator(nrand)
disc = DCGANDiscriminator()
# f-GAN objective with the "gan" divergence and gradient-penalty weight 1000,
# moved to the selected device.
fgan = FGANLearningObjective(gen, disc, "gan", gamma=1000.0).to(device)

In [16]:
batchsize = 32
# RMSprop for both networks. Alternative: optim.Adam(..., lr=1.0e-2).
optimizer_gen = optim.RMSprop(fgan.gen.parameters(), lr=1.0e-2)
optimizer_disc = optim.RMSprop(fgan.disc.parameters(), lr=1.0e-2)

# drop_last=True: a ragged final batch would not match the fixed-size noise
# batch, and a batch of size 1 makes the BatchNorm1d layers raise in training
# mode.
trainloader = torch.utils.data.DataLoader(tinyimagenet,
    batch_size=batchsize, shuffle=True, num_workers=8, drop_last=True)

In [19]:
writer = SummaryWriter(log_dir="runs/TinyImageNet", comment="f-GAN-JS")

nepochs = 500
niter = 0
gen_params = list(fgan.gen.parameters())
disc_params = list(fgan.disc.parameters())
try:
    for epoch in range(nepochs):
        # Log a grid of generator samples at the start of every epoch; no
        # gradients are needed for this.
        with torch.no_grad():
            zmodel = torch.rand((batchsize, nrand), device=device)
            xmodel = fgan.gen(zmodel)
        xmodelimg = vutils.make_grid(xmodel,
            normalize=True, scale_each=True)
        writer.add_image('Generated', xmodelimg, global_step=niter)

        for i, (imgs, _labels) in enumerate(trainloader):
            niter += 1

            # Real Tiny ImageNet images (requires_grad for the gradient
            # penalty) and uniform noise for the generator
            xreal = imgs.to(device).requires_grad_(True)
            zmodel = torch.rand((batchsize, nrand), device=device)

            loss_gen, loss_disc = fgan(xreal, zmodel)
            writer.add_scalar('obj/disc', loss_disc.item(), niter)
            writer.add_scalar('obj/gen', loss_gen.item(), niter)
            if i == 0:
                print("epoch %d  iter %d  obj(D) %.4f  obj(G) %.4f" % (epoch, niter, loss_disc.item(), loss_gen.item()))

            # Take both gradients from this forward pass before updating
            # either network. Stepping the generator and then calling
            # loss_disc.backward() would backpropagate through generator
            # weights the optimizer already modified in place (recent PyTorch
            # raises an in-place modification error). Restricting each
            # gradient to its own network's parameters also keeps it out of
            # the other network's .grad.
            gen_grads = torch.autograd.grad(loss_gen, gen_params, retain_graph=True)
            disc_grads = torch.autograd.grad(loss_disc, disc_params)
            for param, grad in zip(gen_params, gen_grads):
                param.grad = grad
            for param, grad in zip(disc_params, disc_grads):
                param.grad = grad
            optimizer_gen.step()
            optimizer_disc.step()
finally:
    # Runs on normal completion and on interruption (e.g. KeyboardInterrupt),
    # so scalars are exported and the event file is closed either way.
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()

epoch 0  iter 1  obj(D) 1.7478  obj(G) -0.7721
POOPOO
POOPOO
POOPOO
POOPOO
POOPOO
POOPOO
POOPOO
POOPOO
POOPOO
POOPOO
POOPOO


KeyboardInterrupt: 