**Main imports**

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
from torch.nn.utils import spectral_norm    
from torch import optim;
import os

# Pick the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device.type)

**Import dataset**

In [None]:
# helper function to make getting another batch of data easier
def cycle(iterable):
    """Yield the items of ``iterable`` forever, re-iterating it after each pass.

    Every pass calls ``iter(iterable)`` again, so a DataLoader is restarted
    (and reshuffled, if it shuffles) on each pass.

    Raises:
        ValueError: if a pass produces no items. This covers an empty iterable,
            and a one-shot iterator that has been exhausted. Without this check
            ``next()`` would spin forever.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError("cycle() received an empty iterable; it would loop forever")

class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm',]

batch_size = 64

# Shared preprocessing for both splits. CIFAR-100 images are already 32x32, so a
# Resize is unnecessary: just convert to tensors and map each channel from
# [0, 1] to [-1, 1].
cifar_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Shuffle the training split so batches are drawn in a new order on every pass.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=cifar_transform),
    batch_size=batch_size, shuffle=True, drop_last=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=cifar_transform),
    batch_size=batch_size, drop_last=True)

# cycle() is a generator function, so its result is already an iterator.
train_iterator = cycle(train_loader)
test_iterator = cycle(test_loader)

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
print("Number of classes: ", len(class_names))

# With drop_last=True this equals len(train_loader.dataset) // batch_size.
num_batches_per_epoch = len(train_loader)

# 50000 appears to be a total optimisation-step budget; convert it to whole epochs.
total_training_steps = 50000
num_of_epochs = total_training_steps // num_batches_per_epoch

print("Number of batches per epoch: ", num_batches_per_epoch)
print("Number of epochs: ", num_of_epochs)

**View some of the training dataset**

In [None]:
# let's view some of the training data
plt.rcParams['figure.dpi'] = 100
x, t = next(train_iterator)

# Move the batch to the selected device (GPU if available, otherwise CPU)
x = x.to(device)
t = t.to(device)

# The loaders normalise pixels to [-1, 1]. Map them back to [0, 1] before plotting;
# otherwise imshow clips every negative value and the preview looks far too dark.
plt.imshow(torchvision.utils.make_grid(x * 0.5 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0))
plt.axis('off')
plt.title('Training batch')
plt.show()

### VAE Model, Loss, and Training Loop

In [6]:
# Constants carried over from the template; neither name is referenced elsewhere in this notebook.
num_classes, check_interval = 100, 5

Reference notebook (conditional GAN example): https://github.com/atapour/dl-pytorch/blob/main/Conditional_GAN_Example/Conditional_GAN_Example.ipynb

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from scipy import linalg
from torchvision.models.inception import inception_v3
from torchvision.utils import make_grid

class AttentionBlock(nn.Module):
    """Self-attention over all spatial positions, added back as a residual.

    Query and key use 1x1 convolutions down to ``channels // 8`` channels. Value
    keeps ``channels`` channels. The result is ``gamma * attended + x``.
    ``gamma`` is a learnable scalar initialised to 0, so a freshly built block
    returns its input unchanged.

    Args:
        channels: number of input (and output) feature channels.
    """
    def __init__(self, channels):
        super().__init__()
        reduced = channels // 8
        self.channels = channels
        self.query = nn.Conv2d(channels, reduced, 1)
        self.key = nn.Conv2d(channels, reduced, 1)
        self.value = nn.Conv2d(channels, channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, h, w = x.size()
        positions = h * w

        # q: (n, positions, c//8), k: (n, c//8, positions) -> scores: (n, positions, positions)
        q = self.query(x).view(n, -1, positions).transpose(1, 2)
        k = self.key(x).view(n, -1, positions)
        weights = F.softmax(torch.bmm(q, k), dim=-1)

        # Each output position is a softmax-weighted sum of the value vectors.
        v = self.value(x).view(n, -1, positions)
        attended = torch.bmm(v, weights.transpose(1, 2)).view(n, c, h, w)

        return self.gamma * attended + x

class EnhancedResBlock(nn.Module):
    """Residual block with grouped convolutions and squeeze-excitation gating.

    Path: grouped 3x3 conv -> BN -> LeakyReLU(0.2) -> grouped 3x3 conv -> BN,
    then channel-wise scaling by a squeeze-excitation gate (global average
    pool -> 1x1 bottleneck -> ReLU -> 1x1 -> sigmoid). The skip connection is
    added back (via a 1x1 projection when channel counts differ), followed by a
    final LeakyReLU(0.2). Spatial size is preserved (3x3 convs with padding=1).

    Args:
        in_channels: input channels; must be divisible by ``groups``.
        out_channels: output channels; must be divisible by ``groups``.
        groups: number of groups in both 3x3 convolutions.
        reduction: squeeze-excitation bottleneck ratio. The hidden width is
            ``max(1, out_channels // reduction)``, so narrow blocks never end up
            with a zero-channel bottleneck.
    """
    def __init__(self, in_channels, out_channels, groups=8, reduction=16):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1, groups=groups)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, groups=groups)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Squeeze-Excitation block; at least one hidden channel.
        se_hidden = max(1, out_channels // reduction)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, se_hidden, 1),
            nn.ReLU(),
            nn.Conv2d(se_hidden, out_channels, 1),
            nn.Sigmoid()
        )

        # Skip connection with 1x1 conv if needed
        self.skip = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        identity = self.skip(x)

        out = F.leaky_relu(self.bn1(self.conv1(x)), 0.2)
        out = self.bn2(self.conv2(out))

        # Apply squeeze-excitation: per-channel gates in (0, 1)
        se_weight = self.se(out)
        out = out * se_weight

        out += identity
        out = F.leaky_relu(out, 0.2)
        return out

class ImprovedVAE(nn.Module):
    """Convolutional VAE built from residual/squeeze-excitation and self-attention blocks.

    Sized for 3x32x32 inputs: three ``AvgPool2d(2)`` stages reduce 32x32 to a
    512x4x4 feature map (``flatten_dim = 4 * 4 * 512``). Two MLP heads map that
    map to the mean and log-variance of a diagonal Gaussian posterior over
    ``latent_dim``-dimensional latents. The decoder mirrors the encoder and ends
    in ``Tanh``, so its outputs lie in [-1, 1].

    Args:
        latent_dim: dimensionality of the latent vector ``z``.
    """
    def __init__(self, latent_dim=256):  # Increased latent dimension
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder with enhanced blocks and attention
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            EnhancedResBlock(64, 128),
            AttentionBlock(128),
            nn.AvgPool2d(2),  # 16x16

            EnhancedResBlock(128, 256),
            AttentionBlock(256),
            nn.AvgPool2d(2),  # 8x8

            EnhancedResBlock(256, 512),
            AttentionBlock(512),
            nn.AvgPool2d(2),  # 4x4
        )

        self.flatten_dim = 4 * 4 * 512

        # Latent space projections: flatten_dim -> flatten_dim // 2 -> latent_dim
        self.fc_mu = nn.Sequential(
            nn.Linear(self.flatten_dim, self.flatten_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Linear(self.flatten_dim // 2, latent_dim)
        )

        self.fc_logvar = nn.Sequential(
            nn.Linear(self.flatten_dim, self.flatten_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Linear(self.flatten_dim // 2, latent_dim)
        )

        # Decoder input: latent_dim -> flatten_dim, reshaped to 512x4x4 in decode()
        self.decoder_input = nn.Sequential(
            nn.Linear(latent_dim, self.flatten_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Linear(self.flatten_dim // 2, self.flatten_dim)
        )

        # Decoder with enhanced blocks and attention
        self.decoder = nn.Sequential(
            EnhancedResBlock(512, 512),
            AttentionBlock(512),
            nn.Upsample(scale_factor=2),  # 8x8

            EnhancedResBlock(512, 256),
            AttentionBlock(256),
            nn.Upsample(scale_factor=2),  # 16x16

            EnhancedResBlock(256, 128),
            AttentionBlock(128),
            nn.Upsample(scale_factor=2),  # 32x32

            EnhancedResBlock(128, 64),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape (B, latent_dim), for a batch of images ``x``."""
        h = self.encoder(x)
        # flatten(1) keeps the batch dimension fixed. An input of the wrong spatial
        # size therefore fails with a shape error in fc_mu, instead of being silently
        # re-batched as view(-1, flatten_dim) would do.
        h = torch.flatten(h, 1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latents of shape (B, latent_dim) to images of shape (B, 3, 32, 32) in [-1, 1]."""
        x = self.decoder_input(z)
        x = x.view(-1, 512, 4, 4)
        return self.decoder(x)

    def forward(self, x):
        """Encode, sample ``z`` and decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def generate(self, num_samples, device=None):
        """Decode ``num_samples`` latents drawn from the N(0, I) prior, without tracking gradients.

        Args:
            num_samples: number of images to generate.
            device: device for the latents. Defaults to the device holding the
                model's parameters, so this works on CPU-only machines too.

        The model's current train/eval mode is used as-is; call ``model.eval()``
        first if BatchNorm should use its running statistics.
        """
        if device is None:
            device = next(self.parameters()).device
        with torch.no_grad():
            z = torch.randn(num_samples, self.latent_dim).to(device)
            return self.decode(z)

def improved_vae_loss(recon_x, x, mu, logvar, kld_weight=0.0005):  # Reduced KLD weight
    """VAE objective: a blended L1/MSE reconstruction loss plus a weighted KL term.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: target images.
        mu: posterior means from the encoder.
        logvar: posterior log-variances from the encoder (same shape as ``mu``).
        kld_weight: multiplier applied to the KL divergence.

    Returns:
        Scalar tensor ``0.5 * (MSE + L1) + kld_weight * KL``. All terms use
        ``reduction='sum'`` over every element, so the value grows with batch size.

    There is no perceptual (feature-space) term; the reconstruction loss is
    purely pixel-wise.
    """
    # Reconstruction loss (equal-weight blend of L1 and MSE)
    mse_loss = F.mse_loss(recon_x, x, reduction='sum')
    l1_loss = F.l1_loss(recon_x, x, reduction='sum')
    recon_loss = 0.5 * (mse_loss + l1_loss)

    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), summed over batch and latent dims
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kld_weight * kld_loss

# Training setup with improved optimizer settings
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = ImprovedVAE().to(device)
# optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_of_epochs, eta_min=1e-6)

def train_epoch(model, optimizer, train_iterator, device, kld_weight, num_batches=1):
    """Run ``num_batches`` optimisation steps and return their mean loss.

    Despite the name, the default ``num_batches=1`` takes a single step on one
    batch. Pass ``num_batches=len(train_loader)`` to cover a full epoch.

    Args:
        model: VAE where ``model(images)`` returns ``(recon, mu, logvar)``.
        optimizer: optimizer over ``model``'s parameters.
        train_iterator: iterator yielding ``(images, labels)`` batches; labels are ignored.
        device: device each batch is moved to.
        kld_weight: KL weight forwarded to ``improved_vae_loss``.
        num_batches: number of batches / optimisation steps to run (>= 1).

    Returns:
        float: mean per-batch loss over the steps taken.

    Raises:
        ValueError: if ``num_batches < 1``.
    """
    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")

    model.train()
    total_loss = 0.0

    for _ in range(num_batches):
        images, _ = next(train_iterator)
        images = images.to(device)

        optimizer.zero_grad()
        recon_images, mu, logvar = model(images)

        loss = improved_vae_loss(recon_images, images, mu, logvar, kld_weight)
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        total_loss += loss.item()

    return total_loss / num_batches

def evaluate(model, test_iterator, device, num_batches=10):
    """Return the mean ``improved_vae_loss`` over ``num_batches`` batches pulled from ``test_iterator``.

    The model is switched to eval mode (and left there) and gradients are not
    tracked. The loss uses ``improved_vae_loss``'s default KL weight.
    Batches are consumed from the iterator, so the caller decides which test
    batches are scored.
    """
    def _batch_loss():
        # One test batch -> scalar loss as a Python float.
        batch, _ = next(test_iterator)
        batch = batch.to(device)
        reconstruction, mu, logvar = model(batch)
        return improved_vae_loss(reconstruction, batch, mu, logvar).item()

    model.eval()
    with torch.no_grad():
        running = sum(_batch_loss() for _ in range(num_batches))

    return running / num_batches
# NOTE: a verbatim second copy of `improved_vae_loss` was defined here. It
# silently shadowed the identical definition earlier in this cell, so edits to
# that earlier definition would have had no effect. It has been removed;
# `evaluate` and the training loop resolve the name to the earlier definition.

def visualize_results(model, images, epoch, device, directories):
    """
    Save a reconstruction-comparison grid and a grid of fresh samples for ``epoch``.

    Args:
        model: VAE exposing ``model(images) -> (recon, mu, logvar)`` and ``model.generate(n, device)``
        images: batch of original images; up to the first 8 go into the comparison
        epoch: epoch number, zero-padded to 4 digits in the output filenames
        device: device passed to ``model.generate``
        directories: mapping that must contain the 'reconstructions' and 'generated_images' keys

    Side effects: switches ``model`` to eval mode and writes two PNG files.
    """
    def _save_grid(tensor, folder_key, filename, per_row):
        # normalize=True rescales the grid to [0, 1] by its own min/max before writing.
        destination = os.path.join(directories[folder_key], filename)
        torchvision.utils.save_image(tensor, destination, normalize=True, nrow=per_row)

    model.eval()
    with torch.no_grad():
        recon_images, _, _ = model(images)
        # First grid row: originals; second row: their reconstructions.
        _save_grid(
            torch.cat([images[:8], recon_images[:8]]),
            'reconstructions',
            f'reconstruction_epoch_{epoch:04d}.png',
            8,
        )

        # 16 prior samples laid out as a 4x4 grid.
        _save_grid(
            model.generate(16, device),
            'generated_images',
            f'generated_epoch_{epoch:04d}.png',
            4,
        )

# Output folders written by visualize_results and the training loop.
directories = dict(
    generated_images='generated_images',      # For generated samples
    reconstructions='reconstruction_images',  # For reconstruction comparisons
    checkpoints='model_checkpoints',          # For saved model states
    metrics='training_metrics',               # For loss plots and other metrics
)

for dir_name in directories.values():
    if os.path.exists(dir_name):
        continue
    os.makedirs(dir_name)
    print(f"Created directory: {dir_name}")

# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ImprovedVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# T_max is expressed in epochs, so scheduler.step() is called once per epoch below.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_of_epochs, eta_min=1e-6)

print("Parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

train_losses = []
test_losses = []
best_loss = float('inf')

# Training loop
print("Starting training...")

for epoch in range(num_of_epochs):
    # Cyclic KL-weight schedule: oscillates in [0, 0.002] with a period of 20 epochs.
    kld_weight = 0.001 * (1 + np.sin(np.pi * epoch / 10))

    # train_epoch takes one optimisation step per call (by default), so call it
    # once per batch to really train for an epoch:
    # num_of_epochs * num_batches_per_epoch steps in total.
    batch_losses = [
        train_epoch(model, optimizer, train_iterator, device, kld_weight)
        for _ in range(num_batches_per_epoch)
    ]
    train_loss = float(np.mean(batch_losses))
    train_losses.append(train_loss)
    scheduler.step()

    # Score a fresh iterator over the (unshuffled) test loader, so every epoch sees
    # the same leading test batches and best-model selection compares like with like.
    test_loss = evaluate(model, iter(test_loader), device)
    test_losses.append(test_loss)

    # Get a batch of images for visualization
    images, _ = next(test_iterator)
    images = images.to(device)

    # Generate and save visualizations
    visualize_results(model, images, epoch, device, directories)

    # Print progress with more detailed information
    print(f'Epoch [{epoch:04d}/{num_of_epochs:04d}]')
    print(f'Training Loss: {train_loss:.6f}')
    print(f'Test Loss: {test_loss:.6f}')
    print(f'KLD Weight: {kld_weight:.6f}')
    print(f'Learning Rate: {scheduler.get_last_lr()[0]:.2e}')
    print('-' * 50)

    # Save best model with epoch number in filename
    if test_loss < best_loss:
        best_loss = test_loss
        model_path = os.path.join(directories['checkpoints'], f'best_vae_epoch_{epoch:04d}.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'test_loss': test_loss,
            'best_loss': best_loss
        }, model_path)
        print(f'Saved new best model at epoch {epoch} with test loss: {test_loss:.6f}')

    # Plot and save loss curves
    if epoch % 5 == 0:
        plt.figure(figsize=(12, 6))
        plt.plot(train_losses, label='Training Loss', color='blue', alpha=0.7)
        plt.plot(test_losses, label='Test Loss', color='red', alpha=0.7)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'Training Progress - Epoch {epoch}')
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.savefig(os.path.join(directories['metrics'], f'loss_plot_epoch_{epoch:04d}.png'))
        plt.close()

