In [1]:
import matplotlib.pyplot as plt
import os
import glob
import pandas as pd
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import wandb
import imageio
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
from torch.utils.data import DataLoader, Dataset
from copy import deepcopy
from torch.autograd import Variable
from tqdm import tqdm
from pprint import pprint
from PIL import Image
from sklearn.model_selection import train_test_split
import torchvision
from torchvision import datasets, transforms
import os

In [2]:
DATASET_PATH = "./mnist_data"
class Config:
    RUNTIME = "KAGGLE"
    SUBSET_FRACTION = None
    NUM_EPOCHS = 30
    BATCH_SIZE = 64
    PRECISION = "16-mixed"
    LOG_EVERY_N_STEPS = 10
    INPUT_IMAGE_SIZE = (45, 45)    
    DIM_Z = 100
    NUM_WORKERS = mp.cpu_count()
    EARLY_STOPPING_PATIENCE = 6  # Add this line for early stopping patience
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('Training on', Config.DEVICE)

Training on cuda


In [3]:
class WandbConfig:
    WANDB_KEY = ""
    WANDB_RUN_NAME = "vanilla_gan_testrun"
    WANDB_PROJECT = "vanilla_gan"
    USE_WANDB = False    

In [4]:
if Config.RUNTIME == "KAGGLE":
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    WANDB_KEY = user_secrets.get_secret("wandb")
    DATASET_PATH = "/kaggle/working/mnist_data"

if WandbConfig.USE_WANDB:
    # Log in to W&B
    os.environ["WANDB_API_KEY"] = WANDB_KEY
    # Initialize W&B
    wandb.init(project=WandbConfig.WANDB_PROJECT, name=WandbConfig.WANDB_RUN_NAME)

ModuleNotFoundError: No module named 'kaggle_secrets'

In [3]:
# Set transform to normalize data
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,))])

# each row in mnist dataset is a tuple of (image, label)
mnist_data = datasets.MNIST(root=DATASET_PATH, train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=mnist_data, 
                                          batch_size=Config.BATCH_SIZE, 
                                          shuffle=True)

In [6]:
mnist_data

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./mnist_data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )

In [None]:
class GeneratorNet(nn.Module):
    def __init__(self, dim_z):
        super(GeneratorNet, self).__init__()
        self.dim_z = dim_z
        self.model = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.Linear(self.dim_z, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 28 * 28), # 28*28 = 784 (MNIST image size)
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img

In [None]:
class DiscriminatorNet(nn.Module):
    def __init__(self):
        super(DiscriminatorNet, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, img):
        flattened = img.view(img.size(0), -1)
        output = self.model(flattened)
        return output

In [None]:
# Function to display input and output images
def display_images(input_img, output_img, epoch):    
    fig, axes = plt.subplots(1, 2, figsize=(4, 2))
    axes[0].imshow(input_img)    
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    axes[1].imshow(output_img)    
    axes[1].set_xticks([])
    axes[1].set_yticks([])
    plt.show()

In [None]:
class GANMonitorCallback(Callback):
    def __init__(self, model, validation_z, log_images=True):
        self.model = model
        self.validation_z = validation_z
        self.log_images = log_images
        self.epoch_g_loss = 0.0
        self.epoch_d_loss = 0.0
        self.epoch_batches = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.epoch_g_loss += outputs["g_loss"]
        self.epoch_d_loss += outputs["d_loss"]
        self.epoch_batches += 1

    def on_train_epoch_end(self, trainer, pl_module):
        avg_g_loss = self.epoch_g_loss / self.epoch_batches
        avg_d_loss = self.epoch_d_loss / self.epoch_batches
        print(f"Epoch {trainer.current_epoch} - Generator Loss: {avg_g_loss:.4f}, Discriminator Loss: {avg_d_loss:.4f}")
        self.epoch_g_loss = 0.0
        self.epoch_d_loss = 0.0
        self.epoch_batches = 0
        # Generate and plot images if enabled
        if self.log_images:
            self.visualize_images(trainer.current_epoch)

    def visualize_images(self, epoch):
        """Generate and display images."""
        self.model.eval()
        with torch.no_grad():
            generated_imgs = self.model(self.validation_z.to(self.model.device))
        self.model.train()

        # Create figure
        plt.figure(figsize=(8, 4))
        plt.subplot(1, 2, 1)
        plt.title(f'Generated Images (Epoch {epoch})')
        img_grid = torchvision.utils.make_grid(generated_imgs.cpu(), normalize=True, nrow=8)
        plt.imshow(img_grid.permute(1, 2, 0).cpu().numpy(), cmap='gray')
        plt.axis('off')
        plt.tight_layout()
        plt.show()        

