# Generative Adversarial Network

In [None]:
import numpy as np
import os
from tqdm import tqdm

import torch
from torch import nn

from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

In [None]:
class Generator(nn.Module):
    """MLP generator that maps a latent noise vector to an image.

    Args:
        latent_dim: size of the input noise vector. Defaults to the
            module-level ``n_latent`` (looked up when the model is built).
        shape: output image shape ``(channels, H, W)``. Defaults to the
            module-level ``img_shape`` (looked up when the model is built).
    """

    def __init__(self, latent_dim=None, shape=None):
        super(Generator, self).__init__()
        # Resolve the fallbacks at construction time (not at def time) so the
        # existing ``Generator()`` calls keep working once the options cell
        # has defined ``n_latent`` / ``img_shape``.
        if latent_dim is None:
            latent_dim = n_latent
        if shape is None:
            shape = img_shape
        # Plain attribute (not a parameter/buffer): state_dict keys are unchanged.
        self.img_shape = tuple(shape)

        def block(in_feat, out_feat):
            # Linear -> BatchNorm -> ReLU.
            # ReLU could be replaced with LeakyReLU for better gradient flow.
            return [
                nn.Linear(in_feat, out_feat),
                nn.BatchNorm1d(out_feat),
                nn.ReLU(inplace=True),
            ]

        self.model = nn.Sequential(
            *block(latent_dim, 128),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Sigmoid(),  # pixel values squashed into [0, 1]
        )

    def forward(self, x):
        """Map noise ``x`` of shape (batch, latent_dim) to (batch, *img_shape)."""
        x = self.model(x)
        return x.view(-1, *self.img_shape)  # batch x channels x H x W

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator producing one sigmoid score per input image.

    In the training loop, targets are 1 for real images and 0 for generated ones.

    Args:
        shape: input image shape ``(channels, H, W)``. Defaults to the
            module-level ``img_shape`` (looked up when the model is built).
    """

    def __init__(self, shape=None):
        super(Discriminator, self).__init__()
        # Resolve the fallback at construction time so ``Discriminator()``
        # keeps working once the options cell has defined ``img_shape``.
        if shape is None:
            shape = img_shape
        # Plain attributes (not parameters/buffers): state_dict keys are unchanged.
        self.img_shape = tuple(shape)
        self.n_pixels = int(np.prod(self.img_shape))

        self.model = nn.Sequential(
            nn.Linear(self.n_pixels, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # score in [0, 1]
        )

    def forward(self, x):
        """Flatten each image in ``x`` and return scores of shape (batch, 1)."""
        # reshape (not view) so non-contiguous inputs are accepted as well
        x = x.reshape(-1, self.n_pixels)
        return self.model(x)

# Training

In [None]:
# training options
# These are module-level globals: the Generator/Discriminator classes and the
# training-loop cell read names such as n_latent, img_shape and batch_size directly.
n_latent = 10 # dimension of the input latent (noise) vector fed to the generator
batch_size = 32
num_epochs = 100
img_shape = (1,28,28) # (channels, H, W); Resize below uses img_shape[1:]
save_after = 10 # write a checkpoint every `save_after` epochs
load = -1 # -1 = train from scratch; otherwise a checkpoint filename inside checkpoints/ (e.g. '0010.pth')

# initialize generator and discriminator and move them to the GPU
gen = Generator().cuda()
dis = Discriminator().cuda()

# define loss function: binary cross-entropy on the discriminator's sigmoid output
loss = nn.BCELoss().cuda()

# create optimizers (one Adam instance per network)
optimizer_G = torch.optim.Adam(gen.parameters(),lr=2e-4) #GANs are highly sensitive to LRs
optimizer_D = torch.optim.Adam(dis.parameters(),lr=2e-4)

# Dataloader
# MNIST training split, downloaded into ./data if missing. Normalize([0], [1])
# (mean 0, std 1) leaves the ToTensor values unchanged, so real pixels share
# the [0, 1] range of the generator's Sigmoid output.
mnist = datasets.MNIST(
        "data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(img_shape[1:]), transforms.ToTensor(), transforms.Normalize([0], [1])]
        ),
    )

# NOTE(review): no drop_last, so the final batch is smaller than batch_size
# whenever batch_size does not divide the dataset size; consumers must not
# assume full batches.
dataloader = torch.utils.data.DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

# display model params: count trainable parameters across both networks
total_params = 0
for model in [gen,dis]:
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params+=params
#     print("params in {} : {}".format(model,params)) #uncomment for network architecture
print("Total trainable params: {}".format(total_params))

# summary writer: TensorBoard event files go to ./logs
writer = SummaryWriter(log_dir='logs')

In [None]:
# training Loop
#
# Alternates one generator step and one discriminator step per batch, logs
# both losses to TensorBoard, saves a 4x4 sample grid to images/ after every
# epoch and a checkpoint to checkpoints/ every `save_after` epochs (plus the
# final epoch).

# Resume from a checkpoint when `load` names one (-1 means train from scratch).
if load != -1:
    checkpoint = torch.load(os.path.join('checkpoints', load))
    # The stored epoch was fully trained before it was saved, so continue
    # with the next one instead of repeating it.
    start_epoch = checkpoint['epoch'] + 1
    global_step = checkpoint['global_step']
    gen.load_state_dict(checkpoint['gen_state_dict'])
    dis.load_state_dict(checkpoint['dis_state_dict'])
    optimizer_G.load_state_dict(checkpoint['genopt_state_dict'])
    optimizer_D.load_state_dict(checkpoint['disopt_state_dict'])
else:
    global_step = 0
    start_epoch = 0

os.makedirs('images', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)

for epoch in tqdm(range(start_epoch, num_epochs)):

    gen.train()
    dis.train()
    for real_images, _ in dataloader:
        global_step += 1
        real_images = real_images.cuda()
        # Size labels and noise by the actual batch: the DataLoader does not
        # drop the last, possibly smaller, batch.
        cur_batch = real_images.size(0)

        # label convention: fake - 0, real - 1
        real = torch.ones(cur_batch, 1, device=real_images.device)
        fake = torch.zeros(cur_batch, 1, device=real_images.device)

        # generator training: push D to score generated images as real
        optimizer_G.zero_grad()
        z = torch.randn(cur_batch, n_latent, device=real_images.device)
        generated_images = gen(z)
        g_loss = loss(dis(generated_images), real)
        g_loss.backward()
        optimizer_G.step()

        writer.add_scalar('Generator loss', g_loss.item(), global_step=global_step)

        # discriminator training: real -> 1, detached generated -> 0.
        # zero_grad also discards the D gradients left by g_loss.backward().
        optimizer_D.zero_grad()
        real_loss = loss(dis(real_images), real)
        fake_loss = loss(dis(generated_images.detach()), fake)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        writer.add_scalar('Discriminator loss', d_loss.item(), global_step=global_step)

    # saving generated images (inference only, so no autograd graph is built)
    gen.eval()
    with torch.no_grad():
        z = torch.randn(16, n_latent, device='cuda')
        samples = gen(z)
    save_image(samples, "images/{}.png".format(str(epoch).zfill(4)), nrow=4, normalize=True)

    # checkpoint periodically and always after the last epoch
    if epoch % save_after == 0 or epoch == num_epochs - 1:
        torch.save({
            'epoch': epoch,
            'global_step': global_step,
            'gen_state_dict': gen.state_dict(),
            'dis_state_dict': dis.state_dict(),
            'genopt_state_dict': optimizer_G.state_dict(),
            'disopt_state_dict': optimizer_D.state_dict()
        }, os.path.join('checkpoints', str(epoch).zfill(4) + '.pth'))

writer.flush()  # make sure all logged scalars reach disk
print('Finished Training !')

## Inference

In [None]:
# Sample `n` images from a saved generator checkpoint and write them as a grid.
n = 100 # number of images to be generated
n_latent = 10 # must match the value the checkpoint was trained with
img_shape = (1,28,28) # must match the value the checkpoint was trained with
load = '0010.pth' # checkpoint filename inside checkpoints/

gen = Generator().cuda()

checkpoint = torch.load(os.path.join('checkpoints', load))
gen.load_state_dict(checkpoint['gen_state_dict'])

gen.eval()  # BatchNorm layers use their running statistics
with torch.no_grad():  # inference only: skip building the autograd graph
    x = torch.randn(n, n_latent, device='cuda')
    generated_images = gen(x)
# roughly square grid: int(sqrt(n)) images per row
save_image(generated_images, "sample.png", nrow=int(n ** 0.5), normalize=True)