print("Training complete!")

**Latent interpolations**

In [None]:
# now show some interpolations (note you do not have to do linear interpolations as shown here, you can do non-linear or gradient-based interpolation if you wish)
# Grid layout: the top row decodes `col_size` random latents and the bottom row
# `col_size` other random latents; row i blends them with weight i / (col_size - 1).
col_size = int(np.sqrt(batch_size))
num_points = col_size * col_size  # exact square grid even if batch_size is not a perfect square

model.eval()
with torch.no_grad():
    z = torch.randn(num_points, model.latent_dim, device=device)
    z0 = z[:col_size].repeat(col_size, 1)   # z for top row
    z1 = z[-col_size:].repeat(col_size, 1)  # z for bottom row

    # One interpolation weight per grid row, repeated across that row's columns.
    t = torch.linspace(0, 1, col_size, device=device).repeat_interleave(col_size).unsqueeze(1)

    lerp_z = (1 - t) * z0 + t * z1  # linearly interpolate between two points in the latent space
    lerp_g = model.decode(lerp_z)   # decode the interpolated latents

# The decoder ends in tanh ([-1, 1]); rescale to [0, 1] so imshow does not clip.
grid = torchvision.utils.make_grid(lerp_g * 0.5 + 0.5, nrow=col_size).clamp(0, 1)

