# Basic GAN

Just a simple good old original GAN for me to refresh my memory

In [2]:
import pdb
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np

In [3]:
import torch
from torch.utils.data import DataLoader
from torch import nn
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid


  warn(


## Utilities functions

In [4]:
#visualization function

def show_images(tensor,channels=1,size=(28,28),num=16):
    """
    tensor = 128 x 784
    """
    data = tensor.detach().cpu().view(-1,channels,*size) #128x1x28x28
    
    grid_original = make_grid(data[:num],nrow=int(np.sqrt(num))) # num x 1 x 28 x 28
    grid = grid_original.permute(1, 2, 0) # num x 28 x 28 x 1
    
    plt.imshow(grid)
    plt.show()

In [5]:
def preprocessing_tranform():
    
    transform = transforms.ToTensor()
    
    return transform

## Parameters and configs

In [6]:
# setup main params and hyper-params

device = "cuda:0"

EPOCHS = 1000
LEARNING_RATE = 0.00001
BATCH_SIZE = 256
Z_DIM = 64

loss_function = nn.BCEWithLogitsLoss()

current_step = 0
show_every = 1000
show_ims_every = 1000

mean_generator_loss = 0
mean_discriminator_loss = 0


## Download data and prepare loader

In [7]:
data_loader = DataLoader(MNIST(".", download = True, 
                               transform = preprocessing_tranform()),
                         shuffle=True,batch_size = BATCH_SIZE)

## Generator

In [8]:
def genBlock(input_dim, output_dim):
    """
        input_dim = input size
        output_dim = output_size
    """
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True)
    )

In [9]:


class Generator(nn.Module):
    
    def __init__(self, z_dim=64, image_dim = 784, hidden_dim = 128):
        """
        z_dim = size of the random input vector
        image_dim = size of the flattened images
        hidden_dim = base size of the hidden layers
        """
        super().__init__()
        
        self.gen = nn.Sequential(
            genBlock(z_dim, hidden_dim), #z_dim x hidden_dim,
            genBlock(hidden_dim, hidden_dim*2),
            genBlock(hidden_dim*2, hidden_dim*4),
            genBlock(hidden_dim*4, hidden_dim*8),
            nn.Linear(hidden_dim*8, image_dim),
            
            # make the output between 0 and 1 (black and white images)
            nn.Sigmoid(),
        )
        
    def forward(self, noise):
        """
        noise = random noise vector
        """
        return self.gen(noise)
    
def generate_noise(number, z_dim):
    """
    number = number of random vectors
    z_dim = size of every random_vector
    """
    return torch.randn(number, z_dim).to(device)

## Discriminator

In [10]:
def discBlock(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2)
    )

In [11]:
class Discriminator(nn.Module):
    
    def __init__(self, input_dim = 784,hidden_dim=256):
        
        super().__init__()
        
        self.disc = nn.Sequential(
            discBlock(input_dim, hidden_dim*4), 
            discBlock(hidden_dim*4, hidden_dim*2), 
            discBlock(hidden_dim*2, hidden_dim),
            # 1 output that classifies an input image as real or fake
            nn.Linear(hidden_dim, 1), # input_dim x 1
        )
        
    def forward(self, image):
        """
        image = input image to be classified as real or fake
        """
        return self.disc(image)


## Define models and optimizers

In [12]:
generator = Generator(z_dim=Z_DIM,hidden_dim=256).to(device)
generator_optimizer = torch.optim.Adam(generator.parameters(), lr = LEARNING_RATE)

AssertionError: Torch not compiled with CUDA enabled

In [None]:
generator

In [None]:
discriminator = Discriminator(hidden_dim=256).to(device)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr = LEARNING_RATE)

In [None]:
discriminator

## Check everything is correct
* Dimensions
* Initial generator output(should be random)

In [None]:
x,y = next(iter(data_loader))
print(x.shape, y.shape)

In [None]:
noise = generate_noise(BATCH_SIZE,Z_DIM)
fake_images = generator(noise)

