In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
from torch.nn.utils import spectral_norm    
from torch import optim;
import os

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device.type)

# helper function to make getting another batch of data easier
def cycle(iterable):
    """Yield items from ``iterable`` forever, restarting it each time it is exhausted.

    Every pass re-iterates ``iterable`` from scratch (a fresh ``for`` loop), so a
    DataLoader built with ``shuffle=True`` draws a new order on each pass. This
    differs from ``itertools.cycle``, which replays a cached copy of the first pass.

    If a complete pass yields nothing (an empty iterable, or a one-shot iterator
    that is already exhausted), the generator returns instead of spinning forever
    in a loop that never yields.
    """
    while True:
        produced_any = False
        for x in iterable:
            produced_any = True
            yield x
        if not produced_any:
            return

class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm',]

# Single batch size shared by both loaders and by the epoch arithmetic below.
batch_size = 64

# Both loaders scale pixels to [-1, 1]: ToTensor yields [0, 1], and
# Normalize(mean=0.5, std=0.5) maps that to [-1, 1], the generator's tanh range.
# shuffle=True so each pass of cycle() visits the training set in a new random
# order; without it every epoch replays the identical batch sequence.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
    batch_size=batch_size, shuffle=True, drop_last=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize([32,32]),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])),
    batch_size=batch_size, drop_last=True)

train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
print("Number of classes: ", len(class_names))

# With drop_last=True, the loader length is exactly the number of full batches
# per epoch (len(dataset) // batch_size).
num_batches_per_epoch = len(train_loader)

print("Length of train_loader: ", len(train_loader))

# Overall optimisation-step budget; the number of epochs is derived from it.
max_train_steps = 50000
num_of_epochs = max_train_steps // num_batches_per_epoch

print("Number of batches per epoch: ", num_batches_per_epoch)
print("Number of epochs: ", num_of_epochs)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
from torch.nn import functional as F
from typing import List, Tuple, Optional
import logging
import os

class CompactResBlock(nn.Module):
    """Shape-preserving residual block.

    Branch: 3x3 conv (padding=1) -> BatchNorm -> LeakyReLU(0.2) -> 1x1 conv -> BatchNorm.
    The branch output is added to the input and passed through a final LeakyReLU(0.2).
    Channel count and spatial size are unchanged.
    """

    def __init__(self, channels: int):
        super().__init__()
        branch_layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*branch_layers)
        # In-place is safe here: it only overwrites the freshly computed sum.
        self.activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.block(x)
        return self.activation(x + residual)

class Generator(nn.Module):
    """Maps a latent vector of size ``latent_dim`` to a 3x32x32 image in [-1, 1].

    A linear layer projects z to a 512x4x4 feature map. Three identical stages then
    double the resolution while halving the channels (512 -> 256 -> 128 -> 64):
    BatchNorm, 2x upsample, 3x3 conv, CompactResBlock. A final 3x3 conv to 3
    channels followed by tanh yields the image.
    """

    def __init__(self, latent_dim: int = 128):
        super().__init__()

        self.latent_dim = latent_dim
        self.init_size = 4          # side length of the projected feature map
        self.init_channels = 512    # channel count of the projected feature map

        projected_width = self.init_channels * self.init_size ** 2
        self.project = nn.Sequential(
            nn.Linear(latent_dim, projected_width),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Build the three upsampling stages in order so module indices inside
        # `main` stay 0..11, followed by the output conv (12) and tanh (13).
        layers = []
        in_ch = self.init_channels
        for _ in range(3):
            out_ch = in_ch // 2
            layers += [
                nn.BatchNorm2d(in_ch),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_ch, out_ch, 3, 1, 1),
                CompactResBlock(out_ch),
            ]
            in_ch = out_ch
        layers += [
            nn.Conv2d(in_ch, 3, 3, 1, 1),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal (fan_out, leaky_relu gain) weights and zero biases for conv/linear layers."""
        if not isinstance(m, (nn.Conv2d, nn.Linear)):
            return
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        feature_map = self.project(z).view(-1, self.init_channels, self.init_size, self.init_size)
        return self.main(feature_map)

class _CriticResBlock(nn.Module):
    """Normalization-free residual block used inside the critic.

    It has the same layout as CompactResBlock minus its BatchNorm layers:
    3x3 conv (padding=1) -> LeakyReLU(0.2) -> 1x1 conv, added to the input, then
    LeakyReLU(0.2). Without BatchNorm, each sample's output depends only on that
    sample, which the per-sample WGAN-GP gradient penalty relies on.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, channels, 1),
        )
        self.activation = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(x + self.block(x))


