In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
from torch.nn.utils import spectral_norm    
from torch import optim;
import os

# Prefer the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device.type)

In [None]:
# helper function to make getting another batch of data easier
def cycle(iterable):
    while True:
        for x in iterable:
            yield x

class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm',]

# Single source of truth for the batch size: used by both loaders and by the
# epoch arithmetic below (previously the loaders hardcoded their own 32).
batch_size = 32

# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1], the same
# range as the generator's final Tanh.
cifar_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# shuffle=True so training batches are drawn in a new order every epoch.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=cifar_transform),
    batch_size=batch_size, shuffle=True, drop_last=True)

# CIFAR-100 images are already 32x32, so the test split needs no Resize.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=cifar_transform),
    batch_size=batch_size, drop_last=True)

# cycle() already returns a generator, so it can be used directly with next().
train_iterator = cycle(train_loader)
test_iterator = cycle(test_loader)

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
print("Number of classes: ", len(class_names))

# With drop_last=True this equals len(train_loader.dataset) // batch_size.
num_batches_per_epoch = len(train_loader)

# Total training budget measured in batches (optimizer steps); the number of
# epochs is derived from it so the budget stays fixed if batch_size changes.
num_training_iterations = 50000
num_of_epochs = num_training_iterations // num_batches_per_epoch

print("Number of batches per epoch: ", num_batches_per_epoch)
print("Number of epochs: ", num_of_epochs)

In [None]:
# let's view some of the training data
plt.rcParams['figure.dpi'] = 100
x, t = next(train_iterator)

# The loaders normalise pixels to [-1, 1]. Undo that (x * 0.5 + 0.5) before
# plotting, otherwise imshow clips every negative value and the colours come
# out wrong. Plotting is done on the CPU, so no device transfer is needed.
grid = torchvision.utils.make_grid(x * 0.5 + 0.5)
plt.imshow(grid.clamp(0, 1).numpy().transpose(1, 2, 0))
plt.axis('off')
plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
import numpy as np

class ResNetBlock(nn.Module):
    """Residual block with an identity skip connection.

    Computes ``leaky_relu(bn2(conv2(dropout(leaky_relu(bn1(conv1(x)))))) + x)``
    with a LeakyReLU slope of 0.2. Both convolutions are 3x3 with padding=1 and
    keep the channel count, so the output has the same shape as the input.

    Args:
        channels: Number of input (and output) channels.
        dropout_rate: Dropout probability applied between the two convolutions.
    """
    def __init__(self, channels, dropout_rate=0.2):
        super().__init__()
        # Registration order fixes both parameter-initialisation order and the
        # state_dict key order, so it is kept as-is.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        hidden = self.dropout(F.leaky_relu(self.bn1(self.conv1(x)), 0.2))
        branch = self.bn2(self.conv2(hidden))
        # The skip connection is added before the final activation.
        return F.leaky_relu(branch + x, 0.2)

class ResNetGenerator(nn.Module):
    """Map a latent vector to an image through a 512x4x4 seed and three 2x upsamplings.

    Spatial size goes 4 -> 8 -> 16 -> 32, and the final Tanh squashes the
    ``output_channels`` channels to [-1, 1].

    Args:
        latent_dim: Length of each input noise vector ``z``.
        output_channels: Number of image channels produced.
    """
    def __init__(self, latent_dim=100, output_channels=3):
        super().__init__()

        # Dense projection of z to a flattened 512x4x4 feature map.
        self.initial = nn.Sequential(
            nn.Linear(latent_dim, 512 * 4 * 4),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512 * 4 * 4),
        )

        def upsample_x2():
            return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        def project(in_ch, out_ch):
            # 1x1 conv channel reduction followed by BN and LeakyReLU.
            return [nn.Conv2d(in_ch, out_ch, 1), nn.BatchNorm2d(out_ch), nn.LeakyReLU(0.2)]

        # The list is evaluated left to right, so modules are created (and
        # indexed inside the Sequential) in one flat, fixed order.
        layers = [
            upsample_x2(), ResNetBlock(512),                        # 4x4 -> 8x8
            upsample_x2(), ResNetBlock(512), *project(512, 256),    # 8x8 -> 16x16
            upsample_x2(), ResNetBlock(256), *project(256, 128),    # 16x16 -> 32x32
            ResNetBlock(128),                                       # final processing
            nn.Conv2d(128, output_channels, 1),
            nn.Tanh(),
        ]
        self.main_path = nn.Sequential(*layers)

    def forward(self, z):
        seed = self.initial(z).view(-1, 512, 4, 4)
        return self.main_path(seed)

class ResNetDiscriminator(nn.Module):
    """Score images as real/fake with residual blocks and strided downsampling.

    Three stride-2 convolutions each halve height and width (32 -> 16 -> 8 -> 4
    for 32x32 inputs). Global average pooling, a linear layer and a Sigmoid
    then produce a ``(batch, 1)`` probability.

    Args:
        input_channels: Number of channels in the input images.
    """
    def __init__(self, input_channels=3):
        super().__init__()

        def downsample(in_ch, out_ch):
            # Residual refinement, then a strided 4x4 conv that halves H and W.
            return [
                ResNetBlock(in_ch),
                nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]

        self.main_path = nn.Sequential(
            nn.Conv2d(input_channels, 64, 3, padding=1),    # keeps spatial size
            nn.LeakyReLU(0.2),
            *downsample(64, 128),     # 32x32 -> 16x16
            *downsample(128, 256),    # 16x16 -> 8x8
            *downsample(256, 512),    # 8x8 -> 4x4
            ResNetBlock(512),
        )

        self.output = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.output(self.main_path(x))

