# DCGAN on CIFAR-10

In this assignment, you will train a DCGAN on CIFAR-10. Below you will find the data loader and the generator and discriminator networks. The assignment consists of the following tasks:

**1)** Compare the two ways to update the generator discussed in the original GAN [paper](https://arxiv.org/abs/1406.2661) (page 3, above the figure). The first is to minimise $\mathbb E_z \log (1 - D(G(z)))$ and the second is to maximise $\mathbb E_z \log D(G(z))$.

To do this, fix a small number of training epochs (around 5) and plot the generator and discriminator losses for both cases. How do the two graphs compare? As a sanity check, plot the fake samples at the end of each epoch.

**2)** Train a GAN with the second generator update for a larger number of training epochs. To track training progress, plot fake samples at the end of each epoch. How do the resulting samples compare visually with the real data samples?

**3)** To qualitatively assess how well the generator generalizes, plot interpolations between randomly generated samples. To do this, pick two random noise vectors, connect them with a straight line, and decode points along that line with the generator.

In [0]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

from torchvision import transforms

In [0]:
# --- data / training configuration ---
data_folder = './data'
image_size = 64
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- model hyperparameters ---
n_channels = 3
n_feature_maps = 64
z_dim = 128

# The DCGAN layers below are laid out for 64x64 inputs, so CIFAR-10 images are
# resized to image_size; Normalize(0.5, 0.5) maps ToTensor's [0, 1] pixel range
# to [-1, 1], matching the generator's final Tanh.
dataset = torchvision.datasets.CIFAR10(
    data_folder,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * n_channels, std=(0.5,) * n_channels),
    ]),
    download=True,
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

