In [None]:
# Installs the Cloud TPU client and a torch_xla 1.7 wheel.
# NOTE(review): the wheel is tagged cp36 (CPython 3.6), so pip will refuse it on any other
# Python version — confirm the runtime matches before running.
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

In [None]:
# jupyterthemes is only needed for the jtplot styling applied in the imports cell (version unpinned).
!pip3 install jupyterthemes

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchsummary import summary
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import os
import sys
from datetime import datetime
from tqdm.notebook import tqdm
from jupyterthemes import jtplot
jtplot.style(theme="monokai", context="notebook", ticks=True)



In [None]:
# Preprocessing for CIFAR-10: upsample to the 64x64 resolution the DCGAN layers are sized for,
# convert to a float tensor in [0, 1], then map to [-1, 1] (the range of the generator's Tanh).
# Resize is applied to the PIL image *before* ToTensor: this works on every torchvision version
# and avoids interpolating already-normalised float tensors (which newer torchvision also warns
# about because of the tensor antialias default).
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, ), std=(0.5, )),
])

In [None]:
# CIFAR-10 training split (train=True is also the torchvision default), fetched into ./DATA/
# on first run and preprocessed with `transform`.
train_data = torchvision.datasets.CIFAR10(
    root="./DATA/",
    train=True,
    download=True,
    transform=transform,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./DATA/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./DATA/cifar-10-python.tar.gz to ./DATA/


In [None]:
def get_discriminator_block(input_channel, output_channel, kernel_size=3,
                            stride=2, padding=0, final_layer=False, batchNorm=True):
    """
    Description : Build one convolutional stage of the discriminator

    Every stage starts with a bias-free Conv2d. Non-final stages add an optional
    BatchNorm2d and a LeakyReLU(0.2); the final stage is the bare convolution.

    Parameters:
    @param input_channel -- number of input channels (int)
    @param output_channel -- number of output channels (int)
    @param kernel_size -- convolution kernel size (by default=3)
    @param stride -- convolution stride (by default=2)
    @param padding -- convolution padding (by default=0)
    @param final_layer -- if True, return only the convolution (by default=False)
    @param batchNorm -- insert BatchNorm2d in a non-final stage (by default=True); ignored when final_layer is True

    Return:
    @ret disc_block -- an nn.Sequential holding the stage's layers
    """

    layers = [nn.Conv2d(input_channel, output_channel, kernel_size=kernel_size,
                        stride=stride, padding=padding, bias=False)]

    # Output stage: no normalisation and no activation, whatever batchNorm says.
    if final_layer:
        return nn.Sequential(*layers)

    if batchNorm:
        layers.append(nn.BatchNorm2d(output_channel))
    layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))

    return nn.Sequential(*layers)




In [None]:
def get_generator_block(input_channel, output_channel, kernel_size=3, stride=2,
                        padding=0, final_layer=False):
    """
    Description : Build one upsampling stage of the generator

    Every stage starts with a bias-free ConvTranspose2d. Intermediate stages follow it
    with BatchNorm2d + ReLU; the final stage follows it with Tanh.

    Parameters:
    @param input_channel -- number of input channels (int)
    @param output_channel -- number of output channels (int)
    @param kernel_size -- transposed-convolution kernel size (by default=3)
    @param stride -- transposed-convolution stride (by default=2)
    @param padding -- transposed-convolution padding (by default=0)
    @param final_layer -- if True, finish with Tanh instead of BatchNorm2d + ReLU (by default=False)

    Return:
    @ret gen_block -- an nn.Sequential holding the stage's layers
    """

    upsample = nn.ConvTranspose2d(input_channel, output_channel, kernel_size=kernel_size,
                                  stride=stride, padding=padding, bias=False)

    if final_layer:
        # Tanh keeps generated pixels in [-1, 1], matching the Normalize(0.5, 0.5) range of real images.
        return nn.Sequential(upsample, nn.Tanh())

    return nn.Sequential(upsample, nn.BatchNorm2d(output_channel), nn.ReLU(inplace=True))

