<a href="https://colab.research.google.com/github/deepeshhada/SA-GAN/blob/master/Train%20-%20SAGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import math
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.optim.optimizer import Optimizer, required
from torch.autograd import Variable
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

In [0]:
# Data / device configuration
batch_size = 8
image_size = 64  # CIFAR-10 images are resized and center-cropped to this side length
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [0]:
# Preprocessing: upscale CIFAR-10 to image_size and map pixels from [0, 1] to
# [-1, 1], the same range as the generator's Tanh output.
cifar_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar_transform)
trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data


In [0]:
# Spectral Normalization

def l2normalize(v, eps=1e-12):
    """Return ``v`` scaled to unit L2 norm.

    ``eps`` is added to the norm so an all-zero tensor maps to zeros instead of
    producing NaNs from a division by zero.
    """
    magnitude = v.norm()
    return v / (magnitude + eps)


class SpectralNorm(nn.Module):
    """Wrap a layer so its weight is divided by a power-iteration estimate of its top singular value.

    On construction the wrapped module's ``<name>`` parameter is removed and
    replaced by three registered parameters:
      * ``<name>_bar`` - the raw, trainable weight;
      * ``<name>_u`` / ``<name>_v`` - non-trainable unit vectors used by power iteration.
    Each forward pass refines u/v with ``power_iterations`` steps, stores
    ``<name>_bar / sigma`` on the module as a plain attribute named ``<name>``,
    and then delegates to the wrapped module's ``forward``.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        # Only split the weight if the module has not been prepared already.
        if self._made_params():
            return
        self._make_params()

    def _update_u_v(self):
        """Run power iteration on the 2-D view of the raw weight and install the normalized weight."""
        prefix = self.name
        u_vec = getattr(self.module, prefix + "_u")
        v_vec = getattr(self.module, prefix + "_v")
        w_raw = getattr(self.module, prefix + "_bar")

        rows = w_raw.data.shape[0]
        # Detached (rows, -1) view: the u/v refresh must not be tracked by autograd.
        w_mat = w_raw.data.view(rows, -1)
        for _ in range(self.power_iterations):
            v_vec.data = l2normalize(torch.mv(torch.t(w_mat), u_vec.data))
            u_vec.data = l2normalize(torch.mv(w_mat, v_vec.data))

        # sigma uses the live (non-detached) weight so gradients flow through the normalization.
        sigma = u_vec.dot(w_raw.view(rows, -1).mv(v_vec))
        setattr(self.module, prefix, w_raw / sigma.expand_as(w_raw))

    def _made_params(self):
        """Return True if the wrapped module already has ``<name>_u``, ``<name>_v`` and ``<name>_bar``."""
        for suffix in ("_u", "_v", "_bar"):
            try:
                getattr(self.module, self.name + suffix)
            except AttributeError:
                return False
        return True

    def _make_params(self):
        """Replace ``module.<name>`` with the ``_u`` / ``_v`` / ``_bar`` parameters."""
        weight = getattr(self.module, self.name)

        rows = weight.data.shape[0]
        cols = weight.view(rows, -1).data.shape[1]

        # Random unit vectors seed the power iteration; they are never optimized.
        u_vec = Parameter(weight.data.new(rows).normal_(0, 1), requires_grad=False)
        v_vec = Parameter(weight.data.new(cols).normal_(0, 1), requires_grad=False)
        u_vec.data = l2normalize(u_vec.data)
        v_vec.data = l2normalize(v_vec.data)
        w_bar = Parameter(weight.data)

        del self.module._parameters[self.name]

        for suffix, param in (("_u", u_vec), ("_v", v_vec), ("_bar", w_bar)):
            self.module.register_parameter(self.name + suffix, param)

    def forward(self, *args):
        """Refresh the normalized weight, then run the wrapped module."""
        self._update_u_v()
        return self.module.forward(*args)

In [0]:
# Model hyperparameters
# Length of the latent noise vector fed to the generator.
z_dim = 100

In [0]:
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Query/key 1x1 projections reduce channels to ``in_channels // 8``; the value
    projection keeps ``in_channels``. The attended features are scaled by the
    learnable ``gamma`` (initialized to 0, so the layer starts as the identity)
    and added back onto the input.

    forward(input) returns ``(out, attention)``: ``out`` has the same shape as
    ``input`` (B, C, H, W) and ``attention`` is (B, N, N) with N = H * W, each
    row softmax-normalized over the last dimension.
    """

    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        reduced = in_channels // 8
        self.query_conv = nn.Conv2d(in_channels=in_channels, out_channels=reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_channels, out_channels=reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        batch, channels, rows, cols = input.size()
        positions = rows * cols

        def flat(t):
            # (B, C', H, W) -> (B, C', N)
            return t.view(batch, -1, positions)

        queries = flat(self.query_conv(input)).transpose(1, 2)  # B x N x C'
        keys = flat(self.key_conv(input))                       # B x C' x N
        attention = self.softmax(queries @ keys)                # B x N x N
        values = flat(self.value_conv(input))                   # B x C x N

        attended = (values @ attention.transpose(1, 2)).view(batch, channels, rows, cols)
        # Residual connection: gamma controls how much attention output is mixed in.
        return self.gamma * attended + input, attention

In [0]:
class Generator(nn.Module):
    """Generator: latent noise -> 3 x 64 x 64 image in [-1, 1] (Tanh output).

    Four spectrally-normalized ConvTranspose2d stacks upsample a
    ``z_dim x 1 x 1`` input to ``64 x 32 x 32``. Self-attention is applied after
    the 128-channel and 64-channel stages. A final plain ConvTranspose2d + Tanh
    produces the RGB image.

    Args:
        z_dim: number of latent channels expected by ``forward``. Defaults to
            100, the value previously hard-coded; pass the notebook's ``z_dim``
            so the model and the sampled noise always agree.
    """

    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        self.z_dim = z_dim

        # Build all conv stacks first, then the attention layers, so modules are
        # created (and randomly initialized) in the same order as before.
        layer1 = self._up_block(z_dim, 512, stride=1, padding=0)  # z_dim x 1 x 1 -> 512 x 4 x 4
        layer2 = self._up_block(512, 256)                          # -> 256 x 8 x 8
        layer3 = self._up_block(256, 128)                          # -> 128 x 16 x 16
        layer4 = self._up_block(128, 64)                           # -> 64 x 32 x 32
        # -> 3 x 64 x 64; the output layer is intentionally not spectrally normalized.
        layer5 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

        # Registration order (hence parameters() and state_dict keys) is unchanged,
        # so previously saved checkpoints still load.
        self.layer1 = layer1
        self.layer2 = layer2
        self.layer3 = layer3
        self.attention1 = SelfAttention(in_channels=128)
        self.layer4 = layer4
        self.attention2 = SelfAttention(in_channels=64)
        self.layer5 = layer5

    @staticmethod
    def _up_block(in_channels, out_channels, stride=2, padding=1):
        """SpectralNorm(ConvTranspose2d, k=4) -> BatchNorm2d -> ReLU (submodules indexed 0/1/2)."""
        return nn.Sequential(
            SpectralNorm(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                            kernel_size=4, stride=stride, padding=padding, bias=False)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, z):
        """Generate images from latent codes.

        Args:
            z: noise whose first two dims are (batch, z_dim), e.g. (B, z_dim) or
               (B, z_dim, 1, 1).
        Returns:
            Tensor of shape (B, 3, 64, 64). The attention maps are discarded.
        """
        out = z.view(z.size(0), z.size(1), 1, 1)  # B x z_dim x 1 x 1
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out, _ = self.attention1(out)
        out = self.layer4(out)
        out, _ = self.attention2(out)
        return self.layer5(out)

# Build the generator with the configured latent size instead of a hard-coded 100.
G = Generator(z_dim=z_dim).to(device)

In [0]:
class Discriminator(nn.Module):
    """Discriminator: 3 x 64 x 64 image -> probability that the image is real.

    Four spectrally-normalized stride-2 Conv2d + LeakyReLU(0.1) stacks
    downsample to ``512 x 4 x 4``. Self-attention is applied after the
    256-channel and 512-channel stages. A final 4x4 conv + Sigmoid yields a
    ``(B, 1, 1, 1)`` tensor; callers must flatten it before comparing it with
    ``(B,)`` labels.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Conv stacks are built before the attention layers to keep the original
        # module-creation (random-initialization) order.
        stacks = [
            self._down_block(3, 64),      # 3 x 64 x 64 -> 64 x 32 x 32
            self._down_block(64, 128),    # -> 128 x 16 x 16
            self._down_block(128, 256),   # -> 256 x 8 x 8
            self._down_block(256, 512),   # -> 512 x 4 x 4
            nn.Sequential(                # -> 1 x 1 x 1 probability
                nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
                nn.Sigmoid(),
            ),
        ]
        # Same registration order as before, so state_dict keys are unchanged.
        self.layer1, self.layer2, self.layer3 = stacks[:3]
        self.attention1 = SelfAttention(in_channels=256)
        self.layer4 = stacks[3]
        self.attention2 = SelfAttention(in_channels=512)
        self.layer5 = stacks[4]

    @staticmethod
    def _down_block(in_channels, out_channels):
        """SpectralNorm(Conv2d k=4, s=2, p=1) -> LeakyReLU(0.1) (submodules indexed 0/1)."""
        return nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                   kernel_size=4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        )

    def forward(self, x):
        hidden = self.layer3(self.layer2(self.layer1(x)))
        hidden, _ = self.attention1(hidden)  # attention map is not used
        hidden = self.layer4(hidden)
        hidden, _ = self.attention2(hidden)
        return self.layer5(hidden)

D = Discriminator().to(device)

In [0]:
# Optimizer hyperparameters: the discriminator learns 4x faster than the generator.
lr_g, lr_d = 1e-4, 4e-4
lr_decay = 0.95  # NOTE(review): not referenced by any visible cell - no LR schedule is applied
beta1, beta2 = 0.5, 0.9
opt_G = optim.Adam(G.parameters(), lr=lr_g, betas=(beta1, beta2))
opt_D = optim.Adam(D.parameters(), lr=lr_d, betas=(beta1, beta2))
max_epochs = 40
pre_trained = True  # resume from the checkpoints at model_path (defined in the training cell)

In [0]:
model_path = "./drive/My Drive/saved models/sagan"
generator_path = model_path + " - generator"
discriminator_path = model_path + " - discriminator"

# Seed unconditionally so both fresh and resumed runs are reproducible.
torch.manual_seed(0)
if pre_trained:
    # map_location lets a checkpoint written on GPU be resumed on a CPU-only runtime.
    G.load_state_dict(torch.load(generator_path, map_location=device))
    D.load_state_dict(torch.load(discriminator_path, map_location=device))

# BCELoss requires floating-point targets, so the label values are floats.
real_label = 1.0
fake_label = 0.0
loss_function = nn.BCELoss()

checkpoint_every = 1000  # iterations between sample snapshots / checkpoint saves

fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)
losses_G = []
losses_D = []
img_list = []
iters = 0


def save_checkpoint():
    """Save G and D state dicts atomically next to ``model_path``.

    Each file is written to a temporary path and then moved over the old one
    with ``os.replace``. An interrupted save therefore never destroys the
    previous checkpoint, which deleting first could do.
    """
    checkpoint_dir = os.path.dirname(model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    for module, path in ((G, generator_path), (D, discriminator_path)):
        tmp_path = path + ".tmp"
        torch.save(module.state_dict(), tmp_path)
        os.replace(tmp_path, path)


for epoch in range(max_epochs):
    for i, data in enumerate(trainloader):
        real_inputs = data[0].to(device)
        n_real = real_inputs.size(0)

        #   Step 1.1: Train Discriminator with minibatch of only real samples
        D.zero_grad()
        # D returns (B, 1, 1, 1); flatten to (B,) because BCELoss needs matching shapes.
        real_outputs = D(real_inputs).view(-1)
        real_labels = torch.full_like(real_outputs, real_label)
        err_D_real = loss_function(real_outputs, real_labels)
        err_D_real.backward()
        D_x = real_outputs.mean().item()  # D(x), where x is a real image

        #   Step 1.2: Train Discriminator with minibatch of only fake samples
        z = torch.randn(n_real, z_dim, 1, 1, device=device)  # match the actual batch size
        fake_inputs = G(z)
        # detach(): the D update must not backprop into G, so no retain_graph is needed.
        fake_outputs = D(fake_inputs.detach()).view(-1)
        fake_labels = torch.full_like(fake_outputs, fake_label)
        err_D_fake = loss_function(fake_outputs, fake_labels)
        err_D_fake.backward()
        D_G_z1 = fake_outputs.mean().item()

        err_D = err_D_real + err_D_fake
        opt_D.step()

        #   Step 2: Train Generator with minibatch of fake samples
        G.zero_grad()
        fake_outputs = D(fake_inputs).view(-1)
        # Real labels (1) are the generator's targets: it wants D to call its fakes real.
        fake_labels = torch.full_like(fake_outputs, real_label)
        err_G = loss_function(fake_outputs, fake_labels)
        err_G.backward()
        D_G_z2 = fake_outputs.mean().item()
        opt_G.step()

        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, max_epochs, i, len(trainloader),
                     err_D.item(), err_G.item(), D_x, D_G_z1, D_G_z2))

        losses_G.append(err_G.item())
        losses_D.append(err_D.item())

        # Periodically snapshot G's output on fixed_noise and checkpoint both models.
        is_last_step = (epoch == max_epochs - 1) and (i == len(trainloader) - 1)
        if iters % checkpoint_every == 0 or is_last_step:
            with torch.no_grad():
                fake = G(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            save_checkpoint()

        iters += 1

In [0]:
# Animate the fixed-noise sample grids collected during training (one frame per snapshot).
fig, anim_ax = plt.subplots(figsize=(8, 8))
anim_ax.axis("off")
ims = [[anim_ax.imshow(np.transpose(grid, (1, 2, 0)), animated=True)] for grid in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

In [0]:
# Persist the final weights. Each file is written to a temporary path and then
# moved into place with os.replace, so the previous checkpoint survives an
# interrupted save. Deleting it first would not give that guarantee. The target
# folder is created if it does not exist yet.
checkpoint_dir = os.path.dirname(model_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)
for module, suffix in ((G, " - generator"), (D, " - discriminator")):
    final_path = model_path + suffix
    tmp_path = final_path + ".tmp"
    torch.save(module.state_dict(), tmp_path)
    os.replace(tmp_path, final_path)

In [0]:
# Draw 64 fresh samples from the generator and show them as one image grid.
noise = torch.randn(64, z_dim, 1, 1, device=device)

with torch.no_grad():
    fake = G(noise).cpu()
img = vutils.make_grid(fake, padding=2, normalize=True)

figs, sample_ax = plt.subplots(figsize=(8, 8))
sample_ax.axis("off")
imss = [sample_ax.imshow(np.transpose(img, (1, 2, 0)), animated=False)]