# Implementation of a basic GAN
- imports
- load data
- create the Generator and Discriminator classes
- set up the hyperparameters, loss function, and optimizers
- define loss functions for the Discriminator and the Generator
- define training loop 
- evaluation


### imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST  # Training dataset
from torchvision.utils import make_grid
from tqdm.auto import tqdm

# Set the random seed exactly once for reproducibility; a second
# torch.manual_seed call would silently override the first one.
torch.manual_seed(0)

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


### load data


In [None]:
# Image pipeline: convert to a tensor, then shift/scale pixel values with
# mean 0.5 and std 0.5, i.e. x -> (x - 0.5) / 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# MNIST training split, downloaded into the current directory if missing,
# served in shuffled mini-batches of 64 images.
mnist = MNIST(root='.', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=64, shuffle=True)

### create the Generator and Discriminator classes

In [None]:
def get_noise(n_samples, z_dim, device=device):
    '''
    Create a batch of noise vectors for the generator.
    Parameters:
        n_samples: the number of samples to generate, a scalar
        z_dim: the dimension of the noise vector, a scalar
        device: the device to allocate the tensor on (defaults to the
            module-level `device` as it was when this function was defined)
    Returns:
        a tensor of shape (n_samples, z_dim) filled with samples from the
        standard normal distribution, living on `device`
    '''
    # Allocate directly on the target device rather than creating the tensor
    # on the CPU and copying it over with .to(device).
    return torch.randn(n_samples, z_dim, device=device)

