In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
from torch.nn.utils import spectral_norm    
from torch import optim;
import os

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU, and
# report which backend was picked.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device.type)

def cycle(iterable):
    """Yield the items of ``iterable`` over and over, to make fetching "the next batch" easy.

    Each pass runs a fresh ``for`` loop over ``iterable``, so the iterable is
    re-iterated rather than cached.

    Args:
        iterable: Any iterable; it is iterated once per pass.

    Yields:
        The items of ``iterable``, pass after pass.

    The generator finishes (instead of spinning forever without yielding) as
    soon as a complete pass produces no items. That happens for an empty
    iterable, or for a one-shot iterator once it has been exhausted.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            # Nothing left to repeat: another pass would busy-loop forever.
            return

# Human-readable label names, one entry per class (100 in total).
# NOTE(review): presumably meant to line up with the CIFAR-100 fine-label
# indices - confirm against the dataset's own class list before relying on it.
class_names = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver',
    'bed', 'bee', 'beetle', 'bicycle', 'bottle',
    'bowl', 'boy', 'bridge', 'bus', 'butterfly',
    'camel', 'can', 'castle', 'caterpillar', 'cattle',
    'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach',
    'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox',
    'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard',
    'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard',
    'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',
    'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
    'plain', 'plate', 'poppy', 'porcupine', 'possum',
    'rabbit', 'raccoon', 'ray', 'road', 'rocket',
    'rose', 'sea', 'seal', 'shark', 'shrew',
    'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
    'tank', 'telephone', 'television', 'tiger', 'tractor',
    'train', 'trout', 'tulip', 'turtle', 'wardrobe',
    'whale', 'willow_tree', 'wolf', 'woman', 'worm',
]

# Single source of truth for the mini-batch size: used by both loaders and by
# the steps-per-epoch arithmetic further down.
batch_size = 64

# Training data: each image is converted to a tensor and then normalized per
# channel as (x - 0.5) / 0.5.
# shuffle=True makes every pass over the dataset use a new sample order;
# without it each epoch replays exactly the same batches in the same order.
# drop_last=True discards the final partial batch so every batch has exactly
# `batch_size` images.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
    batch_size=batch_size, shuffle=True, drop_last=True)

# Held-out data: same normalization, plus an explicit resize to 32x32.
# Order is left unshuffled.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize([32,32]),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])),
    batch_size=batch_size, drop_last=True)

# Endless batch streams, so training code can simply call next(...) without
# tracking epoch boundaries itself.
train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
print("Number of classes: ", len(class_names))

# Full batches per pass over the training set (the partial batch is dropped).
num_batches_per_epoch = len(train_loader.dataset) // batch_size

# Note: this prints the number of samples in the dataset, not the number of
# batches the loader yields.
print("Length of train_loader: ", len(train_loader.dataset))

# NOTE(review): 50000 looks like a budget of total optimisation steps that is
# converted to whole epochs here - confirm the intent, it is a magic number.
num_of_epochs = 50000 // num_batches_per_epoch

print("Number of batches per epoch: ", num_batches_per_epoch)
print("Number of epochs: ", num_of_epochs)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np

class Generator(nn.Module):
    """Decode a latent vector into a 3-channel 32x32 image with values in [-1, 1].

    Pipeline: a dense projection turns the latent vector into a 512x4x4
    feature map, three "upsample x2 -> conv -> ResBlock" stages grow it to
    64x32x32, and two refinement convolutions followed by Tanh produce the
    final image.

    Args:
        latent_dim: Size of the latent vector fed to ``forward``.
    """

    # Negative slope shared by the LeakyReLU layers created here and by the
    # Kaiming gain computation in `_init_weights`, so both always agree.
    LEAKY_SLOPE = 0.2

    def __init__(self, latent_dim=128):
        super(Generator, self).__init__()

        self.init_size = 4          # spatial side of the first feature map
        self.latent_dim = latent_dim
        self.init_channels = 512    # channel count of the first feature map

        # Dense projection: latent vector -> flattened 512x4x4 feature map.
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, self.init_channels * self.init_size ** 2),
            nn.BatchNorm1d(self.init_channels * self.init_size ** 2),
            nn.LeakyReLU(self.LEAKY_SLOPE, inplace=True)
        )

        # Convolutional decoder; every resolution stage ends in a residual block.
        self.conv_blocks = nn.Sequential(
            # 4x4x512 -> 8x8x256
            nn.BatchNorm2d(self.init_channels),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.init_channels, self.init_channels//2, 3, 1, 1),
            ResBlock(self.init_channels//2),

            # 8x8x256 -> 16x16x128
            nn.BatchNorm2d(self.init_channels//2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.init_channels//2, self.init_channels//4, 3, 1, 1),
            ResBlock(self.init_channels//4),

            # 16x16x128 -> 32x32x64
            nn.BatchNorm2d(self.init_channels//4),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.init_channels//4, self.init_channels//8, 3, 1, 1),
            ResBlock(self.init_channels//8),

            # Final refinement at full resolution: 64 -> 32 -> 3 channels.
            nn.BatchNorm2d(self.init_channels//8),
            nn.Conv2d(self.init_channels//8, self.init_channels//16, 3, 1, 1),
            nn.LeakyReLU(self.LEAKY_SLOPE, inplace=True),
            nn.Conv2d(self.init_channels//16, 3, 3, 1, 1),
            nn.Tanh()
        )

        # Re-initialize every Conv2d / Linear in the tree (including those
        # inside the ResBlocks) with the scheme in `_init_weights`.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal (fan_out) init for Conv2d/Linear weights; biases set to zero.

        Other module types are left untouched.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            # `a` must be the real negative slope: with the default a=0 the
            # gain is the plain-ReLU gain, not the one for LeakyReLU(0.2).
            nn.init.kaiming_normal_(m.weight, a=self.LEAKY_SLOPE,
                                    mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 3, 32, 32); values lie in [-1, 1] because
            the last layer is Tanh.
        """
        out = self.l1(z)
        # Unflatten the dense output into (batch, 512, 4, 4) for the conv stack.
        out = out.view(out.shape[0], self.init_channels, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class ResBlock(nn.Module):
    """Residual unit that keeps both channel count and spatial size unchanged.

    The residual branch is conv3x3 -> BatchNorm -> LeakyReLU(0.2) -> conv3x3
    -> BatchNorm. Its output is added to the input and the sum goes through a
    final LeakyReLU(0.2).

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        # 3x3 convolutions with stride 1 and padding 1 preserve height/width.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        branch_layers = [
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*branch_layers)
        self.activation = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Return ``LeakyReLU(x + block(x))``; the output has the same shape as ``x``."""
        # The sum is a new tensor, so the in-place activation does not modify `x`.
        skip_sum = x + self.block(x)
        return self.activation(skip_sum)

class Critic(nn.Module):
    """Score a 3-channel 32x32 image with a single unbounded scalar.

    Three stride-2 spectral-normalized convolutions (each followed by a
    ResBlock) reduce the image to a 4x4 map, a fourth convolution widens it to
    512 channels, and a two-layer spectral-normalized MLP maps the flattened
    features to one value per image.

    NOTE(review): the ResBlocks use plain Conv2d/BatchNorm2d layers that are
    not wrapped in spectral_norm, so the spectral constraint only covers part
    of the network - confirm this is intended.
    """

    # Negative slope shared by the LeakyReLU layers created here and by the
    # Kaiming gain computation in `_init_weights`, so both always agree.
    LEAKY_SLOPE = 0.2

    def __init__(self):
        super(Critic, self).__init__()

        init_channels = 64  # width of the first convolutional stage

        self.features = nn.Sequential(
            # 32x32x3 -> 16x16x64
            spectral_norm(nn.Conv2d(3, init_channels, 4, 2, 1)),
            nn.LeakyReLU(self.LEAKY_SLOPE, inplace=True),
            ResBlock(init_channels),

            # 16x16x64 -> 8x8x128
            spectral_norm(nn.Conv2d(init_channels, init_channels*2, 4, 2, 1)),
            nn.LeakyReLU(self.LEAKY_SLOPE, inplace=True),
            ResBlock(init_channels*2),

            # 8x8x128 -> 4x4x256
            spectral_norm(nn.Conv2d(init_channels*2, init_channels*4, 4, 2, 1)),
            nn.LeakyReLU(self.LEAKY_SLOPE, inplace=True),
            ResBlock(init_channels*4),

            # 4x4x256 -> 4x4x512
            spectral_norm(nn.Conv2d(init_channels*4, init_channels*8, 3, 1, 1)),
            nn.LeakyReLU(self.LEAKY_SLOPE, inplace=True)
        )

        # Scoring head: flatten 512x4x4 -> 1024 -> 1 (no final activation).
        self.critic_head = nn.Sequential(
            nn.Flatten(),
            spectral_norm(nn.Linear(init_channels*8 * 4 * 4, 1024)),
            nn.LeakyReLU(self.LEAKY_SLOPE, inplace=True),
            spectral_norm(nn.Linear(1024, 1))
        )

        # Re-initialize every Conv2d / Linear in the tree (including those
        # inside the ResBlocks) with the scheme in `_init_weights`.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal (fan_in) init for Conv2d/Linear weights; biases set to zero.

        Other module types are left untouched.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            # spectral_norm recomputes `weight` before every forward call, so
            # when the wrapper is present initialize the underlying
            # `weight_orig` parameter that it is derived from.
            weight = m.weight_orig if hasattr(m, 'weight_orig') else m.weight
            # `a` must be the real negative slope: with the default a=0 the
            # gain is the plain-ReLU gain, not the one for LeakyReLU(0.2).
            nn.init.kaiming_normal_(weight, a=self.LEAKY_SLOPE,
                                    mode='fan_in', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor of shape (batch, 3, 32, 32); the 32x32 size is required
                by the fixed 512*4*4 input width of the first Linear layer.

        Returns:
            Tensor of shape (batch, 1) holding one unbounded score per image.
        """
        features = self.features(img)
        validity = self.critic_head(features)
        return validity

class WGAN:
    """Weight-clipped Wasserstein GAN trainer bundling generator, critic and optimizers.

    Args:
        latent_dim: Size of the latent vectors fed to the generator.
        clip_value: After every critic update each critic parameter is clamped
            to ``[-clip_value, clip_value]``.
        n_critic: The generator is updated on every ``n_critic``-th call to
            ``train_step`` (starting with the very first call); the critic is
            updated on every call.
        device: Device the networks and sampled noise are moved to.
        class_names: Stored on the instance; not used by this class itself.

    Attributes:
        steps: Number of completed ``train_step`` calls.
        G_losses / C_losses: History of generator / critic loss values.
    """

    def __init__(self, latent_dim=64, clip_value=0.01, n_critic=5, device='cuda', class_names=None):
        self.latent_dim = latent_dim
        self.clip_value = clip_value
        self.n_critic = n_critic
        self.device = device
        self.class_names = class_names

        # Initialize networks
        self.generator = Generator(latent_dim).to(device)
        self.critic = Critic().to(device)

        # Print model parameters
        generator_params = sum(p.numel() for p in self.generator.parameters())
        critic_params = sum(p.numel() for p in self.critic.parameters())
        print(f"Generator parameters: {generator_params:,}")
        print(f"Critic parameters: {critic_params:,}")
        print(f"Total parameters: {generator_params + critic_params:,}")

        # Both networks use Adam with lr=2e-4 and betas=(0.5, 0.999).
        # NOTE(review): weight-clipped WGANs are often trained with RMSprop
        # instead; confirm Adam is the intended choice here.
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_C = optim.Adam(self.critic.parameters(), lr=0.0002, betas=(0.5, 0.999))

        self.steps = 0
        self.G_losses = []
        self.C_losses = []

    def train_step(self, real_imgs):
        """Run one critic update and, every ``n_critic``-th call, one generator update.

        Args:
            real_imgs: Batch of real images; its first dimension is used as the
                batch size for the sampled noise.

        Returns:
            ``(generator_loss, critic_loss)`` as Python floats. ``generator_loss``
            is ``None`` on calls where the generator was not updated.
        """
        batch_size = real_imgs.shape[0]
        real_imgs = real_imgs.to(self.device)

        # ---- Critic update: minimize E[critic(fake)] - E[critic(real)] ----
        self.optimizer_C.zero_grad()
        z = torch.randn(batch_size, self.latent_dim).to(self.device)
        # detach(): the critic update must not backpropagate into the generator.
        fake_imgs = self.generator(z).detach()

        real_validity = self.critic(real_imgs)
        fake_validity = self.critic(fake_imgs)
        critic_loss = -torch.mean(real_validity) + torch.mean(fake_validity)

        critic_loss.backward()
        self.optimizer_C.step()

        # Weight clipping: clamp every critic parameter into the allowed range.
        with torch.no_grad():
            for p in self.critic.parameters():
                p.clamp_(-self.clip_value, self.clip_value)

        critic_loss_value = critic_loss.item()
        self.C_losses.append(critic_loss_value)

        # Decide on the generator update and advance the counter *before* any
        # return. Incrementing only on the non-generator path would leave
        # `steps` stuck at 0, so the generator would train on every call and
        # `n_critic` would have no effect.
        update_generator = self.steps % self.n_critic == 0
        self.steps += 1

        if not update_generator:
            return None, critic_loss_value

        # ---- Generator update: minimize -E[critic(generator(z))] ----
        self.optimizer_G.zero_grad()
        gen_imgs = self.generator(z)  # same noise, this time with gradients
        gen_validity = self.critic(gen_imgs)
        generator_loss = -torch.mean(gen_validity)

        generator_loss.backward()
        self.optimizer_G.step()

        generator_loss_value = generator_loss.item()
        self.G_losses.append(generator_loss_value)
        return generator_loss_value, critic_loss_value

    def generate_samples(self, n_samples):
        """Sample ``n_samples`` images from the generator without tracking gradients.

        The generator is switched to eval mode for the forward pass and then
        put back into whichever mode (train or eval) it was in before the call,
        even if the forward pass raises.

        Returns:
            The generator output for ``n_samples`` freshly drawn latent vectors.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                z = torch.randn(n_samples, self.latent_dim).to(self.device)
                samples = self.generator(z)
        finally:
            self.generator.train(was_training)
        return samples

def train_wgan(num_batches_per_epoch, num_of_epochs, train_iterator, device='cuda', class_names=None):
    """Train a freshly constructed WGAN and return it.

    Runs ``num_batches_per_epoch * num_of_epochs`` training steps, pulling one
    batch per step from ``train_iterator``. Losses are printed every 100 steps
    and a 4x4 grid of generated samples is written to ``generated_images/``
    every 500 steps.

    Args:
        num_batches_per_epoch: Steps counted as one epoch (used for the step
            total and for the epoch/batch numbers in the log line).
        num_of_epochs: Number of epochs to run.
        train_iterator: Iterator yielding ``(images, labels)`` pairs; the labels
            are ignored. It must be able to supply one batch per step.
        device: Device passed through to the WGAN.
        class_names: Passed through to the WGAN.

    Returns:
        The trained WGAN instance.
    """
    import os

    save_dir = 'generated_images'
    os.makedirs(save_dir, exist_ok=True)

    wgan = WGAN(latent_dim=128, clip_value=0.01, n_critic=5, device=device, class_names=class_names)

    total_steps = num_batches_per_epoch * num_of_epochs
    print("Total number of steps: ", total_steps)

    for step in range(total_steps):
        real_imgs, _ = next(train_iterator)
        g_loss, c_loss = wgan.train_step(real_imgs)

        # Everything below is reporting only; skip it except on every 100th step.
        if step % 100 != 0:
            continue

        epoch, batch = divmod(step, num_batches_per_epoch)
        # train_step reports None for the generator loss when it did not update it.
        g_part = "" if g_loss is None else f"[G loss: {g_loss:.4f}]"
        header = f"[Epoch {epoch}/{num_of_epochs}] [Batch {batch}/{num_batches_per_epoch}] "
        print(header + f"[C loss: {c_loss:.4f}] " + g_part)

        # Sample snapshots are saved on every 500th step only.
        if step % 500 != 0:
            continue

        samples = wgan.generate_samples(16)
        save_path = os.path.join(save_dir, f'wgan_step_{step}.png')
        torchvision.utils.save_image(samples, save_path, normalize=True, nrow=4)
        print(f"Saved generated samples to {save_path}")

    return wgan
# Pick the compute device, then run the full training loop. The trained WGAN
# wrapper stays available as `wgan` for later cells.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
wgan = train_wgan(
    num_batches_per_epoch=num_batches_per_epoch,
    num_of_epochs=num_of_epochs,
    train_iterator=train_iterator,
    device=device,
)