plt.rcParams['figure.dpi'] = 100
plt.grid(False)
plt.imshow(grid.cpu().numpy().transpose(1, 2, 0))
plt.axis('off')
plt.title('Latent-space interpolation (top row -> bottom row)')
plt.show()

**FID scores**

Evaluate the FID from 10k of your model samples (do not sample more than this) and compare it against the 10k test images. Calculating FID is somewhat involved, so we use a library for it. It can take a few minutes to evaluate. Lower FID scores are better.

In [7]:
%%capture
!pip install clean-fid
import os
from cleanfid import fid
from torchvision.utils import save_image

In [9]:
import shutil

# define directories
real_images_dir = 'real_images'
# Dedicated folder: the training loop already writes sample *grids* into
# 'generated_images', and those files must not be counted as FID samples.
generated_images_dir = 'fid_generated_images'
num_samples = 10000 # do not change

# create/clean the directories
def setup_directory(directory):
    """Recreate ``directory`` as an empty folder, deleting any existing contents."""
    if os.path.exists(directory):
        shutil.rmtree(directory)  # remove any existing (old) data
    os.makedirs(directory)

def to_unit_range(images):
    """Map tensors normalised to [-1, 1] back to [0, 1].

    save_image clamps its input to [0, 1], so skipping this step would turn
    every negative pixel black and distort the FID comparison.
    """
    return (images * 0.5 + 0.5).clamp(0, 1)

