# Generative Model

This notebook deals with finding a good generative model.

In [1]:
# Imports
import import_ipynb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchsummary import summary
from utils import * # type: ignore
import os
import polars as pl
import matplotlib.pyplot as plt
import umap.umap_ as umap
import warnings
from torchvision.transforms import v2
import torch.optim as optim
import math

Mean: tensor([0.1982])
Std: tensor([0.3426])


In [24]:
# Global variables

# Hyperparameters
BATCH_SIZE = 128
LEARNING_RATE = 0.001
NUM_EPOCHS = 120
PATIENCE = 50

NOISE_DIM = 100
GAN_LATENT_DIM = 100

## Conditional Variatonal Autoencoder

First of all I want to start with a CNN-based variatonal Autoencoder. This way I can apply the knowledge that I gained while building the Classifier Model earlier.

In order to put more focus on the class label I did embed the labels before passing them into the encoder, as well concatonating the image label to the decoder input.

Moreover I did try to use KL annealing where I linearly increase the weight of the KL Divergence in the loss function to focus on proper reconstruction first.

In [3]:
# No normalization needed for CVAE
vae_train_transforms = v2.Compose([
    # Include data augmentation, but exclude normalization
    *train_transforms.transforms[:-1],
])

vae_test_transforms = v2.Compose([
    # Include data augmentation, but exclude normalization
    *test_transforms.transforms[:-1],
])

# Create Train and Test Datasets
train_data = QuickDrawDataset('../dataset/train.csv', '../dataset/images', vae_train_transforms)
test_data = QuickDrawDataset('../dataset/test.csv', '../dataset/images', vae_test_transforms)

# Data loaders
TRAIN_LOADER = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=6)
TEST_LOADER = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=6)

In [4]:
# VAE loss function - combines reconstruction loss and KL divergence
def vae_loss(recon_x, x, mu, logvar, class_logits, labels, beta=1.0, gamma=1.0):
    MSE = F.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    CLASS = F.cross_entropy(class_logits, labels, reduction='sum')
    return MSE, KLD, CLASS, (MSE + beta * KLD + gamma * CLASS)

In [14]:
# Define the Conditional Variational Autoencoder (CVAE) model
class CVAE(nn.Module):
    def __init__(self, latent_dim=150, num_classes=len(classes)):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # label embedding to inject the information into the encoder
        self.label_embedding = nn.Embedding(num_classes, 8) 

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1 + 8, 128, kernel_size=3, padding=1),  # 1 channel image + 1 channel label map
            nn.LeakyReLU(),
            nn.BatchNorm2d(128),
            
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(128),
            
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(128),
            
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(256),
            
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(256),
            
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(256),
        )

        # Latent mappings
        self.fc_input_dim = 256 * 7 * 7
        self.fc_mu = nn.Linear(self.fc_input_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.fc_input_dim, latent_dim)

        # Classifier head on latent space
        self.classifier = nn.Linear(latent_dim, num_classes)

        # Decoder
        self.fc_decode = nn.Linear(latent_dim + num_classes, 256 * 7 * 7)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(256),
            nn.Dropout(0.3),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(256),

            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(128),
            nn.Dropout(0.3),

            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(128),

            nn.Conv2d(128, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    # Reparameterization trick to sample from the latent space
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, y):
        batch_size = x.size(0)

        # Embed image label 
        label_embed = self.label_embedding(y)
        label_map = label_embed.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 28, 28)
        x_cat = torch.cat([x, label_map], dim=1) 

        # Encode
        x_encoded = self.encoder(x_cat)

        # mu + log_var
        x_flat = x_encoded.view(batch_size, -1)
        mu = self.fc_mu(x_flat)
        logvar = self.fc_logvar(x_flat)
        
        z = self.reparameterize(mu, logvar)

        # Class logits from latent vector z
        class_logits = self.classifier(z)

        # Decode + add class label
        y_onehot = F.one_hot(y, self.num_classes).float().to(x.device)
        z_cat = torch.cat([z, y_onehot], dim=1)
        x_decoded = self.fc_decode(z_cat)
        x_decoded = x_decoded.view(batch_size, 256, 7, 7)
        recon_x = self.decoder(x_decoded)

        return recon_x, mu, logvar, z, class_logits

    def sample(self, z, y):
        y_onehot = F.one_hot(y, self.num_classes).float().to(z.device)
        z_cat = torch.cat([z, y_onehot], dim=1)
        x_decoded = self.fc_decode(z_cat)
        x_decoded = x_decoded.view(z.size(0), 256, 7, 7)
        samples = self.decoder(x_decoded)
        return samples