In [None]:
def get_noise(batch_size, latent_dim, device):
    """
    Description : Draw a batch of latent vectors from a standard normal distribution

    Parameters :
    @param batch_size -- number of noise vectors (rows) to draw
    @param latent_dim -- length of each noise vector
    @param device -- device the tensor is created on

    Return :
    @ret noise -- a (batch_size, latent_dim) tensor sampled from N(0, 1)
    """
    return torch.randn(batch_size, latent_dim, device=device)

In [None]:
class Discriminator(nn.Module):
    """
    Description : DCGAN discriminator producing one unnormalised score per image

    Four 4x4 stride-2 stages (padding 1) halve the spatial size while the width grows
    hidden_channel -> 2x -> 4x -> 8x; the first stage skips batch norm. A final 4x4
    stride-1 unpadded stage reduces the 4x4 map left from a 64x64 input to an
    output of shape (N, 1, 1, 1).

    Parameters:
    @param input_channel -- number of channels of the input images
    @param hidden_channel -- width of the first stage (by default=64)
    """

    def __init__(self, input_channel, hidden_channel=64):

        super(Discriminator, self).__init__()

        # Channel count entering/leaving each downsampling stage: in, h, 2h, 4h, 8h.
        widths = [input_channel] + [hidden_channel * mult for mult in (1, 2, 4, 8)]
        stages = [
            get_discriminator_block(c_in, c_out, kernel_size=4, stride=2, padding=1,
                                    batchNorm=(i > 0))
            for i, (c_in, c_out) in enumerate(zip(widths, widths[1:]))
        ]
        # Output stage: a single channel, no normalisation or activation.
        stages.append(get_discriminator_block(widths[-1], 1, kernel_size=4, stride=1,
                                              padding=0, final_layer=True))
        self.disc = nn.Sequential(*stages)

    def forward(self, X):
        return self.disc(X)

In [None]:
class Generator(nn.Module):
    """
    Description : DCGAN generator turning latent vectors into images

    A 4x4 stride-1 transposed convolution lifts the (N, latent_dim, 1, 1) input to a
    4x4 map with hidden_channel*8 channels; three 4x4 stride-2 stages double the
    spatial size while halving the width, and a last stride-2 stage with Tanh emits
    im_channels channels. The output shape is (N, im_channels, 64, 64).

    Parameters:
    @param latent_dim -- length of each noise vector given to forward()
    @param im_channels -- number of channels of the generated images (by default=3)
    @param hidden_channel -- base width; the first stage has hidden_channel*8 channels (by default=64)
    """

    def __init__(self, latent_dim, im_channels=3, hidden_channel=64):

        super(Generator, self).__init__()

        self.latent_dim = latent_dim

        # Widths of the intermediate feature maps: 8h, 4h, 2h, h.
        widths = [hidden_channel * mult for mult in (8, 4, 2, 1)]

        stages = [get_generator_block(latent_dim, widths[0], kernel_size=4, stride=1, padding=0)]
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(get_generator_block(c_in, c_out, kernel_size=4, stride=2, padding=1))
        stages.append(get_generator_block(widths[-1], im_channels, kernel_size=4, stride=2,
                                          padding=1, final_layer=True))
        self.gen = nn.Sequential(*stages)

    def unsqueeze_noise(self, noise):
        """Reshape flat noise (N, latent_dim) into the (N, latent_dim, 1, 1) map the first stage expects."""
        batch = len(noise)
        return noise.view(batch, self.latent_dim, 1, 1)

    def forward(self, noise):
        """Generate one image per row of `noise`."""
        return self.gen(self.unsqueeze_noise(noise))

