In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
from datetime import datetime
import os
import matplotlib.pyplot as plt

In [2]:
def download_cifar10(data_path='./data'):
    """
    Fetch the CIFAR-10 train and test splits (downloading them if needed).

    Every image is converted to a tensor and then normalized per channel
    with mean 0.5 and std 0.5, which maps pixel values into [-1, 1].

    Args:
        data_path: Directory where the dataset is stored / downloaded to.

    Returns:
        A ``(trainset, testset, classes)`` tuple, where ``classes`` is the
        tuple of the ten human-readable CIFAR-10 class names.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    def _load_split(is_train):
        # Both splits share the same root directory and preprocessing.
        return torchvision.datasets.CIFAR10(
            root=data_path,
            train=is_train,
            download=True,
            transform=to_normalized_tensor,
        )

    trainset = _load_split(True)
    testset = _load_split(False)

    print(f"Training set size: {len(trainset)}")
    print(f"Test set size: {len(testset)}")

    # Human-readable CIFAR-10 class names.
    classes = (
        'plane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    )

    return trainset, testset, classes

def get_dataloader(trainset, testset, batch_size=128, num_workers=2):
    """
    Create DataLoader objects for training and testing.

    Args:
        trainset: Dataset used for training; it is reshuffled every epoch.
        testset: Dataset used for evaluation; it is iterated in order.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes per loader. Defaults to 2
            (the previously hard-coded value); pass 0 to load data in the
            main process, e.g. where worker processes are unavailable.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # Settings shared by both loaders; only the shuffling differs.
    shared_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

    train_loader = DataLoader(trainset, shuffle=True, **shared_kwargs)
    test_loader = DataLoader(testset, shuffle=False, **shared_kwargs)

    return train_loader, test_loader

