The following model is the standard GAN, which is part of **Exercise 1**. It is a very simple example; if you want, you can improve it by adding convolutions and the other ideas we talked about. Fill in the missing pieces and train it.


In [None]:
import os
import numpy as np
import math
import multiprocessing
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.optim.optimizer import Optimizer, required
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Output folders for the sample grids of the GAN and the conditional GAN.
for _out_dir in ("images_gan", "images_cgan"):
    os.makedirs(_out_dir, exist_ok=True)

# --- Training hyperparameters ---
n_epochs = 50  # number of passes over the training set
batch_size = 64  # images per mini-batch
lr = 0.0002  # Adam learning rate
b1 = 0.5  # Adam beta1 (decay of the first-moment estimate)
b2 = 0.999  # Adam beta2 (decay of the second-moment estimate)
n_cpu = multiprocessing.cpu_count()  # worker processes used for batch loading

# --- Model / data dimensions ---
latent_dim = 100  # size of the generator's noise vector
img_size = 28  # height and width of each (square) image
channels = 1  # number of image channels
sample_interval = 400  # save a sample grid every this many batches

# (C, H, W) layout of a single image.
img_shape = (channels,) + (img_size,) * 2

# Fix the PyTorch RNG seed so runs are repeatable.
torch.manual_seed(42)

class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The network stacks ``Linear`` stages of 128, 256, 512 and 1024 units
    (LeakyReLU activations, batch norm on all but the first stage), then
    projects to ``prod(img_shape)`` values and squashes them with ``Tanh``.
    It reads the module-level ``latent_dim`` and ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear [-> BatchNorm] -> LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # The second positional parameter of BatchNorm1d is `eps`, so
                # 0.8 has to be passed by keyword; it is presumably meant as
                # the running-statistics momentum, not as a variance offset.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),  # bounds every output value to [-1, 1]
        )

    def forward(self, z):
        """Generate one image per row of ``z``.

        Args:
            z: Tensor whose last dimension is ``latent_dim``.

        Returns:
            Tensor of shape ``(z.size(0), *img_shape)``.
        """
        img = self.model(z)
        # Un-flatten the per-sample pixel vector back into image layout.
        img = img.view(img.size(0), *img_shape)
        return img

class Discriminator(nn.Module):
    """Fully connected discriminator producing one sigmoid score per image.

    Images are flattened to ``prod(img_shape)`` values and passed through
    Linear layers of 512 and 256 units (LeakyReLU activations) followed by a
    single-unit Linear layer and a Sigmoid. Reads the module-level
    ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        n_pixels = int(np.prod(img_shape))
        layers = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor whose first dimension is the batch size.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in (0, 1).
        """
        # Collapse everything except the batch dimension.
        flat = img.view(img.shape[0], -1)
        return self.model(flat)


# Binary cross-entropy between the discriminator's sigmoid output and 0/1 targets.
bce_loss = nn.BCELoss()

# Build both networks, then move them and the loss module to the target device.
generator = Generator()
discriminator = Discriminator()
for _module in (generator, discriminator, bce_loss):
    _module.to(device)

# MNIST training split: resize to img_size, convert to a tensor and normalise
# with mean 0.5 / std 0.5.
os.makedirs("./mnist", exist_ok=True)
mnist_transform = transforms.Compose(
    [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)
mnist_train = datasets.MNIST(
    "./mnist",
    train=True,
    download=True,
    transform=mnist_transform,
)
dataloader = DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)

# One Adam optimizer per network, sharing the same hyperparameters.
adam_kwargs = dict(lr=lr, betas=(b1, b2))
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)

# ----------
#  Training
# ----------
# Label convention used by this loop: the discriminator's sigmoid output is
# pushed towards 0 for real images and towards 1 for generated images, and the
# generator is trained so that the discriminator outputs 0 for its samples.

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        real_imgs = imgs.to(device)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # One latent vector per real image in the batch.
        z = torch.randn((imgs.shape[0], latent_dim)).to(device)
        gen_imgs = generator(z)

        y_pred_fake = discriminator(gen_imgs)

        # The generator wants its samples to receive the "real" target (0).
        g_loss = bce_loss(y_pred_fake, torch.zeros_like(y_pred_fake))
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Also clears the discriminator gradients left over from g_loss.backward().
        optimizer_D.zero_grad()

        # Re-generate with the just-updated generator. Only the discriminator
        # is optimised here, so no autograd graph through the generator is
        # needed; building one would only waste memory and backward time.
        with torch.no_grad():
            gen_imgs = generator(z)

        y_pred_real = discriminator(real_imgs)
        y_pred_fake = discriminator(gen_imgs)

        real_loss = bce_loss(y_pred_real, torch.zeros_like(y_pred_real))
        fake_loss = bce_loss(y_pred_fake, torch.ones_like(y_pred_fake))
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        if i % 200 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch+1, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            # You can also save samples to your drive, and maybe save your network as well.
            save_image(gen_imgs.data[:25], "images_gan/GAN-%d.png" % batches_done, nrow=5, normalize=True)