setup_directory(real_images_dir)
setup_directory(generated_images_dir)

# generate and save 10k model samples
model.eval()  # sample with BatchNorm running statistics
num_generated = 0
while num_generated < num_samples:

    # sample from your model, you can modify this
    samples_batch = to_unit_range(model.generate(batch_size, device).cpu())

    for image in samples_batch:
        if num_generated >= num_samples:
            break
        save_image(image, os.path.join(generated_images_dir, f"gen_img_{num_generated}.png"))
        num_generated += 1

# save 10k images from the CIFAR-100 test dataset. Read straight from the dataset so
# each test image is used exactly once; the cycled, drop_last loader would repeat some.
test_dataset = test_loader.dataset
num_real_to_save = min(num_samples, len(test_dataset))
num_saved_real = 0
while num_saved_real < num_real_to_save:
    image, _ = test_dataset[num_saved_real]
    save_image(to_unit_range(image), os.path.join(real_images_dir, f"real_img_{num_saved_real}.png"))
    num_saved_real += 1

print(f"Saved {num_generated} generated and {num_saved_real} real images")

In [None]:
# compute FID between the two image folders using clean-fid's "clean" mode
# (lower is better)
score = fid.compute_fid(
    real_images_dir,
    generated_images_dir,
    mode="clean",
)
print(f"FID score: {score}")