In [3]:
class Generator(nn.Module):
    """
    Transposed-convolution generator that maps Gaussian noise to images.

    The noise tensor has shape (batch_size, noise_vector_len, 1, 1); the
    four transposed convolutions upsample it 1x1 -> 4x4 -> 8x8 -> 16x16 ->
    32x32, and the final Tanh keeps the 3-channel output within [-1, 1].
    """

    def __init__(self, noise_vector_len = 100):
        """
        Args:
            noise_vector_len: Number of channels of the latent noise vector.
        """
        super(Generator, self).__init__()
        self.noise_vector_len = noise_vector_len
        self.model = nn.Sequential(
            nn.ConvTranspose2d(noise_vector_len, 512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, batch_size):
        """
        Sample ``batch_size`` noise vectors and generate images from them.

        Args:
            batch_size: Number of images to generate.

        Returns:
            Tensor of shape (batch_size, 3, 32, 32) with values in [-1, 1].
        """
        # Draw the noise directly on whatever device/dtype the weights live
        # on, instead of hard-coding "cuda": the module then works on CPU
        # and follows any later `.to(...)` call.
        reference_param = next(self.parameters())
        random_samples = torch.randn(
            (batch_size, self.noise_vector_len, 1, 1),
            device=reference_param.device,
            dtype=reference_param.dtype,
        )
        fake_samples = self.model(random_samples)
        return fake_samples

    def loss(self, adversary_output_on_fake):
        """
        Generator aims to maximize adversary output on fake samples
        (or minimize negative log of adversary output).

        Args:
            adversary_output_on_fake: Adversary probabilities (values in
                [0, 1]) for generated samples.

        Returns:
            Scalar binary cross-entropy against an all-ones target.
        """
        return F.binary_cross_entropy(
            adversary_output_on_fake,
            torch.ones_like(adversary_output_on_fake)
        )

class Adversary(nn.Module):
    """
    Convolutional discriminator that scores images as real (1) or fake (0).

    Three stride-2 stages (spectrally normalized conv -> batch norm ->
    LeakyReLU) each halve the spatial size; a final 4x4 convolution followed
    by a sigmoid produces one probability per image. For 32x32 inputs the
    output shape is (batch, 1, 1, 1).
    """

    def __init__(self):
        super().__init__()

        stage_channels = ((3, 64), (64, 128), (128, 256))
        layers = []
        for in_channels, out_channels in stage_channels:
            # Downsampling stage: spectral-norm conv, batch norm, activation.
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
            layers.append(nn.utils.spectral_norm(conv))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU())

        # Classification head: collapse the remaining 4x4 map to one score.
        layers.append(nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the probability that each image in ``x`` is real."""
        scores = self.model(x)
        return scores

    def loss(self, output_on_real, output_on_fake):
        """
        Binary cross-entropy objective of the adversary.

        Real samples are pushed towards a prediction of 1 and fake samples
        towards a prediction of 0; the two terms are averaged.

        Args:
            output_on_real: Adversary probabilities for real samples.
            output_on_fake: Adversary probabilities for generated samples.

        Returns:
            Scalar tensor: the mean of the real and fake BCE terms.
        """
        real_targets = torch.ones_like(output_on_real)
        fake_targets = torch.zeros_like(output_on_fake)

        real_term = F.binary_cross_entropy(output_on_real, real_targets)
        fake_term = F.binary_cross_entropy(output_on_fake, fake_targets)

        return (real_term + fake_term) / 2

In [4]:
def save_image_samples(generator, writer, epoch, num_images=16):
    """
    Log a grid of freshly generated images to TensorBoard.

    Args:
        generator: Callable module invoked as ``generator(batch_size=...)``
            that returns images with values in [-1, 1] (Tanh output).
        writer: TensorBoard writer exposing ``add_images(tag, images, step)``.
        epoch: Global step under which the images are logged.
        num_images: Number of images to generate and log (default 16).

    The generator is switched to eval mode while sampling and restored to
    its previous train/eval mode afterwards.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # Generate exactly as many samples as are logged.
            generated_batch = generator(batch_size=num_images)
        # add_images expects float images in [0, 1]; the generator emits
        # [-1, 1], so rescale (and clamp against rounding) before logging.
        images = ((generated_batch.cpu() + 1.0) / 2.0).clamp(0.0, 1.0)
        writer.add_images('Generated', images, epoch)
    finally:
        generator.train(was_training)

def train_epoch(generator, adversary, train_loader, generator_optimizer, adversary_optimizer, device, writer, epoch):
    """
    Run one epoch of alternating adversary / generator updates.

    For every batch the adversary is updated first (real batch vs. a fake
    batch that carries no gradient to the generator), then the generator is
    updated on a fresh fake batch scored by the just-updated adversary.

    Args:
        generator: Module called as ``generator(batch_size)``; must expose
            ``loss(fake_predictions)``.
        adversary: Module called on image batches; must expose
            ``loss(real_predictions, fake_predictions)``.
        train_loader: Iterable of ``(images, labels)`` batches; labels are
            ignored.
        generator_optimizer: Optimizer over the generator parameters.
        adversary_optimizer: Optimizer over the adversary parameters.
        device: Device the real images are moved to.
        writer: TensorBoard writer used for scalars and sample images.
        epoch: Current epoch number, used as the logging step.

    Returns:
        ``(avg_g_loss, avg_a_loss)``: per-sample average generator and
        adversary losses over the epoch.
    """
    generator.train()
    adversary.train()

    # Losses are per-batch means, so they are weighted by batch size and
    # divided by the number of samples actually processed. Dividing the
    # plain sum of batch means by the dataset size would under-report the
    # averages by roughly a factor of the batch size.
    g_loss_sum = 0.0
    a_loss_sum = 0.0
    samples_seen = 0

    for batch_idx, (data, _) in enumerate(train_loader):
        batch_size = data.size(0)
        real_images = data.to(device)

        # Train Adversary
        adversary_optimizer.zero_grad()
        real_predictions = adversary(real_images)

        # The adversary step needs no gradient w.r.t. the generator, so the
        # fake batch is produced without building an autograd graph.
        with torch.no_grad():
            fake_images = generator(batch_size)
        fake_predictions = adversary(fake_images)
        a_loss = adversary.loss(real_predictions, fake_predictions)
        a_loss.backward()
        adversary_optimizer.step()

        # Train Generator
        generator_optimizer.zero_grad()
        fake_images = generator(batch_size)
        fake_predictions = adversary(fake_images)
        g_loss = generator.loss(fake_predictions)
        g_loss.backward()
        generator_optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch}] Batch [{batch_idx}/{len(train_loader)}] '
                      f'A_Loss: {a_loss.item():.4f} G_Loss: {g_loss.item():.4f}')

        g_loss_sum += g_loss.item() * batch_size
        a_loss_sum += a_loss.item() * batch_size
        samples_seen += batch_size

    # Log freshly generated samples for this epoch.
    save_image_samples(generator, writer, epoch)

    # Guard against an empty loader so the averages stay well-defined (0.0).
    denominator = max(samples_seen, 1)
    avg_g_loss = g_loss_sum / denominator
    avg_a_loss = a_loss_sum / denominator

    writer.add_scalar('Loss/train/g_loss', avg_g_loss, epoch)
    writer.add_scalar('Loss/train/a_loss', avg_a_loss, epoch)
    return avg_g_loss, avg_a_loss

