# ConvVAE on Fashion-MNIST
This notebook trains a Convolutional Variational Autoencoder (ConvVAE) on the Fashion-MNIST dataset, plots training/test loss curves, shows reconstructions, and generates new samples.

In [None]:
import os
# Tell the OpenMP runtime to tolerate more than one copy of itself being
# loaded into this process instead of aborting.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np


In [None]:
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for single-channel images.

    The encoder stacks three stride-2 convolutions and feeds the flattened
    feature map into two linear heads (latent mean and log-variance). The
    decoder mirrors it: a linear layer, then three stride-2 transposed
    convolutions ending in a sigmoid.

    The linear layers are sized for a 128x3x3 feature map, which is what the
    encoder produces for 28x28 inputs.
    """

    # (channels, height, width) of the encoder's final feature map.
    _FEATURE_SHAPE = (128, 3, 3)

    def __init__(self, latent_dim=32):
        """Build the encoder and decoder layers.

        Args:
            latent_dim: Size of the latent vector z.
        """
        super().__init__()
        channels, height, width = self._FEATURE_SHAPE
        flat_features = channels * height * width
        # Every (transposed) convolution halves / doubles the spatial size.
        conv_kwargs = dict(kernel_size=4, stride=2, padding=1)

        # Encoder
        self.enc1 = nn.Conv2d(1, 32, **conv_kwargs)
        self.enc2 = nn.Conv2d(32, 64, **conv_kwargs)
        self.enc3 = nn.Conv2d(64, channels, **conv_kwargs)
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_log = nn.Linear(flat_features, latent_dim)

        # Decoder
        self.fc_dec = nn.Linear(latent_dim, flat_features)
        # output_padding=1 makes the first upsampling step go 3 -> 7 (not 6),
        # so the following two doublings land on 14 and then 28.
        self.dec1 = nn.ConvTranspose2d(channels, 64, output_padding=1, **conv_kwargs)
        self.dec2 = nn.ConvTranspose2d(64, 32, **conv_kwargs)
        self.dec3 = nn.ConvTranspose2d(32, 1, **conv_kwargs)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        features = x
        for conv in (self.enc1, self.enc2, self.enc3):
            features = F.relu(conv(features))
        flat = features.view(features.size(0), -1)
        return self.fc_mu(flat), self.fc_log(flat)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        Sampling the noise separately keeps ``z`` differentiable with respect
        to ``mu`` and ``logvar``.
        """
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors ``z`` to images with values in [0, 1]."""
        hidden = F.relu(self.fc_dec(z)).view(-1, *self._FEATURE_SHAPE)
        hidden = F.relu(self.dec1(hidden))
        hidden = F.relu(self.dec2(hidden))
        return torch.sigmoid(self.dec3(hidden))

    def forward(self, x):
        """Encode, sample and decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar


In [None]:
# Loss function
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE training loss: reconstruction error plus KL divergence.

    Both terms are summed (not averaged) over every element of the batch, so
    the result scales with batch size.

    Args:
        recon_x: Reconstructed batch with values in [0, 1].
        x: Target batch, same shape as ``recon_x``.
        mu: Latent means produced by the encoder.
        logvar: Latent log-variances produced by the encoder.

    Returns:
        A scalar tensor.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction="sum")
    # Closed-form KL(N(mu, sigma^2) || N(0, I)).
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * kl_terms.sum()
    return reconstruction + kl_divergence


In [None]:
# Data loading
# ToTensor yields float tensors, which is the only preprocessing applied.
transform = transforms.Compose([transforms.ToTensor()])

# Arguments shared by the train and test splits.
dataset_kwargs = dict(root="data/fashion-mnist", download=True, transform=transform)
train_ds = datasets.FashionMNIST(train=True, **dataset_kwargs)
test_ds = datasets.FashionMNIST(train=False, **dataset_kwargs)

# Only the training split is shuffled; evaluation order stays fixed.
batch_size = 128
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)


In [None]:
# Training setup
# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

model = ConvVAE(latent_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
epochs = 20

# Per-sample average losses, one entry appended per epoch.
train_losses, test_losses = [], []


In [None]:
# Training loop
# Each epoch runs one optimisation pass over the training set followed by a
# gradient-free evaluation pass over the test set.
for epoch in range(1, epochs + 1):
    # Optimisation pass
    model.train()
    running_train = 0
    for batch, _ in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon_batch, batch_mu, batch_logvar = model(batch)
        batch_loss = loss_function(recon_batch, batch, batch_mu, batch_logvar)
        batch_loss.backward()
        optimizer.step()
        running_train += batch_loss.item()
    # loss_function returns a sum over the batch, so dividing the epoch total
    # by the dataset size gives the average loss per sample.
    train_losses.append(running_train / len(train_loader.dataset))

    # Evaluation pass
    model.eval()
    running_test = 0
    with torch.no_grad():
        for batch, _ in test_loader:
            batch = batch.to(device)
            recon_batch, batch_mu, batch_logvar = model(batch)
            running_test += loss_function(
                recon_batch, batch, batch_mu, batch_logvar
            ).item()
    test_losses.append(running_test / len(test_loader.dataset))

    print(f"Epoch {epoch}: Train loss {train_losses[-1]:.2f}, Test loss {test_losses[-1]:.2f}")


In [None]:
# Plot losses
import matplotlib.pyplot as plt

# The training loop numbers epochs from 1, so plot against 1..N rather than
# the 0-based list indices; otherwise the "Epoch" axis is shifted by one
# relative to the printed log. Each curve gets its own x-axis so the plot
# still works if the two lists differ in length (e.g. an interrupted run).
train_epochs = range(1, len(train_losses) + 1)
test_epochs = range(1, len(test_losses) + 1)
plt.plot(train_epochs, train_losses, label="Train")
plt.plot(test_epochs, test_losses, label="Test")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()


In [None]:
# Visualize reconstructions
import matplotlib.pyplot as plt

model.eval()

# Run the first test batch through the model without tracking gradients.
data, _ = next(iter(test_loader))
data = data.to(device)
with torch.no_grad():
    recon, _, _ = model(data)

# Keep the first eight images: originals on the top row, reconstructions below.
orig = data[:8]
recon = recon[:8]
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for col in range(8):
    for ax, image in zip(axes[:, col], (orig[col], recon[col])):
        ax.imshow(image.cpu().squeeze(), cmap="gray")
        ax.axis("off")
plt.show()


In [None]:
# Generate new samples
import math

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

model.eval()

n_samples = 16
# Read the latent size from the decoder's input layer so this cell stays in
# sync with the model instead of repeating a hard-coded 32.
latent_dim = model.fc_dec.in_features

# Decode draws from the standard-normal prior.
z = torch.randn(n_samples, latent_dim).to(device)
with torch.no_grad():
    samples = model.decode(z)

# decode() already returns one batched tensor, so moving it to the CPU is
# enough; there is no need to re-stack the samples one by one.
grid = samples.cpu()
# Lay the samples out on a (roughly) square grid: 4 columns for 16 samples.
grid_img = make_grid(grid, nrow=math.isqrt(n_samples))

plt.figure(figsize=(6, 6))
plt.imshow(grid_img.permute(1, 2, 0).squeeze(), cmap="gray")
plt.axis("off")
plt.show()
