## Hierarchical VAE

This is a simple implementation of a Hierarchical VAE.


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
from datetime import datetime
import os
import matplotlib.pyplot as plt

## Download the dataset

In [2]:
def download_cifar10(data_path='./data'):
    """
    Fetch the CIFAR-10 train and test splits (downloading them if needed).

    Every image is converted to a tensor and normalized per channel with
    mean 0.5 and std 0.5, which maps ToTensor's [0, 1] output onto [-1, 1].

    Args:
        data_path: Directory where the dataset is stored / downloaded to.

    Returns:
        Tuple ``(trainset, testset, classes)`` where ``classes`` is the tuple
        of the ten CIFAR-10 class names.
    """
    half = (0.5, 0.5, 0.5)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(half, half),
    ])

    # Build the train split first, then the test split, with shared settings.
    trainset, testset = (
        torchvision.datasets.CIFAR10(
            root=data_path,
            train=is_train,
            download=True,
            transform=transform,
        )
        for is_train in (True, False)
    )

    print(f"Training set size: {len(trainset)}")
    print(f"Test set size: {len(testset)}")

    # CIFAR-10 class names, in label order.
    classes = (
        'plane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    )

    return trainset, testset, classes

def get_dataloader(trainset, testset, batch_size=128):
    """
    Wrap the train and test datasets in DataLoaders.

    Both loaders use the same batch size and two worker processes; only the
    training loader shuffles.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    shared = dict(batch_size=batch_size, num_workers=2)
    return (
        DataLoader(trainset, shuffle=True, **shared),
        DataLoader(testset, shuffle=False, **shared),
    )

## Define the Hierarchical VAE


In [3]:
class Reparameterize(nn.Module):
    """Draw a sample from a diagonal Gaussian using the reparameterization trick."""

    def forward(self, x):
        """
        Split ``x`` in half along dim 1 into ``(mu, log_var)`` and sample
        ``z = mu + eps * std`` with ``eps ~ N(0, I)``, so gradients can flow
        through ``mu`` and ``log_var``.

        Returns:
            Tuple ``(z, mu, log_var)``.
        """
        mu, log_var = x.chunk(2, dim=1)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z, mu, log_var


class HierarchicalVAE(nn.Module):
    """
    Three-level VAE: encoder1 maps the image to z1, encoder2 maps z1 to z2 and
    encoder3 maps z2 to z3. The decoder starts from z3 and mixes in z2 and z1
    on the way back to image space.

    Each encoder ends in ``Reparameterize``, so calling it returns the tuple
    ``(z, mu, log_var)`` rather than a single tensor.
    """

    def __init__(self, latent_dims=(512, 256, 128)):
        """
        Args:
            latent_dims: Channel sizes of ``(z1, z2, z3)``. A tuple default is
                used so the default is never shared mutable state; lists are
                still accepted.
        """
        super().__init__()

        # Stored as a fresh list: [z1, z2, z3] channel sizes.
        latent_dims = list(latent_dims)
        self.latent_dims = latent_dims

        self.encoder1 = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 8x8 -> 4x4
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, (latent_dims[0] * 2), 4),  # 4x4 -> 1x1
            Reparameterize()
        )

        # 1x1 projections: z1 -> (mu2, log_var2) -> z2
        self.encoder2 = nn.Sequential(
            nn.Conv2d(latent_dims[0], latent_dims[1] * 2, 1),
            nn.BatchNorm2d(latent_dims[1] * 2),
            nn.LeakyReLU(),
            nn.Conv2d(latent_dims[1] * 2, latent_dims[1] * 2, 1),
            Reparameterize()
        )

        # 1x1 projections: z2 -> (mu3, log_var3) -> z3
        self.encoder3 = nn.Sequential(
            nn.Conv2d(latent_dims[1], latent_dims[2] * 2, 1),
            nn.BatchNorm2d(latent_dims[2] * 2),
            nn.LeakyReLU(),
            nn.Conv2d(latent_dims[2] * 2, latent_dims[2] * 2, 1),
            Reparameterize()
        )

        # Decoder path: z3 -> features with z2's channel count
        self.decoder3 = nn.Sequential(
            nn.Conv2d(latent_dims[2], latent_dims[1], 1),
            nn.BatchNorm2d(latent_dims[1]),
            nn.LeakyReLU(),
            nn.Conv2d(latent_dims[1], latent_dims[1], 1),
        )

        # Mixes z2 with the output of decoder3 (concatenated on channels)
        self.mix2 = nn.Sequential(
            nn.Conv2d(latent_dims[1] * 2, latent_dims[1], 1),
            nn.BatchNorm2d(latent_dims[1]),
            nn.LeakyReLU(),
            nn.Conv2d(latent_dims[1], latent_dims[1], 1),
        )

        # NOTE: the layer types/order here determine the state_dict layout,
        # so they are kept as-is to stay compatible with saved checkpoints.
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(latent_dims[1], latent_dims[0], 1),
            nn.BatchNorm2d(latent_dims[0]),
            nn.LeakyReLU(),
            nn.Conv2d(latent_dims[0], latent_dims[0], 1),
        )

        # Mixes z1 with the output of decoder2 (concatenated on channels)
        self.mix1 = nn.Sequential(
            nn.Conv2d(latent_dims[0] * 2, latent_dims[0], 1),
            nn.BatchNorm2d(latent_dims[0]),
            nn.LeakyReLU(),
            nn.Conv2d(latent_dims[0], latent_dims[0], 1)
        )

        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(latent_dims[0], 128, 4),  # 1x1 -> 4x4
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # [b, 64, 8, 8]
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # [b, 32, 16, 16]
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # [b, 3, 32, 32]
            nn.Tanh()  # outputs in [-1, 1]
        )

    def encode(self, x):
        """
        Encode the input into the three hierarchical latent variables.

        Returns:
            ``([z3, z2, z1], [(mu3, log_var3), (mu2, log_var2), (mu1, log_var1)])``
            -- both lists are ordered from the top level (z3) down to z1.
        """
        z1, mu1, log_var1 = self.encoder1(x)
        z2, mu2, log_var2 = self.encoder2(z1)
        z3, mu3, log_var3 = self.encoder3(z2)

        return [z3, z2, z1], [(mu3, log_var3),
                             (mu2, log_var2),
                             (mu1, log_var1)]

    def decode(self, zs):
        """
        Decode latents ``[z3, z2, z1]`` (the order returned by ``encode``)
        into an image tensor produced by the final Tanh.
        """
        z3, z2, z1 = zs

        h3 = self.decoder3(z3)  # channels: latent_dims[1]

        combined_z2 = self.mix2(torch.cat([z2, h3], dim=1))
        h2 = self.decoder2(combined_z2)  # channels: latent_dims[0]

        combined_z1 = self.mix1(torch.cat([z1, h2], dim=1))

        # Reconstruct the image
        h1 = self.decoder1(combined_z1)
        return h1

    def forward(self, x):
        """
        Run a full encode/decode pass.

        Returns:
            ``(recon_x, mu_vars, zs)`` with ``mu_vars`` and ``zs`` ordered
            z3, z2, z1 as in ``encode``.
        """
        zs, mu_vars = self.encode(x)
        recon_x = self.decode(zs)
        return recon_x, mu_vars, zs

    def random_samples(self, num_samples, device='cuda'):
        """
        Generate images by sampling z1 from a standard normal prior and then
        sampling z2 and z3 conditionally through encoder2 / encoder3.

        Args:
            num_samples: Number of images to generate.
            device: Device on which z1 is created; it must be the device the
                model's parameters live on.

        Returns:
            CPU tensor of shape ``[num_samples, 3, H, W]`` rescaled from the
            decoder's [-1, 1] range to [0, 1].
        """
        with torch.no_grad():
            # z1 is a 1x1 spatial map, matching what the 1x1 conv stacks in
            # encoder2 / mix1 are applied to during training.
            z1 = torch.randn(num_samples, self.latent_dims[0], 1, 1, device=device)

            # Each encoder already ends in Reparameterize, so it returns
            # (z, mu, log_var); only the sample is needed here.
            z2, _, _ = self.encoder2(z1)

            # Small jitter on z2 before conditioning z3 (kept from the
            # original sampling scheme).
            z3, _, _ = self.encoder3(z2 + torch.randn_like(z2) * 0.05)

            # Decode all latent variables and map [-1, 1] -> [0, 1]
            samples = self.decode([z3, z2, z1])
            samples = (samples + 1) / 2
            return samples.cpu()

In [4]:
def hvae_loss_function(recon_x, x, mu_vars, beta=0.5, epoch=None, warmup_epochs=10):
    """
    Compute the HVAE training loss: L1 reconstruction + beta-weighted KL.

    Args:
        recon_x: Reconstructed batch.
        x: Target batch, same shape as ``recon_x``.
        mu_vars: Sequence of ``(mu, log_var)`` pairs, one per latent level,
            each with the batch on dim 0.
        beta: Final weight of the KL term.
        epoch: Current epoch for KL annealing; ``None`` disables annealing.
        warmup_epochs: Number of epochs over which the KL weight grows
            linearly up to ``beta``. Values <= 0 disable annealing.

    Returns:
        ``(total_loss, recon_loss, total_kl_loss, kl_losses)`` where
        ``kl_losses`` holds one batch-averaged KL term per entry of
        ``mu_vars`` (same order).
    """
    # Reconstruction loss: summed absolute error per sample, averaged over the
    # batch. (Sum-reduced MSE is the alternative that was tried before:
    # F.mse_loss(recon_x, x, reduction='sum') / x.size(0).)
    recon_loss = F.l1_loss(recon_x, x, reduction='sum') / x.size(0)

    # KL divergence to a standard normal, per level:
    # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - ln(sigma^2) - 1)
    # One weight per level so no level is silently dropped by zip().
    kl_weights = [1.0] * len(mu_vars)
    kl_losses = []
    for (mu, log_var), weight in zip(mu_vars, kl_weights):
        kl_elementwise = 1 + log_var - mu.pow(2) - log_var.exp()
        # Sum over every non-batch dimension, then average over the batch.
        kl_per_sample = -0.5 * kl_elementwise.flatten(start_dim=1).sum(dim=1)
        kl_losses.append(torch.mean(kl_per_sample) * weight)

    total_kl_loss = sum(kl_losses)

    # KL annealing: the reconstruction term tends to dominate early on, which
    # can lead to posterior collapse. Ramping the KL weight linearly from 0 to
    # beta over the warm-up epochs mitigates that.
    if epoch is not None and warmup_epochs > 0:
        beta_weight = min(epoch / warmup_epochs, 1.0) * beta
    else:
        # No annealing requested (or no warm-up period): use the full weight.
        beta_weight = beta

    total_loss = recon_loss + beta_weight * total_kl_loss

    return total_loss, recon_loss, total_kl_loss, kl_losses

def save_image_samples(model, data, writer, epoch, device):
    """
    Log up to 8 original images and their reconstructions to TensorBoard.

    Args:
        model: Model whose forward pass returns the reconstruction as the
            first element of a 3-tuple.
        data: Image batch normalized to [-1, 1].
        writer: Object exposing ``add_images(tag, img_tensor, global_step)``.
        epoch: Step value the images are logged under.
        device: Device the batch is moved to before the forward pass.

    The model is switched to eval mode for the forward pass and its previous
    train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Only the first 8 images are logged, so only those are reconstructed.
            originals = data[:8].to(device)
            recon_batch, _, _ = model(originals)

            comparison = torch.cat([originals.cpu(), recon_batch.cpu()])

            # add_images expects float images in [0, 1]; the data and the
            # reconstructions live in [-1, 1], so rescale before logging.
            comparison = ((comparison + 1) / 2).clamp(0, 1)

            writer.add_images('Original_Reconstructed', comparison, epoch)
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

def train_epoch(model, train_loader, vae_optimizer, device, writer, epoch):
    """
    Train the model for one epoch and log per-sample average losses.

    Args:
        model: Model returning ``(recon_batch, mu_vars, zs)`` from its forward.
        train_loader: Iterable of ``(data, label)`` batches.
        vae_optimizer: Optimizer stepping the model parameters.
        device: Device batches are moved to.
        writer: TensorBoard writer used for scalars and image samples.
        epoch: Current epoch (used for KL annealing and as the logging step).

    Returns:
        Average total loss per sample over the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()

    # hvae_loss_function returns per-sample averages for each batch, so each
    # batch is weighted by its size and the sums are divided by the number of
    # samples actually seen. This stays exact with a smaller final batch.
    loss_sum = 0.0
    recon_sum = 0.0
    kl_sum = 0.0
    kl_level_sums = []  # sized on the first batch, one entry per latent level
    n_seen = 0
    last_batch = None

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        batch_len = len(data)

        recon_batch, mu_vars, _ = model(data)  # HVAE returns mu_vars list
        loss, recon_loss, total_kl, kl_losses = hvae_loss_function(
            recon_batch, data, mu_vars, epoch=epoch
        )

        vae_optimizer.zero_grad()
        loss.backward()
        vae_optimizer.step()

        loss_sum += loss.item() * batch_len
        recon_sum += recon_loss.item() * batch_len
        kl_sum += total_kl.item() * batch_len
        if not kl_level_sums:
            kl_level_sums = [0.0] * len(kl_losses)
        for i, kl in enumerate(kl_losses):
            kl_level_sums[i] += kl.item() * batch_len

        n_seen += batch_len
        last_batch = data

        if batch_idx % 100 == 0:
            # loss is already a per-sample average; no further division.
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    if last_batch is None or n_seen == 0:
        raise ValueError("train_loader produced no samples")

    # Log reconstructions of the final batch of the epoch.
    save_image_samples(model, last_batch, writer, epoch, device)

    # Log Epoch Metrics
    avg_loss = loss_sum / n_seen
    avg_recon_loss = recon_sum / n_seen
    avg_kl_loss = kl_sum / n_seen
    avg_kl_losses = [kl / n_seen for kl in kl_level_sums]
    writer.add_scalar('Loss/train/total', avg_loss, epoch)
    writer.add_scalar('Loss/train/reconstruction', avg_recon_loss, epoch)
    writer.add_scalar('Loss/train/kl_divergence', avg_kl_loss, epoch)

    # Log individual KL losses. kl_losses is ordered top level first
    # (z3, z2, z1), so index 0 belongs to the highest level number.
    n_levels = len(avg_kl_losses)
    for i, kl in enumerate(avg_kl_losses):
        writer.add_scalar(f'Loss/train/kl_level_{n_levels - i}', kl, epoch)

    return avg_loss

In [5]:

def train_vae(epochs=100, batch_size=128, learning_rate=1e-3, device="cuda"):
    """
    Train a HierarchicalVAE on CIFAR-10 with TensorBoard logging.

    Logs go to a timestamped directory under ``runs_new/``; a checkpoint is
    written to ``<log_dir>/models`` every 10 epochs.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size for the data loaders.
        learning_rate: Adam learning rate.
        device: Device the model and batches are placed on.
    """
    # Get data
    trainset, testset, _ = download_cifar10()
    train_loader, _ = get_dataloader(trainset, testset, batch_size)

    # Initialize model, optimizer, and tensorboard
    model = HierarchicalVAE().to(device)
    vae_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    log_dir = f'runs_new/HVAE_CIFAR10_{datetime.now().strftime("%Y%m%d-%H%M%S")}'
    models_dir = f'{log_dir}/models'
    writer = SummaryWriter(log_dir)

    # try/finally so buffered events are flushed even if training fails.
    try:
        for epoch in range(1, epochs + 1):
            train_loss = train_epoch(model, train_loader, vae_optimizer, device, writer, epoch)

            # Save a checkpoint every 10 epochs
            if epoch % 10 == 0:
                os.makedirs(models_dir, exist_ok=True)
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': vae_optimizer.state_dict(),
                    'loss': train_loss,
                }, f'{models_dir}/hvae_checkpoint_epoch_{epoch}.pt')
    finally:
        writer.close()

In [6]:
# Run training for 30 epochs on the GPU. train_vae checkpoints every 10
# epochs, so this run writes checkpoints for epochs 10, 20 and 30.
train_vae(epochs=30, batch_size=128, learning_rate=1e-3, device="cuda")


Files already downloaded and verified
Files already downloaded and verified
Training set size: 50000
Test set size: 10000


In [7]:
def generate_and_display_samples(log_dir, epoch=100, num_samples=64, device='cuda'):
    """
    Load a checkpoint, draw random samples from the model and show them in a grid.

    Args:
        log_dir: Run directory containing ``models/hvae_checkpoint_epoch_<epoch>.pt``.
        epoch: Epoch of the checkpoint to load.
        num_samples: Number of images to generate; the grid is sized to fit
            them (8x8 for the default of 64).
        device: Device the model is loaded onto and sampled on.

    Returns:
        The tensor of generated samples returned by ``model.random_samples``.
    """
    import math

    model = HierarchicalVAE().to(device)
    # map_location lets a checkpoint saved on one device load on another.
    checkpoint = torch.load(
        f'{log_dir}/models/hvae_checkpoint_epoch_{str(epoch)}.pt',
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    with torch.no_grad():
        samples = model.random_samples(num_samples, device)

    # Smallest square grid that holds every sample.
    grid_size = max(1, math.ceil(math.sqrt(num_samples)))
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid.
    fig, axes = plt.subplots(
        grid_size, grid_size,
        figsize=(1.5 * grid_size, 1.5 * grid_size),
        squeeze=False,
    )
    for idx, ax in enumerate(axes.flat):
        if idx < len(samples):
            # Convert from [C,H,W] to [H,W,C] format
            ax.imshow(samples[idx].permute(1, 2, 0).numpy())
        # Cells beyond the last sample stay blank.
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return samples
    
def display_reconstructions(log_dir, epoch, test_loader, num_images=64, device='cuda'):
    """
    Display original test images (top half) and their reconstructions (bottom half).

    Args:
        log_dir: Run directory containing ``models/hvae_checkpoint_epoch_<epoch>.pt``.
        epoch: Epoch of the checkpoint to load.
        test_loader: Iterable of batches; a batch is either an image tensor or
            a list/tuple whose first element is the image tensor.
        num_images: Maximum number of images to show.
        device: Device the model is loaded onto and run on.

    Returns:
        The loaded model (in eval mode).
    """
    import math

    model = HierarchicalVAE().to(device)
    # map_location lets a checkpoint saved on one device load on another.
    checkpoint = torch.load(
        f'{log_dir}/models/hvae_checkpoint_epoch_{str(epoch)}.pt',
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Grid dimensions for a square-ish layout
    grid_size = max(1, math.ceil(math.sqrt(num_images)))

    # Collect batches until enough *images* (not batches) are gathered, so
    # loaders with batch_size > 1 do not pull far more data than is shown.
    collected = []
    n_collected = 0
    for batch in test_loader:
        img = batch[0] if isinstance(batch, (list, tuple)) else batch
        collected.append(img)
        n_collected += img.size(0)
        if n_collected >= num_images:
            break

    test_images = torch.cat(collected, dim=0)[:num_images].to(device)

    with torch.no_grad():
        recons, _, _ = model(test_images)

    # Denormalize from [-1, 1] to [0, 1] and move to CPU for plotting.
    test_images = ((test_images + 1) / 2).cpu()
    recons = ((recons + 1) / 2).cpu()

    # Originals occupy the first grid_size rows, reconstructions the next
    # grid_size rows. squeeze=False keeps `axes` 2-D even when grid_size == 1.
    fig, axes = plt.subplots(
        2 * grid_size, grid_size,
        figsize=(grid_size * 2, 4 * grid_size),
        squeeze=False,
    )

    # Hide every cell's frame up front so unused cells do not show axes.
    for ax in axes.flat:
        ax.axis('off')

    for idx in range(len(test_images)):
        row, col = divmod(idx, grid_size)
        axes[row, col].imshow(test_images[idx].permute(1, 2, 0).numpy())
        # Same cell position in the second half of the figure.
        axes[row + grid_size, col].imshow(recons[idx].permute(1, 2, 0).numpy())

    # Add titles
    axes[0, grid_size // 2].set_title('Original Images')
    axes[grid_size, grid_size // 2].set_title('Reconstructions')

    plt.tight_layout()
    plt.show()
    return model

In [None]:
# Load the epoch-30 checkpoint of a finished run and display random samples.
# NOTE(review): train_vae writes its runs under 'runs_new/', while this path
# points under './runs/' -- confirm this directory is the intended run.
log_dir = './runs/HVAE_CIFAR10_20241229-153141'
samples = generate_and_display_samples(log_dir, 30)

# Optionally save the samples to disk as an 8-column grid
torchvision.utils.save_image(samples, 'vae_samples.png', nrow=8, normalize=True)

In [None]:
# Rebuild the loaders with batch_size=1 and show test-set reconstructions.
# Relies on `log_dir` defined in an earlier cell.
# NOTE(review): this loads the epoch-100 checkpoint, while the training cell in
# this notebook runs 30 epochs -- confirm that checkpoint exists for log_dir.
trainset, testset, _ = download_cifar10()
train_loader, test_loader = get_dataloader(trainset, testset, 1)
display_reconstructions(log_dir, 100, test_loader)

In [None]:
# Show 16 reconstructions and keep the loaded model for further use.
# Relies on `log_dir` and `test_loader` defined in earlier cells.
model = display_reconstructions(log_dir, 100, test_loader, 16)