# MNIST Digit Generation with Conditional VAE
Train a Conditional Variational Autoencoder (CVAE) on MNIST to generate handwritten digits of a chosen class.

In [None]:
# Install required packages.
# NOTE: the leading "!" is IPython/Jupyter shell-escape syntax; this line only
# works inside a notebook, not as plain Python.
!pip install torch torchvision matplotlib tqdm

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_available else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Hyperparameters
# Optimization: images per step, Adam step size, passes over the training set.
batch_size, learning_rate, num_epochs = 128, 1e-3, 50
# Model shape: width of the latent code z and of the one-hot label vector.
latent_dim, num_classes = 20, 10

In [None]:
# Load MNIST dataset.
# ToTensor() converts each image to a float tensor with values in [0, 1],
# which is the range binary cross-entropy in the loss expects.
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

num_train_images = len(train_dataset)
print(f"Dataset size: {num_train_images}")

In [None]:
# CVAE Model Definition
class ConditionalVAE(nn.Module):
    """Convolutional conditional VAE for single-channel 28x28 images.

    The class label is supplied as a one-hot (float) vector ``c`` of width
    ``num_classes``. It is concatenated to the flattened conv features in the
    encoder and to the latent code ``z`` in the decoder.

    Spatial sizes through the network:
        encoder (k=4, s=2, p=1 convs):   28 -> 14 -> 7 -> 3
        decoder (transposed convs):        3 ->  7 -> 14 -> 28
    so ``decode`` returns images of the same 1x28x28 size as the input.
    """

    def __init__(self, latent_dim=20, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Encoder: three stride-2 convs halve (and floor) the spatial size.
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1),     # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),    # 14 -> 7
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),   # 7 -> 3
            nn.ReLU(),
        )

        # 128 channels on a 3x3 grid after the encoder convs (28x28 input).
        self.conv_output_size = 128 * 3 * 3

        self.encoder_fc = nn.Sequential(
            nn.Linear(self.conv_output_size + num_classes, 256),
            nn.ReLU(),
        )

        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder_fc = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, self.conv_output_size),
            nn.ReLU(),
        )

        self.decoder_conv = nn.Sequential(
            # Kernel 3, stride 2, no padding: (3 - 1) * 2 + 3 = 7. A
            # kernel-4/padding-1 layer here would give 6 and a final 24x24
            # output that no longer matches the 28x28 input in the BCE loss.
            nn.ConvTranspose2d(128, 64, 3, 2, 0),   # 3 -> 7
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # 7 -> 14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),     # 14 -> 28
            nn.Sigmoid(),                           # per-pixel values in (0, 1)
        )

    def encode(self, x, c):
        """Map images ``x`` and one-hot labels ``c`` to ``(mu, logvar)``.

        ``x`` is expected to be (N, 1, 28, 28); ``c`` is (N, num_classes).
        Both returned tensors are (N, latent_dim).
        """
        x = self.encoder_conv(x)
        x = x.view(x.size(0), -1)
        x = torch.cat([x, c], dim=1)
        x = self.encoder_fc(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (differentiable)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        """Decode latent codes ``z`` (N, latent_dim) with one-hot labels ``c``.

        Returns images of shape (N, 1, 28, 28) with values in (0, 1).
        """
        z = torch.cat([z, c], dim=1)
        x = self.decoder_fc(z)
        x = x.view(x.size(0), 128, 3, 3)
        x = self.decoder_conv(x)
        return x

    def forward(self, x, c):
        """Encode, sample a latent code, and decode; return ``(recon_x, mu, logvar)``."""
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z, c)
        return recon_x, mu, logvar

In [None]:
# Loss function
def loss_function(recon_x, x, mu, logvar):
    """Return the batch-summed CVAE objective (negative ELBO).

    The objective is the pixel-wise binary cross-entropy between ``recon_x``
    and ``x`` plus the closed-form KL divergence of N(mu, exp(logvar)) from
    the standard normal prior. Both terms are summed over every element, not
    averaged.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    total = reconstruction_term + kl_term
    return total

In [None]:
# Initialize model and optimizer
model = ConditionalVAE(latent_dim=latent_dim, num_classes=num_classes)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Total number of scalar weights across all parameter tensors.
param_count = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {param_count}")

In [None]:
# Training loop: one pass over train_loader per epoch. Logs the mean
# per-sample loss for each epoch and checkpoints every 10 epochs.
model.train()
train_losses = []

for epoch in range(num_epochs):
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

    for data, labels in progress_bar:
        data, labels = data.to(device), labels.to(device)
        # Condition on the true label, encoded as a float one-hot vector.
        labels_onehot = F.one_hot(labels, num_classes=num_classes).float()

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data, labels_onehot)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        # loss is summed over the batch; show it per sample in the bar.
        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix({'Loss': batch_loss / len(data)})

    avg_loss = total_loss / len(train_loader.dataset)
    train_losses.append(avg_loss)
    print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')

    is_checkpoint_epoch = (epoch + 1) % 10 == 0
    if is_checkpoint_epoch:
        torch.save(model.state_dict(), f'cvae_epoch_{epoch+1}.pth')

In [None]:
# Save final model weights (state_dict only, not the module object).
final_checkpoint = 'cvae_final.pth'
torch.save(model.state_dict(), final_checkpoint)
print(f"Model saved as '{final_checkpoint}'")

In [None]:
# Generate and visualize samples
def generate_digits(model, digit, num_samples=5):
    """Draw ``num_samples`` images of class ``digit`` from the model's decoder.

    Latent codes are sampled from the standard-normal prior and decoded
    together with the one-hot encoding of ``digit``. The latent size, the
    number of classes and the device all come from ``model`` itself, so the
    function does not depend on notebook globals. The model's previous
    train/eval mode is restored before returning.

    Args:
        model: Module exposing ``latent_dim``, ``num_classes`` and
            ``decode(z, c)``.
        digit: Class index to condition on (``0 <= digit < model.num_classes``).
        num_samples: Number of samples to draw.

    Returns:
        The decoded batch as returned by ``model.decode``, moved to the CPU.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            labels = torch.full((num_samples,), digit, dtype=torch.long, device=target_device)
            labels_onehot = F.one_hot(labels, num_classes=model.num_classes).float()
            z = torch.randn(num_samples, model.latent_dim, device=target_device)
            generated = model.decode(z, labels_onehot)
    finally:
        # Don't leave the caller's model stuck in eval mode.
        model.train(was_training)
    return generated.cpu()

# Test generation for all digits: one grid row per digit class, one column
# per sample.
rows, cols = 10, 5
fig, axes = plt.subplots(rows, cols, figsize=(15, 30))

for digit in range(rows):
    generated = generate_digits(model, digit, cols)
    for col, ax in enumerate(axes[digit]):
        ax.imshow(generated[col].squeeze(), cmap='gray')
        ax.set_title(f'Digit {digit} - Sample {col+1}')
        ax.axis('off')

plt.tight_layout()
plt.savefig('all_generated_digits.png', dpi=150, bbox_inches='tight')
plt.show()