In [0]:
def get_dcgan_generator(z_dim, n_feature_maps=64, n_channels=3):
    layers = [
        nn.ConvTranspose2d(z_dim, n_feature_maps * 8, 4, 1, 0, bias=False),
        nn.BatchNorm2d(n_feature_maps * 8),
        nn.ReLU(),
        # n_features * 8 x 4 x 4
        nn.ConvTranspose2d(n_feature_maps * 8, n_feature_maps * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(n_feature_maps * 4),
        nn.ReLU(),
        # n_features * 4 x 8 x 8
        nn.ConvTranspose2d(n_feature_maps * 4, n_feature_maps * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(n_feature_maps * 2),
        nn.ReLU(),
        # n_features * 2 x 16 x 16
        nn.ConvTranspose2d(n_feature_maps * 2, n_feature_maps, 4, 2, 1, bias=False),
        nn.BatchNorm2d(n_feature_maps),
        nn.ReLU(),
        # n_features x 32 x 32
        nn.ConvTranspose2d(n_feature_maps, n_channels, 4, 2, 1, bias=False),
        nn.Tanh()]
        # n_channels x 64 x 64
    return nn.Sequential(*layers)
    
def get_dcgan_discriminator(n_feature_maps=64, n_channels=3):
    """Build the DCGAN discriminator as an ``nn.Sequential``.

    Maps images of shape ``(N, n_channels, 64, 64)`` to scores of shape
    ``(N, 1, 1, 1)`` in ``[0, 1]`` (final ``Sigmoid``). With
    F = n_feature_maps, the spatial size goes 64 -> 32 -> 16 -> 8 -> 4 -> 1
    while the channel count goes n_channels -> F -> 2F -> 4F -> 8F -> 1.
    """
    f = n_feature_maps

    # n_channels x 64 x 64 -> F x 32 x 32; the input layer has no BatchNorm
    modules = [nn.Conv2d(n_channels, f, 4, 2, 1, bias=False), nn.LeakyReLU(0.2)]

    # three stride-2 stages (32 -> 16 -> 8 -> 4), doubling the width each time
    stages = ((f, 2 * f), (2 * f, 4 * f), (4 * f, 8 * f))
    for stage, (c_in, c_out) in enumerate(stages):
        modules.extend([
            nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
            nn.BatchNorm2d(c_out),
            # only the last stage's activation runs in place
            nn.LeakyReLU(0.2, inplace=(stage == len(stages) - 1)),
        ])

    # 8F x 4 x 4 -> 1 x 1 x 1 probability
    modules.extend([nn.Conv2d(8 * f, 1, 4, 1, 0, bias=False), nn.Sigmoid()])
    return nn.Sequential(*modules)

def weights_init(m):
    """Initialise one module's parameters; intended for ``model.apply(weights_init)``.

    * Modules whose class name contains ``'Conv'`` (``Conv2d``,
      ``ConvTranspose2d``, ...) get weights drawn from N(0, 0.02^2). A conv
      bias, if present, is left untouched.
    * Modules whose class name contains ``'BatchNorm'`` get their scale drawn
      from N(1, 0.02^2) and their shift set to 0. BatchNorm layers created
      with ``affine=False`` have no scale/shift parameters and are skipped.
    * All other modules are left unchanged.
    """
    classname = type(m).__name__
    if 'Conv' in classname:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif 'BatchNorm' in classname:
        # affine=False BatchNorm stores weight/bias as None
        if m.weight is not None:
            nn.init.normal_(m.weight, mean=1.0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [0]:
# Build both networks from the configured hyperparameters (z_dim,
# n_feature_maps, n_channels) instead of the builders' defaults, so editing the
# configuration cell actually changes the models; then apply the DCGAN
# weight initialisation.
generator = get_dcgan_generator(z_dim, n_feature_maps, n_channels).to(device)
generator.apply(weights_init)
discriminator = get_dcgan_discriminator(n_feature_maps, n_channels).to(device)
discriminator.apply(weights_init)

In [0]:
# Both networks use the same Adam settings: lr = 2e-4, betas = (0.5, 0.999).
g_optimizer = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

In [0]:
# Number of full passes over the dataset in the training loop below.
# Task 1 asks for around 5 epochs and task 2 for more; raise this accordingly.
n_epochs: int = 1

In [0]:
# Generator objective (task 1 compares both):
#   False -> maximise E_z log D(G(z))      (second update)
#   True  -> minimise E_z log(1 - D(G(z))) (first update)
use_saturating_g_loss = False

criterion = nn.BCELoss()
real_label, fake_label = 1.0, 0.0
# per-iteration losses, kept for plotting the training curves
d_losses, g_losses = [], []

for epoch in range(n_epochs):
    for real_batch, _ in dataloader:
        # update discriminator
        discriminator.zero_grad()
        real_data = real_batch.to(device)
        # The last batch of an epoch can be smaller than batch_size, so size
        # the noise and label tensors from the batch actually received.
        n_real = real_data.size(0)
        # discriminator outputs (N, 1, 1, 1); flatten to (N,) for the loss
        d_real_scores = discriminator(real_data).view(-1)

        noise = torch.randn(n_real, z_dim, 1, 1, device=device)
        fake_samples = generator(noise)
        # detach(): this step trains D only, so do not backpropagate into G
        d_fake_scores = discriminator(fake_samples.detach()).view(-1)
        # -[log D(x) + log(1 - D(G(z)))], each term averaged over the batch
        d_loss = (criterion(d_real_scores, torch.full_like(d_real_scores, real_label))
                  + criterion(d_fake_scores, torch.full_like(d_fake_scores, fake_label)))
        d_loss.backward()
        d_optimizer.step()

        # update generator
        generator.zero_grad()
        noise = torch.randn(n_real, z_dim, 1, 1, device=device)
        fake_samples = generator(noise)
        d_fake_scores = discriminator(fake_samples).view(-1)
        if use_saturating_g_loss:
            # BCE(p, 0) = -log(1 - p), so minimising log(1 - D(G(z)))
            # is minimising -BCE(p, 0)
            g_loss = -criterion(d_fake_scores, torch.full_like(d_fake_scores, fake_label))
        else:
            # BCE(p, 1) = -log p, so minimising it maximises log D(G(z))
            g_loss = criterion(d_fake_scores, torch.full_like(d_fake_scores, real_label))
        g_loss.backward()
        g_optimizer.step()

        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    print(f'epoch {epoch + 1}/{n_epochs}: '
          f'd_loss={d_losses[-1]:.4f} g_loss={g_losses[-1]:.4f}')