class Critic(nn.Module):
    """WGAN-GP critic: maps a batch of 3x32x32 images to one unbounded score per image.

    Three stride-2, spectrally-normalised 4x4 convolutions halve the resolution
    (32 -> 16 -> 8 -> 4) while growing channels (64 -> 128 -> 256). Each is
    followed by LeakyReLU and a normalization-free residual block. A
    spectrally-normalised linear head turns the flattened 256x4x4 features into a
    scalar, so inputs are expected to be 32x32.

    No batch normalization is used. The gradient penalty constrains the gradient
    of each sample's score with respect to that sample alone, and BatchNorm would
    couple samples within a batch.
    """

    def __init__(self):
        super(Critic, self).__init__()

        # Initial number of channels in critic
        init_channels = 64

        self.main = nn.Sequential(
            # 32x32x3 -> 16x16x64
            spectral_norm(nn.Conv2d(3, init_channels, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            _CriticResBlock(init_channels),

            # 16x16x64 -> 8x8x128
            spectral_norm(nn.Conv2d(init_channels, init_channels * 2, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            _CriticResBlock(init_channels * 2),

            # 8x8x128 -> 4x4x256
            spectral_norm(nn.Conv2d(init_channels * 2, init_channels * 4, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            _CriticResBlock(init_channels * 4),
        )

        # Critic head: (256 channels * 4 * 4) features -> 1 score
        self.critic_head = nn.Sequential(
            nn.Flatten(),
            spectral_norm(nn.Linear(init_channels * 4 * 4 * 4, 1))
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal (fan_in) weights and zero biases for conv/linear layers."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Return critic scores of shape (batch, 1)."""
        features = self.main(img)
        return self.critic_head(features)

class WGAN:
    """Wasserstein GAN with gradient penalty (WGAN-GP) training wrapper.

    Owns a Generator, a Critic and one Adam optimizer for each. Every call to
    ``train_step`` performs one critic update. The generator is updated on the
    first call and then on every ``n_critic``-th call.
    """

    def __init__(
        self,
        latent_dim: int = 128,
        lr: float = 2e-4,
        betas: Tuple[float, float] = (0.5, 0.999),
        n_critic: int = 3,
        gp_weight: float = 10.0,
        device: str = 'cuda',
        class_names: Optional[List[str]] = None
    ):
        """Build both networks on ``device`` along with their optimizers.

        Args:
            latent_dim: size of the generator's input noise vector.
            lr: Adam learning rate shared by both optimizers.
            betas: Adam (beta1, beta2) shared by both optimizers.
            n_critic: the generator is updated once every ``n_critic`` train_step calls.
            gp_weight: coefficient of the gradient penalty in the critic loss.
            device: device holding the networks and the sampled noise.
            class_names: stored on the instance; not used by any method of this class.
        """
        self.latent_dim = latent_dim
        self.n_critic = n_critic
        self.gp_weight = gp_weight
        self.device = device
        self.class_names = class_names

        # Initialize networks
        self.generator = Generator(latent_dim).to(device)
        self.critic = Critic().to(device)

        # Print model parameters for verification
        generator_params = sum(p.numel() for p in self.generator.parameters())
        critic_params = sum(p.numel() for p in self.critic.parameters())
        print(f"Generator parameters: {generator_params:,}")
        print(f"Critic parameters: {critic_params:,}")

        # Optimizers
        self.optimizer_G = optim.Adam(
            self.generator.parameters(),
            lr=lr,
            betas=betas
        )
        self.optimizer_C = optim.Adam(
            self.critic.parameters(),
            lr=lr,
            betas=betas
        )

        # Training state
        self.steps = 0
        self.G_losses = []
        self.C_losses = []

    def compute_gradient_penalty(self, real_samples: torch.Tensor, fake_samples: torch.Tensor) -> torch.Tensor:
        """Return the WGAN-GP penalty mean((||grad_x D(x_hat)||_2 - 1)^2).

        Each x_hat is a random convex combination of one real and one fake sample.
        The penalty keeps its autograd graph (create_graph=True), so calling
        backward on it reaches the critic's parameters.

        Args:
            real_samples: batch of real data, shape (B, ...); any rank >= 1.
            fake_samples: generated batch with the same shape as real_samples.

        Returns:
            Scalar tensor: the penalty averaged over the batch.
        """
        batch_size = real_samples.size(0)
        # One mixing weight per sample, drawn from U[0, 1), shaped (B, 1, ..., 1)
        # so it broadcasts over all non-batch dimensions for any input rank.
        alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)

        # Make the interpolates the leaf being differentiated. Gradients then flow
        # into the critic only, not back into alpha or the input batches.
        interpolates = (real_samples + alpha * (fake_samples - real_samples)).detach()
        interpolates.requires_grad_(True)

        d_interpolates = self.critic(interpolates)

        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]

        # Per-sample L2 norm; the epsilon keeps sqrt differentiable at 0.
        gradients = gradients.reshape(batch_size, -1)
        gradient_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
        return ((gradient_norm - 1) ** 2).mean()

    def train_step(self, real_imgs: torch.Tensor) -> Tuple[Optional[float], float]:
        """Run one critic update and, every ``n_critic`` calls, one generator update.

        Args:
            real_imgs: batch of real images; moved to ``self.device``.

        Returns:
            ``(generator_loss, critic_loss)`` as Python floats. ``generator_loss`` is
            None on calls where the generator was not updated.
        """
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(self.device)

        # ---- Critic update: minimise D(fake) - D(real) + gp_weight * GP ----
        self.optimizer_C.zero_grad()

        z = torch.randn(batch_size, self.latent_dim, device=self.device)
        fake_imgs = self.generator(z).detach()

        real_validity = self.critic(real_imgs)
        fake_validity = self.critic(fake_imgs)
        gradient_penalty = self.compute_gradient_penalty(real_imgs, fake_imgs)

        critic_loss = torch.mean(fake_validity) - torch.mean(real_validity) + \
                     self.gp_weight * gradient_penalty
        critic_loss.backward()
        self.optimizer_C.step()

        c_loss = critic_loss.item()
        self.C_losses.append(c_loss)

        # ---- Generator update (first call, then every n_critic-th call) ----
        g_loss = None
        if self.steps % self.n_critic == 0:
            self.optimizer_G.zero_grad()

            # Freeze the critic so backward computes only what the generator
            # needs; the previous requires_grad flags are restored afterwards.
            critic_params = list(self.critic.parameters())
            previous_flags = [p.requires_grad for p in critic_params]
            for p in critic_params:
                p.requires_grad_(False)
            try:
                gen_imgs = self.generator(z)
                gen_validity = self.critic(gen_imgs)
                generator_loss = -torch.mean(gen_validity)
                generator_loss.backward()
            finally:
                for p, flag in zip(critic_params, previous_flags):
                    p.requires_grad_(flag)

            self.optimizer_G.step()

            g_loss = generator_loss.item()
            self.G_losses.append(g_loss)

        self.steps += 1

        return g_loss, c_loss

    @torch.no_grad()
    def generate_samples(self, n_samples: int) -> torch.Tensor:
        """Generate ``n_samples`` images with the generator in eval mode.

        The generator's previous train/eval mode is restored afterwards, even if
        generation raises.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            z = torch.randn(n_samples, self.latent_dim, device=self.device)
            return self.generator(z)
        finally:
            self.generator.train(was_training)

def train_wgan(
    train_iterator,
    num_batches_per_epoch: int,
    num_of_epochs: int,
    device: str = 'cuda',
    class_names: Optional[List[str]] = None,
    max_steps: int = 50000,
    save_dir: str = 'generated_images',
) -> WGAN:
    """Train a WGAN-GP on batches pulled from ``train_iterator``.

    Runs ``num_of_epochs`` x ``num_batches_per_epoch`` steps, capped at
    ``max_steps`` optimisation steps in total. Losses are logged every 100 steps,
    and a 4x4 grid of generated samples is saved to ``save_dir`` every 500 steps.

    Args:
        train_iterator: iterator yielding ``(images, labels)`` batches; labels are ignored.
            It must not run out before training ends.
        num_batches_per_epoch: number of steps counted as one epoch.
        num_of_epochs: number of epochs to run.
        device: device passed to the WGAN.
        class_names: forwarded to the WGAN.
        max_steps: hard cap on the total number of training steps.
        save_dir: directory for the periodic sample grids (created if missing).

    Returns:
        The trained WGAN.
    """
    os.makedirs(save_dir, exist_ok=True)

    wgan = WGAN(
        latent_dim=128,
        lr=2e-4,
        n_critic=3,
        device=device,
        class_names=class_names
    )

    total_steps = num_batches_per_epoch * num_of_epochs

    print(f"Training for {total_steps} total steps across {num_of_epochs} epochs")

    for epoch in range(num_of_epochs):
        for batch in range(num_batches_per_epoch):
            current_step = epoch * num_batches_per_epoch + batch

            # Check the budget before training, so steps 0 .. max_steps-1 are run:
            # exactly max_steps updates.
            if current_step >= max_steps:
                print(f"Reached maximum number of steps ({max_steps:,}). Stopping training.")
                return wgan

            real_imgs, _ = next(train_iterator)
            g_loss, c_loss = wgan.train_step(real_imgs)

            if current_step % 100 == 0:
                print(f"[Epoch {epoch}/{num_of_epochs}] "
                      f"[Batch {batch}/{num_batches_per_epoch}] "
                      f"[Step {current_step}/{total_steps}] "
                      f"[C loss: {c_loss:.4f}] "
                      + (f"[G loss: {g_loss:.4f}]" if g_loss is not None else ""))

            if current_step % 500 == 0:
                samples = wgan.generate_samples(16)
                save_path = os.path.join(save_dir, f'wgan_epoch_{epoch}_step_{current_step}.png')
                torchvision.utils.save_image(
                    samples,
                    save_path,
                    normalize=True,
                    nrow=4
                )
                print(f"Saved generated samples to {save_path}")

    return wgan

# Echo the run configuration before starting the long training loop.
print("Training parameters:")
print(f"Batch size: {batch_size}")
print(f"Batches per epoch: {num_batches_per_epoch}")
print(f"Number of epochs: {num_of_epochs}")

# Train on the endless CIFAR-100 batch iterator, using the epoch and batch counts computed above.
wgan = train_wgan(
    train_iterator=train_iterator,
    num_of_epochs=num_of_epochs,
    num_batches_per_epoch=num_batches_per_epoch,
    class_names=class_names,
    device=device,
)

# Write a 4x4 grid of 16 fresh samples from the trained generator to disk.
samples = wgan.generate_samples(16)
save_path = 'generated_samples.png'
torchvision.utils.save_image(samples, save_path, nrow=4, normalize=True)