In [None]:
def get_discriminator_loss(G, D, criterion, real_data, latent_dim, batch_size, device):

    """
    Description : Function to calculate the loss of the discriminator

    D is scored on the real batch against a target of ones and on a freshly
    generated batch against a target of zeros; the two losses are averaged.

    Parameters:
    @param G : the Generator network
    @param D : the Discriminator network
    @param criterion : the loss function, called as criterion(prediction, target)
    @param real_data : the batch of real images
    @param latent_dim : the latent dimension for the noise vector
    @param batch_size : how many fake images to generate
    @param device : the device the targets and noise are created on

    Return:
    @ret disc_loss : 0.5 * (fake loss + real loss)
    """

    # Real images: target 1. The real batch goes through D first; D's BatchNorm
    # running statistics depend on that order.
    logits_real = D(real_data)
    loss_real = criterion(logits_real, torch.ones_like(logits_real, device=device))

    # Fake images: target 0. Detached so this loss back-propagates into D only.
    generated = G(get_noise(batch_size, latent_dim, device)).detach()
    logits_fake = D(generated)
    loss_fake = criterion(logits_fake, torch.zeros_like(logits_fake, device=device))

    return 0.5 * (loss_fake + loss_real)

In [None]:
def get_generator_loss(G, D, criterion, latent_dim, batch_size, device):

    """
    Description : Function to calculate the loss of the generator

    G is rewarded when D labels its samples as real, so the generated batch is
    scored against a target of ones. Gradients flow through D into G.

    Parameters:
    @param G : the Generator network
    @param D : the Discriminator network
    @param criterion : the loss function, called as criterion(prediction, target)
    @param latent_dim : the latent dimension for the noise vector
    @param batch_size : how many fake images to generate
    @param device : the device the target and noise are created on

    Return:
    @ret gen_loss : the Generator Loss
    """

    logits = D(G(get_noise(batch_size, latent_dim, device)))
    return criterion(logits, torch.ones_like(logits, device=device))

In [None]:
def init_weight(m):
    """
    Description : DCGAN-style weight initialisation, intended for nn.Module.apply

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). BatchNorm2d scale
    (weight) is drawn from N(1, 0.02) and its shift (bias) is set to 0. Any other
    module type is left untouched.

    Parameters:
    @param m -- a submodule visited by nn.Module.apply
    """

    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)

    # affine=False BatchNorm layers have no weight/bias to initialise.
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
        # Centre the scale on 1 so every BN layer starts close to an identity affine map;
        # a scale centred on 0 would start by zeroing/sign-flipping activations.
        torch.nn.init.normal_(m.weight, mean=1.0, std=0.02)
        torch.nn.init.zeros_(m.bias)

