# Dependencies

In [1]:
import numpy as np
from tqdm.auto import tqdm
from typing import *

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import matplotlib.pyplot as plt

# Utils

In [2]:
def display(images, nrow):
    """Render a batch of images as one grid figure.

    Args:
        images: Batch of image tensors accepted by
            ``torchvision.utils.make_grid`` (e.g. ``N x C x H x W``).
        nrow: Number of images placed in each row of the grid.
    """
    image_grid = torchvision.utils.make_grid(images, nrow=nrow)
    # Move channels last (C x H x W -> H x W x C), which is what imshow expects.
    imgs = image_grid.permute(1, 2, 0).cpu().detach().numpy()

    # Min-max normalise to [0, 1]. A constant grid has zero range; dividing by
    # it would fill the array with NaNs, so render such a grid as all zeros.
    min_val = np.min(imgs)
    value_range = np.max(imgs) - min_val
    if value_range > 0:
        imgs = (imgs - min_val) / value_range
    else:
        imgs = np.zeros_like(imgs)

    plt.figure(figsize=(16, 16))
    plt.imshow(imgs)
    plt.show()

# Dataset

Data variables.

In [3]:
# Image geometry: every sample is resized to img_size x img_size and has a
# single channel; the dataset distinguishes num_classes labels.
img_size, img_channels, num_classes = 64, 1, 10

# Length of the random noise vector fed to the generator.
noise_dim = 64

In [4]:
# Preprocessing pipeline: resize to the working resolution, convert to a
# tensor, then normalise the single channel with mean 0.5 and std 0.5.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((img_size, img_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Training split of FashionMNIST, downloaded into the current directory if
# it is not already there.
dataset = torchvision.datasets.FashionMNIST(
    root=".",
    train=True,
    transform=transforms,
    download=True,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./FashionMNIST/raw/train-images-idx3-ubyte.gz to ./FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./FashionMNIST/raw



# Model

## Discriminator

In [5]:
class Discriminator(nn.Module):
  """Strided-convolution classifier producing one sigmoid score per input.

  The layer stack maps ``N x input_channels x 64 x 64`` to ``N x 1 x 1 x 1``.

  Args:
    input_channels: Number of channels of the input tensor.
    features_d: Base channel width; deeper layers use multiples of it.
  """

  def __init__(self, input_channels, features_d=64):
    super().__init__()

    layers = [
        # N x input_channels x 64 x 64 -> N x features_d x 32 x 32.
        nn.Conv2d(input_channels, features_d, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2),
        # -> N x features_d*2 x 16 x 16.
        self.block(features_d, features_d * 2),
        # -> N x features_d*4 x 8 x 8.
        self.block(features_d * 2, features_d * 4),
        # -> N x features_d*8 x 4 x 4.
        self.block(features_d * 4, features_d * 8),
        # -> N x 1 x 1 x 1.
        nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        nn.Sigmoid(),
    ]
    self.net = nn.Sequential(*layers)

  def block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """Build a bias-free Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage."""
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.LeakyReLU(negative_slope=0.2)
    return nn.Sequential(conv, norm, activation)

  def forward(self, x):
    """Return the sigmoid scores for the batch ``x``."""
    scores = self.net(x)
    return scores

## Generator

In [6]:
class Generator(nn.Module):
  """Transposed-convolution network that upsamples a 1 x 1 input to 64 x 64.

  The layer stack maps ``N x input_dim x 1 x 1`` to
  ``N x img_channels x 64 x 64`` with values in [-1, 1] (final Tanh).

  Args:
    input_dim: Number of channels of the 1 x 1 input tensor.
    img_channels: Number of channels of the generated image.
    features_g: Base channel width; earlier layers use multiples of it.
  """

  def __init__(self, input_dim, img_channels, features_g=64):
    super().__init__()

    layers = [
        # N x input_dim x 1 x 1 -> N x features_g*16 x 4 x 4.
        self.block(input_dim, features_g * 16, stride=1, padding=0),
        # -> N x features_g*8 x 8 x 8.
        self.block(features_g * 16, features_g * 8),
        # -> N x features_g*4 x 16 x 16.
        self.block(features_g * 8, features_g * 4),
        # -> N x features_g*2 x 32 x 32.
        self.block(features_g * 4, features_g * 2),
        # -> N x img_channels x 64 x 64.
        nn.ConvTranspose2d(
            features_g * 2, img_channels, kernel_size=4, stride=2, padding=1
        ),
        # Squash pixel values into [-1, 1].
        nn.Tanh(),
    ]
    self.net = nn.Sequential(*layers)

  def block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """Build a bias-free ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
    upsample = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU()
    return nn.Sequential(upsample, norm, activation)

  def forward(self, x):
    """Return the generated images for the input batch ``x``."""
    generated = self.net(x)
    return generated

# Setup

In [7]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Prepare the dataloader.

In [8]:
# Number of samples per mini-batch.
batch_size = 128

# Batches are drawn in shuffled order.
dataset_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

Prepare models.

In [9]:
# Generator input has shape: batch_size x (noise_dim + num_classes) x 1 x 1.
g = Generator(
    input_dim=noise_dim + num_classes,
    img_channels=img_channels,
).to(device)

# Discriminator input has shape:
# batch_size x (img_channels + num_classes) x img_size x img_size.
d = Discriminator(input_channels=img_channels + num_classes).to(device)

In [10]:
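# Optional weight initialisation, currently disabled. It draws Conv2d,
# ConvTranspose2d and BatchNorm2d weights from a normal distribution
# (mean 0.0, std 0.02) and sets BatchNorm2d biases to 0.
# Uncomment to enable. `Module.apply` modifies the module in place, so `g` and
# `d` receive the new weights even though the results are bound to `gen`/`disc`.

# def weights_init(m):
#   if (isinstance(m,nn.Conv2d) or isinstance(m,nn.ConvTranspose2d)):
#     nn.init.normal_(m.weight,0.0,0.02)
#   elif (isinstance(m,nn.BatchNorm2d)):
#     nn.init.normal_(m.weight,0.0,0.02)
#     nn.init.constant_(m.bias,0)

# gen=g.apply(weights_init)
# disc=d.apply(weights_init)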

Optimizer.

In [11]:
# Learning rate shared by both optimizers.
lr = 2e-4

# One Adam optimizer per network, both with betas (0.5, 0.999).
g_opt = torch.optim.Adam(params=g.parameters(), lr=lr, betas=(0.5, 0.999))
d_opt = torch.optim.Adam(params=d.parameters(), lr=lr, betas=(0.5, 0.999))

# Binary cross-entropy between predicted probabilities and 0/1 targets.
loss_fn = nn.BCELoss()

We "validate" by displaying images generated from a fixed set of class labels (`ncol` samples for each of the `num_classes` classes).

In [12]:
# The number of images we want to check for in each class.
ncol = 10

# Fixed one-hot labels for the validation grid, shaped
# (ncol * num_classes) x num_classes x 1 x 1. The labels cycle through
# 0 .. num_classes - 1, so sample i belongs to class i % num_classes.
# F.one_hot returns an integer tensor; cast it to float so it can be
# concatenated with the float noise tensor without relying on implicit
# dtype promotion.
val_oh = F.one_hot(
    torch.arange(end=ncol * num_classes, device=device) % num_classes,
    num_classes=num_classes,
)[..., None, None].float()

## Training

In [14]:
# One epoch = one full pass over the training set. Increase to train longer.
n_epochs = 10

# Running means of the generator / discriminator losses over the current
# reporting window of `delta_step` batches.
mean_g_loss = 0.0
mean_d_loss = 0.0

# Number of batches processed so far.
step = 0

# Report the mean losses and show samples every `delta_step` batches
# (here: once per epoch).
delta_step = len(dataset_loader)

# Noise for the validation grid, sampled once so that successive grids differ
# only because the generator changed.
val_noise = torch.randn((num_classes * ncol, noise_dim, 1, 1), device=device)

for epoch in range(n_epochs):
  for images, labels in tqdm(dataset_loader, leave=False):
    # Use the GPU if available.
    images = images.to(device)
    labels = labels.to(device)

    num_examples = images.shape[0]

    # One-hot labels as floats, shape: batch_size x num_classes x 1 x 1.
    oh = F.one_hot(labels, num_classes).unsqueeze(-1).unsqueeze(-1).float()

    # Same labels tiled to batch_size x num_classes x img_size x img_size:
    # for every image the plane of the correct class is all 1s, the rest 0s.
    oh_d = oh.repeat(1, 1, img_size, img_size)

    # --- Train the discriminator ---

    # The generator is told which class to produce by concatenating the
    # one-hot label with the noise vector along the channel dimension.
    noise = torch.randn((num_examples, noise_dim, 1, 1), device=device)
    fake_images = g(torch.cat((noise, oh), dim=1))

    # The discriminator receives the class information as extra channels:
    # batch_size x (img_channels + num_classes) x img_size x img_size.
    # The fake images are detached so this step does not backpropagate into
    # the generator.
    combined_real_disc = torch.cat((images, oh_d), dim=1)
    combined_fake_disc = torch.cat((fake_images.detach(), oh_d), dim=1)

    d_opt.zero_grad()

    # Generated images should be classified as fake (target 0) and training
    # images as real (target 1); the two losses are averaged.
    fake_preds = d(combined_fake_disc)
    fake_loss = loss_fn(fake_preds, torch.zeros_like(fake_preds))
    real_preds = d(combined_real_disc)
    real_loss = loss_fn(real_preds, torch.ones_like(real_preds))
    d_loss = (fake_loss + real_loss) / 2
    d_loss.backward()
    d_opt.step()

    # --- Train the generator ---

    g_opt.zero_grad()

    # Fresh fakes for the same labels; the generator is rewarded when the
    # discriminator classifies them as real (target 1).
    noise = torch.randn((num_examples, noise_dim, 1, 1), device=device)
    new_images = g(torch.cat((noise, oh), dim=1))
    preds = d(torch.cat((new_images, oh_d), dim=1))
    g_loss = loss_fn(preds, torch.ones_like(preds))
    g_loss.backward()
    g_opt.step()

    # Accumulate the mean losses of the current reporting window.
    mean_g_loss += g_loss.item() / delta_step
    mean_d_loss += d_loss.item() / delta_step
    step += 1

    # Report only after a full window of `delta_step` batches has been seen.
    if step % delta_step == 0:
      print(f"Epoch: {epoch+1}, Step: {step}, Discriminator loss: {mean_d_loss:.4f}, Generator loss: {mean_g_loss:.4f}")

      # Validation: a single generator pass with autograd disabled.
      with torch.no_grad():
        samples = g(torch.cat((val_noise, val_oh.float()), dim=1))

      # Validation labels cycle through the classes, so placing num_classes
      # images in each row puts one class in each column.
      display(images=samples, nrow=num_classes)

      # Start a new reporting window.
      mean_g_loss = 0.0
      mean_d_loss = 0.0

Output hidden; open in https://colab.research.google.com to view.