[Epoch 1/50] [Batch 0/938] [D loss: 0.682676] [G loss: 0.709050]
[Epoch 1/50] [Batch 200/938] [D loss: 0.566229] [G loss: 0.856968]
[Epoch 1/50] [Batch 400/938] [D loss: 0.807519] [G loss: 0.433746]
[Epoch 1/50] [Batch 600/938] [D loss: 0.496054] [G loss: 1.077175]
[Epoch 1/50] [Batch 800/938] [D loss: 0.577621] [G loss: 1.605823]
[Epoch 2/50] [Batch 0/938] [D loss: 0.438962] [G loss: 1.527040]
[Epoch 2/50] [Batch 200/938] [D loss: 0.599387] [G loss: 2.615825]
[Epoch 2/50] [Batch 400/938] [D loss: 0.456024] [G loss: 1.240072]
[Epoch 2/50] [Batch 600/938] [D loss: 0.444105] [G loss: 1.262440]
[Epoch 2/50] [Batch 800/938] [D loss: 0.581124] [G loss: 1.980942]
[Epoch 3/50] [Batch 0/938] [D loss: 0.487133] [G loss: 1.328975]
[Epoch 3/50] [Batch 200/938] [D loss: 0.599038] [G loss: 0.674074]
[Epoch 3/50] [Batch 400/938] [D loss: 0.508071] [G loss: 2.584096]
[Epoch 3/50] [Batch 600/938] [D loss: 0.370612] [G loss: 1.616733]
[Epoch 3/50] [Batch 800/938] [D loss: 0.429238] [G loss: 1.512800]
[

[Epoch 25/50] [Batch 400/938] [D loss: 0.515128] [G loss: 1.697827]
[Epoch 25/50] [Batch 600/938] [D loss: 0.480215] [G loss: 1.701632]
[Epoch 25/50] [Batch 800/938] [D loss: 0.434877] [G loss: 1.327255]
[Epoch 26/50] [Batch 0/938] [D loss: 0.496443] [G loss: 1.193023]
[Epoch 26/50] [Batch 200/938] [D loss: 0.426083] [G loss: 1.162979]
[Epoch 26/50] [Batch 400/938] [D loss: 0.501156] [G loss: 1.425578]
[Epoch 26/50] [Batch 600/938] [D loss: 0.440134] [G loss: 1.322614]
[Epoch 26/50] [Batch 800/938] [D loss: 0.499348] [G loss: 1.816809]
[Epoch 27/50] [Batch 0/938] [D loss: 0.512640] [G loss: 1.465984]
[Epoch 27/50] [Batch 200/938] [D loss: 0.441508] [G loss: 1.600295]
[Epoch 27/50] [Batch 400/938] [D loss: 0.513211] [G loss: 1.407165]
[Epoch 27/50] [Batch 600/938] [D loss: 0.474057] [G loss: 1.356120]
[Epoch 27/50] [Batch 800/938] [D loss: 0.545602] [G loss: 0.793062]
[Epoch 28/50] [Batch 0/938] [D loss: 0.516948] [G loss: 1.860915]
[Epoch 28/50] [Batch 200/938] [D loss: 0.480711] [G lo

In [None]:
# Draw as many latent vectors as `imgs` currently holds images
# (presumably the last batch of the training loop, if cells ran in order).
n_samples = imgs.shape[0]
z = torch.randn((n_samples, latent_dim)).to(device)

# Sampling only, so no autograd graph is needed.
with torch.no_grad():
    gen_imgs = generator(z)

# Show the first generated sample as a 28x28 grayscale image.
first_sample = gen_imgs.cpu().numpy()[0]
plt.imshow(first_sample.reshape((28, 28)), cmap='gray')

In [None]:
class Classifer(nn.Module):
    """Small convolutional classifier that returns 10 class logits per image.

    Three stride-2 3x3 convolutions (16, 32, 64 channels) with ReLU are
    followed by a single Linear layer. The flatten step assumes the last
    convolution yields 4x4 feature maps, which is the case for 28x28 inputs.

    Args:
        in_dims: Number of channels of the input images.
    """

    def __init__(self, in_dims):
        super().__init__()
        self.conv1 = nn.Conv2d(in_dims, 16, kernel_size=(3, 3), padding=1, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(3, 3), padding=1, stride=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1, stride=2)
        self.dense1 = nn.Linear(4 * 4 * 64, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, 10)``.

        The scores are raw logits: ``nn.CrossEntropyLoss`` applies
        ``log_softmax`` itself, so a softmax here would be applied twice and
        flatten the training signal. ``argmax`` over the logits gives the same
        class as ``argmax`` over the softmax probabilities; call
        ``torch.softmax(logits, dim=1)`` when probabilities are needed.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Flatten the 64 feature maps of size 4x4 into one vector per sample.
        x = x.view(-1, 4 * 4 * 64)
        return self.dense1(x)
    
# Classifier for single-channel images, with its optimizer and training criterion.
clf = Classifer(1).to(device)

clf_optim = torch.optim.Adam(clf.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().to(device)


# Train for three epochs and report the mean per-batch loss and accuracy.
for epoch in range(1, 4):
    loss_sum = 0
    acc_sum = 0
    n_batches = 0
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        # Standard optimisation step.
        clf_optim.zero_grad()
        y_pred = clf(x)
        batch_loss = criterion(y_pred, y)
        batch_loss.backward()
        clf_optim.step()

        loss_sum += batch_loss.item()
        n_batches += 1

        # Fraction of this batch whose highest-scoring class matches the label.
        with torch.no_grad():
            hits = (y_pred.argmax(dim=1) == y).float()
            acc_sum += hits.mean().item()

    loss = loss_sum / n_batches
    acc = acc_sum / n_batches
    print(f"Epoch {epoch}/3 ==> train loss: {loss}, train acc: {acc}")

In [None]:
# Sample 1000 latent vectors and let the classifier score the generated images.
z = torch.randn((1000, latent_dim)).to(device)

# Inference only, so no autograd graph is needed.
with torch.no_grad():
    gen_imgs = generator(z)
    clf_scores = clf(gen_imgs)

# Index of the highest-scoring class for each generated image (NumPy array).
y_pred = clf_scores.cpu().numpy().argmax(axis=1)

In [None]:
# Number of generated samples assigned to each of the ten classes.
class_distributions = [(y_pred == digit).sum() for digit in range(10)]
print(class_distributions)

In [None]:
# Bar chart of how often each class was predicted.
digits = list(range(10))
plt.bar(digits, class_distributions, tick_label=digits)
plt.xlabel("Class")
plt.ylabel("Number of predictions")
# NOTE(review): plt.plot() without arguments draws nothing; plt.show() may have been intended.
plt.plot()

In [None]:
def to_onehot(digits, num_classes):
    """Return a floating-point one-hot encoding of integer class labels.

    Example: ``to_onehot(torch.tensor([2]), 3)`` gives ``[[0., 0., 1.]]``.

    Args:
        digits: Integer tensor of class indices in ``[0, num_classes)``; its
            first dimension is the batch size.
        num_classes: Width of the one-hot encoding.

    Returns:
        Tensor of shape ``(digits.shape[0], num_classes)`` in the default
        floating dtype, on the same device as ``digits``.
    """
    # scatter_ needs an int64 index with one column per row to set.
    index = digits.reshape(-1, 1).long()
    # Allocate next to the input so the function does not depend on a global
    # device and scatter_ never sees tensors on two different devices.
    labels_onehot = torch.zeros(digits.shape[0], num_classes, device=digits.device)
    labels_onehot.scatter_(1, index, 1)
    return labels_onehot

class Generator(nn.Module):
    """Class-conditional fully connected generator.

    The latent vector is concatenated with a label encoding of width
    ``num_classes`` and passed through ``Linear`` stages of 128, 256, 512 and
    1024 units (LeakyReLU activations, batch norm on all but the first stage),
    then projected to ``prod(img_shape)`` values and squashed with ``Tanh``.
    It reads the module-level ``latent_dim`` and ``img_shape``.

    Args:
        num_classes: Width of the label vector appended to the latent vector.
    """

    def __init__(self, num_classes):
        super().__init__()

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear [-> BatchNorm] -> LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # The second positional parameter of BatchNorm1d is `eps`, so
                # 0.8 has to be passed by keyword; it is presumably meant as
                # the running-statistics momentum, not as a variance offset.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + num_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),  # bounds every output value to [-1, 1]
        )

    def forward(self, z, y):
        """Generate one image per row of ``z``, conditioned on ``y``.

        Args:
            z: Latent tensor whose second dimension is ``latent_dim``.
            y: Label tensor with ``num_classes`` columns and the same number
                of rows as ``z``.

        Returns:
            Tensor of shape ``(z.size(0), *img_shape)``.
        """
        # Condition the generator by appending the labels to the noise.
        z = torch.cat((z, y), dim=1)
        img = self.model(z)
        # Un-flatten the per-sample pixel vector back into image layout.
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    """Class-conditional fully connected discriminator.

    The flattened image is concatenated with a label vector of width
    ``num_classes`` and passed through Linear layers of 512 and 256 units
    (LeakyReLU activations), then a single-unit Linear layer and a Sigmoid.
    Reads the module-level ``img_shape``.
    """

    def __init__(self, num_classes):
        super().__init__()

        in_features = int(np.prod(img_shape)) + num_classes
        layers = [
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img, y):
        """Score a batch of images given their label vectors.

        Args:
            img: Tensor whose first dimension is the batch size.
            y: Label tensor with ``num_classes`` columns, one row per image.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in (0, 1).
        """
        # Flatten each image, then append its label vector.
        flat = img.view(img.shape[0], -1)
        conditioned = torch.cat((flat, y), dim=1)
        return self.model(conditioned)


# Conditional generator and discriminator for the ten digit classes.
generator = Generator(10)
discriminator = Discriminator(10)

# Move both networks and the existing BCE loss module to the target device.
for _module in (generator, discriminator, bce_loss):
    _module.to(device)

# One Adam optimizer per network, sharing the same hyperparameters.
adam_kwargs = dict(lr=lr, betas=(b1, b2))
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)

# ----------
#  Training
# ----------
# Label convention used by this loop: the discriminator's sigmoid output is
# pushed towards 0 for real images and towards 1 for generated images, and the
# generator is trained so that the discriminator outputs 0 for its samples.

for epoch in range(n_epochs):
    for i, (imgs, y) in enumerate(dataloader):

        # Adversarial ground truths: the 0/1 targets for the BCE loss are
        # built inline below with zeros_like / ones_like.

        # Class labels of the real batch, one-hot encoded. They condition both
        # networks, for real and for generated images alike.
        y = y.to(device)
        y = to_onehot(y, 10)

        # Configure input
        real_imgs = imgs.to(device)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = torch.randn((imgs.shape[0], latent_dim)).to(device)

        # Generate a batch of images
        gen_imgs = generator(z, y)

        y_pred_fake = discriminator(gen_imgs, y)

        # Loss measures the generator's ability to fool the discriminator:
        # its samples should receive the "real" target (0).
        g_loss = bce_loss(y_pred_fake, torch.zeros_like(y_pred_fake))

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Also clears the discriminator gradients left over from g_loss.backward().
        optimizer_D.zero_grad()

        z = torch.randn((imgs.shape[0], latent_dim)).to(device)

        # Only the discriminator is optimised here, so no autograd graph
        # through the generator is needed; building one would only waste
        # memory and backward time.
        with torch.no_grad():
            gen_imgs = generator(z, y)

        y_pred_real = discriminator(real_imgs, y)
        y_pred_fake = discriminator(gen_imgs, y)

        # Measure the discriminator's ability to classify real from generated samples
        real_loss = bce_loss(y_pred_real, torch.zeros_like(y_pred_real))
        fake_loss = bce_loss(y_pred_fake, torch.ones_like(y_pred_fake))
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        if i % 200 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch+1, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            # You can also save samples to your drive, and maybe save your network as well.
            save_image(gen_imgs.data[:25], "images_cgan/GAN-%d.png" % batches_done, nrow=5, normalize=True)

In [None]:
# Sample 1000 latent vectors for the conditional generator.
z = torch.randn((1000, latent_dim)).to(device)

# Generate a batch of images
with torch.no_grad():
    # Requested labels cycle 0, 1, ..., 9, so every class is requested 100 times.
    y = torch.LongTensor([i % 10 for i in range(1000)]).to(device)
    y = to_onehot(y, 10)
    gen_imgs = generator(z, y)
    y_pred = clf(gen_imgs)

# Index of the highest-scoring class for each generated image.
y_pred = np.argmax(y_pred.cpu().numpy(), axis=1)

# Number of generated samples the classifier assigned to each class.
class_distributions = [np.sum(y_pred == i) for i in range(10)]
print(class_distributions)

plt.bar(list(range(10)), class_distributions, tick_label=list(range(10)))
plt.ylabel("Number of predictions")
plt.xlabel("Class")
# NOTE(review): plt.plot() without arguments draws nothing; plt.show() may have been intended.
plt.plot()