# In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
from torch.nn.utils import spectral_norm   
from torch import optim 
import os

# Run on the GPU when CUDA is usable, otherwise on the CPU; report the choice.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device.type)

def cycle(iterable):
    """Yield the items of ``iterable`` forever, restarting it after each pass.

    The iterable is re-iterated from scratch on every pass (nothing is
    cached), which makes fetching "one more batch" from a data loader a
    plain ``next()`` call on the returned generator.
    """
    while True:
        yield from iterable

class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm']

# Data loading
# Single source of truth for the batch size: used by both loaders and by the
# epoch arithmetic below.
batch_size = 64

# Normalize with mean 0.5 / std 0.5 per channel maps ToTensor's [0, 1] output
# to [-1, 1], the same range as the decoder's tanh output.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
    # shuffle=True re-draws the sample order each time the loader is
    # re-iterated, so successive passes do not replay identical batches.
    batch_size=batch_size, shuffle=True, drop_last=True)

# Evaluation data keeps its fixed order (no shuffling).
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize([32,32]),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])),
    batch_size=batch_size, drop_last=True)

# Endless iterators: next(...) always returns another batch, starting a new
# pass over the loader whenever the previous one is exhausted.
train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

# Calculate training parameters
# Full batches per pass (the trailing partial batch is dropped by drop_last).
num_batches_per_epoch = len(train_loader.dataset) // batch_size
# 50000 here is the optimisation-step budget, expressed in whole epochs.
num_of_epochs = 50000 // num_batches_per_epoch
LATENT_DIM = 128
NUM_CLASSES = 100

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
print("Number of classes: ", len(class_names))
print("Number of batches per epoch: ", num_batches_per_epoch)
print("Number of epochs: ", num_of_epochs)

class ResBlock(nn.Module):
    """Two-convolution residual block that keeps channel count and spatial size.

    Each 3x3 convolution (padding 1, spectrally normalised) is followed by
    batch normalisation; the block input is added back before the final ReLU.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Attribute names and construction order are kept stable because they
        # determine the state_dict keys and the order of random initialisation.
        self.conv1 = spectral_norm(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        )
        self.conv2 = spectral_norm(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Skip connection: add the untouched input before the last activation.
        return F.relu(out + x)

class CVAE(nn.Module):
    """Conditional variational autoencoder for 3x32x32 images.

    Both the encoder and the decoder are conditioned on an integer class
    label, which is turned into a 512-dimensional vector by a shared
    embedding table and concatenated with the image features (encoder) or
    with the latent code (decoder).

    The encoder halves the spatial size three times (32 -> 16 -> 8 -> 4), so
    inputs must be 32x32 for the flattened feature size (256 * 4 * 4) to
    match the first fully connected layer.

    Args:
        latent_dim: Size of the latent code ``z``.
        num_classes: Number of distinct class labels (rows in the embedding).
    """

    def __init__(self, latent_dim=LATENT_DIM, num_classes=NUM_CLASSES):
        super().__init__()

        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Encoder: strided convolutions downsample, residual blocks refine.
        self.enc_conv1 = spectral_norm(nn.Conv2d(3, 64, 3, stride=2, padding=1))  # 16x16
        self.enc_res1 = ResBlock(64)
        self.enc_conv2 = spectral_norm(nn.Conv2d(64, 128, 3, stride=2, padding=1))  # 8x8
        self.enc_res2 = ResBlock(128)
        self.enc_conv3 = spectral_norm(nn.Conv2d(128, 256, 3, stride=2, padding=1))  # 4x4
        self.enc_res3 = ResBlock(256)

        # Class embedding, shared by encoder and decoder.
        self.class_embedding = nn.Embedding(num_classes, 512)

        # Fully connected layers producing the posterior parameters.
        self.enc_fc = nn.Linear(256 * 4 * 4 + 512, 1024)
        self.fc_mu = nn.Linear(1024, latent_dim)
        self.fc_var = nn.Linear(1024, latent_dim)

        # Decoder: mirror of the encoder using transposed convolutions.
        self.dec_fc1 = nn.Linear(latent_dim + 512, 1024)
        self.dec_fc2 = nn.Linear(1024, 256 * 4 * 4)

        self.dec_conv1 = spectral_norm(nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1))  # 8x8
        self.dec_res1 = ResBlock(128)
        self.dec_conv2 = spectral_norm(nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1))  # 16x16
        self.dec_res2 = ResBlock(64)
        self.dec_conv3 = spectral_norm(nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1))  # 32x32

    def encode(self, x, c):
        """Map images and their class labels to posterior parameters.

        Args:
            x: Batch of images, shape (batch, 3, 32, 32).
            c: Integer class labels, shape (batch,), valid embedding indices.

        Returns:
            Tuple ``(mu, log_var)``, each of shape (batch, latent_dim).
        """
        # Encode image
        x = F.relu(self.enc_conv1(x))
        x = self.enc_res1(x)
        x = F.relu(self.enc_conv2(x))
        x = self.enc_res2(x)
        x = F.relu(self.enc_conv3(x))
        x = self.enc_res3(x)
        # Flatten per sample; keeping the batch dimension explicit means a
        # wrongly sized input fails in enc_fc instead of being silently
        # folded into extra batch rows.
        x = x.view(x.size(0), -1)

        # Process class label
        c_emb = self.class_embedding(c)

        # Combine image and class features
        x = torch.cat([x, c_emb], dim=1)
        x = F.relu(self.enc_fc(x))

        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling through ``eps`` keeps ``z`` differentiable with respect to
        ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        """Generate images from latent codes and class labels.

        Args:
            z: Latent codes, shape (batch, latent_dim).
            c: Integer class labels, shape (batch,), valid embedding indices.

        Returns:
            Images of shape (batch, 3, 32, 32) with values in [-1, 1]
            (the final activation is tanh).
        """
        # Process class label
        c_emb = self.class_embedding(c)

        # Combine latent and class features
        z = torch.cat([z, c_emb], dim=1)

        x = F.relu(self.dec_fc1(z))
        x = F.relu(self.dec_fc2(x))
        x = x.view(-1, 256, 4, 4)

        x = F.relu(self.dec_conv1(x))
        x = self.dec_res1(x)
        x = F.relu(self.dec_conv2(x))
        x = self.dec_res2(x)
        x = torch.tanh(self.dec_conv3(x))
        return x

    def forward(self, x, c):
        """Encode, sample a latent code, and reconstruct.

        Args:
            x: Batch of images, shape (batch, 3, 32, 32).
            c: Integer class labels for ``x``, shape (batch,).

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var

