### Import Libraries

In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim 
import torch.nn.functional as F 
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.utils import make_grid, save_image

import os
import numpy as np 
import matplotlib.pyplot as plt

### Learning Parameters

In [2]:
batch_size = 512
epochs = 2
sample_size = 64 #fixed sample size
flatten_image_size = 784
nz = 128 # latent vector size 
k = 1 # no. of steps to apply to the discriminator
lr = 1e-3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Preparing the Dataset

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

to_pil_image = transforms.ToPILImage() # This is required when we want to save the generated images as a .gif file.

train_data = datasets.MNIST(
    root='./input',
    train=True,
    download=True,
    transform=transform,
)

train_loader = DataLoader(
    train_data, 
    batch_size=batch_size, 
    shuffle=True
)

### Generator Network

In [4]:
class GenNet(nn.Module):
    def __init__(self, nz):
        super(GenNet, self).__init__()
        
        self.linear1 = nn.Linear(nz, 256)
        self.linear2 = nn.Linear(256, 512)
        self.linear3 = nn.Linear(512, 1024)
        self.linear4 = nn.Linear(1024, 784)
    
    def forward(self, x):
        x = F.leaky_relu(self.linear1(x), 0.2)
        x = F.leaky_relu(self.linear2(x), 0.2)
        x = F.leaky_relu(self.linear3(x), 0.2)
        x = torch.tanh(self.linear4(x))
        
        return x.view(-1, 1, 28, 28)

### Discriminator Network

In [5]:
class DisNet(nn.Module):
    def __init__(self, in_size):
        super(DisNet, self).__init__()
        
        self.linear1 = nn.Linear(in_size, 1024)
        self.dropout1 = nn.Dropout(0.3)
        
        self.linear2 = nn.Linear(1024, 512)
        self.dropout2 = nn.Dropout(0.3)
        
        self.linear3 = nn.Linear(512, 256)
        self.dropout3 = nn.Dropout(0.3)
        
        self.linear4 = nn.Linear(256, 1)
        
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.dropout1(F.leaky_relu(self.linear1(x), 0.2))
        x = self.dropout2(F.leaky_relu(self.linear2(x), 0.2))
        x = self.dropout3(F.leaky_relu(self.linear3(x), 0.2))
        x = torch.sigmoid(self.linear4(x))
        
        return x 

In [6]:
generator = GenNet(nz).to(device)
discriminator = DisNet(flatten_image_size).to(device)

In [7]:
print(generator)
print(discriminator)

GenNet(
  (linear1): Linear(in_features=128, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=512, bias=True)
  (linear3): Linear(in_features=512, out_features=1024, bias=True)
  (linear4): Linear(in_features=1024, out_features=784, bias=True)
)
DisNet(
  (linear1): Linear(in_features=784, out_features=1024, bias=True)
  (dropout1): Dropout(p=0.3, inplace=False)
  (linear2): Linear(in_features=1024, out_features=512, bias=True)
  (dropout2): Dropout(p=0.3, inplace=False)
  (linear3): Linear(in_features=512, out_features=256, bias=True)
  (dropout3): Dropout(p=0.3, inplace=False)
  (linear4): Linear(in_features=256, out_features=1, bias=True)
)


In [8]:
optim_gen = optim.Adam(generator.parameters(), lr=lr)
optim_dis = optim.Adam(generator.parameters(), lr=lr)

In [9]:
print(optim_gen)
print(optim_gen)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)


### Criterion

In [10]:
criterion = nn.BCELoss()

### Utils function

In [11]:
def latent_vector(size, nz):
    return torch.randn(size, nz).to(device)

def make_dir(path="./output/images"):
    if not os.path.exists(path):
        os.makedirs(path)     
make_dir()

def save_generator_image(image, path):
    save_image(image, path)
    
def label_real(size):
    data = torch.ones(size, 1)
    return data.to(device)

def label_fake(size):
    data = torch.zeros(size, 1)
    return data.to(device)

### Discriminator Training Function

In [12]:
def train_discriminator(optim, real_data, fake_data):
    b_size = real_data.size(0)

    real_label = label_real(b_size)
    fake_label = label_fake(b_size)
    
    optim.zero_grad()
    
    output_real = discriminator(real_data)
    loss_real = criterion(output_real, real_label)

    output_fake = discriminator(fake_data)
    loss_fake = criterion(output_fake, fake_label)
    
    loss_real.backward()
    loss_fake.backward()
    
    optim.step()
    
    total_loss = loss_real+loss_fake
    
    return total_loss

### Generator Training Function

In [13]:
def train_generator(optim, fake_data):
    b_size = fake_data.size(0)
    real_label = label_real(b_size)
    
    optim.zero_grad()
    
    output = discriminator(fake_data)
    loss = criterion(output, real_label)
    
    loss.backward()
    
    optim.step()
    
    return loss

### Training Loop

In [14]:
losses_g = [] # to store generator loss after each epoch
losses_d = [] # to store discriminator loss after each epoch
images = [] # to store images generatd by the generator

In [15]:
for epoch in range(epochs):
    
    loss_g = 0.0
    loss_d = 0.0 
    
    for i, (image,_) in enumerate(train_loader):
        image = image.to(device)
        b_size = image.size(0)
        noise = latent_vector(b_size, nz)
        for step in range(k):
            data_fake = generator(noise).detach()
            data_real = image 
            
            loss_d += train_discriminator(optim_dis, data_real, data_fake)
            
        data_fake = generator(noise)
        loss_g += train_generator(optim_gen, data_fake)
        
    #create the final fake image for the epoch
    generated_img = generator(latent_vector(sample_size, nz)).cpu().detach()
    #make the images as grid
    generated_img = make_grid(generated_img)
    #Save the generated torch tensor models to disk 
    save_generator_image(generated_img, f"./output/images/gen_img{epoch}.png")
    
    images.append(generated_img)
    
    epoch_loss_g = loss_g/i
    epoch_loss_d = loss_d/(i*k)
    
    losses_d.append(epoch_loss_d)
    losses_g.append(epoch_loss_g)
    
    print(f"{epoch+1}/{epochs}")
    print(f"Generator Loss: {epoch_loss_g:.4f}, Discriminator Loss: {epoch_loss_d:.4f}")
    

1/2
Generator Loss: 0.4461, Discriminator Loss: 1.7893
2/2
Generator Loss: 0.4349, Discriminator Loss: 1.8048
