In [45]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [46]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [47]:
# Load Dataset (MNIST-like)
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST(root="data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [48]:
latent_dim = 10

In [49]:
# Define VAutoencoder Model
class VariationalAutoencoder(nn.Module):
    def __init__(self, input_dim=28*28, latent_dim=64):
        super(VariationalAutoencoder, self).__init__()

        # Encoder
        # self.encoder = nn.Sequential() Can no longer build with sequential because we have two paths to train. One to mu, the other to sigma^2 (variance)

        # Shared encoder
        self.encoder_fc1 = nn.Linear(input_dim, 128)

        # Separate paths for mean and log variance
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_var = nn.Linear(128, latent_dim)

        # Decoder
        self.decode = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
            nn.Sigmoid()  # Output in range [0,1] for images
        )

    def encode(self, x):
        hidden_layer = F.relu(self.encoder_fc1(x))
        mu = self.fc_mu(hidden_layer)
        logvar = self.fc_var(hidden_layer) # actually the log variance
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar) # this is why it's the log variance e^(0.5*var)
        eps = torch.randn_like(std) # randomly generated from a standard normal distribution *like* (same dimensions) std
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z)
        return decoded, mu, logvar

In [50]:
def vae_loss(decoded_images, images, mu, logvar):
    recon_loss = nn.MSELoss()(decoded_images, images) #
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return recon_loss + KLD.mean() # think about what must be returned here for the loss.backward() step to work properly.

In [51]:
# def vae_loss(decoded_images, images, mu, logvar):
#     recon_loss = nn.MSELoss()(decoded_images, images) #
#     KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(),dim=1)
#     return recon_loss + KLD # think about what must be returned here for the loss.backward() step to work properly.

In [52]:
import numpy as np

In [None]:
# Initialize Model
# There's something wrong in here that involves the loss function

vae = VariationalAutoencoder(latent_dim=latent_dim).to(device)
criterion = vae_loss
optimizer = optim.Adam(vae.parameters(), lr=0.001)

# Training Loop
num_epochs = 100
for epoch in range(num_epochs):
  for images, _ in dataloader:
    images = images.view(images.size(0), -1).to(device)  # Flatten images
    decoded_images, mu, logvar = vae(images)
    loss = criterion(decoded_images, images, mu, logvar)  # Compare reconstructed and original

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

print("Training complete!")

Epoch [1/100], Loss: 0.0639
Epoch [2/100], Loss: 0.0716
Epoch [3/100], Loss: 0.0651
Epoch [4/100], Loss: 0.0669
Epoch [5/100], Loss: 0.0633
Epoch [6/100], Loss: 0.0680
Epoch [7/100], Loss: 0.0607
Epoch [8/100], Loss: 0.0703
Epoch [9/100], Loss: 0.0649
Epoch [10/100], Loss: 0.0642
Epoch [11/100], Loss: 0.0672
Epoch [12/100], Loss: 0.0656
Epoch [13/100], Loss: 0.0661
Epoch [14/100], Loss: 0.0750
Epoch [15/100], Loss: 0.0685
Epoch [16/100], Loss: 0.0710
Epoch [17/100], Loss: 0.0669
Epoch [18/100], Loss: 0.0676
Epoch [19/100], Loss: 0.0687
Epoch [20/100], Loss: 0.0643
Epoch [21/100], Loss: 0.0648
Epoch [22/100], Loss: 0.0655
Epoch [23/100], Loss: 0.0655
Epoch [24/100], Loss: 0.0728
Epoch [25/100], Loss: 0.0709
Epoch [26/100], Loss: 0.0699
Epoch [27/100], Loss: 0.0686
Epoch [28/100], Loss: 0.0649
Epoch [29/100], Loss: 0.0699
Epoch [30/100], Loss: 0.0704
Epoch [31/100], Loss: 0.0660
Epoch [32/100], Loss: 0.0742
Epoch [33/100], Loss: 0.0697
Epoch [34/100], Loss: 0.0662
Epoch [35/100], Loss: 0