In [6]:
# Sigmoid like anealing
def kl_anneal_sigmoid(epoch, total_epochs, max_beta=1.0, k=0.1, x0=None):
    if x0 is None:
        # midpoint of annealing
        x0 = total_epochs / 2  
    beta = max_beta / (1 + math.exp(-k * (epoch - x0)))
    return beta


In [7]:
def sample_conditional_images(model, epoch, num_classes=5, latent_dim=50, device=DEVICE):
    with torch.no_grad():
        z = torch.randn(num_classes, latent_dim).to(device)
        y = torch.arange(num_classes).to(device)
        samples = model.sample(z, y).cpu()

        _, axes = plt.subplots(1, num_classes, figsize=(num_classes*2, 2))
        for i in range(num_classes):
            ax = axes[i]
            ax.imshow(samples[i].squeeze(), cmap='gray')
            ax.axis('off')
            ax.set_title(classes[i])

        plt.suptitle(f'Sampled Images at Epoch {epoch}', y=1.05)
        plt.subplots_adjust(top=0.8)
        plt.show()

In [8]:
# Plots a UMAP projection to visualize the latent space of the model
def plot_umap(z_all, y_all, epoch, class_names):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        reducer = umap.UMAP(n_components=2, random_state=None)
    z_2d = reducer.fit_transform(z_all)

    # Get unique classes sorted to align with colorbar ticks
    unique_classes = sorted(set(y_all))

    # Create a dict mapping class index to name
    class_labels = [class_names[c] for c in unique_classes]

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(z_2d[:, 0], z_2d[:, 1], c=y_all, cmap='tab10', alpha=0.6)

    # Setup colorbar with ticks and labels
    cbar = plt.colorbar(scatter, ticks=unique_classes)
    cbar.ax.set_yticklabels(class_labels)

    plt.title(f'UMAP Projection of Latent Space at Epoch {epoch}')
    plt.xlabel("UMAP dim 1")
    plt.ylabel("UMAP dim 2")
    plt.show()


In [17]:
# Training and evaluation loop for the CVAE model
kl = []
recon = []
claz = []
loss = []

# Beta parameter for KL divergence weighting
beta = 2
# Gamma parameter for classification loss weighting
gamma = 0.5

model = CVAE().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Params for early stopping
best_loss = float('inf')
epochs_no_improve = 0

# Create directories for saving weights if they don't exist
os.makedirs(os.path.dirname(f'../weights/generative/cvae.pt'), exist_ok=True)

model.train()
for epoch in range(NUM_EPOCHS):

    z_all = []
    y_all = []

    for i, (x, labels) in enumerate(TRAIN_LOADER):
        x = x.to(DEVICE)
        labels = labels.to(DEVICE)

        x_reconst, mu, log_var, z, class_logits = model(x,labels)
        # Set the midpoint quite early
        current_beta = kl_anneal_sigmoid(epoch, NUM_EPOCHS, beta, x0=25)
        reconst_loss, kl_div, class_loss, train_loss = vae_loss(x_reconst, x, mu, log_var, class_logits, labels, current_beta, gamma)

        recon.append(reconst_loss.item()/len(x))
        kl.append(kl_div.item()/len(x))
        loss.append(train_loss.item()/len(x))
        claz.append(class_loss.item()/len(x))

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        z_all.append(z.detach().cpu())
        y_all.append(labels.detach().cpu())

    # Early stopping logic
    if train_loss.item()/len(x) < best_loss:
        best_loss = train_loss
        epochs_no_improve = 0
        torch.save(model, f'../weights/generative/cvae.pt')
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= PATIENCE:
        print(f"Early stopping at epoch {epoch+1}")
        break

    print(f"Epoch {epoch+1} | Learning Rate {optimizer.param_groups[0]['lr']:.4f}\n"
        f"Loss: {train_loss.item()/len(x):.4f} | "
        f"Reconstruction Loss: {reconst_loss.item()/len(x):.4f} | "
        f"Beta: {current_beta:.4f} | "
        f"KL Divergence: {kl_div.item()/len(x):.4f} | "
        f"Classification Loss: {class_loss.item()/len(x):.4f}")


    if (epoch + 1) % 10 == 0:
        model.eval()
        if (epoch + 1) % 50 == 0:
            # Visualize the latent space using UMAP
            z_all = torch.cat(z_all)
            y_all = torch.cat(y_all)
            plot_umap(z_all.numpy(), y_all.numpy(), epoch + 1, classes)
        sample_conditional_images(model, epoch + 1, model.num_classes, model.latent_dim, DEVICE)
        model.train()

