In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
from torch.nn.utils import spectral_norm    
from torch import optim;
import os

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device.type)

In [None]:
def cycle(iterable):
    """Yield the items of ``iterable`` over and over, restarting after each pass.

    Helper that makes fetching "the next batch" easy: wrap a DataLoader and
    call ``next()`` on the result without worrying about epoch boundaries.

    The generator is endless as long as each pass over ``iterable`` produces
    at least one item. If a pass produces nothing (empty container, or a
    one-shot iterator that is already exhausted) the generator stops instead
    of spinning forever without ever yielding.
    """
    while True:
        produced_any = False
        for x in iterable:
            produced_any = True
            yield x
        if not produced_any:
            # Nothing to repeat: looping again would hang the caller's next().
            return

# CIFAR-100 fine-label names.
class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm',]

# Single source of truth for the loader batch size (defined before the loaders
# so they cannot drift apart from the bookkeeping below).
batch_size = 32

# ToTensor gives values in [0, 1]; Normalize with mean = std = 0.5 rescales them
# to [-1, 1], the same range as the generator's tanh output.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
    # DataLoader does not shuffle by default; without it every epoch replays
    # exactly the same batches in the same order.
    batch_size=batch_size, shuffle=True, drop_last=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize([32,32]),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])),
    batch_size=batch_size, drop_last=True)

# cycle() already returns a generator (its own iterator), so no iter() is needed.
train_iterator = cycle(train_loader)
test_iterator = cycle(test_loader)

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
print("Number of classes: ", len(class_names))

# With drop_last=True this equals len(dataset) // batch_size.
num_batches_per_epoch = len(train_loader)

# NOTE(review): 50000 looks like a budget of optimisation steps rather than the
# dataset size -- confirm. num_of_epochs is only printed in the visible code.
num_of_epochs = 50000 // num_batches_per_epoch

print("Number of batches per epoch: ", num_batches_per_epoch)
print("Number of epochs: ", num_of_epochs)

In [None]:
# Extra imports for the training cell; the random seed is set just below for reproducibility
from tqdm import tqdm
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


# Seed PyTorch's RNG so weight initialisation and noise sampling are repeatable.
torch.manual_seed(42)

# Hyperparameters
latent_dim = 100          # length of the noise vector fed to the generator
image_channels = 3        # channels of the generated / real images
feature_dim = 64          # base channel width of both networks
num_epochs = 100          # passes over the training set performed by train()
# Must match the batch size the DataLoaders were built with (32). The previous
# value of 64 was never read by any later code and misreported the batch size
# actually used for training.
batch_size = 32
critic_iterations = 5     # critic updates per generator update
lambda_gp = 10            # weight of the gradient-penalty term in the critic loss
lr = 0.0002               # Adam learning rate (both optimizers)
beta1 = 0.5               # Adam beta1
beta2 = 0.999             # Adam beta2

# Generator Network
class Generator(nn.Module):
    """Turns latent noise vectors into ``image_channels x 32 x 32`` images.

    The output passes through ``tanh`` and therefore lies in [-1, 1]. Layer
    sizes are taken from the module-level ``latent_dim``, ``feature_dim`` and
    ``image_channels`` settings.
    """

    def __init__(self):
        super().__init__()

        widest = feature_dim * 8

        # Linear projection of the noise vector; reshaped in forward() to a
        # (feature_dim * 8) x 4 x 4 feature map.
        self.fc = nn.Linear(latent_dim, 4 * 4 * widest)

        # Three stride-2 transposed convolutions double the resolution
        # (4 -> 8 -> 16 -> 32) while halving the channel count each time.
        stages = []
        channels = widest
        for _ in range(3):
            narrower = channels // 2
            stages.extend([
                nn.ConvTranspose2d(channels, narrower, 4, 2, 1, bias=False),
                nn.BatchNorm2d(narrower),
                nn.ReLU(True),
            ])
            channels = narrower

        # A 1x1 transposed convolution maps the remaining feature_dim channels
        # to image channels at 32 x 32; tanh bounds the pixels to [-1, 1].
        stages.extend([
            nn.ConvTranspose2d(channels, image_channels, 1, 1, 0, bias=False),
            nn.Tanh(),
        ])

        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Generate images from a batch of latent vectors of width ``latent_dim``."""
        seed_map = self.fc(x).view(-1, feature_dim * 8, 4, 4)
        return self.main(seed_map)

# Critic Network (Discriminator)
class Critic(nn.Module):
    """Scores images with one unbounded real number per sample.

    Layer sizes are taken from the module-level ``feature_dim`` and
    ``image_channels`` settings. The spatial sizes noted below assume a
    32 x 32 input.
    """

    def __init__(self):
        super().__init__()

        # First stage (no normalisation): image_channels x 32 x 32 -> feature_dim x 16 x 16
        layers = [
            nn.Conv2d(image_channels, feature_dim, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three further stride-2 stages, each doubling the channels and
        # halving the resolution: 16 -> 8 -> 4 -> 2.
        width = feature_dim
        for _ in range(3):
            wider = width * 2
            layers += [
                nn.Conv2d(width, wider, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(wider),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            width = wider

        # A 2x2 convolution collapses the (feature_dim * 8) x 2 x 2 map to 1 x 1 x 1.
        layers.append(nn.Conv2d(width, 1, 2, 1, 0, bias=False))

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the critic scores for ``x`` flattened to a 1-D tensor."""
        scores = self.main(x)
        return scores.view(-1)

