## E0270 Machine Learning Course Project
# SAGAN implementation
Name : Mohit Kumar<br>
M. Tech in Artificial Intelligence<br>
SR No. : 04-01-03-10-51-21-1-19825<br>
email : mohitk2@iisc.ac.in<br>

I learned to use the PyTorch library from [PyTorch Tutorials by Aladdin Persson](https://www.youtube.com/playlist?list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz)

To implement SAGAN, I drew on the following YouTube videos:<br> [PyTorch Conditional GAN Tutorial](https://www.youtube.com/watch?v=Hp-jWm2SzR8&list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va&index=5&t=207s)<br>
[WGAN implementation from scratch(with gradient penalty)](https://www.youtube.com/watch?v=pG0QZ7OddX4&list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va&index=4&t=1370s)<br>
and the official SAGAN paper [Self-Attention Generative Adversarial Networks](https://proceedings.mlr.press/v97/zhang19d.html) by Han Zhang, Ian Goodfellow, Dimitris Metaxas, Augustus Odena

In [1]:
# imports
from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
import matplotlib.animation as animation
from IPython.display import HTML
import pandas as pd
import torchvision
import time
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d
from PIL import Image
import matplotlib.pyplot as plt
import sys
import os
from google.colab import drive

# Mount Google Drive and work from the project folder, so that the relative
# paths used below (./checkpoint, ./dataset, ./outputs) resolve inside it.
drive.mount('/content/drive/')
%cd /content/drive/'My Drive'/GAN/
!ls

Mounted at /content/drive/
/content/drive/My Drive/GAN
checkpoint  dataset  outputs


## FID implementation
Code taken from [FID score for PyTorch](https://github.com/mseitzer/pytorch-fid)<br>
<img src='https://drive.google.com/uc?id=11r3lyk-PUkOzSV5LFZeeZ1AYGVB7QMDa'  height='450' width='700' ><br>
[image source](https://www.kaggle.com/code/ibtesama/gan-in-pytorch-with-fid/notebook)

In [2]:
class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps.

    Adapted from mseitzer/pytorch-fid. The torchvision InceptionV3 is split
    into up to four sequential blocks, and ``forward`` returns the outputs of
    the blocks selected by ``output_blocks``.

    NOTE: this uses torchvision's ImageNet weights rather than the TF-ported
    FID weights used by pytorch-fid. Absolute scores may therefore differ from
    published FID numbers.
    """

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,   # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3  # Final average pooling features
    }

    def __init__(self,
                 output_blocks=(DEFAULT_BLOCK_INDEX,),
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False):
        """Build the block-wise feature extractor.

        Parameters
        ----------
        output_blocks : iterable of int
            Indices (0-3) of the blocks whose outputs ``forward`` returns.
            The default is a tuple, so no mutable list is shared between calls.
        resize_input : bool
            Bilinearly resize inputs to 299x299 before the first block.
        normalize_input : bool
            Map inputs from (0, 1) to (-1, 1) before the first block.
        requires_grad : bool
            Value assigned to ``requires_grad`` of every parameter.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        # Read from the sorted copy so a one-shot iterable is not consumed twice.
        self.last_needed_block = max(self.output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps.

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1).

        Returns
        -------
        list of torch.Tensor
            Outputs of the selected blocks, sorted ascending by block index.
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            # No later block is requested, so stop early.
            if idx == self.last_needed_block:
                break

        return outp
    
# FID feature extractor: keep only the 2048-d final-average-pool block, on the GPU.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3(output_blocks=[block_idx]).cuda()

Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

In [3]:
def calculate_activation_statistics(images, model, batch_size=128, dims=2048,
                                    cuda=False):
    """Compute the mean and covariance of ``model`` activations over ``images``.

    Parameters
    ----------
    images : torch.Tensor
        Image batch of shape (N, C, H, W), in the value range ``model`` expects.
    model : torch.nn.Module
        Feature extractor whose forward returns a list; only its first element
        is used. The model is switched to eval mode.
    batch_size : int
        Number of images per forward pass. ``None`` or a non-positive value
        processes all images in a single pass.
    dims : int
        Kept for API compatibility. The feature dimensionality is taken from
        the model output.
    cuda : bool
        Move each chunk of images to the GPU before the forward pass.

    Returns
    -------
    mu : np.ndarray, shape (D,)
        Mean activation vector.
    sigma : np.ndarray, shape (D, D)
        Covariance of the activations (rows are samples).

    Raises
    ------
    ValueError
        If ``images`` is empty.
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("calculate_activation_statistics needs at least one image")

    model.eval()
    step = batch_size if batch_size and batch_size > 0 else n_images
    chunks = []
    # No gradients are needed for statistics; this avoids keeping
    # activation graphs alive.
    with torch.no_grad():
        for start in range(0, n_images, step):
            batch = images[start:start + step]
            if cuda:
                batch = batch.cuda()
            pred = model(batch)[0]

            # If model output is not scalar, apply global spatial average pooling.
            # This happens if you choose a dimensionality not equal 2048.
            if pred.size(2) != 1 or pred.size(3) != 1:
                pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

            chunks.append(pred.cpu().numpy().reshape(pred.size(0), -1))

    act = np.concatenate(chunks, axis=0)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma

In [4]:
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Squared Frechet distance between two multivariate Gaussians (NumPy).

    For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2):

        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2))

    Parameters
    ----------
    mu1, mu2 : array_like
        Mean vectors (promoted to at least 1-D).
    sigma1, sigma2 : array_like
        Covariance matrices (promoted to at least 2-D).
    eps : float
        Added to both diagonals when the first matrix square root is not finite.

    Returns
    -------
    numpy scalar
        The squared distance d^2.

    Raises
    ------
    ValueError
        If the square root has a non-negligible imaginary diagonal.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    mean_gap = mu1 - mu2

    root, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(root).all():
        # The product is (near-)singular: nudge both diagonals and retry.
        print(f'fid calculation produces singular product; adding {eps} to diagonal of cov estimates')
        jitter = np.eye(sigma1.shape[0]) * eps
        root = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))

    # Round-off can leave a tiny imaginary part; drop it only when negligible.
    if np.iscomplexobj(root):
        if not np.allclose(np.diagonal(root).imag, 0, atol=1e-3):
            raise ValueError(f'Imaginary component {np.max(np.abs(root.imag))}')
        root = root.real

    return (mean_gap.dot(mean_gap) + np.trace(sigma1)
            + np.trace(sigma2) - 2 * np.trace(root))

In [5]:
def calculate_fretchet(images_real, images_fake, model, cuda=True,
                       input_range=(-1.0, 1.0)):
    """Return the Frechet Inception Distance (FID) between two image batches.

    The misspelt name is kept because the training loop calls it.

    Parameters
    ----------
    images_real, images_fake : torch.Tensor
        Image batches of shape (N, 3, H, W) with values in ``input_range``.
    model : torch.nn.Module
        Feature extractor passed to ``calculate_activation_statistics``
        (the ``InceptionV3`` wrapper in this notebook).
    cuda : bool
        Forwarded to ``calculate_activation_statistics``. Defaults to True,
        as before.
    input_range : (float, float)
        Value range of the input images. They are rescaled linearly to
        [0, 1], the range the InceptionV3 wrapper expects (its
        ``normalize_input`` step maps [0, 1] to [-1, 1]). The default (-1, 1)
        matches the Normalize(0.5, 0.5) data pipeline and the generator's Tanh
        output. Passing those images unchanged would put Inception's inputs
        in [-3, 1]. Pass (0, 1) to skip the rescale.

    Returns
    -------
    numpy scalar
        The FID between the two sets of activations.

    Raises
    ------
    ValueError
        If ``input_range`` is not ``(low, high)`` with ``low < high``.
    """
    low, high = input_range
    if not high > low:
        raise ValueError(
            "input_range must be (low, high) with low < high, got {!r}".format(input_range))
    if (low, high) != (0.0, 1.0):
        scale = high - low
        images_real = (images_real - low) / scale
        images_fake = (images_fake - low) / scale

    mu_1, sigma_1 = calculate_activation_statistics(images_real, model, cuda=cuda)
    mu_2, sigma_2 = calculate_activation_statistics(images_fake, model, cuda=cuda)
    return calculate_frechet_distance(mu_1, sigma_1, mu_2, sigma_2)

Function for calculating the gradient penalty

In [6]:
def gradient_penalty(critic, labels, real, fake, device=None):
    """Compute the WGAN-GP gradient penalty.

    Random points are sampled on straight lines between paired real and fake
    images. The critic is penalised when the L2 norm of its gradient with
    respect to those points deviates from 1.

    Parameters
    ----------
    critic : callable
        ``critic(images, labels)`` returning one score per image.
    labels : torch.Tensor
        Class labels forwarded unchanged to ``critic``.
    real, fake : torch.Tensor
        Image batches of identical shape (B, C, H, W).
    device : str or torch.device, optional
        Device for the interpolation coefficients. Defaults to ``real.device``.
        The former default ``"gpu"`` is not a valid torch device string.

    Returns
    -------
    torch.Tensor
        Scalar ``mean((||grad||_2 - 1) ** 2)``. It is differentiable with
        respect to the critic's parameters (``create_graph=True``).
    """
    if device is None:
        device = real.device
    batch_size = real.size(0)

    # One mixing coefficient per sample; broadcasting spreads it over C, H, W.
    alpha = torch.rand((batch_size, 1, 1, 1)).to(device=device, dtype=real.dtype)
    # Make the interpolates a leaf that requires grad. autograd.grad then works
    # even when neither input carries a graph, and the penalty back-propagates
    # only into the critic, not into the generator behind `fake`.
    interpolated_images = (real * alpha + fake * (1 - alpha)).detach().requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, labels)

    # Calculate the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.reshape(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

Functions for loading and saving checkpoints

In [7]:
# Persist the generator/critic state dictionary to the SAGAN checkpoint folder.
def save_checkpoint(state, filename="cifar10_sagan.pth.tar"):
    """Write ``state`` with ``torch.save`` to ./checkpoint/sagan/<filename>.

    The checkpoint directory is created when it does not exist yet.
    """
    target_dir = "./checkpoint/{}/".format('sagan')
    missing = not os.path.exists(target_dir)
    if missing:
        os.makedirs(target_dir)
    print("=> Saving checkpoint....")
    destination = os.path.join(target_dir, filename)
    torch.save(state, destination)

# Restore generator and critic weights from a loaded checkpoint dictionary.
def load_checkpoint(checkpoint, gen, critic):
    """Load ``checkpoint['gen']`` into ``gen`` and ``checkpoint['critic']`` into ``critic``."""
    print("=> Loading checkpoint....")
    for key, network in (('gen', gen), ('critic', critic)):
        network.load_state_dict(checkpoint[key])

## Self Attention Layer
<img src='https://drive.google.com/uc?id=1diWtHE9iXa6z2kmgUa4qTweXhqxCqr1A'  height='225' width='600' ><br>
[image source](https://proceedings.mlr.press/v97/zhang19d.html)



In [8]:
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch.nn.init import xavier_uniform_

# Helper: an nn.Conv2d wrapped in spectral normalisation.
def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    """Return ``nn.Conv2d(...)`` wrapped with ``torch.nn.utils.spectral_norm``."""
    return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))


# Self Attention Layer
class Self_Attention_Layer(nn.Module):
    """SAGAN self-attention block.

    Queries (theta, C/8 channels) are taken at every spatial position. Keys
    (phi, C/8) and values (g, C/2) are 2x max-pooled. The attention output is
    projected back to C channels and added to the input, scaled by the learnable
    scalar ``sigma``. ``sigma`` is initialised to 0, so the layer starts as the
    identity.
    """

    def __init__(self, in_channels):
        super(Self_Attention_Layer, self).__init__()
        self.in_channels = in_channels
        # Attribute names and creation order are unchanged, so state_dict keys
        # (and existing checkpoints) and the random initialisation order match.
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.sigma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply self-attention.

        x : tensor of shape (B, C, H, W). Returns a tensor of the same shape,
        ``x + sigma * attention_output``.
        """
        batch, _, h, w = x.size()
        # Query path: (B, C/8, N), with N = H*W positions.
        theta = self.snconv1x1_theta(x).flatten(2)
        # Key path, max-pooled: (B, C/8, M). flatten(2) uses the actual pooled
        # size, so odd H or W (which max-pooling floors) no longer break it.
        phi = self.maxpool(self.snconv1x1_phi(x)).flatten(2)
        # Attention map (B, N, M), softmax-normalised over the M key positions.
        attn = self.softmax(torch.bmm(theta.permute(0, 2, 1), phi))
        # Value path, max-pooled: (B, C/2, M).
        g = self.maxpool(self.snconv1x1_g(x)).flatten(2)
        # Weighted sum of values for every query position: (B, C/2, H, W).
        attn_g = torch.bmm(g, attn.permute(0, 2, 1)).view(batch, -1, h, w)
        attn_g = self.snconv1x1_attn(attn_g)
        # Residual connection scaled by the learnable sigma.
        return x + self.sigma * attn_g

## Definition of the Discriminator and Generator Networks
<img src='https://drive.google.com/uc?id=1FMI_cN563HG8eT4vS8P5V0AEm7BI9wBs'  height='500' width='450' ><br>
**Conditional GAN**<br>
[image source](https://cedar.buffalo.edu/~srihari/CSE676/index.html)


In [9]:
import torch
import torch.nn as nn

# discriminator network
class Discriminator(nn.Module):
    """Conditional critic (WGAN-GP) with one self-attention layer.

    The class label is embedded into an img_size x img_size plane that is
    concatenated to the image as an extra channel. For 64x64 inputs the spatial
    size goes 64 -> 32 -> 16 (self-attention) -> 8 -> 4 -> 1. The output is
    therefore N x 1 x 1 x 1: one unbounded score per image, with no sigmoid.
    """

    def __init__(self, img_channels, features_d, num_classes, img_size):
        super(Discriminator, self).__init__()
        self.img_size = img_size
        self.disc = nn.Sequential(
            # input: N x (img_channels + 1) x img_size x img_size  (image + label plane)
            nn.Conv2d(img_channels+1, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._attention(features_d * 2),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            # For 64x64 inputs the feature map is 4x4 here; the Conv2d below reduces it to 1x1
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        )
        # One learned img_size*img_size plane per class, reshaped in forward().
        self.embed = nn.Embedding(num_classes, img_size*img_size)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Spectral-normalised strided conv -> InstanceNorm (affine) -> LeakyReLU(0.2)."""
        # NOTE(review): InstanceNorm rather than BatchNorm, presumably because
        # WGAN-GP's per-sample gradient penalty does not suit batch-wise normalisation.
        return nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False,)),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )
    def _attention(self, in_channels):
        """Self-attention layer operating on ``in_channels`` feature maps."""
        return Self_Attention_Layer(in_channels)

    def forward(self, x, labels):
        """Score images ``x`` (N, img_channels, img_size, img_size) given integer ``labels`` (N,)."""
        embedding = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
        x = torch.cat([x, embedding], dim=1) # N X (C + 1) X img_size(H) X img_size(W)
        return self.disc(x)

# generator network
class Generator(nn.Module):
    """Conditional DCGAN-style generator with one self-attention layer.

    The label embedding (``embed_size`` values) is concatenated to the noise
    along the channel axis. Transposed convolutions then upsample
    1 -> 4 -> 8 -> 16 (self-attention) -> 32 -> 64, and a final Tanh bounds the
    output to [-1, 1].

    NOTE(review): the layer stack always produces 64x64 images. ``img_size`` is
    stored but not used anywhere in this class.
    """

    def __init__(self, channels_noise, img_channels, features_g, num_classes, img_size, embed_size,):
        super(Generator, self).__init__()
        self.img_size = img_size
        self.gen = nn.Sequential(
            # Input: N x (channels_noise + embed_size) x 1 x 1
            self._block(channels_noise + embed_size, features_g * 16, 4, 1, 0),  # img: 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
            self._attention(features_g * 4),
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
            nn.ConvTranspose2d(
                features_g * 2, img_channels, kernel_size=4, stride=2, padding=1
            ),
            # Output: N x img_channels x 64 x 64
            nn.Tanh(),
        )
        self.embed = nn.Embedding(num_classes, embed_size)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Spectral-normalised transposed conv -> BatchNorm -> ReLU."""
        return nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False,)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def _attention(self, in_channels):
        """Self-attention layer operating on ``in_channels`` feature maps."""
        return Self_Attention_Layer(in_channels)

    def forward(self, x, labels):
        """Generate images from noise ``x`` (N, channels_noise, 1, 1) and integer ``labels`` (N,)."""
        # latent vector z: N X noise_dim X 1 X 1; embedding becomes N X embed_size X 1 X 1
        embedding = self.embed(labels).unsqueeze(2).unsqueeze(3)
        x = torch.cat([x, embedding], dim=1)
        return self.gen(x)

def _trainable_weight(module):
    """Return the tensor that actually stores ``module``'s trainable weight (or None)."""
    # Old-style torch.nn.utils.spectral_norm moves the parameter to `weight_orig`.
    # `weight` becomes a derived tensor that is recomputed on every forward pass.
    if hasattr(module, "weight_orig"):
        return module.weight_orig
    # torch.nn.utils.parametrize-based wrappers keep the raw parameter here.
    parametrizations = getattr(module, "parametrizations", None)
    if parametrizations is not None and "weight" in parametrizations:
        return parametrizations.weight.original
    return getattr(module, "weight", None)


# function for initializing weights of the neural network layers
def initialize_weights(model):
    """Initialise the weights of the conv, BatchNorm and Linear layers in ``model`` in place.

    Conv2d, ConvTranspose2d and BatchNorm2d weights are drawn from N(0, 0.02).
    Linear weights use Xavier-uniform. Biases and other layer types (e.g.
    Embedding, InstanceNorm2d) keep their PyTorch defaults.

    For spectral-normalised layers the underlying ``weight_orig`` parameter is
    initialised. Writing to ``weight`` would be overwritten by the
    spectral-norm hook on the next forward pass, so the init would be lost.
    Layers without a weight (e.g. BatchNorm2d with affine=False) are skipped.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            weight = _trainable_weight(m)
            if weight is not None:
                nn.init.normal_(weight, 0.0, 0.02)
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(_trainable_weight(m))

## Training of the SAGAN
<img src='https://drive.google.com/uc?id=10cuYVwPZNHm2hQJMltlQqayRjjOrkTiQ'  height='250' width='600' ><br>

<img src='https://drive.google.com/uc?id=1Lu4kCY3VITu7I9zFGj2jkGpWpjt1A3z8'  height='320' width='625' ><br>
**WGAN training**<br>[image source](https://proceedings.neurips.cc/paper/2017/file/892c3b1c6dccd52936e27cbd0ff683d6-Paper.pdf)

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Hyperparameters ----
channels_img = 3
batch_size = 64
img_size = 64
num_classes = 10
z_dim = 100
gen_embedding = 100
num_epochs = 100
critic_features = 16
gen_features = 16
disc_lr = 4e-4  # critic LR is 4x the generator LR (two time-scale update rule)
gen_lr = 1e-4
critic_iterations = 5  # critic updates per generator update
lambda_gp = 10  # weight of the gradient-penalty term
checkpoint_every = 5  # save a checkpoint after every this many epochs

# ---- Output locations (relative to the mounted project folder) ----
checkpoint_path = os.path.join("./checkpoint/{}/".format('sagan'), "cifar10_sagan.pth.tar")
image_dir = "./outputs/{}/generated_images/".format('sagan')
output_dir = "./outputs/{}/".format('sagan')

# Named `transform` so the pipeline does not shadow the torchvision.transforms module.
transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], the range of the generator's Tanh.
        transforms.Normalize(
            [0.5 for _ in range(channels_img)], [0.5 for _ in range(channels_img)]),
    ]
)

dataset = datasets.CIFAR10(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,)

gen = Generator(z_dim, channels_img, gen_features, num_classes, img_size, gen_embedding).to(device)
print(gen)
critic = Discriminator(channels_img, critic_features, num_classes, img_size).to(device)
print(critic)

initialize_weights(gen)
initialize_weights(critic)

opt_gen = optim.Adam(gen.parameters(), lr=gen_lr, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=disc_lr, betas=(0.0, 0.9))

# Resume from a previous run if a checkpoint exists. map_location allows a
# checkpoint written on a GPU runtime to be loaded on a CPU-only one.
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    load_checkpoint(checkpoint, gen, critic)

step = 0

gen.train()
critic.train()

Gen_losses = []
Disc_losses = []
fretchet_distances = []
iter_list = []

print("Starting Training Loop....")

for epoch in range(num_epochs):
    for batch_idx, (real, labels) in enumerate(loader):
        real = real.to(device)
        cur_batch_size = real.shape[0]
        labels = labels.to(device)

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
        for _ in range(critic_iterations):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
            fake = gen(noise, labels)
            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)
            gp = gradient_penalty(critic, labels, real, fake, device=device)
            loss_critic = (-(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp)
            critic.zero_grad()
            # retain_graph: the generator graph behind the last `fake` is
            # reused in the generator step below.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake, labels).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses
        if batch_idx % 100 == 0 and batch_idx > 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                  Disc Loss : {loss_critic:.4f}, Gen loss : {loss_gen:.4f}"
            )

        # Every 500 iterations: save a grid of generated images, log the losses
        # and an FID estimate, and rewrite the metrics CSV.
        if step % 500 == 0 and step > 0:
            with torch.no_grad():
                fake = gen(noise, labels)
                img_grid_fake = torchvision.utils.make_grid(fake[:64], normalize=True)
                os.makedirs(image_dir, exist_ok=True)
                torchvision.utils.save_image(img_grid_fake, os.path.join(image_dir, "img_{}.png".format(step)),)
            iter_list.append(step)
            Gen_losses.append(loss_gen.item())
            Disc_losses.append(loss_critic.item())
            # NOTE(review): the FID uses a single batch of 64 images with 2048-d
            # features, so the covariance is rank-deficient. Treat it as a rough
            # trend only.
            fretchet_distance = calculate_fretchet(real, fake, model)
            fretchet_distances.append(fretchet_distance.item())

            losses = pd.DataFrame({"Iteration No.": iter_list, "Generator loss": Gen_losses, "Discriminator loss": Disc_losses, "FID Score": fretchet_distances})

            os.makedirs(output_dir, exist_ok=True)
            losses.to_csv(os.path.join(output_dir, "output_data.csv"))

        step += 1

    # Save after every `checkpoint_every` epochs and after the final epoch,
    # so the most recent training is never lost.
    if (epoch + 1) % checkpoint_every == 0 or epoch + 1 == num_epochs:
        checkpoint = {'gen': gen.state_dict(), 'critic': critic.state_dict()}
        save_checkpoint(checkpoint)
        step += 1

Files already downloaded and verified
Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(200, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Self_Attention_Layer(
      (snconv1x1_theta): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
      (snconv1x1_phi): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
      (snconv1x1_g): Conv2d(64, 32, kernel_size=(1, 1), stri

KeyboardInterrupt: ignored

In [None]:
# Load the metrics logged by the training loop for plotting.
# The keys must match the column names written there exactly:
# "Generator loss" and "Discriminator loss" use a lower-case "loss".
df = pd.read_csv("./outputs/{}/output_data.csv".format('sagan'))
Iterations = df['Iteration No.']
Gen_Loss = df['Generator loss']
Disc_Loss = df['Discriminator loss']
FID = df['FID Score']

In [None]:
# Generator loss curve over training iterations
axes = plt.gca()
axes.set_xlabel('Iterations')
axes.set_ylabel('Generator Loss')
plt.plot(Iterations, Gen_Loss)

In [None]:
# Discriminator (critic) loss curve over training iterations
axes = plt.gca()
axes.set_xlabel('Iterations')
axes.set_ylabel('Discriminator Loss')
plt.plot(Iterations, Disc_Loss)

In [None]:
# FID score curve over training iterations
axes = plt.gca()
axes.set_xlabel('Iterations')
axes.set_ylabel('FID Score')
plt.plot(Iterations, FID)