In [1]:
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms
import torch.functional as F

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
root = "C:\\Users\\shant\\celeba"
LEARNING_RATE = 0.00001
BATCH_SIZE = 64
Z_DIM = 100
FEATURES_DISC = 64
FEATURES_GEN = 64
IMAGE_CHANNELS = 3
NUM_EPOCHS = 5
IMAGE_SIZE = 64
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10

dataset = datasets.ImageFolder(root = root , transform=transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))

dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 2)

In [4]:
# Implementation of the Wasserstein GAN
# WGAN Has better stability and loss means something for WGAN: Its a termination criteria
# WGAN Also prevents Mode Collapse(Model only outputs specific classes)
# When Discriminator converged to 0 obtained great results

class Generator(nn.Sequential):
    """
    z_dim: 
    channels_img: Input channels(for example for an RGB image this value is 3)
    features_g: Size of the output feature map(In this case its 64x64)
    """
    def __init__(self, z_dim, channels_img, features_g):
        
        modules = [self._block(z_dim, features_g*16, 4, 1, 0),
                   self._block(features_g*16, features_g*8, 4, 2, 1),
                   self._block(features_g*8, features_g*4, 4, 2, 1),
                   self._block(features_g*4, features_g*2, 4, 2, 1),
                   nn.ConvTranspose2d(features_g*2, channels_img, 4, 2, 1),
                   nn.Tanh()]
        
        super(Generator, self).__init__(*modules)
    
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
class Critic(nn.Sequential):
    
    def __init__(self, channels_img, features_d):
        
        modules = [nn.Conv2d(channels_img, features_d, kernel_size = 4, stride = 2, padding = 1), #32x32
                   nn.LeakyReLU(0.2, inplace=True),
                   self._block(features_d, features_d*2, 4, 2, 1),# 16x16
                   self._block(features_d*2, features_d*4, 4, 2, 1), #8x8
                   self._block(features_d*4, features_d*8, 4, 2, 1), #4x4
                   nn.Conv2d(features_d*8, 1, kernel_size = 4, stride = 2, padding = 0)]
        
        super(Critic, self).__init__(*modules)
    
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        
        return nn.Sequential(
         nn.Conv2d(in_channels, out_channels,kernel_size, stride, padding, bias=False),
            nn.InstanceNorm2d(out_channels, affine = True), # Learnable Parameters
            nn.LeakyReLU(0.2, inplace=True)
        )

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    if classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1, 0.02)
        nn.init.constant_(m.bias.data, 0)

def gradient_penalty(critic, real, fake, device = "cpu"):
    batch_size, C, H, W = real.shape
    # Creating interpolated images
    epsilon = torch.randn([batch_size, 1, 1, 1]).repeat(1,C,H,W).to(device)
    interpolated_images = epsilon*real + (1-epsilon) * fake

    #calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Compute the gradients with respect to the interpolated images, just need the first value
    gradient = torch.autograd.grad(
        inputs = interpolated_images, 
        outputs = mixed_scores, 
        grad_outputs = torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True)[0]

    # Number of Dimension
    # Calculate the Norm of the gradient to Eforce 1-Lipschitz Constraint
    gradient = gradient.view(gradient.size(0), -1)
    gradient_norm = gradient.norm(2, dim = 1)
    gradient_penalty = torch.mean((gradient_norm - 1)**2)

    return gradient_penalty

# def test():
#     N, in_channels, H, W = 8, 3, 64, 64
#     z_dim = 100
#     X = torch.randn((N, in_channels, H, W))
#     disc = Critic(in_channels,8)
#     disc.apply(weights_init)
#     assert disc(X).shape == (N, 1, 1, 1) # One Value per example
#     gen = Generator(z_dim, in_channels, 64)
#     gen.apply(weights_init)
#     z = torch.randn((N, z_dim, 1, 1))
#     assert gen(z).shape == (N, in_channels, H, W) # Ouput Generated image
#     print("Success")
# test()

