In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
from torch.nn.utils import spectral_norm    
from torch import optim;
import os

# Seed every RNG the notebook draws from (weight init, latent noise, dropout)
# so a Restart & Run All reproduces the same training run.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices
np.random.seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device.type)

In [None]:
# helper function to make getting another batch of data easier
def cycle(iterable):
    while True:
        for x in iterable:
            yield x

# CIFAR-100 fine-label names; in this cell they are only used for the class count printed below.
class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm',]

# Single source of truth for the batch size used by both loaders and the epoch maths.
batch_size = 32

# Normalize((0.5,)*3, (0.5,)*3) maps pixels from [0, 1] to [-1, 1].
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
    # Reshuffle on every pass so training does not see the same batch sequence each epoch.
    batch_size=batch_size, shuffle=True, drop_last=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize([32,32]),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])),
    batch_size=batch_size, drop_last=True)

train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
print("Number of classes: ", len(class_names))

# With drop_last=True this equals len(train_loader.dataset) // batch_size.
num_batches_per_epoch = len(train_loader)

# Target training budget in batches, rounded down to whole epochs.
total_training_steps = 50000
num_of_epochs = total_training_steps // num_batches_per_epoch

print("Number of batches per epoch: ", num_batches_per_epoch)
print("Number of epochs: ", num_of_epochs)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np