# Gradient Penalty
def compute_gradient_penalty(critic, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP.

    Random per-sample interpolations between ``real_samples`` and
    ``fake_samples`` are scored by ``critic``; the penalty is the mean squared
    deviation of the critic's gradient norm (w.r.t. those interpolations)
    from 1.

    Args:
        critic: callable mapping a batch of samples to one score per sample.
        real_samples: tensor of real data, batch dimension first.
        fake_samples: tensor of generated data with the same shape.

    Returns:
        Scalar tensor. It stays attached to the autograd graph
        (``create_graph=True``) so it can be back-propagated into the critic's
        parameters.
    """
    batch = real_samples.size(0)

    # One mixing coefficient per sample, broadcast across all remaining dims.
    # Created directly on the inputs' device/dtype so the function does not
    # depend on a module-level ``device``.
    alpha_shape = (batch,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)

    # Detach so the penalty only ever differentiates through the critic, then
    # make the interpolated points a leaf we can take gradients against.
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).detach().requires_grad_(True)

    # Calculate critic scores
    d_interpolates = critic(interpolates)

    # d(score)/d(interpolates); grad_outputs of ones sums the per-sample scores.
    # create_graph=True keeps the result differentiable (and implies
    # retain_graph=True).
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
    )[0]

    # reshape (not view): the gradient is not guaranteed to be contiguous.
    gradients = gradients.reshape(batch, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty



# Sample-visualisation helper and training loop
def save_generated_images(generator, critic, epoch, device, save_dir='generated_images', num_images=16):
    """Save a figure of freshly generated images annotated with critic scores.

    Args:
        generator: network mapping ``(num_images, latent_dim)`` noise to images
            in [-1, 1].
        critic: network returning one score per image.
        epoch: value used in the figure title and the output file name.
        device: device on which the noise is created and the networks are run.
        save_dir: output directory (created if missing).
        num_images: number of samples to draw; they are laid out on a
            near-square grid (4 x 4 for the default of 16).

    The figure is written to ``<save_dir>/fake_images_epoch_<epoch>.png``.
    Both networks are switched to eval mode while sampling and are returned to
    whatever mode they were in before the call, even if sampling fails.
    """
    # Create directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    generator_was_training = generator.training
    critic_was_training = critic.training
    generator.eval()
    critic.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_images, latent_dim, device=device)
            fake_imgs = generator(z)
            critic_scores = critic(fake_imgs).cpu().numpy()
            # Undo the (0.5, 0.5) normalisation: [-1, 1] -> [0, 1]; clamp guards
            # against rounding, and imshow wants channels last (N x H x W x C).
            fake_imgs = (fake_imgs.cpu() * 0.5 + 0.5).clamp(0, 1).permute(0, 2, 3, 1).numpy()
    finally:
        generator.train(generator_was_training)
        critic.train(critic_was_training)

    # Near-square layout sized from num_images.
    n_cols = max(1, math.ceil(math.sqrt(num_images)))
    n_rows = max(1, math.ceil(num_images / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
    fig.suptitle(f'Generated Images (Epoch {epoch})', fontsize=16)

    # Plot every sample directly. (Slicing tiles back out of a make_grid canvas
    # drags the grid padding into each tile and only works for a fixed 4 x 4.)
    for idx, ax in enumerate(axs.flat):
        ax.axis('off')
        if idx >= num_images:
            continue  # spare cell in the grid
        ax.imshow(fake_imgs[idx])
        # The generator is unconditional, so a sample has no class; label it by
        # index rather than by an unrelated CIFAR-100 class name.
        ax.set_title(f'Sample {idx}\nCritic Score:{critic_scores[idx]:.3f}', fontsize=10)

    # Adjust layout and save
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f'fake_images_epoch_{epoch}.png'))
    plt.close(fig)  # release the figure; this is called repeatedly during training

def train():
    """Run WGAN-GP training for ``num_epochs`` passes over the training data.

    Uses module-level state: ``generator``, ``critic``, ``g_optimizer``,
    ``c_optimizer``, ``train_iterator``, ``device`` and the hyperparameters
    (``num_epochs``, ``num_batches_per_epoch``, ``critic_iterations``,
    ``latent_dim``, ``lambda_gp``).

    Each iteration draws one real batch, performs ``critic_iterations`` critic
    updates followed by one generator update, prints the losses every 100
    iterations, and saves a sample figure after every 5th completed epoch
    (including the last one).
    """
    total_iterations = num_epochs * num_batches_per_epoch

    for iteration in tqdm(range(total_iterations)):
        # Calculate current epoch for logging
        current_epoch = iteration // num_batches_per_epoch

        # Get next batch using our infinite iterator
        real_imgs, _ = next(train_iterator)
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # Train Critic
        # NOTE(review): the same real batch is reused for every critic step of
        # this iteration; the WGAN-GP reference algorithm samples a fresh real
        # batch per critic step -- confirm this is intended.
        for _ in range(critic_iterations):
            c_optimizer.zero_grad()

            # The generator is not updated here, so skip building its graph.
            z = torch.randn(batch_size, latent_dim, device=device)
            with torch.no_grad():
                fake_imgs = generator(z)

            # Critic loss
            real_validity = critic(real_imgs)
            fake_validity = critic(fake_imgs)

            # Gradient penalty (both inputs are already free of autograd history)
            gp = compute_gradient_penalty(critic, real_imgs, fake_imgs)

            # Negative Wasserstein estimate plus the weighted penalty
            c_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp

            c_loss.backward()
            c_optimizer.step()

        # Train Generator
        g_optimizer.zero_grad()

        # Generate fake images
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = generator(z)

        # Generator loss: raise the critic's score on generated images
        fake_validity = critic(fake_imgs)
        g_loss = -torch.mean(fake_validity)

        g_loss.backward()
        g_optimizer.step()

        # Print progress
        if iteration % 100 == 0:
            print(
                f"[Epoch {current_epoch}/{num_epochs}] "
                f"[C loss: {c_loss.item():.4f}] "
                f"[G loss: {g_loss.item():.4f}]"
            )

        # Save generated images after every 5th completed epoch. Checking at the
        # end of an epoch (rather than one step into the next) also captures
        # the final model after the last epoch.
        completed_epochs, batches_into_epoch = divmod(iteration + 1, num_batches_per_epoch)
        if batches_into_epoch == 0 and completed_epochs % 5 == 0:
            save_generated_images(generator, critic, completed_epochs, device)

# Initialize networks and optimizers
# Build the generator first and report how many trainable parameters it has
generator = Generator().to(device)

n_generator_params = sum(
    param.numel() for param in generator.parameters() if param.requires_grad
)
print("Parameters: ", n_generator_params)

critic = Critic().to(device)

# Both networks use Adam with the same learning rate and betas
adam_settings = {'lr': lr, 'betas': (beta1, beta2)}
g_optimizer = optim.Adam(generator.parameters(), **adam_settings)
c_optimizer = optim.Adam(critic.parameters(), **adam_settings)

# Start training (reads the networks and optimizers defined above)
train()