Epoch 1 | Learning Rate 0.0010
Loss: 296.6012 | Reconstruction Loss: 75.3498 | Beta: 0.1517 | KL Divergence: 1454.0845 | Classification Loss: 1.2860
Epoch 2 | Learning Rate 0.0010
Loss: nan | Reconstruction Loss: nan | Beta: 0.1663 | KL Divergence: nan | Classification Loss: nan
Epoch 3 | Learning Rate 0.0010
Loss: nan | Reconstruction Loss: nan | Beta: 0.1822 | KL Divergence: nan | Classification Loss: nan
Epoch 4 | Learning Rate 0.0010
Loss: nan | Reconstruction Loss: nan | Beta: 0.1995 | KL Divergence: nan | Classification Loss: nan
Epoch 5 | Learning Rate 0.0010
Loss: nan | Reconstruction Loss: nan | Beta: 0.2182 | KL Divergence: nan | Classification Loss: nan


Exception ignored in: <function _afterFork at 0x7f1c4923b060>
Traceback (most recent call last):
  File "/usr/lib/python3.13/logging/__init__.py", line 245, in _afterFork
    def _afterFork():
KeyboardInterrupt: 


RuntimeError: DataLoader worker (pid(s) 156653, 156654, 156655, 156656, 156657) exited unexpectedly

## Conditional GAN

Based on this [paper](https://arxiv.org/pdf/1511.06434) I want to build a conditional version that is able to create images

In [None]:
# The GAN expected images with values from [-1, 1] so we need different normalizations
cgan_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.5], [0.5])
])

# Add data augmentations during training
cgan_train_transforms = v2.Compose([
    *train_transforms.transforms[:-3],
    *cgan_transforms.transforms
])

# cGAN Dataset and Dataloader
cgan_train_data = QuickDrawDataset('../dataset/train.csv', '../dataset/images', cgan_train_transforms)
cgan_test_data = QuickDrawDataset('../dataset/test.csv', '../dataset/images', cgan_transforms)

cgan_train_loader = DataLoader(cgan_train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=6, pin_memory=True)
cgan_test_loader = DataLoader(cgan_test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=6, pin_memory=True)
    
# The Generator
class ConditionalGenerator(nn.Module):
    def __init__(self, noise_dim=100, num_classes=5, embedding_dim=100):
        super(ConditionalGenerator, self).__init__()

        # Conditional Part: Embedd label and add concat with input
        self.label_embed = nn.Embedding(num_classes, embedding_dim)
        
        self.fc = nn.Sequential(
            nn.Linear(noise_dim * 2, 512 * 7 * 7),  # was 256*7*7
            nn.BatchNorm1d(512 * 7 * 7),
            nn.ReLU(True)
        )
        
        self.net = nn.Sequential(
            # Upsample 7x7 -> 14x14, reduce channels from 512 to 256
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),  
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # Upsample 14x14 -> 28x28, reduce channels from 256 to 128
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),  
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # Final conv to get single channel image 28x28, no stride/padding change
            nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1),  
            nn.Tanh()  # normalize to [-1,1]
        )

    def forward(self, noise, labels):
        label_embedding = self.label_embed(labels)
        x = torch.cat([noise, label_embedding], dim=1)
        x = self.fc(x).view(-1, 512, 7, 7)
        return self.net(x)

# The Discriminator
class ConditionalDiscriminator(nn.Module):
    def __init__(self, num_classes=5, embedding_dim=100):
        super(ConditionalDiscriminator, self).__init__()

        # Conditional Part: Embedd label and add concat with input
        self.label_embed = nn.Embedding(num_classes, embedding_dim)
        self.label_fc = nn.Linear(embedding_dim, 28 * 28)

        self.net = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1, bias=False),  # 28x28 -> 14x14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
        
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),  # 14x14 -> 7x7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        label_embedding = self.label_embed(labels)
        label_img = self.label_fc(label_embedding).view(-1, 1, 28, 28)
        x = torch.cat([img, label_img], dim=1)  # concat on channel dim
        return self.net(x)


In [None]:
def sample_images(generator, noise, image_labels, epoch):
    
    with torch.no_grad():
        generated_images = generator(noise, image_labels) 
        generated_images = generated_images * 0.5 + 0.5  # Unnormalize from [-1, 1] to [0, 1]

    # Plot
    # Get model predictions -> logits
    _, axs = plt.subplots(5, 5, figsize=(10, 10))
    for i, ax in enumerate(axs.flatten()):
        img = generated_images[i]
        img = img.squeeze(0).cpu()

        ax.imshow(img, cmap='gray')
        ax.set_title(f"Class {classes[image_labels[i].item()]}")
        ax.axis('off')
        
    plt.suptitle(f'Sampled Images at Epoch {epoch}')
    plt.show()

