In [105]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
import pandas as pd
from torchvision import transforms
from torchvision.utils import save_image


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create a directory if not exists
sample_dir = 'data/vae'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [106]:
# Hyper-parameters
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# MNIST dataset
dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

# Data loader
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size, 
    shuffle=True
)

In [107]:
# VAE model
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)
        
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

In [109]:
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)
        
        # Compute reconstruction loss and kl divergence
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction="sum")
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"Epoch[{epoch + 1}/{num_epochs}], Reconstruction Loss: {reconst_loss.item()}, KL Div: {kl_div.item()}")

Epoch[1/15], Reconstruction Loss: 10694.8896484375, KL Div: 1973.07177734375
Epoch[2/15], Reconstruction Loss: 9448.3955078125, KL Div: 2221.28125
Epoch[3/15], Reconstruction Loss: 8144.0048828125, KL Div: 2297.97900390625
Epoch[4/15], Reconstruction Loss: 8486.9521484375, KL Div: 2283.920654296875
Epoch[5/15], Reconstruction Loss: 8251.640625, KL Div: 2412.15673828125
Epoch[6/15], Reconstruction Loss: 8103.5390625, KL Div: 2371.3935546875
Epoch[7/15], Reconstruction Loss: 7980.45361328125, KL Div: 2430.5537109375
Epoch[8/15], Reconstruction Loss: 8062.298828125, KL Div: 2432.346435546875
Epoch[9/15], Reconstruction Loss: 7852.603515625, KL Div: 2373.0703125
Epoch[10/15], Reconstruction Loss: 7743.85107421875, KL Div: 2507.126220703125
Epoch[11/15], Reconstruction Loss: 7822.986328125, KL Div: 2408.00927734375
Epoch[12/15], Reconstruction Loss: 7638.9970703125, KL Div: 2473.6162109375
Epoch[13/15], Reconstruction Loss: 7936.580078125, KL Div: 2498.7265625
Epoch[14/15], Reconstruction L

In [119]:
x, _ = next(iter(data_loader))
x0 = x.view(-1, image_size)
mu, log_var = model.encode(x0)
n_samples = 10

z_samples, x_samples = [], []
for i in range(n_samples):
    z = model.reparameterize(mu, log_var)
    z_df = pd.DataFrame(z.detach().numpy())
    z_df["sample"] = i
    z_samples.append(z_df)
    
pd.concat(z_samples).to_csv(f"{sample_dir}/z_samples.csv", index=False)
pd.DataFrame(x0.detach().numpy()).to_csv(f"{sample_dir}/x0.csv", index=False)

In [122]:
unwrap = lambda z_i: z_i.view(28, -1).detach().numpy()

n_interp = 15
z_interp = torch.zeros(n_interp, z_dim)
for i in range(n_interp):
    z_interp[i, :] = (i / n_interp) * z[0] + (1 - i / n_interp) * z[1]
    
x_hat = model.decode(z_interp)

pd.DataFrame(x_hat.detach().numpy()).to_csv(f"{sample_dir}/x_hat.csv", index=False)
pd.DataFrame(z_interp.detach().numpy()).to_csv(f"{sample_dir}/z_interp.csv", index=False)