# GAN Image Example: CIFAR

* _Author_: Sebastian Nowozin (Sebastian.Nowozin@microsoft.com)
* _Date_: 16th July 2018

In [1]:
import math
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from tensorboardX import SummaryWriter

In [3]:
import torchvision
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils

In [4]:
# Per-image preprocessing: convert to a tensor, then normalise each of the
# three colour channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split, stored under ./data-cifar10 (downloaded on demand).
CIFAR10 = torchvision.datasets.CIFAR10(
    root='data-cifar10',
    train=True,
    download=True,
    transform=transform,
)


Files already downloaded and verified


## GAN Model

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
from torch.autograd import Variable

In [6]:
class ConjugateDualFunction:
    """Output activation T(v) and conjugate term f^*(T(v)) for an f-divergence.

    The divergence is selected by name; supported names are "kl", "klrev",
    "pearson", "neyman", "hellinger", "jensen" and "gan".  The name is only
    checked when `T` or `fstarT` is called.
    """

    def __init__(self, divergence_name):
        """Remember which divergence the two mappings should implement.

        Arguments
        divergence_name -- One of the names listed in the class docstring.
        """
        self.divergence_name = divergence_name

    def T(self, v):
        """Compute T(v) representation

        Arguments
        v -- The scalar output (full real number range) of the discriminator

        Returns a tensor of the same shape as v.
        Raises ValueError if the divergence name is not recognised.
        """
        # NOTE: torch.nn.functional has no `exp`; the exponential lives in the
        # top-level torch namespace.
        if self.divergence_name == "kl":
            return v
        elif self.divergence_name == "klrev":
            return -torch.exp(v)
        elif self.divergence_name == "pearson":
            return v
        elif self.divergence_name == "neyman":
            return 1.0 - torch.exp(v)
        elif self.divergence_name == "hellinger":
            return 1.0 - torch.exp(v)
        elif self.divergence_name == "jensen":
            return math.log(2.0) - F.softplus(-v)
        elif self.divergence_name == "gan":
            return -F.softplus(-v)
        else:
            raise ValueError("Unknown divergence name in t function.")

    def fstarT(self, v):
        """Compute the f^*(T(v)) representation

        The composition is evaluated in closed form directly from v instead of
        applying f^* to the result of `T`.

        Arguments
        v -- The scalar output of the variational function neural network.

        Returns a tensor of the same shape as v.
        Raises ValueError if the divergence name is not recognised.
        """
        if self.divergence_name == "kl":
            return torch.exp(v - 1.0)
        elif self.divergence_name == "klrev":
            return -1.0 - v
        elif self.divergence_name == "pearson":
            return 0.25*v*v + v
        elif self.divergence_name == "neyman":
            return 2.0 - 2.0*torch.exp(0.5*v)
        elif self.divergence_name == "hellinger":
            return torch.exp(-v) - 1.0
        elif self.divergence_name == "jensen":
            return F.softplus(v) - math.log(2.0)
        elif self.divergence_name == "gan":
            return F.softplus(v)
        else:
            raise ValueError("Unknown divergence name in fstar_t function.")

### DCGAN architecture

