## Libraries

In [1]:
from torch import nn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets
from torchvision.utils import save_image
from tqdm import tqdm
import numpy as np

## VAE Architecture

In [2]:
class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim=200, z_dim=20):
        super().__init__()
        # encoding
        self.img2hidden = nn.Linear(input_dim, hidden_dim)
        self.hidden2hidden = nn.Linear(hidden_dim, hidden_dim)
        self.hidden2mu = nn.Linear(hidden_dim, z_dim)
        self.hidden2sigma = nn.Linear(hidden_dim, z_dim)

        # decoder
        self.z2hidden = nn.Linear(z_dim, hidden_dim)
        self.hidden2image = nn.Linear(hidden_dim, input_dim)

        # define ReLU
        self.relu = nn.ReLU()

    def encoder(self, x):
        h = self.relu(self.img2hidden(x))
        h = self.relu(self.hidden2hidden(h))
        mu = self.hidden2mu(h)
        logvar = self.hidden2sigma(h)
        return mu, logvar

    def decoder(self, z):
        h = self.relu(self.z2hidden(z))
        h = self.relu(self.hidden2hidden(h))
        img = self.hidden2image(h)
        img = torch.sigmoid(img)
        return img


    def forward(self, x):
        mu, logvar = self.encoder(x)
        sigma = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(sigma)
        z_reparameterized = mu + sigma * epsilon

        x_reconstructed = self.decoder(z_reparameterized)

        return x_reconstructed, mu, logvar




## Train

In [3]:
device = torch.device("cpu")
input_dim = 784
hidden_dim = 200
z_dim = 20
num_epochs = 60
batch_size = 128
karpathy_constant = 3e-4



dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
random_indices = np.random.choice(len(dataset), size=5000, replace=False)
subset_train_dataset = Subset(dataset, random_indices)

train_loader = DataLoader(
    dataset=subset_train_dataset,batch_size=batch_size,shuffle = True
    )

model = VariationalAutoEncoder(input_dim=input_dim,
                               hidden_dim=hidden_dim,
                               z_dim=z_dim)

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr = karpathy_constant)
loss_fn = nn.BCELoss(reduction='sum')


for epoch in range(num_epochs):
    loop = tqdm(enumerate(train_loader), total=len(train_loader))

    epoch_loss = 0

    # not using labels for now
    for i, (x, _) in loop:
        x = x.to(device).view(x.shape[0], input_dim)

        x_reconstructed, mu, logvar = model(x)

        recontruction_loss = loss_fn(x_reconstructed, x)

        # kl divergence
        # minimizing the same as negative
        # pushes towards gaussian
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        loss = recontruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        loop.set_postfix(loss=loss.item())

    avg_epoch_loss = epoch_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_epoch_loss:.2f}")


# save the model
torch.save(model.state_dict(), "vae_model.pth")

100%|██████████| 26.4M/26.4M [00:02<00:00, 12.7MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 201kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.78MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 22.1MB/s]
100%|██████████| 40/40 [00:01<00:00, 36.42it/s, loss=3.28e+3]


Epoch [1/60], Average Loss: 487.13


100%|██████████| 40/40 [00:00<00:00, 41.55it/s, loss=3.16e+3]


Epoch [2/60], Average Loss: 388.23


100%|██████████| 40/40 [00:00<00:00, 42.17it/s, loss=2.47e+3]


Epoch [3/60], Average Loss: 368.31


100%|██████████| 40/40 [00:00<00:00, 43.00it/s, loss=2.34e+3]


Epoch [4/60], Average Loss: 330.32


100%|██████████| 40/40 [00:01<00:00, 37.06it/s, loss=2.61e+3]


Epoch [5/60], Average Loss: 312.89


100%|██████████| 40/40 [00:01<00:00, 31.93it/s, loss=2.5e+3]


Epoch [6/60], Average Loss: 306.51


100%|██████████| 40/40 [00:01<00:00, 34.95it/s, loss=2.27e+3]


Epoch [7/60], Average Loss: 302.39


100%|██████████| 40/40 [00:01<00:00, 28.81it/s, loss=1.9e+3]


Epoch [8/60], Average Loss: 298.90


100%|██████████| 40/40 [00:01<00:00, 34.72it/s, loss=2.4e+3]


Epoch [9/60], Average Loss: 295.95