In [None]:
# Define the Generator
class Generator(nn.Module):
    '''
    Generator Class: maps a batch of noise vectors to a batch of flattened images.
    Values:
        img_shape: [H,W] of the input image as input_image.shape
        z_dim: the dimension of the noise vector, a scalar
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, img_shape, z_dim=10, hidden_dim=128):

        super(Generator, self).__init__()
        # number of values in one flattened image (product of img_shape)
        self.im_dim = int(np.prod(img_shape))

        def gen_block(in_feat, out_feat):
            '''Hidden block: Linear -> BatchNorm1d -> LeakyReLU(0.2).'''
            return nn.Sequential(
                        nn.Linear(in_feat, out_feat),
                        # NOTE(review): the 2nd positional argument of BatchNorm1d
                        # is `eps`, so this is eps=0.8 (the default is 1e-5). Kept
                        # to preserve behavior; confirm whether momentum=0.8 was
                        # intended instead.
                        nn.BatchNorm1d(out_feat, eps=0.8),
                        # nn.ReLU only accepts `inplace`, so nn.ReLU(0.2, inplace=True)
                        # raised TypeError; the 0.2 negative slope means a
                        # LeakyReLU was intended.
                        nn.LeakyReLU(0.2, inplace=True)
                    )

        # Layer structure (and therefore state_dict keys) is unchanged:
        # model.0..3 are hidden blocks, model.4 the output Linear, model.5 Tanh.
        self.model = nn.Sequential(
            gen_block(z_dim, hidden_dim),
            gen_block(hidden_dim, hidden_dim * 2),
            gen_block(hidden_dim * 2, hidden_dim * 4),
            gen_block(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, self.im_dim),
            nn.Tanh()  # squashes outputs into [-1, 1]
        )

    def forward(self, noise):
        '''
        Generate images from noise.
        Parameters:
            noise: a tensor of shape (n_samples, z_dim)
        Returns:
            a tensor of shape (n_samples, prod(img_shape)) with values in [-1, 1]
        '''
        return self.model(noise)


In [None]:
# Define the Discriminator
class Discriminator(nn.Module):
    '''
    Discriminator Class: scores each flattened image with a single raw logit.
    No Sigmoid is applied at the end; the loss is expected to work on logits.
    Values:
        img_shape: [H,W] of the input image
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, img_shape, hidden_dim=128):
        super().__init__()

        def _hidden_layer(n_in, n_out):
            # fully connected layer followed by LeakyReLU with slope 0.2
            return nn.Sequential(nn.Linear(n_in, n_out), nn.LeakyReLU(0.2))

        # feature widths: flattened image -> 4h -> 2h -> h
        widths = [int(np.prod(img_shape)), hidden_dim * 4, hidden_dim * 2, hidden_dim]
        layers = [
            _hidden_layer(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        # final projection to one logit per image
        layers.append(nn.Linear(hidden_dim, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, image):
        '''
        Parameters:
            image: a batch of flattened images, one row per image
        Returns:
            a tensor with one logit per image, shape (batch, 1)
        '''
        return self.model(image)

### set up hyperparameters

In [None]:
# Hyperparameters
criterion = nn.BCEWithLogitsLoss()  # works on raw logits (the discriminator has no Sigmoid)
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
img_shape = [28, 28]

# Rebuild the DataLoader so the batch_size hyperparameter above is actually
# used; the loader created in the data cell hard-codes batch_size=64.
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

# Initialize generator, discriminator, and their optimizers
gen = Generator(img_shape, z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator(img_shape).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

### define loss functions

In [None]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    '''
    Return the loss of the discriminator given inputs.
    Parameters:
        gen: the generator model, which returns an image given z-dimensional noise
        disc: the discriminator model, which returns a single-dimensional prediction of real/fake
        criterion: the loss function, which should be used to compare 
               the discriminator's predictions to the ground truth reality of the images 
               (e.g. fake = 0, real = 1)
        real: a batch of real images (expected to already be on `device`)
        num_images: the number of images the generator should produce, 
                which is also the length of the real images
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    Returns:
        disc_loss: a torch scalar loss value for the current batch,
            the mean of the fake-image loss and the real-image loss
    '''
    # Sample noise on the requested device (the argument was previously ignored).
    noise = get_noise(num_images, z_dim, device=device)
    # Generate fakes without tracking gradients: the discriminator update must
    # not flow back into the generator, so there is no point building its graph.
    with torch.no_grad():
        fake_img = gen(noise)

    # The discriminator should label fakes as 0 ...
    pred_fake = disc(fake_img)
    loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
    # ... and real images as 1.
    pred_true = disc(real)
    loss_true = criterion(pred_true, torch.ones_like(pred_true))

    # Average of the two losses.
    disc_loss = (loss_fake + loss_true) / 2
    return disc_loss

In [None]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    '''
    Return the loss of the generator given inputs.
    Parameters:
        gen: the generator model, which returns an image given z-dimensional noise
        disc: the discriminator model, which returns a single-dimensional prediction of real/fake
        criterion: the loss function, which should be used to compare 
               the discriminator's predictions to the ground truth reality of the images 
               (e.g. fake = 0, real = 1)
        num_images: the number of images the generator should produce, 
                which is also the length of the real images
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    Returns:
        gen_loss: a torch scalar loss value for the current batch
    '''
    # Sample noise on the requested device (the argument was previously ignored).
    noise = get_noise(num_images, z_dim, device=device)
    fake_img = gen(noise)
    # No .detach() here: the gradient must flow through the discriminator back
    # into the generator.
    pred_fake = disc(fake_img)
    # The generator wants its fakes to be classified as real (label 1).
    gen_loss = criterion(pred_fake, torch.ones_like(pred_fake))
    return gen_loss

### training loop and evaluation

In [None]:
# training params
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0

# run the generator weight-update sanity checks inside the training loop
test_generator = True
# display_step comes from the hyperparameter cell; it is not redefined here,
# so the value set there is the one the training loop uses.

# set to True by the training loop if a sanity check fails
error = False

def plot_tensor_of_imgs(image_tensor, num_images=25, size=(1, 28, 28), value_range=(-1.0, 1.0)):
    '''
    Plot the first `num_images` images of a tensor in a uniform grid (5 per row).
    Parameters:
        image_tensor: a batch of images, flattened or not; it is reshaped to (-1, *size)
        num_images: how many images of the batch to plot
        size: (C, H, W) of a single image
        value_range: (low, high) range of the pixel values. The default (-1, 1)
            matches the generator's Tanh output and the Normalize((0.5,), (0.5,))
            transform used on the real data. Values are mapped linearly to
            [0, 1] for display.
    '''
    low, high = value_range
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    # matplotlib clips float RGB data to [0, 1]; without rescaling, every pixel
    # in the negative half of [-1, 1] would be rendered as pure black.
    image_unflat = ((image_unflat - low) / (high - low)).clamp(0, 1)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [None]:


# Training loop
for epoch in range(n_epochs):

    for batch, (imgs, _) in enumerate(dataloader):

        current_batch_size = len(imgs)

        # Flatten each real image of the batch into a vector.
        # (Previously read the undefined name `real`, a NameError on the first batch.)
        real = imgs.view(current_batch_size, -1).to(device)

        #  ######### Train Discriminator #########
        # zero out the gradients before backprop
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen,
                                  disc,
                                  criterion,
                                  real,
                                  current_batch_size,
                                  z_dim,
                                  device)
        # No retain_graph: the generator step below builds a fresh graph.
        disc_loss.backward()
        disc_opt.step()

        #  ######### Train Generator #########
        # Snapshot the first generator layer's weights for the sanity check.
        # The Generator stores its layers in `.model` (there is no `.gen`).
        if test_generator:
            old_generator_weights = gen.model[0][0].weight.detach().clone()
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen,
                                disc,
                                criterion,
                                current_batch_size,
                                z_dim,
                                device)
        # Also fills disc grads; those are cleared by disc_opt.zero_grad()
        # before the next discriminator step.
        gen_loss.backward()
        gen_opt.step()

        # Sanity checks: gradients are small when lr is tiny, and the
        # generator's weights actually changed after the optimizer step.
        if test_generator:
            first_layer = gen.model[0][0]
            grad_ok = lr > 0.0000002 or (first_layer.weight.grad.abs().max() < 0.0005 and epoch == 0)
            weights_changed = torch.any(first_layer.weight.detach() != old_generator_weights)
            if not (grad_ok and weights_changed):
                error = True
                print("Runtime tests have failed")

        # ######### print progress #########
        # NOTE(review): prints on every batch of epochs divisible by 100 — confirm intended.
        if epoch % 100 == 0:
            print(
                f"[Epoch {epoch}/{n_epochs}] [Batch {batch}/{len(dataloader)}] "
                f"[D loss: {disc_loss.item():.4f}] [G loss: {gen_loss.item():.4f}]"
            )

        # ######### visualize progress #########
        # Keep track of the average discriminator / generator losses
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        # visualize training progress
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            # sampling for display only; no gradients needed
            with torch.no_grad():
                fake_noise = get_noise(current_batch_size, z_dim, device=device)
                fake = gen(fake_noise)
            plot_tensor_of_imgs(fake)
            plot_tensor_of_imgs(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1


### save for later eval/inference

In [None]:
from pathlib import Path

# Create the output directory if it does not exist yet
mdl_path = Path("models")
mdl_path.mkdir(parents=True,  # create parent directories if needed
               exist_ok=True  # if models directory already exists, don't error
)

# File names of the saved state dicts
mdl_name_gen = "basic_GAN_gen.pth"
mdl_name_disc = "basic_GAN_disc.pth"

# Save the model state dicts (learned parameters and buffers only).
# Paths are joined with pathlib's `/` operator; formatting them as
# f"{mdl_path} / {name}" embeds literal spaces ("models / basic_GAN_gen.pth")
# and points at a non-existent "models " directory.
print(f"Saving gen-disc models to: {mdl_path}")
torch.save(obj=gen.state_dict(), f=mdl_path / mdl_name_gen)
torch.save(obj=disc.state_dict(), f=mdl_path / mdl_name_disc)