In [42]:
LEARNING_RATE = 0.0001

# Init models
G = ConditionalGenerator(noise_dim=NOISE_DIM, num_classes=len(classes)).to(DEVICE)
D = ConditionalDiscriminator(num_classes=len(classes)).to(DEVICE)

criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

# Pregenerated fixed nose and fixed labels to check if generation improves
fixed_noise = torch.randn(25, NOISE_DIM, device=DEVICE)
fixed_image_labels = torch.arange(0, len(classes)).repeat(5).to(DEVICE)

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

G.apply(weights_init_normal)
D.apply(weights_init_normal)

for epoch in range(0, NUM_EPOCHS):
    for i, (real_imgs, image_labels) in enumerate(cgan_train_loader):
        real_imgs = real_imgs.to(DEVICE)
        image_labels = image_labels.to(DEVICE)

        # Add a bit of gaussian noice to images in order to prevent the discriminator from becoming to good
        # Decay the added noise over time
        noise_std = max(0.1 * (1.0 - epoch / NUM_EPOCHS), 0.01)
        instance_noise = torch.randn_like(real_imgs) * noise_std
        noisy_real_images = real_imgs + instance_noise

        # Train Discriminator
        optimizer_D.zero_grad()

        # Generate fake images
        z = torch.randn(BATCH_SIZE, GAN_LATENT_DIM, device=DEVICE)
        fake_image_labels = image_labels
        fake_images = G(z, fake_image_labels)

        # Add instance noise to fake images
        noisy_fake_images = fake_images + torch.randn_like(fake_images) * noise_std

        # Fake labels in [0.0, 0.25]
        fake_d_labels = torch.empty(BATCH_SIZE, device=DEVICE).uniform_(0.0, 0.25)
        
        # Generate real labels in [0.75, 1.0] - prevent discriminator  overconfidence
        real_d_labels = torch.empty(BATCH_SIZE, device=DEVICE).uniform_(0.75, 1.0)

        # Clamp noisy images to keep training stable
        noisy_real_images = torch.clamp(noisy_real_images, -1, 1)
        noisy_fake_images = torch.clamp(noisy_fake_images, -1, 1)

        # Shuffle real and fake images to prevent ordering bias
        all_images = torch.cat([noisy_real_images, noisy_fake_images], dim=0)
        all_image_labels = torch.cat([image_labels, fake_image_labels], dim=0)
        all_d_labels = torch.cat([real_d_labels, fake_d_labels], dim=0)

        perm = torch.randperm(all_images.size(0))
        shuffled_images = all_images[perm]
        shuffled_image_labels = all_image_labels[perm]
        shuffled_d_labels = all_d_labels[perm]

        outputs = D(shuffled_images, shuffled_image_labels)
        loss_D = criterion(outputs.view(-1), shuffled_d_labels)
        loss_D.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()

        # Generate fake images again
        fake_images = G(z, fake_image_labels)

        # Get discriminator's prediction on these fake images with their labels
        outputs = D(fake_images, fake_image_labels)

        # Generator wants the discriminator to believe these are real
        gen_targets = torch.empty(BATCH_SIZE, device=DEVICE).uniform_(0.75, 1.0)

        # Compute generator loss 
        loss_G = criterion(outputs.view(-1), gen_targets)
        
        loss_G.backward()
        optimizer_G.step()

    print(f"Epoch {epoch+1} | Loss_D: {loss_D.item():.4f} | Loss_G: {loss_G.item():.4f}")

    # Save sample images
    if (epoch + 1) % 10 == 0:
        G.eval()
        sample_images(G, fixed_noise, fixed_image_labels, epoch+1)
        G.train()

Epoch 1 | Loss_D: 0.3681 | Loss_G: 0.3910
Epoch 2 | Loss_D: 0.3773 | Loss_G: 0.3603
Epoch 3 | Loss_D: 0.3688 | Loss_G: 0.3965
Epoch 4 | Loss_D: 0.3728 | Loss_G: 0.3880
Epoch 5 | Loss_D: 0.3767 | Loss_G: 0.3885
Epoch 6 | Loss_D: 0.3852 | Loss_G: 0.3556
Epoch 7 | Loss_D: 0.3713 | Loss_G: 0.3998
Epoch 8 | Loss_D: 0.3829 | Loss_G: 0.3848
Epoch 9 | Loss_D: 0.3826 | Loss_G: 0.3721


KeyboardInterrupt: 

## Non-conditional GAN

Maybe training the generator to produce output for every class is less efficient then just training five smaller models.

In [None]:
class_label

## Conditional Diffusion Model