100%|██████████| 40/40 [00:00<00:00, 42.45it/s, loss=2.13e+3]


Epoch [10/60], Average Loss: 293.14


100%|██████████| 40/40 [00:00<00:00, 42.28it/s, loss=1.98e+3]


Epoch [11/60], Average Loss: 290.56


100%|██████████| 40/40 [00:00<00:00, 42.96it/s, loss=2.34e+3]


Epoch [12/60], Average Loss: 287.62


100%|██████████| 40/40 [00:00<00:00, 40.70it/s, loss=2.35e+3]


Epoch [13/60], Average Loss: 284.86


100%|██████████| 40/40 [00:00<00:00, 40.85it/s, loss=2.51e+3]


Epoch [14/60], Average Loss: 282.29


100%|██████████| 40/40 [00:00<00:00, 40.56it/s, loss=2.47e+3]


Epoch [15/60], Average Loss: 280.36


100%|██████████| 40/40 [00:01<00:00, 38.75it/s, loss=2.19e+3]


Epoch [16/60], Average Loss: 278.87


100%|██████████| 40/40 [00:01<00:00, 30.26it/s, loss=2.06e+3]


Epoch [17/60], Average Loss: 277.20


100%|██████████| 40/40 [00:01<00:00, 29.63it/s, loss=2.05e+3]


Epoch [18/60], Average Loss: 275.85


100%|██████████| 40/40 [00:00<00:00, 40.09it/s, loss=2.59e+3]


Epoch [19/60], Average Loss: 274.44


100%|██████████| 40/40 [00:01<00:00, 38.75it/s, loss=2.35e+3]


Epoch [20/60], Average Loss: 273.04


100%|██████████| 40/40 [00:01<00:00, 38.34it/s, loss=2.62e+3]


Epoch [21/60], Average Loss: 271.62


100%|██████████| 40/40 [00:01<00:00, 39.06it/s, loss=2.14e+3]


Epoch [22/60], Average Loss: 270.28


100%|██████████| 40/40 [00:01<00:00, 39.25it/s, loss=1.96e+3]


Epoch [23/60], Average Loss: 269.06


100%|██████████| 40/40 [00:00<00:00, 40.07it/s, loss=2.23e+3]


Epoch [24/60], Average Loss: 268.16


100%|██████████| 40/40 [00:01<00:00, 39.35it/s, loss=2.1e+3]


Epoch [25/60], Average Loss: 267.08


100%|██████████| 40/40 [00:01<00:00, 38.48it/s, loss=2.15e+3]


Epoch [26/60], Average Loss: 266.35


100%|██████████| 40/40 [00:01<00:00, 39.73it/s, loss=2.06e+3]


Epoch [27/60], Average Loss: 265.63


100%|██████████| 40/40 [00:01<00:00, 32.91it/s, loss=2.13e+3]


Epoch [28/60], Average Loss: 264.92


100%|██████████| 40/40 [00:01<00:00, 28.57it/s, loss=1.92e+3]


Epoch [29/60], Average Loss: 264.29


100%|██████████| 40/40 [00:01<00:00, 35.61it/s, loss=2.29e+3]


Epoch [30/60], Average Loss: 263.78


100%|██████████| 40/40 [00:01<00:00, 39.70it/s, loss=2.23e+3]


Epoch [31/60], Average Loss: 263.19


100%|██████████| 40/40 [00:01<00:00, 38.72it/s, loss=2.01e+3]


Epoch [32/60], Average Loss: 262.86


100%|██████████| 40/40 [00:00<00:00, 40.92it/s, loss=2.38e+3]


Epoch [33/60], Average Loss: 262.18


100%|██████████| 40/40 [00:01<00:00, 35.07it/s, loss=2.2e+3]


Epoch [34/60], Average Loss: 261.98


100%|██████████| 40/40 [00:01<00:00, 34.80it/s, loss=2.12e+3]


Epoch [35/60], Average Loss: 261.36


100%|██████████| 40/40 [00:01<00:00, 34.43it/s, loss=1.75e+3]


Epoch [36/60], Average Loss: 260.77


100%|██████████| 40/40 [00:00<00:00, 40.08it/s, loss=1.87e+3]


Epoch [37/60], Average Loss: 260.62


100%|██████████| 40/40 [00:01<00:00, 39.62it/s, loss=1.98e+3]


Epoch [38/60], Average Loss: 260.16


