In [1]:
import os 
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, datasets, transforms
import torch.functional as F

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
root = "C:\\Users\\shant\\celeba"
LEARNING_RATE = 0.00001
BATCH_SIZE = 64
Z_DIM = 100
FEATURES_DISC = 16
FEATURES_GEN = 16
IMAGE_CHANNELS = 1
NUM_EPOCHS = 5
IMAGE_SIZE = 64
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10
NUM_CLASSES = 10
GEN_EMBEDDING = 100

transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(IMAGE_CHANNELS)], [0.5 for _ in range(IMAGE_CHANNELS)])
])

dataset = datasets.MNIST(root = "../dataset", transform = transforms, download = True)

dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

In [4]:
class Generator(nn.Module):
    """
    z_dim: 
    channels_img: Input channels(for example for an RGB image this value is 3)
    features_g: Size of the output feature map(In this case its 64x64)
    """
    def __init__(self, z_dim, channels_img, features_g, num_classes, img_size, embed_size):
        
        super(Generator, self).__init__()
        self.img_size = img_size
        self.embed  = nn.Embedding(num_classes, embed_size)
        self.gen = nn.Sequential(self._block(z_dim + embed_size, features_g*16, 4, 1, 0),
                                 self._block(features_g*16, features_g*8, 4, 2, 1),
                                 self._block(features_g*8, features_g*4, 4, 2, 1),
                                 self._block(features_g*4, features_g*2, 4, 2, 1),
                                 nn.ConvTranspose2d(features_g*2, channels_img, 4, 2, 1),
                                 nn.Tanh())
    
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x, labels):
        # latent vector z: N x noise_dim x 1 x 1
        embedding = self.embed(labels).unsqueeze(2).unsqueeze(3) # Basically adds dimension 1 and then again dimension 1
        x = torch.cat([x,embedding], dim = 1)
        return self.gen(x)

class Critic(nn.Module):
    
    def __init__(self, channels_img, features_d, num_classes, img_size):
        
        super(Critic, self).__init__()
        self.img_size = img_size
        self.embed = nn.Embedding(num_classes, img_size*img_size)
        
        ## Additional stamp as to what the image is 
        self.critic = nn.Sequential(nn.Conv2d(channels_img + 1, features_d, kernel_size = 4, stride = 2, padding = 1), #32x32
                                    nn.LeakyReLU(0.2, inplace=True),
                                    self._block(features_d, features_d*2, 4, 2, 1),# 16x16
                                    self._block(features_d*2, features_d*4, 4, 2, 1), #8x8
                                    self._block(features_d*4, features_d*8, 4, 2, 1), #4x4
                                    nn.Conv2d(features_d*8, 1, kernel_size = 4, stride = 2, padding = 0)
                                    )
    
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        
        return nn.Sequential(
         nn.Conv2d(in_channels, out_channels,kernel_size, stride, padding, bias=False),
            nn.InstanceNorm2d(out_channels, affine = True),
            nn.LeakyReLU(0.2, inplace=True)
        )
    
    def forward(self,x,labels):
        embedding = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
        x = torch.cat([x,embedding], dim = 1) # N x C x H x W
        return self.critic(x)

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    if classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1, 0.02)
        nn.init.constant_(m.bias.data, 0)

def gradient_penalty(critic, real, fake, labels, device = "cpu"):
    BATCH_SIZE, C, H, W = real.shape
    # Creating interpolated images
    epsilon = torch.randn([BATCH_SIZE, 1, 1, 1]).repeat(1,C,H,W).to(device)
    interpolated_images = real*epsilon + fake * (1-epsilon)

    #calculate critic scores
    mixed_scores = critic(interpolated_images, labels)

    # Compute the gradients with respect to the interpolated images, just need the first value
    gradient = torch.autograd.grad(inputs = interpolated_images, 
    outputs = mixed_scores, 
    grad_outputs = torch.ones_like(mixed_scores),
    create_graph=True,
    retain_graph=True)[0]

    # Number of Dimension
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim = 1)
    gradient_penalty = torch.mean((gradient_norm - 1)**2)

    return gradient_penalty

In [5]:
########################
# Generator and Discriminator Model objects
########################
generator = Generator(Z_DIM,IMAGE_CHANNELS,FEATURES_GEN,NUM_CLASSES,IMAGE_SIZE,GEN_EMBEDDING).to(device)
critic = Critic(IMAGE_CHANNELS,FEATURES_DISC,NUM_CLASSES,IMAGE_SIZE).to(device)

########################
# Weight Initialization for the model
########################
generator.apply(weights_init)
critic.apply(weights_init)

########################
# Optimizers for Critic and the Generator
########################
optimizer_gen = optim.Adam(generator.parameters(), lr = LEARNING_RATE, betas = (0,0.9))
optimizer_critic = optim.Adam(critic.parameters(), lr = LEARNING_RATE, betas = (0,0.9))

#######################
# Create tensorboard SummaryWriter objects to display generated fake images and associated loss curves
#######################
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
loss_curves = SummaryWriter(f"logs/loss_curves")

#######################
# Create a batch of latent vectors. Will be used to to do a single pass through the generator after 
# the training has terminated
#######################
fixed_noise = torch.randn((64, Z_DIM, 1, 1)).to(device)

step = 0 # For printing to tens

In [6]:
for epoch in range(NUM_EPOCHS):
    
    # Unsupervised
    for batch_idx, (real, labels) in enumerate(dataloader):
        
        # The real world images
        real = real.to(device)
        
        cur_batch_size = real.shape[0]
        
        labels = labels.to(device)

        #####################################################
        # Train the Critic
        #####################################################
        
        for _ in range(CRITIC_ITERATIONS):
            critic.zero_grad()
            # Latent noise
            noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(device)
            # Pass the latent vector through the generator
            fake = generator(noise)     
            critic_real = critic(real, labels).view(-1)
            critic_fake = critic(fake.detach(), labels).view(-1)
            gp = gradient_penalty(critic, real, fake, labels, device=device)
            ## Loss for the critic. Taking -ve because RMSProp are designed to minimize 
            ## Hence to minimize something -ve is equivalent to maximizing that expression
            loss_critic = (-(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP*gp) 
            loss_critic.backward(retain_graph=True)
            optimizer_critic.step()

        #############################
        # Train the generator minimizing -E[critic(gen_fake)]
        #############################
        generator.zero_grad()
        output = critic(fake, labels).view(-1)
        loss_gen = -torch.mean(output)
        loss_gen.backward()
        optimizer_gen.step()
        
        if batch_idx % 100 == 0:
            
            print(
            f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )
            with torch.no_grad():
                fake = generator(noise, labels)
            
                # The [:64] prints out the 4-D tensor BxCxHxW
                img_grid_real = torchvision.utils.make_grid(
                    real[:64], normalize = True)
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:64], normalize = True)
                ##########################
                # TensorBoard Visualizations
                ##########################
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
#                 loss_curves.add_scalar("generator", {loss_gen, global_step=step)
                loss_curves.add_scalars("curves", {
                    "generator":loss_gen, "critic":loss_critic
                }, global_step = step)
#                 loss_curves.add_scalar("discriminator", loss_disc, global_step = step)
                
            step += 1 # See progression of images

IndexError: too many indices for tensor of dimension 0