<a href="https://colab.research.google.com/github/goyalpramod/paper_implementations/blob/main/Autoencoder_from_scrath.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Code implementation of the following paper -> [Tutorial on Variational Autoencoders](https://arxiv.org/pdf/1606.05908)
# or this [link](https://www.cs.toronto.edu/~hinton/absps/science.pdf)

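This notebook implements a variational autoencoder (VAE) following [Tutorial on Variational Autoencoders](https://arxiv.org/pdf/1606.05908) (Doersch, 2016).
For background on classic (non-variational) deep autoencoders, see [Reducing the Dimensionality of Data with Neural Networks](https://www.cs.toronto.edu/~hinton/absps/science.pdf) (Hinton & Salakhutdinov, 2006).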

For more intuition about the ideas behind VAEs, see the following blog post:
* [Lil'Log — From Autoencoder to Beta-VAE (VAE section)](https://lilianweng.github.io/posts/2018-08-12-vae/#vae-variational-autoencoder)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
# MNIST as float tensors: ToTensor turns each PIL image into a [1, 28, 28]
# tensor scaled to [0, 1]. No normalization is applied, because the sigmoid
# decoder and the BCE loss used below expect targets in [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transform)

# Shuffled mini-batches of 128 training images
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

In [None]:
class Encoder(nn.Module):
    """MLP encoder mapping inputs to the parameters of a diagonal Gaussian over z.

    :param input_dim: number of features per sample after flattening (784 for 28x28 MNIST)
    :param hidden_dim: width of the single hidden layer
    :param latent_dim: dimensionality of the latent code z
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()

        # Shared hidden layer
        self.linear1 = nn.Linear(input_dim, hidden_dim)

        # Head producing the mean of the latent Gaussian
        self.mean_layer = nn.Linear(hidden_dim, latent_dim)

        # Head producing the log-variance (not the variance): the VAE exponentiates
        # it, which keeps the variance positive without constraining this output
        self.logvar_layer = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Encode ``x`` into ``(mean, logvar)``, each of shape [batch_size, latent_dim].

        :param x: batch whose non-batch dims flatten to ``input_dim``,
            e.g. [batch_size, 1, 28, 28] images
        """
        # torch.flatten (unlike .view) also accepts non-contiguous tensors and
        # empty batches
        x = torch.flatten(x, start_dim=1)  # [batch_size, input_dim]

        x = F.relu(self.linear1(x))

        mean = self.mean_layer(x)
        logvar = self.logvar_layer(x)
        return mean, logvar

In [None]:
class VAE(nn.Module):
    """Variational autoencoder: ``Encoder`` -> reparameterized sample z -> MLP decoder.

    :param input_dim: flattened input size (784 for 28x28 MNIST)
    :param hidden_dim: hidden-layer width used by both encoder and decoder
    :param latent_dim: size of the latent code z
    :param image_shape: per-sample shape the reconstruction is reshaped to; its
        element count must equal ``input_dim`` (default ``(1, 28, 28)``)
    :raises ValueError: if ``image_shape`` does not contain ``input_dim`` elements
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, image_shape=(1, 28, 28)):
        super().__init__()

        image_shape = tuple(image_shape)
        num_pixels = torch.Size(image_shape).numel()
        # Fail fast: a mismatch would otherwise make forward() crash or silently
        # fold samples into the wrong batch dimension when reshaping.
        if num_pixels != input_dim:
            raise ValueError(
                f"image_shape {image_shape} has {num_pixels} elements, "
                f"expected input_dim={input_dim}"
            )
        self.image_shape = image_shape

        # Encoder producing (mean, logvar) of the latent Gaussian
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)

        # Decoder mapping z back to input space
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # outputs in (0, 1), matching MNIST pixel values in [0, 1]
        )

    def reparameterize(self, mean, logvar):
        """
        Reparameterization trick: sample z ~ N(mean, exp(logvar)) as mean + eps * std
        with eps ~ N(0, I), so gradients flow through mean and logvar.
        :param mean: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) Sampled latent code [B x D]
        """
        std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)  # same shape/dtype/device as std
        return mean + eps * std

    def forward(self, x):
        """Encode, sample and decode ``x``.

        :param x: input batch accepted by ``self.encoder``
        :return: (reconstruction of shape [B, *image_shape], mean, logvar)
        """
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        reconstruction = self.decoder(z)
        return reconstruction.view(-1, *self.image_shape), mean, logvar

In [None]:
def loss_function(recon_x, x, mean, logvar):
    """
    Calculate VAE loss = reconstruction loss + KL divergence, summed over the batch.
    :param recon_x: reconstructed input with values in [0, 1]; any shape whose
        per-sample element count matches ``x`` (e.g. [B, 1, 28, 28] or [B, 784])
    :param x: original input with values in [0, 1]
    :param mean: mean of the latent distribution [B x D]
    :param logvar: log variance of the latent distribution [B x D]
    :return: scalar tensor (BCE + KLD)
    """
    # Reconstruction loss (Binary Cross Entropy). Flatten both tensors per
    # sample: the model returns image-shaped reconstructions, and
    # binary_cross_entropy raises ValueError when input and target shapes differ.
    batch_size = x.size(0)
    BCE = F.binary_cross_entropy(
        recon_x.reshape(batch_size, -1),
        x.reshape(batch_size, -1),
        reduction='sum',
    )

    # KL divergence between q(z|x) = N(mean, sigma^2) with sigma^2 = exp(logvar)
    # and the prior N(0, I), in closed form:
    # KL = -0.5 * sum(1 + log(sigma^2) - mean^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

    return BCE + KLD

In [None]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global torch RNG before anything random happens below: weight init,
# reparameterization noise and DataLoader shuffling (which draws from the default
# generator when iteration starts) become repeatable across Restart & Run All.
# Bit-exact results are only expected on CPU; some CUDA kernels are nondeterministic.
seed = 0
torch.manual_seed(seed)

# Model parameters
input_dim = 784  # 28x28 images
hidden_dim = 400
latent_dim = 20
num_epochs = 50
batch_size = 128  # train_loader is built earlier with its own batch_size; keep them in sync
learning_rate = 1e-3

# batch_size is not passed anywhere here, so changing it alone has no effect
assert train_loader.batch_size == batch_size, (
    "train_loader was built with a different batch_size; rebuild it in the data-loading cell"
)

# Initialize model and optimizer
model = VAE(input_dim, hidden_dim, latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Lists to store per-epoch losses for plotting
train_losses = []
bce_losses = []
kld_losses = []

In [None]:
def train_epoch(epoch):
    """Run one optimization pass over ``train_loader`` and return per-sample average losses.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and ``device``.

    :param epoch: epoch number, used only for progress logging
    :return: (avg_loss, avg_bce, avg_kld): each loss term summed over the epoch
        and divided by the number of samples in the dataset
    """
    model.train()
    total_loss = 0.0
    total_bce = 0.0
    total_kld = 0.0
    num_samples = len(train_loader.dataset)

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass
        recon_batch, mean, logvar = model(data)

        # Reconstruction loss. The model returns image-shaped reconstructions, so
        # flatten both sides per sample: binary_cross_entropy raises ValueError
        # when input and target shapes differ.
        recon_loss = F.binary_cross_entropy(
            recon_batch.reshape(data.size(0), -1),
            data.reshape(data.size(0), -1),
            reduction='sum',
        )
        # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian, summed over the batch
        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        loss = recon_loss + kld_loss

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Track summed losses; they are averaged per sample after the loop
        total_loss += loss.item()
        total_bce += recon_loss.item()
        total_kld += kld_loss.item()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')

    return total_loss / num_samples, total_bce / num_samples, total_kld / num_samples

# Training visualization function
def plot_losses(train_losses, bce_losses, kld_losses):
    """Draw the recorded total, reconstruction (BCE) and KL loss curves on one axis.

    Each argument is a sequence holding one value per recorded epoch.
    """
    curves = (
        (train_losses, 'Total Loss'),
        (bce_losses, 'Reconstruction Loss'),
        (kld_losses, 'KL Divergence'),
    )
    fig, ax = plt.subplots(figsize=(10, 5))
    for values, name in curves:
        ax.plot(values, label=name)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    plt.show()

In [None]:
# Training loop: one pass over train_loader per epoch, recording the per-sample
# average of the total, reconstruction and KL losses returned by train_epoch.
for epoch in range(1, num_epochs + 1):
    loss, bce, kld = train_epoch(epoch)
    train_losses.append(loss)
    bce_losses.append(bce)
    kld_losses.append(kld)

    # Plot every 10 epochs, and always after the last epoch so the final curves
    # are shown even when num_epochs is not a multiple of 10.
    if epoch % 10 == 0 or epoch == num_epochs:
        plot_losses(train_losses, bce_losses, kld_losses)

In [None]:
def visualize_reconstruction(model, data, n=8):
    """Plot up to ``n`` images from ``data`` (top row) above the model's reconstructions (bottom row).

    :param model: model whose forward returns ``(reconstruction, mean, logvar)``;
        for the VAE in this notebook the reconstruction is decoded from a sampled z,
        so it varies slightly between calls
    :param data: batch of images; ``data[i][0]`` is drawn as a 2-D image
    :param n: maximum number of images to show (default 8, the original fixed count)
    """
    # Guard against batches smaller than n (the original indexed past the end)
    n = min(n, len(data))
    if n == 0:
        return

    # Run on the device holding the model's parameters: batches straight from
    # train_loader are on the CPU and would fail against a CUDA model.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        images = data[:n].to(model_device)
        recon, _, _ = model(images)

    plt.figure(figsize=(12, 4))
    for i in range(n):
        # Original
        plt.subplot(2, n, i + 1)
        plt.imshow(images[i][0].cpu(), cmap='gray')
        plt.axis('off')

        # Reconstructed
        plt.subplot(2, n, n + i + 1)
        plt.imshow(recon[i][0].cpu(), cmap='gray')
        plt.axis('off')
    plt.show()

# After training: compare one shuffled training batch with its reconstructions
sample_batch, _ = next(iter(train_loader))
visualize_reconstruction(model, sample_batch)