In [6]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
from torch.nn.utils import spectral_norm    
from torch import optim;
import os

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device.type)

# helper function to make getting another batch of data easier
def cycle(iterable):
    """Yield the items of ``iterable`` over and over, restarting it after each pass.

    Unlike ``itertools.cycle`` nothing is cached: every pass iterates over
    ``iterable`` afresh, so a re-iterable source is traversed again from the
    start each time.

    Args:
        iterable: The object to iterate over repeatedly.

    Yields:
        The items of ``iterable``, pass after pass.

    Note:
        If a pass yields nothing (empty source, or a one-shot iterator that is
        already exhausted) the generator finishes, so ``next()`` raises
        ``StopIteration`` instead of spinning forever.
    """
    while True:
        produced_any = False
        for x in iterable:
            produced_any = True
            yield x
        if not produced_any:
            # An empty pass would otherwise turn this into a silent busy-loop.
            return

# Human-readable labels for the 100 classes of the dataset loaded below.
# NOTE(review): presumably ordered by CIFAR-100 fine-label index - verify against `dataset.classes`.
class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm',]

# Single source of truth for the batch size: both loaders and the
# batches-per-epoch arithmetic below must agree on it.
batch_size = 32

# Budget of optimisation steps that the epoch count is derived from.
# NOTE(review): equals the training-set size numerically, but is used as a step count - confirm intent.
max_training_steps = 50000

# Training split: ToTensor gives [0, 1] tensors, Normalize(0.5, 0.5) rescales them to [-1, 1].
# shuffle=True so that successive epochs do not replay the exact same batches;
# drop_last=True keeps every batch at exactly `batch_size` samples.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
    batch_size=batch_size, shuffle=True, drop_last=True)

# Test split: same normalisation, kept in dataset order.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize([32,32]),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])),
    batch_size=batch_size, drop_last=True)

# Endless batch streams: call next(train_iterator) to get another batch.
train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
print("Number of classes: ", len(class_names))

# Floor division matches drop_last=True: the trailing partial batch is never served.
num_batches_per_epoch = len(train_loader.dataset) // batch_size

num_of_epochs = max_training_steps // num_batches_per_epoch

print("Number of batches per epoch: ", num_batches_per_epoch)
print("Number of epochs: ", num_of_epochs)

cuda
Files already downloaded and verified
Files already downloaded and verified
> Size of training dataset 50000
> Size of test dataset 10000
Number of classes:  100
Number of batches per epoch:  1562
Number of epochs:  32


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np