In [5]:
def train_gan(epochs=100, batch_size=128, learning_rate=1e-3, device="cuda", adversary_lr=1e-3):
    """
    Train the GAN on CIFAR-10, logging to TensorBoard and checkpointing.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size for the data loaders.
        learning_rate: Adam learning rate of the generator. This argument
            used to be ignored in favor of a hard-coded 3e-3.
        device: Device the models and real images are placed on.
        adversary_lr: Adam learning rate of the adversary. Defaults to
            1e-3, the value that was previously hard-coded.

    A checkpoint with both models, both optimizers and the latest epoch
    losses is written every 10 epochs under ``<log_dir>/models``.
    """
    # Get data (the test split is not used during GAN training).
    trainset, testset, _ = download_cifar10()
    train_loader, _ = get_dataloader(trainset, testset, batch_size)

    generator = Generator().to(device)
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)

    adversary = Adversary().to(device)
    adversary_optimizer = torch.optim.Adam(adversary.parameters(), lr=adversary_lr)

    # NOTE(review): the run directory and checkpoint files are still labeled
    # "VAE"/"vae" although this trains a GAN; the names are kept unchanged
    # so existing tooling that looks for these paths keeps working.
    log_dir = f'runs/VAE_CIFAR10_{datetime.now().strftime("%Y%m%d-%H%M%S")}'
    checkpoint_dir = f'{log_dir}/models'
    writer = SummaryWriter(log_dir)

    try:
        # Training loop
        for epoch in range(1, epochs + 1):
            g_loss, a_loss = train_epoch(generator, adversary, train_loader, generator_optimizer, adversary_optimizer, device, writer, epoch)

            # Save a checkpoint every 10 epochs
            if epoch % 10 == 0:
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save({
                    'epoch': epoch,
                    'g_model_state_dict': generator.state_dict(),
                    'a_model_state_dict': adversary.state_dict(),
                    'g_optimizer_state_dict': generator_optimizer.state_dict(),
                    'a_optimizer_state_dict': adversary_optimizer.state_dict(),
                    'g_loss': g_loss,
                    'a_loss': a_loss,
                }, f'{checkpoint_dir}/vae_checkpoint_epoch_{epoch}.pt')
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()

In [None]:
# Launch the full training run with the notebook's chosen hyperparameters.
train_gan(
    epochs=100,
    batch_size=128,
    learning_rate=3e-3,
    device="cuda",
)
