In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Hyperparameters
latent_dim = 2          # 2-D latent space so the encoder means can be scatter-plotted directly
batch_size = 128
epochs = 10
learning_rate = 1e-3

# Seed torch's global RNG so weight initialization, DataLoader shuffling and the
# reparameterization noise repeat across Restart & Run All (CUDA kernels may still
# introduce small nondeterminism).
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Data loading: ToTensor() turns each PIL image into a float tensor scaled to [0, 1],
# the range F.binary_cross_entropy expects for its targets.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("./data", train=False, download=True, transform=transform)

# Shuffle only the training set. Evaluation order does not change the test loss or the
# latent scatter plot, and a fixed order keeps test-time outputs reproducible.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Define the Variational Autoencoder
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: input -> hidden (ReLU) -> (mu, logvar) of q(z|x).
    Decoder: z -> hidden (ReLU) -> per-pixel values in (0, 1) via sigmoid.

    Args:
        latent_dim: Size of the latent code z.
        hidden_dim: Width of the single hidden layer in encoder and decoder (default 400).
        input_dim: Number of features after flattening an input (default 28*28 for MNIST).
    """

    def __init__(self, latent_dim, hidden_dim=400, input_dim=28 * 28):
        super().__init__()
        self.input_dim = input_dim
        # Encoder layers (created in the same order as before so seeded inits are unchanged)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder layers
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened inputs of shape (batch, input_dim) to (mu, logvar), each (batch, latent_dim)."""
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I); the sample stays differentiable in mu and logvar."""
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2)  =>  sigma = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes of shape (batch, latent_dim) to reconstructions of shape (batch, input_dim)."""
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            (reconstruction of shape (batch, input_dim), mu, logvar)
        """
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposed images.
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [None]:
def loss_function(recon_x, x, mu, logvar, recon="BCE", KL=True):
    """VAE objective summed over the batch: reconstruction term + KL(q(z|x) || N(0, I)).

    Args:
        recon_x: Decoder output of shape (batch, D) with values in (0, 1).
        x: Original inputs; reshaped to recon_x's shape (e.g. (batch, 1, 28, 28) -> (batch, 784)).
        mu, logvar: Encoder outputs parameterizing q(z|x) = N(mu, exp(logvar)).
        recon: "BCE", "MSE", or "" / None to drop the reconstruction term.
        KL: Whether to add the KL divergence term.

    Returns:
        Scalar tensor with the summed loss.

    Raises:
        ValueError: If `recon` is not one of the accepted values (a typo would otherwise
            silently train KL-only), or if both terms are disabled (the loss would be a
            constant with nothing to backpropagate).
    """
    if recon in ("BCE", "MSE"):
        target = x.reshape(recon_x.shape)
        criterion = F.binary_cross_entropy if recon == "BCE" else F.mse_loss
        recon_loss = criterion(recon_x, target, reduction='sum')
    elif recon in ("", None):
        recon_loss = None
    else:
        raise ValueError(f"recon must be 'BCE', 'MSE', '' or None, got {recon!r}")

    if KL:
        # Closed form of KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims and batch.
        kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    else:
        kl_divergence = None

    terms = [term for term in (recon_loss, kl_divergence) if term is not None]
    if not terms:
        raise ValueError("loss_function needs a reconstruction term ('BCE'/'MSE') and/or KL=True")
    return sum(terms)

# Training loop
def train(recon="BCE", KL=True):
    """Run one training epoch and return the epoch's summed loss divided by the training-set size.

    Reads the notebook-level `model`, `optimizer`, `train_loader` and `device`.
    `recon` and `KL` are forwarded unchanged to `loss_function` to pick the loss terms.
    """
    model.train()
    epoch_total = 0
    for images, _labels in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images)
        batch_loss = loss_function(reconstruction, images, mu, logvar, recon, KL)
        batch_loss.backward()
        epoch_total += batch_loss.item()
        optimizer.step()
    dataset_size = len(train_loader.dataset)
    return epoch_total / dataset_size

# Testing function
def test(recon="BCE", KL=True):
    """Evaluate the notebook-level `model` on `test_loader` without tracking gradients.

    Returns:
        A tuple of (summed test loss divided by the test-set size,
        encoder means stacked into one array with a row per test image,
        the matching labels concatenated into one array). The last two feed
        the latent-space scatter plots.
    """
    model.eval()
    running_loss = 0
    mu_batches, label_batches = [], []
    with torch.no_grad():
        for images, digits in test_loader:
            images = images.to(device)
            reconstruction, mu, logvar = model(images)
            batch_loss = loss_function(reconstruction, images, mu, logvar, recon, KL)
            running_loss += batch_loss.item()
            mu_batches.append(mu.cpu().numpy())
            label_batches.append(digits.numpy())
    mean_loss = running_loss / len(test_loader.dataset)
    return mean_loss, np.vstack(mu_batches), np.hstack(label_batches)




In [None]:
# Experiment: BCE reconstruction + KL divergence (the standard VAE objective).
# A fresh model and optimizer are created so this run does not inherit earlier weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train, evaluating on the test set after every epoch; the last epoch's latent means are plotted.
for epoch in range(epochs):
    train_loss = train(recon="BCE", KL=True)
    test_loss, latent_means, labels = test(recon="BCE", KL=True)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

# Scatter the encoder means, coloured by digit label
fig, ax = plt.subplots(figsize=(8, 6))
points = ax.scatter(latent_means[:, 0], latent_means[:, 1], c=labels, cmap='viridis', alpha=0.7)
fig.colorbar(points, ax=ax)
ax.set_xlabel("Latent Dimension 1")
ax.set_ylabel("Latent Dimension 2")
ax.set_title("Latent Space Visualization of Means With BCE and KL")
plt.show()

# **Observations on Latent Space Visualization of VAE (BCE + KL)**

## **1. Well-Structured Latent Space**
- The latent space has a meaningful clustering of digits, with different digit classes (color-coded) forming distinct but overlapping regions.
- This indicates that the VAE has learned a structured latent representation, where similar digits are placed closer together.

## **2. Overlapping and Separation of Classes**
- Some digit classes (e.g., 1 and 9) are more separated, while others (e.g., 3, 5, and 8) overlap significantly.
- This overlap suggests that the VAE finds certain digits more similar in feature space, possibly due to shared stroke patterns.

## **3. Gaussian-Like Distribution**
- The latent space appears roughly Gaussian, centered around (0, 0) with spread-out clusters. This is a good sign: the **KL divergence term pulls each encoder distribution toward the standard normal prior**, which encourages the latent codes as a whole to look normally distributed.
- Some points extend outward, which might indicate that certain digits have more variation in their representations.


In [None]:
# Decode an evenly spaced 10x10 grid of latent points covering [-3, 3] in both dimensions
# (about ±3 standard deviations of the N(0, I) prior used by the KL term).
grid_x = np.linspace(-3, 3, 10)         # latent dimension 1: left -> right across columns
grid_y = np.linspace(-3, 3, 10)[::-1]   # latent dimension 2: top -> bottom down rows (descending)

# Row-major order (outer loop over rows / grid_y, inner over columns / grid_x), so that
# axes.flat[i] shows z = (grid_x[col], grid_y[row]). This matches the scatter plot above,
# where dimension 1 is horizontal and dimension 2 increases upward.
grid_z = torch.tensor(np.array([[x, y] for y in grid_y for x in grid_x]), dtype=torch.float32).to(device)
with torch.no_grad():
    reconstructions = model.decode(grid_z).cpu().numpy()

fig, axes = plt.subplots(10, 10, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    ax.imshow(reconstructions[i].reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.suptitle("Reconstructed Images from 2D Gaussian Grid Samples")
plt.show()

## Observations from the above grid
- The images are not perfect, but they are close to the digits in the MNIST dataset.
- A few images blend 4 with 9 and 3 with 8: these digits look similar, and their clusters overlap in the latent space.

In [None]:
# Experiment: KL divergence only, no reconstruction term (recon="" disables BCE/MSE).
# A fresh model and optimizer are created so this run does not inherit earlier weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train, evaluating on the test set after every epoch; the last epoch's latent means are plotted.
for epoch in range(epochs):
    train_loss = train(recon="")
    test_loss, latent_means, labels = test(recon="")
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

# Scatter the encoder means, coloured by digit label
fig, ax = plt.subplots(figsize=(8, 6))
points = ax.scatter(latent_means[:, 0], latent_means[:, 1], c=labels, cmap='viridis', alpha=0.7)
fig.colorbar(points, ax=ax)
ax.set_xlabel("Latent Dimension 1")
ax.set_ylabel("Latent Dimension 2")
ax.set_title("Latent Space Visualization of Means Only with KL")
plt.show()

## **KL Divergence Loss**
### **Purpose:**
- Regularizes the latent space by pulling each encoder distribution toward a standard normal prior, $\mathcal{N}(0, I)$.
- Ensures continuity and smoothness, allowing for meaningful interpolation between latent points.

### **Effects:**
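- **Strong KL term** (large weight relative to the reconstruction term) → The latent space is over-constrained, leading to blurry outputs.
- **Weak KL term** → The encoder is freer to learn informative latent representations, at the cost of a less regular latent space.
- If the KL term dominates, every encoder distribution collapses onto the prior and the latent codes carry no information about the input (**posterior collapse**).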

### **Observations from KL-Only Results:**
- **Latent Space:** Completely collapsed, with almost no variation.
- **Loss:** The loss drops to near zero almost immediately. With no reconstruction term, the encoder can drive the KL term to zero by outputting $\mu = 0$ and $\log\sigma^2 = 0$ for every input, so the latent codes ignore the input data entirely.
- **Issue:** Without BCE, the model does not reconstruct meaningful outputs.


In [None]:
# Decode an evenly spaced 10x10 grid of latent points covering [-3, 3] in both dimensions
# (about ±3 standard deviations of the N(0, I) prior used by the KL term).
grid_x = np.linspace(-3, 3, 10)         # latent dimension 1: left -> right across columns
grid_y = np.linspace(-3, 3, 10)[::-1]   # latent dimension 2: top -> bottom down rows (descending)

# Row-major order (outer loop over rows / grid_y, inner over columns / grid_x), so that
# axes.flat[i] shows z = (grid_x[col], grid_y[row]). This matches the scatter plot above,
# where dimension 1 is horizontal and dimension 2 increases upward.
grid_z = torch.tensor(np.array([[x, y] for y in grid_y for x in grid_x]), dtype=torch.float32).to(device)
with torch.no_grad():
    reconstructions = model.decode(grid_z).cpu().numpy()

fig, axes = plt.subplots(10, 10, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    ax.imshow(reconstructions[i].reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.suptitle("Reconstructed Images from 2D Gaussian Grid Samples; when NO BCE but with KL")
plt.show()

## Observations from the above grid
- The reconstructed images look like random noise rather than meaningful structures, so the decoder has not learned to generate structured outputs. In fact, the KL-only loss does not depend on the decoder at all. Its weights therefore receive no gradient and keep their random initialization.
- Without BCE, the model has no incentive to reconstruct meaningful images, which leads to poor outputs.

In [None]:
# Experiment: BCE reconstruction only, no KL term, so nothing regularizes the latent space.
# A fresh model and optimizer are created so this run does not inherit earlier weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train, evaluating on the test set after every epoch; the last epoch's latent means are plotted.
for epoch in range(epochs):
    train_loss = train(KL=False)
    test_loss, latent_means, labels = test(KL=False)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

# Scatter the encoder means, coloured by digit label
plt.figure(figsize=(8, 6))
plt.scatter(latent_means[:, 0], latent_means[:, 1], c=labels, cmap='viridis', alpha=0.7)
plt.colorbar()
plt.xlabel("Latent Dimension 1")
plt.ylabel("Latent Dimension 2")
# Title previously read "of  Only with BCE" (missing word); aligned with the other experiments.
plt.title("Latent Space Visualization of Means Only with BCE")
plt.show()

## **Reconstruction Loss (BCE/MSE)**
### **Purpose:**
- Ensures that the reconstructed output is as close as possible to the original input.
- Encourages the decoder to learn meaningful features from the latent space.

### **Effects:**
- **High BCE loss** → Poor reconstruction quality.
- **Low BCE loss** → The model accurately reconstructs inputs.
- If only BCE is used (without KL), the model behaves like a regular Autoencoder, and the latent space may lack smoothness and structure.

### **Observations from BCE-Only Results:**
- **Latent Space:** Without KL, clusters form, but they are scattered and not regularized.
- **Loss:** The printed loss decreases steadily across epochs, showing improving reconstruction.
- **Issue:** Overfitting can occur, as the latent space is not constrained.


In [None]:
# Decode an evenly spaced 10x10 grid of latent points covering [-3, 3] in both dimensions.
# This model was trained without the KL term, so nothing ties its latent codes to this
# range; some grid points may fall where no training data was encoded.
grid_x = np.linspace(-3, 3, 10)         # latent dimension 1: left -> right across columns
grid_y = np.linspace(-3, 3, 10)[::-1]   # latent dimension 2: top -> bottom down rows (descending)

# Row-major order (outer loop over rows / grid_y, inner over columns / grid_x), so that
# axes.flat[i] shows z = (grid_x[col], grid_y[row]). This matches the scatter plot above,
# where dimension 1 is horizontal and dimension 2 increases upward.
grid_z = torch.tensor(np.array([[x, y] for y in grid_y for x in grid_x]), dtype=torch.float32).to(device)
with torch.no_grad():
    reconstructions = model.decode(grid_z).cpu().numpy()

fig, axes = plt.subplots(10, 10, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    ax.imshow(reconstructions[i].reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.suptitle("Reconstructed Images from 2D Gaussian Grid Samples; when BCE but not with KL")
plt.show()

## Observations from the above grid
- The images above are much better than with KL loss alone; the reconstructions now resemble actual digits.
- Since the KL divergence term is missing, the latent space does not follow a structured distribution.
- This can cause the sampled images to transition less smoothly or produce out-of-distribution artifacts.
- Some transitions between digits are not smooth, and certain regions in the latent space might not generate meaningful samples.
  

In [None]:
# Experiment: MSE reconstruction + KL divergence.
# A fresh model and optimizer are created so this run does not inherit earlier weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train with the MSE reconstruction term, evaluating after every epoch;
# the last epoch's latent means are plotted.
for epoch in range(epochs):
    train_loss = train(recon="MSE")
    test_loss, latent_means, labels = test(recon="MSE")
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

# Scatter the encoder means, coloured by digit label
fig, ax = plt.subplots(figsize=(8, 6))
points = ax.scatter(latent_means[:, 0], latent_means[:, 1], c=labels, cmap='viridis', alpha=0.7)
fig.colorbar(points, ax=ax)
ax.set_xlabel("Latent Dimension 1")
ax.set_ylabel("Latent Dimension 2")
ax.set_title("Latent Space Visualization of Means With MSE and KL")
plt.show()

## **Latent Space Structure**
- The latent space shows **clustered but overlapping digit distributions**.
- This means the **KL divergence successfully regularized the latent space**, ensuring a smooth transition between digits.
- However, some classes are **not entirely separated**, which might cause mixed reconstructions when sampling from certain regions.
- The overall structure seems to be **more continuous**, making **interpolation more reliable** than the BCE-only case.


In [None]:
# Decode an evenly spaced 10x10 grid of latent points covering [-3, 3] in both dimensions
# (about ±3 standard deviations of the N(0, I) prior used by the KL term).
grid_x = np.linspace(-3, 3, 10)         # latent dimension 1: left -> right across columns
grid_y = np.linspace(-3, 3, 10)[::-1]   # latent dimension 2: top -> bottom down rows (descending)

# Row-major order (outer loop over rows / grid_y, inner over columns / grid_x), so that
# axes.flat[i] shows z = (grid_x[col], grid_y[row]). This matches the scatter plot above,
# where dimension 1 is horizontal and dimension 2 increases upward.
grid_z = torch.tensor(np.array([[x, y] for y in grid_y for x in grid_x]), dtype=torch.float32).to(device)
with torch.no_grad():
    reconstructions = model.decode(grid_z).cpu().numpy()

fig, axes = plt.subplots(10, 10, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    ax.imshow(reconstructions[i].reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.suptitle("Reconstructed Images from 2D Gaussian Grid Samples; when MSE and KL")
plt.show()

## **Reconstruction Quality**
- The reconstructions are reasonably good images.
- This fits the tendency of **MSE loss to produce smoother (blurrier) images** than BCE loss, which treats each pixel as a binary (Bernoulli) prediction.
- The transitions between digits are smoother than the **BCE-only case**, showing **better latent space regularization**.