In [None]:
class MnistGAN(pl.LightningModule):
    def __init__(self, dim_z, lr=0.0002, b1=0.5, b2=0.999):
        super(MnistGAN, self).__init__()
        self.save_hyperparameters()
        # hyperparameters
        self.dim_z = dim_z
        self.lr = lr
        self.b1 = b1
        self.b2 = b2        
        # networks
        self.discriminator = DiscriminatorNet()
        self.generator = GeneratorNet(self.dim_z)
        # losses
        self.adversarial_loss = nn.BCEWithLogitsLoss()        
        self.validation_z = torch.randn(4, self.dim_z)
        self.automatic_optimization = False

    def forward(self, z):
        return self.generator(z)
    
    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        # Get optimizers
        opt_g, opt_d = self.optimizers()
        # train generator        
        # sample noise
        z = torch.randn(imgs.shape[0], self.dim_z, device=Config.DEVICE)
        # clear gradients from previous step 
        opt_g.zero_grad()
        # generate images (forward pass)
        generated_imgs = self(z)
        # In the training of the generator, the goal is to fool the discriminator into thinking that the generated images are real.
        # This is why, when training the generator, the target labels (valid) are set to 1. The generator wants to make its generated
        # images look as real as possible so that the discriminator will classify them as real.
        # If you set the ground truth labels to 0, the generator would be trying to generate images that the discriminator
        # can confidently classify as fake, which goes against the goal of training the generator. We want to train the generator
        # to improve its ability to deceive the discriminator, not make it easier for the discriminator to identify fake images.
        real = torch.ones(imgs.size(0), 1, device=Config.DEVICE)        
        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(generated_imgs), real)
        # calculate gradients for generator 
        self.manual_backward(g_loss)
        # update generator weights
        opt_g.step()                
        # train discriminator        
        # ground truth result (ie: all fake)
        real = torch.ones(imgs.size(0), 1, device=Config.DEVICE)
        fake = torch.zeros(imgs.size(0), 1, device=Config.DEVICE)
        opt_d.zero_grad()
        # adversarial loss is binary cross-entropy
        real_loss = self.adversarial_loss(self.discriminator(imgs), real)
        fake_loss = self.adversarial_loss(self.discriminator(generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        # calculate gradients for discriminator
        self.manual_backward(d_loss)
        # update discriminator weights
        opt_d.step()
        self.log('d_loss', d_loss, prog_bar=True, on_epoch=True)
        self.log('g_loss', g_loss, prog_bar=True, on_epoch=True)
        return {'d_loss': d_loss, 'g_loss': g_loss}        

    def configure_optimizers(self):
        lr = self.lr
        b1 = self.b1
        b2 = self.b2
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d]

In [None]:
# Instantiate the GAN model
model = MnistGAN(Config.DIM_Z)
wandb_logger = None
if WandbConfig.USE_WANDB:
    wandb_logger = WandbLogger(project=WandbConfig.WANDB_PROJECT, name=WandbConfig.WANDB_RUN_NAME)    

# Generate fixed noise for validation
validation_z = torch.randn(8, model.dim_z)
gan_monitor = GANMonitorCallback(model=model, validation_z=validation_z, log_images=True)
# Initialize Trainer with W&B logger
trainer = pl.Trainer(
        max_epochs=Config.NUM_EPOCHS,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        logger=wandb_logger,
        enable_model_summary=True,
        precision=Config.PRECISION,
        enable_progress_bar=True,
        log_every_n_steps=Config.LOG_EVERY_N_STEPS,
        callbacks=[gan_monitor]
    )    
trainer.fit(model, train_dataloaders=train_loader)