def loss_function(recon_x, x, mu, log_var, beta=4.0):
    """Beta-VAE objective: summed squared error plus ``beta`` times the KL term.

    Args:
        recon_x: Reconstructed batch.
        x: Target batch, same shape as ``recon_x``.
        mu: Posterior means.
        log_var: Posterior log-variances, same shape as ``mu``.
        beta: Weight applied to the KL divergence.

    Returns:
        Scalar tensor; both terms are summed (not averaged) over all elements.
    """
    reconstruction_term = F.mse_loss(recon_x, x, reduction='sum')

    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latents.
    kl_integrand = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_term = -0.5 * kl_integrand.sum()

    return reconstruction_term + beta * kl_term

def show_images(images, labels=None, title="Generated Images"):
    """Render a batch of images as a square grid and display it inline.

    Any previous notebook output is cleared first, so repeated calls replace
    the displayed figure instead of stacking figures.

    Args:
        images: Sized iterable of channel-first (C, H, W) images, given as
            torch tensors or array-likes. Pixel values are clipped to [0, 1]
            before drawing.
        labels: Optional sequence of integer class indices, one per image;
            each is looked up in ``class_names`` to caption its tile.
        title: Heading drawn above the whole grid.
    """
    fig = plt.figure(figsize=(15, 15))
    # Figure-level heading. plt.title() at this point would create an extra
    # full-figure axes underneath the image grid.
    fig.suptitle(title)

    # Smallest square grid that holds every image.
    grid_size = int(np.ceil(np.sqrt(len(images))))

    for i, img in enumerate(images):
        plt.subplot(grid_size, grid_size, i + 1)
        # Tensors may live on the GPU / carry gradients; bring them to NumPy.
        if torch.is_tensor(img):
            img = img.detach().cpu().numpy()
        # Rearrange from (C, H, W) to the (H, W, C) layout imshow expects.
        img = np.transpose(img, (1, 2, 0))
        # Clip values to the valid display range.
        img = np.clip(img, 0, 1)
        plt.imshow(img)
        if labels is not None:
            # int() accepts plain ints, NumPy integers and 0-dim tensors alike.
            plt.title(class_names[int(labels[i])], fontsize=8)
        plt.axis('off')

    # Reserve a strip at the top so the tiles do not overlap the heading.
    fig.tight_layout(rect=(0, 0, 1, 0.97))
    disp.clear_output(wait=True)
    disp.display(fig)
    plt.close(fig)