In [None]:
def train(index, flags):
    """
    Description : Per-process DCGAN training loop, launched once per TPU core by xmp.spawn

    Parameters:
    @param index -- process index supplied by xmp.spawn (used in logs and sample file names)
    @param flags -- dict with keys 'seed', 'dataset', 'batch_size', 'num_workers',
                    'latent_dim', 'color_channel', 'epochs', 'save'

    Return:
    @ret (G, D, g_losses, d_losses) -- the trained networks and this process's per-epoch
         mean generator / discriminator losses. xmp.spawn does not hand this value back
         to the caller; persistent output is the saved sample images.
    """

    g_losses = []
    d_losses = []

    torch.manual_seed(flags['seed'])

    device = xm.xla_device()

    dataset = flags['dataset']

    # Each replica iterates over its own disjoint shard of the dataset.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=flags['batch_size'],
        sampler=sampler,
        num_workers=flags['num_workers'],
        drop_last=True
    )

    # Initialise on the CPU, whose RNG was seeded above with the same seed in every process,
    # and only then move to the XLA device, so every replica starts from identical weights.
    G = Generator(flags['latent_dim']).apply(init_weight).to(device)
    D = Discriminator(flags['color_channel']).apply(init_weight).to(device)

    criterion = nn.BCEWithLogitsLoss()
    g_optim = torch.optim.Adam(G.parameters(), lr=1e-4 * 4, betas=(0.5, 0.999))
    d_optim = torch.optim.Adam(D.parameters(), lr=1e-4 * 2, betas=(0.5, 0.999))

    for epoch in range(flags['epochs']):

        # Without this, DistributedSampler replays the same shuffled order every epoch.
        sampler.set_epoch(epoch)

        para_loader = pl.ParallelLoader(data_loader, [device]).per_device_loader(device)
        t0 = datetime.now()

        # Loss sums stay on the device: calling .item() every step would force a
        # device-to-host transfer and stall the XLA pipeline.
        g_loss_sum = torch.zeros((), device=device)
        d_loss_sum = torch.zeros((), device=device)
        num_steps = 0

        # ParallelLoader already yields batches placed on `device`.
        for data, _ in tqdm(para_loader):

            batch_size = data.size(0)

            ####################################
            ###### TRAIN DISCRIMINATOR #########
            ####################################
            d_optim.zero_grad()
            dLoss = get_discriminator_loss(G, D, criterion, data,
                                           flags['latent_dim'], batch_size, device)
            dLoss.backward()
            xm.optimizer_step(d_optim)
            d_loss_sum = d_loss_sum + dLoss.detach()

            #######################################
            ############ TRAIN GENERATOR ##########
            #######################################
            g_optim.zero_grad()
            gLoss = get_generator_loss(G, D, criterion,
                                       flags['latent_dim'], batch_size, device)
            gLoss.backward()
            xm.optimizer_step(g_optim)
            g_loss_sum = g_loss_sum + gLoss.detach()

            num_steps += 1

        # One host transfer per epoch. NaN mirrors np.mean([]) when this shard yielded no batch.
        if num_steps:
            g_loss = g_loss_sum.item() / num_steps
            d_loss = d_loss_sum.item() / num_steps
        else:
            g_loss = d_loss = float('nan')

        g_losses.append(g_loss)
        d_losses.append(d_loss)

        print(f"Epoch : {epoch+1}/{flags['epochs']} || Disc Loss : {d_loss} || Gen Loss : {g_loss} || Time elapsed : {datetime.now() - t0} || Process : {index}")

        if flags['save'] and (epoch + 1) % 10 == 0:
            # Sampling only; no autograd graph is needed. drop_last=True makes every
            # training batch exactly flags['batch_size'], so the grid size is unchanged.
            with torch.no_grad():
                noise = get_noise(flags['batch_size'], flags['latent_dim'], device)
                fake_img = G(noise)
            save_image(fake_img, f"./DCGANS_CIFAR10_multi_tpu/gan_{epoch + 1}_index_{index}.png", normalize=True)

    return G, D, g_losses, d_losses




In [None]:
# Output folder for the sample grids train() saves. makedirs(exist_ok=True) is safe to re-run,
# and it raises immediately if a regular file already occupies the path, instead of an
# exists() check silently skipping it and training crashing at the first save_image call.
os.makedirs("./DCGANS_CIFAR10_multi_tpu", exist_ok=True)

In [None]:
# Hyper-parameters and data handed to train() in every spawned process.
flags = dict(
    num_workers=8,       # DataLoader worker processes per TPU process
    batch_size=128,      # batch size of each process's DataLoader
    epochs=200,
    latent_dim=100,      # length of the generator's noise vector
    color_channel=3,     # input channels of the discriminator
    seed=1234,           # passed to torch.manual_seed in each process
    dataset=train_data,
    save=True,           # save a sample grid every 10 epochs
)

In [None]:
# xmp.spawn blocks until all 8 worker processes exit and does not pass train()'s return value
# back (with the default join=True it returns None), so unpacking its result into
# G, D, g_losses, d_losses raised TypeError after the whole run. Results are the sample
# grids that train() writes under ./DCGANS_CIFAR10_multi_tpu/.
xmp.spawn(train, args=(flags, ), nprocs=8, start_method='fork')

In [None]:
# Bundle every saved sample grid into images.zip so it can be downloaded in one go.
!zip -r "images.zip" ./DCGANS_CIFAR10_multi_tpu/

In [None]:
# Colab-only: download the zipped sample grids to the local machine.
from google.colab import files
files.download("images.zip")