In [6]:
########################
# Generator and Discriminator Model objects
########################
generator = Generator(Z_DIM,IMAGE_CHANNELS,FEATURES_GEN).to(device)
critic = Critic(IMAGE_CHANNELS,FEATURES_DISC).to(device)

########################
# Weight Initialization for the model
########################
generator.apply(weights_init)
critic.apply(weights_init)

########################
# Optimizers for Critic and the Generator
########################
optimizer_gen = optim.Adam(generator.parameters(), lr = LEARNING_RATE, betas = (0,0.9))
optimizer_critic = optim.Adam(critic.parameters(), lr = LEARNING_RATE, betas = (0,0.9))

#######################
# Create tensorboard SummaryWriter objects to display generated fake images and associated loss curves
#######################
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
loss_curves = SummaryWriter(f"logs/loss_curves")

#######################
# Create a batch of latent vectors. Will be used to to do a single pass through the generator after 
# the training has terminated
#######################
fixed_noise = torch.randn((64, Z_DIM, 1, 1)).to(device)

step = 0 # For printing to tens

In [7]:
for epoch in range(NUM_EPOCHS):
    
    # Unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        
        # The real world images
        real = real.to(device)
        
        cur_batch_size = real.shape[0]
        #####################################################
        # Train the Critic
        #####################################################
        
        for _ in range(CRITIC_ITERATIONS):
            critic.zero_grad()
            # Latent noise
            noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(device)
            # Pass the latent vector through the generator
            fake = generator(noise)     
            critic_real = critic(real).view(-1)
            critic_fake = critic(fake).view(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            ## Loss for the critic. Taking -ve because RMSProp are designed to minimize 
            ## Hence to minimize something -ve is equivalent to maximizing that expression
            loss_critic = -torch.mean(critic_real) + torch.mean(critic_fake) + LAMBDA_GP*gp 
            loss_critic.backward(retain_graph=True)
            optimizer_critic.step()

        #############################
        # Train the generator minimizing -E[critic(gen_fake)]
        #############################
        generator.zero_grad()
        output = critic(fake).view(-1)
        loss_gen = -torch.mean(output)
        loss_gen.backward()
        optimizer_gen.step()
        
        if batch_idx % 50 == 0:
            
            print(
            f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {-loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )
            with torch.no_grad():
                fake = generator(fixed_noise)
            
                # The [:64] prints out the 4-D tensor BxCxHxW
                img_grid_real = torchvision.utils.make_grid(
                    real[:64], normalize = True)
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:64], normalize = True)
                ##########################
                # TensorBoard Visualizations
                ##########################
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
#                 loss_curves.add_scalar("generator", {loss_gen, global_step=step)
                loss_curves.add_scalars("curves", {
                    "generator":loss_gen, "critic":-(loss_critic)
                }, global_step = step)
#                 loss_curves.add_scalar("discriminator", loss_disc, global_step = step)
                
            step += 1 # See progression of images

# Save critic and generator state dictionaries
torch.save(generator.state_dict(), 'generator.pt')
torch.save(critic.state_dict(), 'critic.pt')

Epoch [0/5] Batch 0/3166                   Loss D: -215.7560, loss G: -1.8182
Epoch [0/5] Batch 50/3166                   Loss D: 29.8819, loss G: 23.8634
Epoch [0/5] Batch 100/3166                   Loss D: 70.4569, loss G: 52.4091
Epoch [0/5] Batch 150/3166                   Loss D: 86.0928, loss G: 64.4262
Epoch [0/5] Batch 200/3166                   Loss D: 106.8502, loss G: 75.5788
Epoch [0/5] Batch 250/3166                   Loss D: 103.8101, loss G: 79.9860
Epoch [0/5] Batch 300/3166                   Loss D: 66.9414, loss G: 75.8155
Epoch [0/5] Batch 350/3166                   Loss D: 78.5758, loss G: 70.8416
Epoch [0/5] Batch 400/3166                   Loss D: 74.5777, loss G: 69.2469
Epoch [0/5] Batch 450/3166                   Loss D: 57.6903, loss G: 64.7120
Epoch [0/5] Batch 500/3166                   Loss D: 47.7621, loss G: 52.6940
Epoch [0/5] Batch 550/3166                   Loss D: 42.5687, loss G: 42.6688
Epoch [0/5] Batch 600/3166                   Loss D: 34.7462, l