show_images(fake_images)

Output is pure noise as expected

## Defining loss functions


In [None]:
# generator loss
def generator_forward_and_loss_calculate(loss_function, generator, discriminator, number, z_dim):
    """Takes random noise and generates fake images, then calculate loss for those generated ims"""
    
    noise = generate_noise(number, z_dim)
    fake_images = generator(noise)
    # pass fake images and get discriminator score(probability  of being real)
    ## we want to maximize this score
    generated_ims_pred = discriminator(fake_images)
    
    #calcualate the generator loss by comparing scores to 1s(1s means real)
    targets = torch.ones_like(generated_ims_pred)
    generator_loss = loss_function(generated_ims_pred, targets)
    
    return generator_loss

In [None]:
# generator loss
def discriminator_forward_and_loss_calculate(loss_function, generator, discriminator, number, 
                                             real_images,z_dim):
    """takes real images and random noise to forward through the discriminator and calculate its loss"""
    
    real_ims_pred = discriminator(real_images)
    
    noise = generate_noise(number, z_dim)
    fake_images = generator(noise)
    # pass fake images and get discriminator score(probability  of being real)
    ## we want to minimize this score
    generated_ims_pred = discriminator(fake_images.detach())
    
    #calculate the discriminator loss for fake images by comparing scores to 0s(0s means fake)
    discriminator_fake_targets = torch.zeros_like(generated_ims_pred)
    disc_fake_loss = loss_function(generated_ims_pred, discriminator_fake_targets)
    
    # calculate discriminator loss for real images by comparing scores to 1s(1s means real)
    discriminator_real_targets = torch.ones_like(real_ims_pred)
    disc_real_loss = loss_function(real_ims_pred, discriminator_real_targets)
    
    # total loss is average between real and fake losses
    total_loss = (disc_fake_loss + disc_real_loss)/2
    
    return total_loss

## Training 

In [None]:
steps_per_epoch = data_loader.dataset.data.data.size()[0]/BATCH_SIZE
show_ims_every = int(steps_per_epoch)*10 # to show every k epochs, this is temporary for overwriting specifying the num of steps
print(f"{steps_per_epoch} steps with {BATCH_SIZE} per step")

In [None]:
disc_losses = []
gen_losses = []

In [None]:

for epoch in range(EPOCHS):
    for real_ims, _ in tqdm(data_loader):
        ### discriminator
        discriminator_optimizer.zero_grad()
        
        current_batch_size = len(real_ims) #last step contains less images
        real_ims = real_ims.view(current_batch_size, -1)
        real_ims = real_ims.to(device)
        
        discriminator_loss = discriminator_forward_and_loss_calculate(loss_function, generator, discriminator,
                                        current_batch_size, real_ims, Z_DIM
                                                                     )
        
        discriminator_loss.backward(retain_graph = True)
        discriminator_optimizer.step()
        
        ### generator
        generator_optimizer.zero_grad()
        generator_loss = generator_forward_and_loss_calculate(loss_function, generator, discriminator,
                                        current_batch_size, Z_DIM                     
                                                            )
        generator_loss.backward(retain_graph = True)
        generator_optimizer.step()
        
        ### output and feedback
        mean_disc_loss = discriminator_loss.item()
        mean_gen_loss = generator_loss.item()
        
        disc_losses.append(mean_disc_loss)
        gen_losses.append(mean_generator_loss)
        
        if current_step > 0 and current_step % show_ims_every == 0:
            fake_noise = generate_noise(current_batch_size, Z_DIM)
            fake_ims = generator(fake_noise)
            
            plt.plot(gen_losses)
            plt.plot(disc_losses)
            
            show_images(fake_ims)
            show_images(real_ims)
            
            plt.show()
            
        if current_step > 0 and current_step % show_every == 0:            
            print(f"{epoch}: step {current_step}, gen loss:{mean_gen_loss}, disc loss: {mean_disc_loss}")
            
        current_step += 1

In [14]:
torch.cuda.is_available()

False