# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# MNIST training split. ToTensor converts each image to a float tensor with
# values scaled to [0, 1]; batches of 128 are reshuffled every epoch.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

# Simple VAE architecture
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent space. The decoder maps a latent
    vector back to ``input_dim`` values squashed into (0, 1) by a sigmoid.

    Args:
        latent_dim: Size of the latent vector. Defaults to 2.
        input_dim: Number of features per flattened sample. The default of
            784 corresponds to a 28x28 image.
    """

    def __init__(self, latent_dim=2, input_dim=784):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim

        # Shared encoder trunk; the two linear heads below turn its output
        # into the parameters of the approximate posterior.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 400),
            nn.ReLU(),
            nn.Linear(400, 200),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(200, latent_dim)
        # Despite its name this head predicts the *log*-variance (it is
        # consumed as ``log_var`` by ``reparameterize``). The attribute name
        # is kept so existing state_dict keys stay valid.
        self.fc_var = nn.Linear(200, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 200),
            nn.ReLU(),
            nn.Linear(200, 400),
            nn.ReLU(),
            nn.Linear(400, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Encode flattened inputs.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tuple ``(mu, log_var)``, each with last dimension ``latent_dim``.
        """
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling the noise separately keeps ``z`` differentiable with
        respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: Input batch; it is flattened to ``(-1, input_dim)``, so image
                shaped input such as ``(B, 1, 28, 28)`` is accepted.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        # reshape (unlike view) also accepts non-contiguous input.
        mu, log_var = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var

# Fun Fact 1: Let's visualize how VAE learns a continuous latent space
def plot_latent_space(vae, loader, num_batches=100):
    """Scatter-plot the latent means of the first ``num_batches`` batches.

    Every sample is encoded to its latent mean; the first two latent
    coordinates give the point position and the label gives the colour.
    All points are drawn with a single scatter call so that one colour
    normalisation (and therefore one colorbar) applies to the whole plot,
    rather than each batch being scaled to its own label range.

    Args:
        vae: Model exposing ``encode(x)`` that returns ``(mu, log_var)``
            for inputs flattened to one row per sample.
        loader: Iterable yielding ``(data, labels)`` batches.
        num_batches: Maximum number of batches to include.
    """
    latent_means = []
    label_batches = []
    with torch.no_grad():
        for i, (data, labels) in enumerate(loader):
            if i >= num_batches:
                break
            # One row per sample, whatever the per-sample shape is.
            mu, _ = vae.encode(data.view(data.size(0), -1))
            latent_means.append(mu.cpu())
            label_batches.append(labels.cpu())

    plt.figure(figsize=(10, 10))
    # With nothing to draw there is no mappable for a colorbar, so skip both.
    if latent_means:
        points = torch.cat(latent_means).numpy()
        point_labels = torch.cat(label_batches).numpy()
        scatter = plt.scatter(
            points[:, 0], points[:, 1], c=point_labels, cmap='tab10', alpha=0.6
        )
        plt.colorbar(scatter)
    plt.title("Latent Space Distribution")
    plt.show()

# Fun Fact 2: Let's see how VAE interpolates between digits
def interpolate_digits(vae, loader, start_digit=1, end_digit=7, steps=10):
    """Plot decoded images along a straight latent line between two digits.

    The first example of each digit found in ``loader`` is encoded to its
    latent mean; ``steps`` evenly spaced points between the two means
    (endpoints included) are decoded and shown side by side.

    Args:
        vae: Model exposing ``encode`` (returning ``(mu, log_var)``) and
            ``decoder``; decoder outputs are displayed as 28x28 images.
        loader: Iterable yielding ``(data, labels)`` batches.
        start_digit: Label of the image at the start of the interpolation.
        end_digit: Label of the image at the end of the interpolation.
        steps: Number of interpolation frames to draw (at least 1).

    Raises:
        ValueError: If ``steps`` is smaller than 1, or if ``loader`` holds
            no example of ``start_digit`` or ``end_digit``.
    """
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}")

    with torch.no_grad():
        # Take the first example of each digit, stopping once both are found.
        start_img = None
        end_img = None
        for data, labels in loader:
            if start_img is None:
                candidates = data[labels == start_digit]
                if len(candidates) > 0:
                    start_img = candidates[0]
            if end_img is None:
                candidates = data[labels == end_digit]
                if len(candidates) > 0:
                    end_img = candidates[0]
            if start_img is not None and end_img is not None:
                break

        # Fail with a clear message instead of an AttributeError on None.
        missing = [
            digit
            for digit, img in ((start_digit, start_img), (end_digit, end_img))
            if img is None
        ]
        if missing:
            raise ValueError(f"No example found in loader for digit(s): {missing}")

        # Get latent representations
        start_mu, _ = vae.encode(start_img.view(-1, 784))
        end_mu, _ = vae.encode(end_img.view(-1, 784))

        # Decode evenly spaced points on the segment between the two means.
        interpolations = [
            vae.decoder(start_mu * (1 - alpha) + end_mu * alpha).view(28, 28)
            for alpha in torch.linspace(0, 1, steps)
        ]

        # Plot
        plt.figure(figsize=(15, 3))
        for i, frame in enumerate(interpolations):
            plt.subplot(1, steps, i + 1)
            plt.imshow(frame.cpu(), cmap='gray')
            plt.axis('off')
        plt.suptitle(f"Interpolation between {start_digit} and {end_digit}")
        plt.show()

# Train VAE
# Two latent dimensions so the latent space can be drawn as a 2-D scatter.
vae = VAE(latent_dim=2)
optimizer = torch.optim.Adam(params=vae.parameters(), lr=0.001)

def loss_function(recon_x, x, mu, log_var):
    """Return the VAE loss (reconstruction + KL), summed over the batch.

    Args:
        recon_x: Reconstruction with values in [0, 1].
        x: Target batch with the same number of elements as ``recon_x``;
            it is reshaped to ``recon_x``'s shape, so image-shaped targets
            can be passed against flattened reconstructions.
        mu: Latent means.
        log_var: Latent log-variances.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the KL divergence
        between N(mu, exp(log_var)) and the standard normal.
    """
    # Follow the reconstruction's shape instead of assuming 784 features.
    target = x.reshape(recon_x.shape)
    bce = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) summed over batch and dims.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce + kld

# Short demonstration run: 5 passes over train_loader, one optimizer step per batch.
print("Training VAE for 5 epochs to demonstrate concepts...")
for epoch in range(5):
    # Labels are discarded (`_`); training only uses the images.
    for batch_idx, (data, _) in enumerate(train_loader):
        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()
        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1} completed')

# Demo 1: scatter the latent means of training batches, coloured by label.
print("\nFun Fact #1: VAE learns a continuous 2D latent space where similar digits cluster together")
plot_latent_space(vae, train_loader)

# Demo 2: decode points on the line between the latent means of a 1 and a 7.
print("\nFun Fact #2: VAE can smoothly interpolate between different digits")
interpolate_digits(vae, train_loader, start_digit=1, end_digit=7)

# Demo 3: text-only summary of the two terms added together in loss_function.
print("\nFun Fact #3: The loss function has two parts:")
print("1. Reconstruction loss (BCE) - Makes the autoencoder reconstruct inputs well")
print("2. KL Divergence loss - Forces the latent space to approximate a normal distribution")