class Generator(nn.Module):
    """Decode a latent vector into a 3-channel 32x32 image squashed by tanh.

    The latent vector is projected to a 64-channel 4x4 feature map, which is
    then upsampled x2 three times (4 -> 8 -> 16 -> 32) with a 3x3 convolution
    after each upsampling, and finally mapped to 3 channels.

    Args:
        latent_dim: Length of the input noise vector.
    """

    def __init__(self, latent_dim=64):
        super().__init__()

        self.init_size = 4  # side length of the seed feature map
        self.latent_dim = latent_dim

        # Dense projection: latent_dim -> 64 * 4 * 4 = 1,024 activations
        # (latent_dim * 1,024 weights + 1,024 biases = 66,560 params for latent_dim=64).
        seed_features = 64 * self.init_size ** 2
        self.l1 = nn.Sequential(nn.Linear(latent_dim, seed_features))

        # Convolutional decoder, assembled stage by stage.
        decoder = [nn.BatchNorm2d(64)]
        # (in_channels, out_channels) per upsampling stage:
        #   4x4x64 -> 8x8x64      (36,928 conv params)
        #   8x8x64 -> 16x16x32    (18,464 conv params)
        #   16x16x32 -> 32x32x16  (4,624 conv params)
        for in_ch, out_ch in ((64, 64), (64, 32), (32, 16)):
            decoder += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # 32x32x16 -> 32x32x3 (435 conv params), squashed by tanh.
        decoder += [
            nn.Conv2d(16, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*decoder)

        # Total generator parameters with latent_dim=64: 127,363

    def forward(self, z):
        """Return the generated images for the latent batch ``z``."""
        projected = self.l1(z)
        # Reshape the flat projection into (batch, 64, 4, 4) feature maps.
        seed = projected.view(projected.shape[0], 64, self.init_size, self.init_size)
        return self.conv_blocks(seed)

class Critic(nn.Module):
    """Score a batch of 3-channel images with one unbounded scalar per image.

    Three stride-2 4x4 convolutions halve the spatial size each time; the
    linear head expects 1,024 flattened features, i.e. a 64-channel 4x4 map,
    which is what a 32x32 input produces.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) for each downsampling stage:
        #   32x32x3 -> 16x16x16  (784 params)
        #   16x16x16 -> 8x8x32   (8,224 params)
        #   8x8x32 -> 4x4x64     (32,832 params)
        stages = []
        for in_ch, out_ch in ((3, 16), (16, 32), (32, 64)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ])
        self.features = nn.Sequential(*stages)

        # Head: flatten 4x4x64 = 1,024 features -> 1 score (1,025 params).
        self.critic_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 1),
        )

        # Total critic parameters: 42,865

    def forward(self, img):
        """Return a (batch, 1) tensor of critic scores for ``img``."""
        return self.critic_head(self.features(img))

class WGAN:
    """Weight-clipping WGAN trainer that owns the generator, the critic and their optimizers.

    Args:
        latent_dim: Length of the noise vector fed to the generator.
        clip_value: After every critic update each critic parameter is clamped
            to ``[-clip_value, clip_value]``.
        n_critic: The generator is updated on every ``n_critic``-th call to
            ``train_step`` (starting with the first call); the critic is
            updated on every call.
        device: Device the networks and all batches are moved to.
        class_names: Stored on the instance; not used by this class itself.
    """

    def __init__(self, latent_dim=64, clip_value=0.01, n_critic=5, device='cuda', class_names=None):
        self.latent_dim = latent_dim
        self.clip_value = clip_value
        self.n_critic = n_critic
        self.device = device
        self.class_names = class_names

        # Initialize networks
        self.generator = Generator(latent_dim).to(device)
        self.critic = Critic().to(device)

        # Print model parameters
        generator_params = sum(p.numel() for p in self.generator.parameters())
        critic_params = sum(p.numel() for p in self.critic.parameters())
        print(f"Generator parameters: {generator_params:,}")
        print(f"Critic parameters: {critic_params:,}")
        print(f"Total parameters: {generator_params + critic_params:,}")

        # Initialize optimizers
        self.optimizer_G = optim.RMSprop(self.generator.parameters(), lr=0.00005)
        self.optimizer_C = optim.RMSprop(self.critic.parameters(), lr=0.00005)

        self.steps = 0      # number of train_step calls made so far
        self.G_losses = []  # one entry per generator update
        self.C_losses = []  # one entry per critic update

    def train_step(self, real_imgs):
        """Run one critic update and, on every ``n_critic``-th call, one generator update.

        Args:
            real_imgs: Batch of real images; its first dimension is the batch size.

        Returns:
            ``(generator_loss, critic_loss)`` as Python floats; ``generator_loss``
            is ``None`` on calls where the generator was not updated.
        """
        batch_size = real_imgs.shape[0]
        real_imgs = real_imgs.to(self.device)

        # ---- Critic update ----
        self.optimizer_C.zero_grad()
        z = torch.randn(batch_size, self.latent_dim).to(self.device)
        # detach(): the critic step must not back-propagate into the generator.
        fake_imgs = self.generator(z).detach()

        real_validity = self.critic(real_imgs)
        fake_validity = self.critic(fake_imgs)
        critic_loss = -torch.mean(real_validity) + torch.mean(fake_validity)

        critic_loss.backward()
        self.optimizer_C.step()

        # Weight clipping: keep every critic parameter inside [-clip_value, clip_value].
        with torch.no_grad():
            for p in self.critic.parameters():
                p.clamp_(-self.clip_value, self.clip_value)

        critic_loss_value = critic_loss.item()
        self.C_losses.append(critic_loss_value)

        # Decide from the current counter, then advance it on EVERY call.
        # (Incrementing only on the non-generator path left the counter stuck
        # at 0, so the generator was updated on every call and n_critic was ignored.)
        update_generator = self.steps % self.n_critic == 0
        self.steps += 1

        if not update_generator:
            return None, critic_loss_value

        # ---- Generator update (reuses the same noise batch z) ----
        self.optimizer_G.zero_grad()
        gen_imgs = self.generator(z)
        gen_validity = self.critic(gen_imgs)
        generator_loss = -torch.mean(gen_validity)

        generator_loss.backward()
        self.optimizer_G.step()

        generator_loss_value = generator_loss.item()
        self.G_losses.append(generator_loss_value)
        return generator_loss_value, critic_loss_value

    def generate_samples(self, n_samples):
        """Generate ``n_samples`` images with the generator in eval mode and without gradients.

        The generator's previous train/eval mode is restored afterwards.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                z = torch.randn(n_samples, self.latent_dim).to(self.device)
                samples = self.generator(z)
        finally:
            # Restore whatever mode the caller had, rather than forcing train mode.
            self.generator.train(was_training)
        return samples

def train_wgan(num_batches_per_epoch, num_of_epochs, train_iterator, device='cuda', class_names=None):
    """Build a WGAN and train it for ``num_batches_per_epoch * num_of_epochs`` steps.

    Every 100 steps the current losses are printed; every 500 steps a 4x4 grid
    of 16 generated samples is written to ``generated_images/wgan_step_<step>.png``.

    Args:
        num_batches_per_epoch: Steps counted as one epoch (used for the step
            total and for the progress read-out).
        num_of_epochs: Number of epochs to run.
        train_iterator: Iterator whose ``next()`` yields ``(images, labels)``
            batches; the labels are ignored.
        device: Device handed to the WGAN.
        class_names: Forwarded unchanged to the WGAN.

    Returns:
        The trained WGAN instance.
    """
    import os

    out_dir = 'generated_images'
    os.makedirs(out_dir, exist_ok=True)

    wgan = WGAN(latent_dim=64, clip_value=0.01, n_critic=5, device=device, class_names=class_names)

    for step in range(num_batches_per_epoch * num_of_epochs):
        real_imgs, _ = next(train_iterator)
        g_loss, c_loss = wgan.train_step(real_imgs)

        # Everything below is periodic reporting only.
        if step % 100 != 0:
            continue

        epoch, batch = divmod(step, num_batches_per_epoch)
        g_loss_text = "" if g_loss is None else f"[G loss: {g_loss:.4f}]"
        print(f"[Epoch {epoch}/{num_of_epochs}] "
              f"[Batch {batch}/{num_batches_per_epoch}] "
              f"[C loss: {c_loss:.4f}] "
              + g_loss_text)

        if step % 500 == 0:
            save_path = os.path.join(out_dir, f'wgan_step_{step}.png')
            torchvision.utils.save_image(
                wgan.generate_samples(16),
                save_path,
                normalize=True,
                nrow=4,
            )
            print(f"Saved generated samples to {save_path}")

    return wgan
# Resolve the compute device again for this cell, then launch training with the
# step counts and batch iterator prepared in the data-loading cell.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
wgan = train_wgan(
    num_batches_per_epoch=num_batches_per_epoch,
    num_of_epochs=num_of_epochs,
    train_iterator=train_iterator,
    device=device,
)

Generator parameters: 127,363
Critic parameters: 42,865
Total parameters: 170,228
[Epoch 0/32] [Batch 0/1562] [C loss: 0.0074] [G loss: 0.0024]
Saved generated samples to generated_images\wgan_step_0.png
[Epoch 0/32] [Batch 100/1562] [C loss: -0.0002] [G loss: 0.0011]
[Epoch 0/32] [Batch 200/1562] [C loss: 0.0003] [G loss: 0.0005]
[Epoch 0/32] [Batch 300/1562] [C loss: -0.0007] [G loss: 0.0012]
[Epoch 0/32] [Batch 400/1562] [C loss: 0.0002] [G loss: 0.0018]
[Epoch 0/32] [Batch 500/1562] [C loss: 0.0002] [G loss: 0.0019]
Saved generated samples to generated_images\wgan_step_500.png
[Epoch 0/32] [Batch 600/1562] [C loss: -0.0001] [G loss: 0.0009]
[Epoch 0/32] [Batch 700/1562] [C loss: 0.0001] [G loss: 0.0013]
[Epoch 0/32] [Batch 800/1562] [C loss: -0.0001] [G loss: 0.0005]
[Epoch 0/32] [Batch 900/1562] [C loss: 0.0001] [G loss: 0.0023]
[Epoch 0/32] [Batch 1000/1562] [C loss: -0.0002] [G loss: 0.0004]
Saved generated samples to generated_images\wgan_step_1000.png
[Epoch 0/32] [Batch 1100/