def generate_and_show_images(model, epoch, num_images=16):
    """Sample class-conditional images from the decoder and display them.

    Classes are drawn uniformly at random (with replacement) and the same
    labels are used both to condition the decoder and to caption the tiles.

    Args:
        model: Conditional model exposing ``decode(z, class_labels)``.
        epoch: Epoch number, used only in the figure title.
        num_images: How many images to generate.

    Returns:
        Tensor of generated images rescaled from [-1, 1] to [0, 1].

    Note:
        The model is switched to eval mode and left there.
    """
    model.eval()

    with torch.no_grad():
        # Pick the classes to generate.
        selected_classes = np.random.choice(NUM_CLASSES, num_images)
        class_labels = torch.tensor(selected_classes, dtype=torch.long, device=device)

        # Sample latent codes from the standard normal prior.
        z = torch.randn(num_images, LATENT_DIM, device=device)

        # Condition the decoder on the chosen classes so each image matches
        # the caption shown beneath it.
        samples = model.decode(z, class_labels)
        # Undo the (0.5, 0.5) normalisation: [-1, 1] -> [0, 1].
        samples = samples * 0.5 + 0.5

        # Show images with class labels
        show_images(samples, selected_classes, f'Generated Images - Epoch {epoch}')

    return samples


def train_model(total_steps=50000, learning_rate=1e-3):
    """Train a fresh CVAE on the training stream and return it.

    Batches are pulled from the module-level ``train_iterator``; an "epoch"
    is ``num_batches_per_epoch`` consecutive batches. After each epoch the
    average per-sample loss is printed and a grid of generated samples is
    displayed.

    Args:
        total_steps: Number of optimiser steps to run in total. Training
            stops mid-epoch once this budget is reached.
        learning_rate: Adam learning rate.

    Returns:
        The trained model, on ``device``.
    """
    model = CVAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    steps = 0
    epoch = 0

    while steps < total_steps:
        model.train()
        train_loss = 0.0
        # Count what was actually processed: drop_last and the early break on
        # the final epoch both make this smaller than the dataset size.
        samples_seen = 0

        for batch_idx in range(num_batches_per_epoch):
            data, labels = next(train_iterator)
            data = data.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            # The model is conditional: it needs the labels with the images.
            recon_batch, mu, log_var = model(data, labels)
            loss = loss_function(recon_batch, data, mu, log_var)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            samples_seen += len(data)
            steps += 1

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Step: {steps}, Loss: {batch_loss / len(data):.4f}')

            if steps >= total_steps:
                break

        # The loss is summed over samples, so divide by the sample count.
        avg_loss = train_loss / max(samples_seen, 1)
        print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')

        generate_and_show_images(model, epoch)

        epoch += 1

    return model

def generate_images(model, num_images=16, class_labels=None):
    """Generate images from the conditional decoder without displaying them.

    Args:
        model: Conditional model exposing ``decode(z, class_labels)``.
        num_images: Number of images to generate when ``class_labels`` is
            not given.
        class_labels: Optional integer class indices (tensor or sequence).
            When provided, one image is generated per label and
            ``num_images`` is ignored. When omitted, classes are drawn
            uniformly at random with replacement.

    Returns:
        Tensor of generated images rescaled from [-1, 1] to [0, 1].

    Note:
        The model is switched to eval mode and left there.
    """
    model.eval()

    with torch.no_grad():
        if class_labels is None:
            # The decoder cannot run without a class; pick random ones.
            class_labels = np.random.choice(NUM_CLASSES, num_images)

        if torch.is_tensor(class_labels):
            labels = class_labels.to(device=device, dtype=torch.long)
        else:
            labels = torch.tensor(class_labels, dtype=torch.long, device=device)
        labels = labels.reshape(-1)

        # Sample one latent code per requested label from the standard normal.
        z = torch.randn(labels.shape[0], LATENT_DIM, device=device)
        # Generate images
        sample = model.decode(z, labels)
        # Undo the (0.5, 0.5) normalisation: [-1, 1] -> [0, 1].
        sample = sample * 0.5 + 0.5

    return sample

# Script entry point: train, draw a batch of samples, then persist the weights.
if __name__ == "__main__":
    # Train the model; train_model() returns the trained network.
    model = train_model()
    
    # Generate some sample images (returned already rescaled to [0, 1]).
    samples = generate_images(model)
    
    # Save only the parameters (state_dict), not the whole module object,
    # to 'cifar100_vae.pth' in the current working directory.
    torch.save(model.state_dict(), 'cifar100_vae.pth')