Epoch [1/5] Batch 2000/3166                   Loss D: 14.9621, loss G: 63.4083
Epoch [1/5] Batch 2050/3166                   Loss D: 14.5861, loss G: 62.8509
Epoch [1/5] Batch 2100/3166                   Loss D: 13.8006, loss G: 64.3905
Epoch [1/5] Batch 2150/3166                   Loss D: 13.9927, loss G: 62.7234
Epoch [1/5] Batch 2200/3166                   Loss D: 15.1381, loss G: 63.9889
Epoch [1/5] Batch 2250/3166                   Loss D: 11.6759, loss G: 65.4801
Epoch [1/5] Batch 2300/3166                   Loss D: 14.0205, loss G: 62.5690
Epoch [1/5] Batch 2350/3166                   Loss D: 13.1323, loss G: 70.3030
Epoch [1/5] Batch 2400/3166                   Loss D: 14.3723, loss G: 67.4497
Epoch [1/5] Batch 2450/3166                   Loss D: 14.0902, loss G: 66.4464
Epoch [1/5] Batch 2500/3166                   Loss D: 12.4679, loss G: 66.7262
Epoch [1/5] Batch 2550/3166                   Loss D: 14.3514, loss G: 67.9610
Epoch [1/5] Batch 2600/3166                   Loss D

Epoch [3/5] Batch 800/3166                   Loss D: 11.7285, loss G: 82.5719
Epoch [3/5] Batch 850/3166                   Loss D: 10.8758, loss G: 83.8583
Epoch [3/5] Batch 900/3166                   Loss D: 13.5974, loss G: 83.3039
Epoch [3/5] Batch 950/3166                   Loss D: 11.4693, loss G: 86.2271
Epoch [3/5] Batch 1000/3166                   Loss D: 11.6608, loss G: 85.6967
Epoch [3/5] Batch 1050/3166                   Loss D: 11.5968, loss G: 84.3653
Epoch [3/5] Batch 1100/3166                   Loss D: 11.3065, loss G: 85.7727
Epoch [3/5] Batch 1150/3166                   Loss D: 11.9035, loss G: 85.1466
Epoch [3/5] Batch 1200/3166                   Loss D: 12.3141, loss G: 82.8839
Epoch [3/5] Batch 1250/3166                   Loss D: 11.9012, loss G: 83.6463
Epoch [3/5] Batch 1300/3166                   Loss D: 11.8969, loss G: 85.9413
Epoch [3/5] Batch 1350/3166                   Loss D: 11.6255, loss G: 84.9797
Epoch [3/5] Batch 1400/3166                   Loss D: 10

Epoch [4/5] Batch 2800/3166                   Loss D: 9.5222, loss G: 97.1319
Epoch [4/5] Batch 2850/3166                   Loss D: 10.8799, loss G: 95.9715
Epoch [4/5] Batch 2900/3166                   Loss D: 12.7303, loss G: 97.8696
Epoch [4/5] Batch 2950/3166                   Loss D: 11.6509, loss G: 97.9898
Epoch [4/5] Batch 3000/3166                   Loss D: 8.6593, loss G: 98.6911
Epoch [4/5] Batch 3050/3166                   Loss D: 10.3881, loss G: 94.9838
Epoch [4/5] Batch 3100/3166                   Loss D: 10.9699, loss G: 93.7298
Epoch [4/5] Batch 3150/3166                   Loss D: 8.9403, loss G: 99.4433