100%|██████████| 40/40 [00:01<00:00, 33.77it/s, loss=1.91e+3]


Epoch [39/60], Average Loss: 259.94


100%|██████████| 40/40 [00:01<00:00, 29.27it/s, loss=2.15e+3]


Epoch [40/60], Average Loss: 259.67


100%|██████████| 40/40 [00:01<00:00, 32.74it/s, loss=2.05e+3]


Epoch [41/60], Average Loss: 259.24


100%|██████████| 40/40 [00:01<00:00, 39.24it/s, loss=2.16e+3]


Epoch [42/60], Average Loss: 258.77


100%|██████████| 40/40 [00:01<00:00, 38.81it/s, loss=2.35e+3]


Epoch [43/60], Average Loss: 258.37


100%|██████████| 40/40 [00:01<00:00, 39.96it/s, loss=1.95e+3]


Epoch [44/60], Average Loss: 258.22


100%|██████████| 40/40 [00:00<00:00, 40.35it/s, loss=2.13e+3]


Epoch [45/60], Average Loss: 258.07


100%|██████████| 40/40 [00:00<00:00, 40.55it/s, loss=1.69e+3]


Epoch [46/60], Average Loss: 257.63


100%|██████████| 40/40 [00:00<00:00, 41.13it/s, loss=1.93e+3]


Epoch [47/60], Average Loss: 257.45


100%|██████████| 40/40 [00:00<00:00, 40.95it/s, loss=2.28e+3]


Epoch [48/60], Average Loss: 257.06


100%|██████████| 40/40 [00:01<00:00, 39.21it/s, loss=1.98e+3]


Epoch [49/60], Average Loss: 256.70


100%|██████████| 40/40 [00:00<00:00, 40.69it/s, loss=2.22e+3]


Epoch [50/60], Average Loss: 256.68


100%|██████████| 40/40 [00:01<00:00, 32.08it/s, loss=2.03e+3]


Epoch [51/60], Average Loss: 256.27


100%|██████████| 40/40 [00:01<00:00, 27.73it/s, loss=2.04e+3]


Epoch [52/60], Average Loss: 255.91


100%|██████████| 40/40 [00:00<00:00, 40.14it/s, loss=1.96e+3]


Epoch [53/60], Average Loss: 255.79


100%|██████████| 40/40 [00:00<00:00, 40.16it/s, loss=2.28e+3]


Epoch [54/60], Average Loss: 255.45


100%|██████████| 40/40 [00:00<00:00, 42.60it/s, loss=1.71e+3]


Epoch [55/60], Average Loss: 255.16


100%|██████████| 40/40 [00:00<00:00, 42.16it/s, loss=2.05e+3]


Epoch [56/60], Average Loss: 255.01


100%|██████████| 40/40 [00:00<00:00, 41.31it/s, loss=2.44e+3]


Epoch [57/60], Average Loss: 254.71


100%|██████████| 40/40 [00:00<00:00, 41.19it/s, loss=2.08e+3]


Epoch [58/60], Average Loss: 254.48


100%|██████████| 40/40 [00:01<00:00, 39.70it/s, loss=1.87e+3]


Epoch [59/60], Average Loss: 254.23


100%|██████████| 40/40 [00:00<00:00, 41.14it/s, loss=2.1e+3]

Epoch [60/60], Average Loss: 253.93





## Inference

In [4]:
# Load the trained model
model = VariationalAutoEncoder(input_dim=input_dim, hidden_dim=hidden_dim, z_dim=z_dim).to(device)
model.load_state_dict(torch.load("vae_model.pth"))  # Make sure to save your model at training time
model.eval()

# Load MNIST test data
test_dataset = datasets.FashionMNIST(root=".data/", train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Get one batch
test_batch = next(iter(test_loader))[0].to(device)
test_batch_flat = test_batch.view(test_batch.size(0), -1)

# Inference
with torch.no_grad():
    recon_batch, _, _ = model(test_batch_flat)

# Reshape for visualization
recon_batch = recon_batch.view(-1, 1, 28, 28)
original = test_batch.view(-1, 1, 28, 28)

# Concatenate and save
comparison = torch.cat([original, recon_batch])
save_image(comparison, "inference_reconstruction.png", nrow=8)

100%|██████████| 26.4M/26.4M [00:01<00:00, 13.4MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 201kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.72MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 22.4MB/s]
