In [None]:
# Cell 0 - Put all these together:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid


In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Hyperparameters
latent_dim = 20     # size of the latent code z
batch_size = 128    # images per mini-batch
epochs = 10         # passes over the MNIST training set
lr = 1e-3           # Adam learning rate
seed = 0            # RNG seed so Restart & Run All reproduces training and samples

# Seed PyTorch's generators (CPU and CUDA). This fixes weight initialisation,
# DataLoader shuffling, the reparameterisation noise and the prior samples
# drawn later for image generation.
torch.manual_seed(seed)

In [None]:
# ToTensor converts each PIL image to a float tensor scaled to [0, 1]. There is
# deliberately no Normalize step: the binary cross-entropy reconstruction loss
# needs targets in [0, 1], matching the decoder's sigmoid output.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# Shuffled mini-batches from the MNIST training split (downloaded to ./data on first run).
train_loader = DataLoader(
    dataset=datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transform,
    ),
    batch_size=batch_size,
    shuffle=True,
)


In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of q(z | x).

    Two stride-2 convolutions halve the spatial size twice. The flattened
    64*7*7 feature size therefore assumes single-channel 28x28 inputs.

    Args:
        latent_dim: dimensionality of the latent code.

    forward(x) returns ``(mu, logvar)``, each of shape (N, latent_dim).
    """

    def __init__(self, latent_dim):
        super().__init__()
        flat_features = 64 * 7 * 7  # channels x height x width after conv2
        # Creation order is kept so parameter init and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)   # -> 32x14x14
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)  # -> 64x7x7
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        flat = h.view(h.shape[0], -1)  # keep the batch dim, flatten the rest
        return self.fc_mu(flat), self.fc_logvar(flat)

In [None]:
# Decoder Network
class Decoder(nn.Module):
    """Maps latent codes back to single-channel 28x28 images with values in [0, 1].

    Args:
        latent_dim: dimensionality of the latent code fed to ``forward``.

    forward(z) returns a tensor of shape (N, 1, 28, 28) passed through a sigmoid.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
        self.deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)  # -> 32x14x14
        self.deconv2 = nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1)   # -> 1x28x28

    def forward(self, z):
        h = F.relu(self.fc(z))
        # -1 infers the batch size, so a single un-batched code becomes a batch of one.
        h = h.view(-1, 64, 7, 7)
        h = F.relu(self.deconv1(h))
        logits = self.deconv2(h)
        return torch.sigmoid(logits)

In [None]:
class VAE(nn.Module):
    """Variational autoencoder that composes an ``Encoder`` and a ``Decoder``.

    forward(x) returns ``(x_recon, mu, logvar)``. ``x_recon`` is the decoder
    output for a latent code sampled from the encoder's Gaussian.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(logvar / 2).

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar`` (the reparameterisation trick).
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        mu, logvar = self.encoder(x)
        return self.decoder(self.reparameterize(mu, logvar)), mu, logvar

def vae_loss(x, x_recon, mu, logvar):
    """Negative ELBO, summed (not averaged) over the whole batch.

    Args:
        x: target images with values in [0, 1].
        x_recon: decoder output, same shape as ``x``, values in [0, 1].
        mu, logvar: encoder outputs parameterising a diagonal Gaussian q(z | x).

    Returns:
        Scalar tensor: binary cross-entropy reconstruction term plus the
        closed-form KL divergence KL(q(z | x) || N(0, I)).
    """
    reconstruction = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Analytic KL between N(mu, exp(logvar)) and the standard normal prior.
    kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    total = reconstruction + kl_term
    return total

# Build the VAE, move its parameters to the selected device, and optimise all
# of them with Adam.
vae = VAE(latent_dim)
vae = vae.to(device)  # nn.Module.to moves parameters in place and returns the module
optimizer = torch.optim.Adam(params=vae.parameters(), lr=lr)

# Training loop: one Adam step per mini-batch. vae_loss sums over the batch, so
# dividing an epoch's running total by the dataset size reports the mean loss
# per training image.
vae.train()
for epoch in range(epochs):
    total_loss = 0.0
    for x, _labels in train_loader:  # digit labels are unused: training is unsupervised
        x = x.to(device)
        x_recon, mu, logvar = vae(x)
        loss = vae_loss(x, x_recon, mu, logvar)
        # Clear gradients left over from the previous step before back-propagating.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader.dataset):.4f}")

# Evaluation: decode random draws from the N(0, I) prior and show them as a grid.
n_samples = 64  # number of digits to generate
grid_cols = 8   # images per grid row

vae.eval()
with torch.no_grad():
    # Draw on the CPU generator and then move to the device, so a fixed seed
    # yields the same latent codes on CPU and GPU runs.
    z = torch.randn(n_samples, latent_dim).to(device)
    sample = vae.decoder(z).cpu()

# make_grid tiles the (N, 1, 28, 28) batch into one (C, H, W) image with white
# padding; matplotlib expects channels last, hence the permute to (H, W, C).
grid = make_grid(sample, nrow=grid_cols, pad_value=1).permute(1, 2, 0).numpy()
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(grid)
ax.set_title("Generated Digits")
ax.axis('off')
plt.show()