class ResNetGAN:
    """Complete GAN implementation with training and visualization.

    Owns a ResNetGenerator / ResNetDiscriminator pair, one Adam optimizer for
    each (lr=2e-4, betas=(0.5, 0.999)), a BCE criterion and per-step loss logs.

    Args:
        latent_dim: Size of the noise vector fed to the generator.
        device: Device the networks and all sampled tensors are placed on.
    """
    def __init__(self, latent_dim=100, device='cuda'):
        self.device = device
        self.latent_dim = latent_dim

        # Initialize networks
        self.generator = ResNetGenerator(latent_dim).to(device)
        self.discriminator = ResNetDiscriminator().to(device)

        # Initialize optimizers
        self.g_optimizer = optim.Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.d_optimizer = optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

        # The discriminator ends in a Sigmoid, so BCELoss is applied to probabilities.
        self.criterion = nn.BCELoss()

        # Per-step loss history; the training loop appends to these.
        self.g_losses = []
        self.d_losses = []

    def _sample_noise(self, n):
        """Draw ``n`` standard-normal latent vectors directly on ``self.device``."""
        return torch.randn(n, self.latent_dim, device=self.device)

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: Batch of real images already on ``self.device``.

        Returns:
            Tuple ``(d_loss, g_loss)`` as Python floats.
        """
        batch_size = real_images.size(0)
        real_label = torch.ones(batch_size, 1, device=self.device)
        fake_label = torch.zeros(batch_size, 1, device=self.device)

        # --- Discriminator: push D(real) -> 1 and D(G(z)) -> 0 ---
        self.d_optimizer.zero_grad()
        real_output = self.discriminator(real_images)
        d_loss_real = self.criterion(real_output, real_label)

        fake_images = self.generator(self._sample_noise(batch_size))
        # detach(): the discriminator loss must not backpropagate into the generator.
        fake_output = self.discriminator(fake_images.detach())
        d_loss_fake = self.criterion(fake_output, fake_label)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        self.d_optimizer.step()

        # --- Generator: fresh samples, trained to make D(G(z)) -> 1 ---
        self.g_optimizer.zero_grad()
        fake_images = self.generator(self._sample_noise(batch_size))
        fake_output = self.discriminator(fake_images)
        g_loss = self.criterion(fake_output, real_label)
        g_loss.backward()
        # Only the generator is stepped here. The gradients this leaves on the
        # discriminator are cleared by d_optimizer.zero_grad() on the next call.
        self.g_optimizer.step()

        return d_loss.item(), g_loss.item()

    def _sample_images(self, num_samples=16):
        """Return ``num_samples`` generated images without tracking gradients.

        The generator is put in eval mode for sampling and then restored to its
        previous mode. Sampling mid-training therefore does not leave
        BatchNorm/Dropout in inference mode for subsequent ``train_step`` calls.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                return self.generator(self._sample_noise(num_samples))
        finally:
            self.generator.train(was_training)

    def generate_samples(self, num_samples=16):
        """Generate and display sample images in a grid (4 images per row)."""
        fake_images = self._sample_images(num_samples)

        fig = plt.figure(figsize=(10, 10))
        grid_img = torchvision.utils.make_grid(fake_images, nrow=4, normalize=True)
        plt.imshow(grid_img.cpu().permute(1, 2, 0))
        plt.axis('off')
        plt.show()
        plt.close(fig)

    def visualize_training(self):
        """Plot the per-iteration generator and discriminator losses."""
        fig = plt.figure(figsize=(10, 5))
        plt.plot(self.g_losses, label='Generator Loss')
        plt.plot(self.d_losses, label='Discriminator Loss')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        plt.show()
        plt.close(fig)

# Training loop
def train_gan(gan, train_loader, num_epochs, log_every=1):
    """Train ``gan`` for ``num_epochs`` passes over ``train_loader``.

    The discriminator and generator losses of every batch are appended to
    ``gan.d_losses`` and ``gan.g_losses``. Every ``log_every`` epochs, the last
    batch's losses are printed and sample images plus the loss curves are shown.

    Args:
        gan: Object exposing ``device``, ``train_step``, ``d_losses``,
            ``g_losses``, ``generate_samples`` and ``visualize_training``
            (e.g. ``ResNetGAN``).
        train_loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        num_epochs: Number of passes over ``train_loader``.
        log_every: Report progress every this many epochs. The default of 1
            reports after every epoch.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    print("Starting GAN training...")

    for epoch in range(num_epochs):
        d_loss = g_loss = None
        for real_images, _ in train_loader:
            real_images = real_images.to(gan.device)

            # Train for one batch and record the losses
            d_loss, g_loss = gan.train_step(real_images)
            gan.d_losses.append(d_loss)
            gan.g_losses.append(g_loss)

        if (epoch + 1) % log_every != 0:
            continue

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        if d_loss is None:
            # Empty loader: there are no losses to report (previously a NameError).
            print('No batches in this epoch; skipping progress report.')
            continue
        print(f'D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}')
        gan.generate_samples()
        gan.visualize_training()

# Usage example: build the GAN on the best available device and train it.
if __name__ == "__main__":
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    gan = ResNetGAN(device=device)

    trainable_g_params = sum(
        param.numel()
        for param in gan.generator.parameters()
        if param.requires_grad
    )
    print("Parameters of generator: ", trainable_g_params)

    # Train on the CIFAR-100 loader built earlier in the notebook.
    train_gan(gan, train_loader, num_of_epochs)