class Generator(nn.Module):
    """Maps latent noise vectors to 3x32x32 images with values in [-1, 1].

    The noise is projected by a linear layer to a 256 x 4 x 4 feature map, then
    three (Upsample x2 -> 3x3 conv -> BatchNorm -> LeakyReLU) stages grow it
    4 -> 8 -> 16 -> 32, and a final 3x3 conv + Tanh produces the RGB image.

    Args:
        latent_dim: length of the noise vector ``z`` (expected shape ``(batch, latent_dim)``).
    """

    def __init__(self, latent_dim=100):
        super().__init__()

        self.init_size = 4
        self.latent_dim = latent_dim

        # Linear projection of the noise into init_size x init_size x 256 features.
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 256 * self.init_size ** 2))

        def upsample_stage(in_channels, out_channels):
            # Doubles the spatial size, then refines with a size-preserving 3x3 conv.
            return [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        stack = [nn.BatchNorm2d(256)]
        for in_channels, out_channels in ((256, 256), (256, 128), (128, 64)):
            stack += upsample_stage(in_channels, out_channels)
        stack += [nn.Conv2d(64, 3, 3, stride=1, padding=1), nn.Tanh()]

        # One flat Sequential so submodule indices (conv_blocks.0 .. conv_blocks.14)
        # and therefore state_dict keys stay stable.
        self.conv_blocks = nn.Sequential(*stack)

    def forward(self, z):
        projected = self.l1(z)
        feature_map = projected.view(projected.shape[0], 256, self.init_size, self.init_size)
        return self.conv_blocks(feature_map)

class Critic(nn.Module):
    """WGAN critic: assigns one unbounded real score to each 3x32x32 image.

    Four stride-2 4x4 convolutions halve the resolution 32 -> 16 -> 8 -> 4 -> 2
    while widening channels 3 -> 64 -> 128 -> 256 -> 512. Each conv is followed by
    LeakyReLU(0.2) and Dropout2d(0.25); all but the first also get BatchNorm.
    A single linear layer (no output activation) maps the 512 x 2 x 2 features to
    the score.
    """

    def __init__(self):
        super().__init__()

        def downsample_stage(in_channels, out_channels, use_batchnorm):
            # Stride-2 4x4 conv with padding 1 halves the spatial size.
            stage = [
                nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if use_batchnorm:
                stage.append(nn.BatchNorm2d(out_channels))
            return stage

        stack = []
        for in_channels, out_channels, use_batchnorm in (
            (3, 64, False),     # 32x32x3  -> 16x16x64
            (64, 128, True),    # 16x16x64 -> 8x8x128
            (128, 256, True),   # 8x8x128  -> 4x4x256
            (256, 512, True),   # 4x4x256  -> 2x2x512
        ):
            stack.extend(downsample_stage(in_channels, out_channels, use_batchnorm))
        self.conv_blocks = nn.Sequential(*stack)

        # Flattened 2x2x512 features -> one score per image.
        self.fc = nn.Sequential(nn.Linear(512 * 2 * 2, 1))

    def forward(self, img):
        feature_map = self.conv_blocks(img)
        flat = feature_map.view(feature_map.size(0), -1)
        return self.fc(flat)

class WGAN:
    """Weight-clipping Wasserstein GAN built from a ``Generator`` and a ``Critic``.

    Args:
        latent_dim: length of the noise vector fed to the generator.
        clip_value: after every critic update, each critic parameter is clamped
            to ``[-clip_value, clip_value]``.
        n_critic: number of critic updates per generator update.
        device: device the networks are moved to and the noise is created on.
    """

    def __init__(self, latent_dim=100, clip_value=0.01, n_critic=5, device='cuda'):
        self.latent_dim = latent_dim
        self.clip_value = clip_value
        self.n_critic = n_critic
        self.device = device

        # Initialize generator and critic
        self.generator = Generator(latent_dim).to(device)
        self.critic = Critic().to(device)

        # RMSprop with lr=5e-5, as in the WGAN paper
        self.optimizer_G = optim.RMSprop(self.generator.parameters(), lr=0.00005)
        self.optimizer_C = optim.RMSprop(self.critic.parameters(), lr=0.00005)

        # Number of train_step calls made so far; drives the n_critic schedule.
        self.steps = 0

        # Per-update loss history for plotting
        self.G_losses = []
        self.C_losses = []

    def train_step(self, real_imgs):
        """Run one critic update and, on every ``n_critic``-th call, one generator update.

        Args:
            real_imgs: batch of real images; moved to ``self.device``.

        Returns:
            ``(generator_loss, critic_loss)`` as Python floats. ``generator_loss``
            is ``None`` on calls where the generator was not updated.
        """
        batch_size = real_imgs.shape[0]
        real_imgs = real_imgs.to(self.device)

        # ---------------------
        #  Train Critic
        # ---------------------
        self.optimizer_C.zero_grad()

        z = torch.randn(batch_size, self.latent_dim).to(self.device)
        # Detach so the critic update does not backpropagate into the generator.
        fake_imgs = self.generator(z).detach()

        real_validity = self.critic(real_imgs)
        fake_validity = self.critic(fake_imgs)

        # Negated Wasserstein estimate: minimizing it pushes real scores up, fake scores down.
        critic_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
        critic_loss.backward()
        self.optimizer_C.step()

        # Weight clipping keeps the critic (roughly) Lipschitz.
        with torch.no_grad():
            for p in self.critic.parameters():
                p.clamp_(-self.clip_value, self.clip_value)

        c_loss = critic_loss.item()
        self.C_losses.append(c_loss)

        # Decide from the counter value before this call, then always advance it.
        # Advancing only on non-generator steps would leave it stuck at 0 and make
        # the generator train on every call, ignoring n_critic.
        train_generator = self.steps % self.n_critic == 0
        self.steps += 1
        if not train_generator:
            return None, c_loss

        # ---------------------
        #  Train Generator
        # ---------------------
        self.optimizer_G.zero_grad()

        gen_imgs = self.generator(z)
        generator_loss = -torch.mean(self.critic(gen_imgs))
        generator_loss.backward()
        # Only the generator is stepped; critic grads from this backward are
        # cleared by optimizer_C.zero_grad() at the start of the next call.
        self.optimizer_G.step()

        g_loss = generator_loss.item()
        self.G_losses.append(g_loss)
        return g_loss, c_loss

    def generate_samples(self, n_samples):
        """Return ``n_samples`` generated images, computed without gradients in eval mode.

        The generator's previous train/eval mode is restored afterwards.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                z = torch.randn(n_samples, self.latent_dim).to(self.device)
                samples = self.generator(z)
        finally:
            self.generator.train(was_training)
        return samples

# Training loop using the provided iterator structure
def train_wgan(num_batches_per_epoch, num_of_epochs, train_iterator, device='cuda',
               sample_dir='generated_images'):
    """Train a fresh WGAN for ``num_batches_per_epoch * num_of_epochs`` steps.

    Args:
        num_batches_per_epoch: batches per epoch; used for the step budget and logging.
        num_of_epochs: number of epochs to train for.
        train_iterator: iterator yielding ``(images, labels)`` batches; labels are ignored.
            It must yield at least one batch per step.
        device: device passed to ``WGAN``.
        sample_dir: directory that receives a 4x4 grid of 16 samples every 500 steps.
            Created if it does not already exist.

    Returns:
        The trained ``WGAN`` instance.
    """
    wgan = WGAN(latent_dim=100, clip_value=0.01, n_critic=5, device=device)

    print("params: ", sum(p.numel() for p in wgan.generator.parameters()))

    # The first sample grid is written at step 0, and saving fails if the target
    # directory is missing, so create it before training starts.
    os.makedirs(sample_dir, exist_ok=True)

    total_steps = num_batches_per_epoch * num_of_epochs

    for step in range(total_steps):
        real_imgs, _ = next(train_iterator)
        g_loss, c_loss = wgan.train_step(real_imgs)

        if step % 100 == 0:
            epoch = step // num_batches_per_epoch
            batch = step % num_batches_per_epoch
            print(f"[Epoch {epoch}/{num_of_epochs}] "
                  f"[Batch {batch}/{num_batches_per_epoch}] "
                  f"[C loss: {c_loss:.4f}] "
                  + (f"[G loss: {g_loss:.4f}]" if g_loss is not None else ""))

        # 500 is a multiple of 100, so this fires on a subset of the logging steps.
        if step % 500 == 0:
            samples = wgan.generate_samples(16)
            torchvision.utils.save_image(samples,
                                         os.path.join(sample_dir, f'wgan_step_{step}.png'),
                                         normalize=True,
                                         nrow=4)

    return wgan


# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
wgan = train_wgan(
    num_batches_per_epoch,
    num_of_epochs,
    train_iterator,
    device,
)