In [7]:
class DCGANGenerator(nn.Module):
    """Generator network: noise vector of size `nrand` -> 3 x 32 x 32 image.

    A linear layer produces a 512 x 4 x 4 feature map, three stride-2
    transposed convolutions double the spatial size each time (4 -> 8 -> 16
    -> 32), and a final 3x3 convolution maps to three output channels.  No
    output non-linearity is applied.
    """

    def __init__(self, nrand):
        """Create all layers.

        Arguments
        nrand -- Dimensionality of the input noise vector.
        """
        super().__init__()

        # Linear stem; its weights get a small-gain Xavier initialisation.
        self.lin1 = nn.Linear(in_features=nrand, out_features=4*4*512)
        init.xavier_uniform_(self.lin1.weight, gain=0.1)
        self.lin1bn = nn.BatchNorm1d(4*4*512)

        # Up-sampling stages, each followed by batch normalisation.
        self.dc1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.dc1bn = nn.BatchNorm2d(256)
        self.dc2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.dc2bn = nn.BatchNorm2d(128)
        self.dc3a = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.dc3abn = nn.BatchNorm2d(64)

        # Size-preserving projection to the three image channels.
        self.dc3b = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Map a batch of noise vectors `z` (batch, nrand) to images."""
        feat = F.relu(self.lin1bn(self.lin1(z)))
        feat = feat.reshape(-1, 512, 4, 4)

        upsampling_stages = (
            (self.dc1, self.dc1bn),
            (self.dc2, self.dc2bn),
            (self.dc3a, self.dc3abn),
        )
        for deconv, bnorm in upsampling_stages:
            feat = F.relu(bnorm(deconv(feat)))

        return self.dc3b(feat)

class DCGANDiscriminator(nn.Module):
    """Discriminator network: 3 x 32 x 32 image -> one unbounded scalar.

    Three stride-2 convolutions halve the spatial size each time (32 -> 16
    -> 8 -> 4); the 256 x 4 x 4 result is flattened and passed through two
    linear layers.  ELU is used after every batch-normalised layer and the
    final output has no activation.
    """

    def __init__(self):
        """Create all layers."""
        super().__init__()

        # Down-sampling convolutional stages with batch normalisation.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.conv1bn = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv2bn = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv3bn = nn.BatchNorm2d(256)

        # Fully connected head producing a single value per image.
        self.lin1 = nn.Linear(in_features=4*4*256, out_features=512)
        self.lin1bn = nn.BatchNorm1d(512)
        self.lin2 = nn.Linear(in_features=512, out_features=1)

    def forward(self, x):
        """Return the discriminator output of shape (batch, 1) for images `x`."""
        feat = x
        conv_stages = (
            (self.conv1, self.conv1bn),
            (self.conv2, self.conv2bn),
            (self.conv3, self.conv3bn),
        )
        for conv, bnorm in conv_stages:
            feat = F.elu(bnorm(conv(feat)))

        flat = feat.reshape(-1, 4*4*256)
        hidden = F.elu(self.lin1bn(self.lin1(flat)))
        return self.lin2(hidden)

In [29]:
class FGANLearningObjective(nn.Module):
    """Computes the generator and discriminator losses of an f-GAN.

    Both networks are registered as sub-modules, so `.parameters()`, `.to()`
    and `.zero_grad()` on this object cover generator and discriminator.
    """

    def __init__(self, gen, disc, divergence_name="gan", gamma=0.01):
        """Store the two players and configure the divergence.

        Arguments
        gen -- Generator module mapping noise to data.
        disc -- Discriminator module mapping data to one value per sample.
        divergence_name -- Name passed on to ConjugateDualFunction.
        gamma -- Weight of the gradient penalty; the penalty is skipped
                 entirely when gamma <= 0.
        """
        super(FGANLearningObjective, self).__init__()
        self.gen = gen
        self.disc = disc
        self.conj = ConjugateDualFunction(divergence_name)
        self.gammahalf = 0.5*gamma

    def forward(self, xreal, zmodel):
        """Evaluate both losses on one batch.

        Arguments
        xreal -- Batch of real samples.  If the gradient penalty is active
                 and xreal does not require grad, a detached view with grad
                 enabled is used internally (the caller's tensor is left
                 untouched).
        zmodel -- Batch of noise vectors fed to the generator.

        Returns (loss_gen, loss_disc), both scalar tensors to be minimised.
        """
        use_penalty = self.gammahalf > 0.0

        # The penalty differentiates w.r.t. the real inputs, which requires
        # them to be part of the autograd graph *before* the forward pass.
        if use_penalty and not xreal.requires_grad:
            xreal = xreal.detach().requires_grad_(True)

        # Real data
        vreal = self.disc(xreal)    # Real data discriminator output
        Treal = self.conj.T(vreal)  # Mapped to T-space

        # Model data
        xmodel = self.gen(zmodel)   # Map noise to data
        vmodel = self.disc(xmodel)  # Model data discriminator output
        fstar_Tmodel = self.conj.fstarT(vmodel)   # Mapped to f^*(T)

        # Compute generator loss
        loss_gen = -fstar_Tmodel.mean()

        # Compute discriminator loss (negation because we minimize)
        loss_disc = fstar_Tmodel.mean() - Treal.mean()

        # Compute gradient penalty as per (Mescheder et al., ICML 2018):
        # gamma/2 times the batch mean of the squared norm of d T(real)/d x.
        # Summing Treal gives per-sample gradients in one call because each
        # sample's output is fed by its own input row.
        # NOTE(review): with batch norm in the discriminator the samples are
        # coupled through batch statistics -- confirm this is intended.
        if use_penalty:
            batchsize = xreal.size(0)
            grad_pd = torch.autograd.grad(Treal.sum(), xreal,
                create_graph=True)[0]
            # reshape (not view) so a non-contiguous gradient is handled too.
            grad_pd_norm2 = grad_pd.pow(2).reshape(batchsize, -1).sum(1)
            gradient_penalty = self.gammahalf * grad_pd_norm2.mean()
            loss_disc = loss_disc + gradient_penalty

        return loss_gen, loss_disc

In [30]:
# Display how many CUDA devices PyTorch can see on this machine.
torch.cuda.device_count()

4

In [31]:
# Prefer GPU index 2 (the card used for the recorded run) when it exists;
# otherwise fall back to the first GPU, and to the CPU when CUDA is
# unavailable.  A hard-coded "cuda:2" is an invalid device on hosts with
# fewer than three GPUs.
if torch.cuda.is_available():
    _preferred_gpu = 2
    _gpu_index = _preferred_gpu if torch.cuda.device_count() > _preferred_gpu else 0
    device = torch.device("cuda:%d" % _gpu_index)
else:
    device = torch.device("cpu")
device

device(type='cuda', index=2)

In [38]:
# Dimensionality of the noise vector fed to the generator.
nrand = 100

# Build both players and wrap them in the f-GAN objective ("gan" divergence,
# gradient-penalty weight gamma = 10), then move every parameter to `device`.
# Multi-GPU wrapping via torch.nn.DataParallel is not used.
gen = DCGANGenerator(nrand)
disc = DCGANDiscriminator()
fgan = FGANLearningObjective(
    gen, disc, divergence_name="gan", gamma=10.0).to(device)

In [39]:
# Mini-batch size for the data loader; the training loop also uses it as the
# number of noise vectors drawn per step.
batchsize = 64

# One Adam optimiser per player, both with the same step size.
adam_lr = 1.0e-3
optimizer_gen = optim.Adam(fgan.gen.parameters(), lr=adam_lr)
optimizer_disc = optim.Adam(fgan.disc.parameters(), lr=adam_lr)

# Shuffled mini-batches of the CIFAR-10 training set, loaded by 8 workers.
trainloader = torch.utils.data.DataLoader(
    dataset=CIFAR10,
    batch_size=batchsize,
    shuffle=True,
    num_workers=8,
)

In [40]:
# NOTE(review): per the tensorboardX docs `comment` has no effect once
# `log_dir` is given, and "f-GAN-JS" does not match the "gan" divergence
# configured for `fgan` -- confirm which label is intended.
writer = SummaryWriter(log_dir="runs/CIFAR10", comment="f-GAN-JS")

nepochs = 1000
niter = 0

# Parameter lists are fixed for the whole run; materialise them once.
gen_params = list(fgan.gen.parameters())
disc_params = list(fgan.disc.parameters())

try:
    for epoch in range(nepochs):
        # Log a grid of generated samples at the start of every epoch.  No
        # gradients are needed here, so skip building the autograd graph.
        with torch.no_grad():
            zmodel = torch.rand((batchsize, nrand), device=device)
            xmodel = fgan.gen(zmodel)
        xmodelimg = vutils.make_grid(xmodel,
            normalize=True, scale_each=True)
        writer.add_image('Generated', xmodelimg, global_step=niter)

        for i, data in enumerate(trainloader, 0):
            niter += 1
            imgs, _ = data  # class labels are not used

            # Real data needs grad enabled for the gradient penalty; noise is
            # drawn uniformly from [0, 1).
            xreal = imgs.to(device).requires_grad_(True)
            zmodel = torch.rand((batchsize, nrand), device=device)

            loss_gen, loss_disc = fgan(xreal, zmodel)
            loss_gen_value = loss_gen.item()
            loss_disc_value = loss_disc.item()
            writer.add_scalar('obj/disc', loss_disc_value, niter)
            writer.add_scalar('obj/gen', loss_gen_value, niter)
            if i == 0:
                print("epoch %d  iter %d  obj(D) %.4f  obj(G) %.4f" % (epoch, niter, loss_disc_value, loss_gen_value))

            # Compute both players' gradients BEFORE either optimiser step.
            # Stepping the generator first modifies its weights in place, and
            # back-propagating the discriminator loss through the same graph
            # afterwards fails autograd's in-place version check on PyTorch
            # 1.5 and later.  Each player only receives the gradient of its
            # own loss, matching the original simultaneous update.
            grads_gen = torch.autograd.grad(loss_gen, gen_params,
                retain_graph=True)
            grads_disc = torch.autograd.grad(loss_disc, disc_params)

            for param, grad in zip(gen_params, grads_gen):
                param.grad = grad
            for param, grad in zip(disc_params, grads_disc):
                param.grad = grad

            optimizer_gen.step()
            optimizer_disc.step()

    writer.export_scalars_to_json("./all_scalars.json")
finally:
    # Flush and release the event file even if training is interrupted.
    writer.close()

epoch 0  iter 1  obj(D) 1.4613  obj(G) -0.6857
epoch 1  iter 783  obj(D) 1.3303  obj(G) -0.6956
epoch 2  iter 1565  obj(D) 1.3545  obj(G) -0.6759
epoch 3  iter 2347  obj(D) 1.2330  obj(G) -0.6286
epoch 4  iter 3129  obj(D) 1.3652  obj(G) -0.6625
epoch 5  iter 3911  obj(D) 1.2568  obj(G) -0.6419
epoch 6  iter 4693  obj(D) 0.9561  obj(G) -0.5000
epoch 7  iter 5475  obj(D) 0.9355  obj(G) -0.5433
epoch 8  iter 6257  obj(D) 1.1334  obj(G) -0.6481
epoch 9  iter 7039  obj(D) 1.1495  obj(G) -0.5758
epoch 10  iter 7821  obj(D) 1.3493  obj(G) -0.6366
epoch 11  iter 8603  obj(D) 1.0560  obj(G) -0.6393
epoch 12  iter 9385  obj(D) 0.8939  obj(G) -0.4566
epoch 13  iter 10167  obj(D) 1.3560  obj(G) -0.7240
epoch 14  iter 10949  obj(D) 0.8296  obj(G) -0.4685
epoch 15  iter 11731  obj(D) 1.1155  obj(G) -0.4456
epoch 16  iter 12513  obj(D) 1.1282  obj(G) -0.5707
epoch 17  iter 13295  obj(D) 1.2698  obj(G) -0.8322
epoch 18  iter 14077  obj(D) 1.2235  obj(G) -0.6891
epoch 19  iter 14859  obj(D) 1.5130  ob

epoch 157  iter 122775  obj(D) 0.7876  obj(G) -0.6376
epoch 158  iter 123557  obj(D) 0.8966  obj(G) -0.6107
epoch 159  iter 124339  obj(D) 1.4445  obj(G) -0.8904
epoch 160  iter 125121  obj(D) 0.9392  obj(G) -0.6347
epoch 161  iter 125903  obj(D) 0.7870  obj(G) -0.6551
epoch 162  iter 126685  obj(D) 0.8489  obj(G) -0.5642
epoch 163  iter 127467  obj(D) 0.8951  obj(G) -0.6522
epoch 164  iter 128249  obj(D) 1.0173  obj(G) -0.8929
epoch 165  iter 129031  obj(D) 0.7900  obj(G) -0.6457
epoch 166  iter 129813  obj(D) 0.9947  obj(G) -0.7554
epoch 167  iter 130595  obj(D) 1.0502  obj(G) -0.3953
epoch 168  iter 131377  obj(D) 1.3208  obj(G) -0.5406
epoch 169  iter 132159  obj(D) 0.6329  obj(G) -0.5329
epoch 170  iter 132941  obj(D) 0.8676  obj(G) -0.5637
epoch 171  iter 133723  obj(D) 0.9474  obj(G) -0.4868
epoch 172  iter 134505  obj(D) 1.0741  obj(G) -0.7831
epoch 173  iter 135287  obj(D) 1.0607  obj(G) -0.8684
epoch 174  iter 136069  obj(D) 0.8705  obj(G) -0.7109
epoch 175  iter 136851  obj(

epoch 309  iter 241639  obj(D) 0.8234  obj(G) -0.4799
epoch 310  iter 242421  obj(D) 1.7026  obj(G) -1.6779
epoch 311  iter 243203  obj(D) 0.7477  obj(G) -0.3120
epoch 312  iter 243985  obj(D) 0.8660  obj(G) -0.6774
epoch 313  iter 244767  obj(D) 1.4863  obj(G) -1.4843
epoch 314  iter 245549  obj(D) 1.6165  obj(G) -1.5803
epoch 315  iter 246331  obj(D) 0.8799  obj(G) -0.7179
epoch 316  iter 247113  obj(D) 0.5746  obj(G) -0.5677
epoch 317  iter 247895  obj(D) 0.4309  obj(G) -0.4255
epoch 318  iter 248677  obj(D) 2.1282  obj(G) -1.1178
epoch 319  iter 249459  obj(D) 1.1464  obj(G) -1.0733
epoch 320  iter 250241  obj(D) 0.7623  obj(G) -0.7562
epoch 321  iter 251023  obj(D) 0.4022  obj(G) -0.3480
epoch 322  iter 251805  obj(D) 0.7237  obj(G) -0.4746
epoch 323  iter 252587  obj(D) 0.3996  obj(G) -0.1901
epoch 324  iter 253369  obj(D) 0.4154  obj(G) -0.4075
epoch 325  iter 254151  obj(D) 1.3167  obj(G) -0.7032
epoch 326  iter 254933  obj(D) 0.8807  obj(G) -0.8801
epoch 327  iter 255715  obj(

epoch 461  iter 360503  obj(D) 0.5504  obj(G) -0.4480
epoch 462  iter 361285  obj(D) 0.5996  obj(G) -0.5489
epoch 463  iter 362067  obj(D) 2.1543  obj(G) -2.1239
epoch 464  iter 362849  obj(D) 0.4378  obj(G) -0.1347
epoch 465  iter 363631  obj(D) 0.4703  obj(G) -0.4684
epoch 466  iter 364413  obj(D) 0.5344  obj(G) -0.0708
epoch 467  iter 365195  obj(D) 0.6727  obj(G) -0.3965
epoch 468  iter 365977  obj(D) 0.5988  obj(G) -0.5318
epoch 469  iter 366759  obj(D) 0.6866  obj(G) -0.3885
epoch 470  iter 367541  obj(D) 0.6367  obj(G) -0.6345
epoch 471  iter 368323  obj(D) 0.1515  obj(G) -0.1446
epoch 472  iter 369105  obj(D) 0.7337  obj(G) -0.7254
epoch 473  iter 369887  obj(D) 0.0533  obj(G) -0.0480
epoch 474  iter 370669  obj(D) 0.1121  obj(G) -0.1074
epoch 475  iter 371451  obj(D) 0.3157  obj(G) -0.0502
epoch 476  iter 372233  obj(D) 0.1043  obj(G) -0.1010
epoch 477  iter 373015  obj(D) 0.0571  obj(G) -0.0391
epoch 478  iter 373797  obj(D) 0.0036  obj(G) -0.0016
epoch 479  iter 374579  obj(

epoch 613  iter 479367  obj(D) 0.5515  obj(G) -0.5506
epoch 614  iter 480149  obj(D) 0.0984  obj(G) -0.0981
epoch 615  iter 480931  obj(D) 0.0560  obj(G) -0.0557
epoch 616  iter 481713  obj(D) 0.0520  obj(G) -0.0497
epoch 617  iter 482495  obj(D) 0.0073  obj(G) -0.0065
epoch 618  iter 483277  obj(D) 0.0078  obj(G) -0.0068
epoch 619  iter 484059  obj(D) 0.0013  obj(G) -0.0001
epoch 620  iter 484841  obj(D) 0.0039  obj(G) -0.0034
epoch 621  iter 485623  obj(D) 0.0005  obj(G) -0.0002
epoch 622  iter 486405  obj(D) 0.0001  obj(G) -0.0000
epoch 623  iter 487187  obj(D) 0.0019  obj(G) -0.0017
epoch 624  iter 487969  obj(D) 0.0010  obj(G) -0.0009
epoch 625  iter 488751  obj(D) 0.0017  obj(G) -0.0007
epoch 626  iter 489533  obj(D) 0.0263  obj(G) -0.0247
epoch 627  iter 490315  obj(D) 0.0031  obj(G) -0.0007
epoch 628  iter 491097  obj(D) 0.0013  obj(G) -0.0012
epoch 629  iter 491879  obj(D) 0.0352  obj(G) -0.0116
epoch 630  iter 492661  obj(D) 0.0801  obj(G) -0.0764
epoch 631  iter 493443  obj(

epoch 765  iter 598231  obj(D) 0.0008  obj(G) -0.0007
epoch 766  iter 599013  obj(D) 0.0009  obj(G) -0.0007
epoch 767  iter 599795  obj(D) 0.0002  obj(G) -0.0001
epoch 768  iter 600577  obj(D) 0.0004  obj(G) -0.0000
epoch 769  iter 601359  obj(D) 0.0001  obj(G) -0.0001
epoch 770  iter 602141  obj(D) 0.0003  obj(G) -0.0002
epoch 771  iter 602923  obj(D) 0.0002  obj(G) -0.0001
epoch 772  iter 603705  obj(D) 0.0001  obj(G) -0.0001
epoch 773  iter 604487  obj(D) 0.0001  obj(G) -0.0001
epoch 774  iter 605269  obj(D) 0.0001  obj(G) -0.0000
epoch 775  iter 606051  obj(D) 0.0005  obj(G) -0.0005
epoch 776  iter 606833  obj(D) 0.0002  obj(G) -0.0001
epoch 777  iter 607615  obj(D) 0.0001  obj(G) -0.0000
epoch 778  iter 608397  obj(D) 0.0006  obj(G) -0.0006
epoch 779  iter 609179  obj(D) 0.0003  obj(G) -0.0003
epoch 780  iter 609961  obj(D) 0.0003  obj(G) -0.0002
epoch 781  iter 610743  obj(D) 0.0001  obj(G) -0.0001
epoch 782  iter 611525  obj(D) 0.0001  obj(G) -0.0001
epoch 783  iter 612307  obj(

epoch 917  iter 717095  obj(D) 0.0000  obj(G) -0.0000
epoch 918  iter 717877  obj(D) 0.0051  obj(G) -0.0051
epoch 919  iter 718659  obj(D) 0.0000  obj(G) -0.0000
epoch 920  iter 719441  obj(D) 0.0000  obj(G) -0.0000
epoch 921  iter 720223  obj(D) 0.0000  obj(G) -0.0000
epoch 922  iter 721005  obj(D) 0.0000  obj(G) -0.0000
epoch 923  iter 721787  obj(D) 0.0000  obj(G) -0.0000
epoch 924  iter 722569  obj(D) 0.0001  obj(G) -0.0000
epoch 925  iter 723351  obj(D) 0.0001  obj(G) -0.0000
epoch 926  iter 724133  obj(D) 0.0001  obj(G) -0.0001
epoch 927  iter 724915  obj(D) 0.0246  obj(G) -0.0243
epoch 928  iter 725697  obj(D) 0.0001  obj(G) -0.0000
epoch 929  iter 726479  obj(D) 0.0000  obj(G) -0.0000
epoch 930  iter 727261  obj(D) 0.0009  obj(G) -0.0009
epoch 931  iter 728043  obj(D) 0.0033  obj(G) -0.0033
epoch 932  iter 728825  obj(D) 0.0001  obj(G) -0.0001
epoch 933  iter 729607  obj(D) 0.0007  obj(G) -0.0007
epoch 934  iter 730389  obj(D) 0.0003  obj(G) -0.0003
epoch 935  iter 731171  obj(