In [2]:
!pip install tqdm

Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1


In [19]:
import os
import math
from tqdm import tqdm

import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, utils, transforms


In [6]:
# Seed PyTorch's global random number generator. torch.manual_seed returns the
# generator object, which is why this cell displays a torch Generator repr.
seed = 42
torch.manual_seed(seed)


<torch._C.Generator at 0x73b0149b6ed0>

In [7]:

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Using device:", device)

Using device: cuda


In [9]:

# Hyperparameters and output location shared by the cells below.
batch_size = 128  # mini-batch size used by both the train and test DataLoaders
latent_dim = 20  # size of the latent vector z (also used when sampling z for image grids)
hidden_dim = 400  # width of the single hidden layer in the encoder and the decoder
epochs = 10  # number of passes over the training set in the training loop
lr = 1e-3  # learning rate passed to the Adam optimizer
save_dir = "outputs"  # directory that receives image grids and model checkpoints
# Create the output directory up front; exist_ok avoids an error when the cell is re-run.
os.makedirs(save_dir, exist_ok=True)

In [11]:
# Dataset

# ToTensor is the only preprocessing step applied to the MNIST images.
transform = transforms.Compose([transforms.ToTensor()])

# Both splits share the same root directory, download flag and transform.
mnist_kwargs = dict(root="mnist_data", download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Only the training loader shuffles; the test loader keeps a fixed order, so its
# first batch is the same on every pass.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

100%|██████████| 9.91M/9.91M [02:31<00:00, 65.5kB/s]
100%|██████████| 28.9k/28.9k [00:02<00:00, 11.8kB/s]
100%|██████████| 1.65M/1.65M [00:12<00:00, 135kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.0kB/s]


In [12]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder maps a flattened input to the mean and log-variance of a diagonal
    Gaussian over the latent code; the decoder maps a latent code back to
    per-element values in [0, 1] through a sigmoid.

    Args:
        input_dim: number of features of a flattened input (default 28*28).
        hidden_dim: width of the hidden layer in both encoder and decoder.
        latent_dim: dimensionality of the latent code z.
    """

    def __init__(self, input_dim=28*28, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Encoder trunk shared by the two output heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # Separate heads for the posterior mean and log-variance.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: latent -> hidden -> input space.
        self.fc_dec1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_dec2 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return (mu, logvar) of q(z|x) for a batch of flattened inputs x: (B, input_dim)."""
        hidden = torch.relu(self.fc1(x))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I) (reparameterization trick)."""
        # logvar is log(sigma^2), so sigma = exp(logvar / 2).
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes z: (B, latent_dim) to outputs in [0, 1] of shape (B, input_dim)."""
        hidden = torch.relu(self.fc_dec1(z))
        # Sigmoid keeps outputs in [0, 1], matching a Bernoulli decoder.
        return self.fc_dec2(hidden).sigmoid()

    def forward(self, x):
        """Encode x, sample a latent code, decode it; return (x_hat, mu, logvar)."""
        mu, logvar = self.encode(x)
        sampled = self.reparameterize(mu, logvar)
        return self.decode(sampled), mu, logvar

In [14]:
# Build the VAE from the configured sizes, move it to the selected device, and
# optimise all of its parameters with Adam at the configured learning rate.
model = VAE(28 * 28, hidden_dim, latent_dim)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [15]:
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO for a Bernoulli decoder and a standard-normal prior.

    Args:
        recon_x: reconstructions with values in [0, 1], same shape as `x`.
        x: targets with values in [0, 1] (flattened to (B, 784) by the callers).
        mu: posterior means.
        logvar: posterior log-variances, same shape as `mu`.

    Returns:
        Tuple (total, reconstruction, kl) of scalar tensors. Every term is summed
        over the whole batch, so callers divide by the number of examples.
    """
    # Reconstruction term: binary cross entropy summed over batch and pixels.
    recon_term = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL(N(mu, var) || N(0, 1)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_inner.sum()

    total = recon_term + kl_term
    return total, recon_term, kl_term

In [16]:
def train_epoch(model, loader, optimizer, epoch):
    """Run one optimisation pass over `loader` and report per-example averages.

    Args:
        model: module returning (reconstruction, mu, logvar) for a flattened batch.
        loader: DataLoader yielding (images, labels); the labels are ignored.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only for the progress-bar and log text.

    Returns:
        The total loss summed over the epoch divided by len(loader.dataset).
    """
    model.train()
    loss_sum, bce_sum, kld_sum = 0.0, 0.0, 0.0

    progress = tqdm(loader, desc=f"Train Epoch {epoch}", leave=False)
    for images, _ in progress:
        # Move to the target device, then flatten each image: (B, 784).
        flat = images.to(device)
        flat = flat.view(flat.size(0), -1)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(flat)
        loss, bce, kld = loss_function(recon_batch, flat, mu, logvar)
        loss.backward()
        optimizer.step()

        # loss_function returns batch sums, so accumulate sums and normalise once.
        loss_sum += loss.item()
        bce_sum += bce.item()
        kld_sum += kld.item()

    n = len(loader.dataset)
    avg_loss, avg_bce, avg_kld = loss_sum / n, bce_sum / n, kld_sum / n
    print(f"Epoch {epoch} Train: Avg loss: {avg_loss:.4f}, BCE: {avg_bce:.4f}, KLD: {avg_kld:.4f}")
    return avg_loss

In [17]:
def eval_epoch(model, loader, epoch, save_images=True):
    """Evaluate the model on `loader` and optionally save image grids.

    Args:
        model: module returning (reconstruction, mu, logvar) for a flattened batch
            and exposing `decode(z)`.
        loader: DataLoader yielding (images, labels); the labels are ignored.
        epoch: epoch number, used in the log text and in the image file names.
        save_images: when True, write `recon_epoch_{epoch}.png` (up to 64 originals
            followed by their reconstructions) and `samples_epoch_{epoch}.png`
            (64 images decoded from z ~ N(0, I)) into `save_dir`.

    Returns:
        The total loss summed over the loader divided by len(loader.dataset).
    """
    model.eval()
    test_loss = 0.0
    test_bce = 0.0
    test_kld = 0.0
    with torch.no_grad():
        for data, _ in loader:
            data = data.to(device)
            data = data.view(data.size(0), -1)  # flatten: (B, 784)
            recon, mu, logvar = model(data)
            loss, bce, kld = loss_function(recon, data, mu, logvar)
            test_loss += loss.item()
            test_bce += bce.item()
            test_kld += kld.item()

    # loss_function returns batch sums, so normalise by the dataset size once.
    n = len(loader.dataset)
    avg_loss = test_loss / n
    avg_bce = test_bce / n
    avg_kld = test_kld / n
    print(f"Epoch {epoch} Eval : Avg loss: {avg_loss:.4f}, BCE: {avg_bce:.4f}, KLD: {avg_kld:.4f}")

    # save a few reconstructions and random samples
    if save_images:
        recon_path = os.path.join(save_dir, f"recon_epoch_{epoch}.png")
        samples_path = os.path.join(save_dir, f"samples_epoch_{epoch}.png")

        # Nothing here is trained on, so keep the whole visualisation (including
        # the decode of random samples) out of autograd.
        with torch.no_grad():
            # Take first batch from loader to show reconstructions
            data_sample, _ = next(iter(loader))
            data_sample = data_sample.to(device)[:64]
            recon_batch, _, _ = model(data_sample.view(data_sample.size(0), -1))
            # reshape back to (B,1,28,28)
            recons = recon_batch.view(-1, 1, 28, 28)
            originals = data_sample

            # Originals first, then their reconstructions, 8 images per row.
            comparison = torch.cat([originals[:64], recons[:64]])
            utils.save_image(comparison.cpu(), recon_path, nrow=8)
            print(f"Saved reconstructions to {recon_path}")

            # Sample from standard normal and decode.
            # NOTE(review): uses the module-level latent_dim; this must match the
            # latent size the model was built with.
            z = torch.randn(64, latent_dim).to(device)
            samples = model.decode(z).view(-1, 1, 28, 28)
            utils.save_image(samples.cpu(), samples_path, nrow=8)
            print(f"Saved samples to {samples_path}")

    return avg_loss

In [None]:
# Train for `epochs` epochs. A checkpoint is written after every epoch, and the
# checkpoint with the lowest evaluation loss so far is also kept as vae_best.pt.
# Note that the evaluation loss is computed on test_loader, which doubles as the
# validation set here.
best_val = float("inf")
best_path = os.path.join(save_dir, "vae_best.pt")
for epoch in range(1, epochs + 1):
    train_loss = train_epoch(model, train_loader, optimizer, epoch)
    val_loss = eval_epoch(model, test_loader, epoch, save_images=True)

    # Save model checkpoint (optimizer state included so training can be resumed).
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    ckpt_path = os.path.join(save_dir, f"vae_epoch_{epoch}.pt")
    torch.save(checkpoint, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

    # Keep a copy of the best model; previously best_val was tracked but nothing
    # was ever written for it.
    if val_loss < best_val:
        best_val = val_loss
        torch.save(checkpoint, best_path)
        print(f"New best model saved to {best_path}")
print("Training finished.")

                                                                 

Epoch 1 Train: Avg loss: 164.5073, BCE: 149.0275, KLD: 15.4798
Epoch 1 Eval : Avg loss: 127.7196, BCE: 106.3436, KLD: 21.3761
Saved reconstructions to outputs/recon_epoch_1.png
Saved samples to outputs/samples_epoch_1.png
Saved checkpoint to outputs/vae_epoch_1.pt
New best model saved to outputs/vae_best.pt


                                                                 

Epoch 2 Train: Avg loss: 121.4926, BCE: 98.8286, KLD: 22.6640
Epoch 2 Eval : Avg loss: 116.1727, BCE: 91.8081, KLD: 24.3646
Saved reconstructions to outputs/recon_epoch_2.png
Saved samples to outputs/samples_epoch_2.png
Saved checkpoint to outputs/vae_epoch_2.pt
New best model saved to outputs/vae_best.pt


                                                                 

Epoch 3 Train: Avg loss: 114.7024, BCE: 90.4516, KLD: 24.2509
Epoch 3 Eval : Avg loss: 112.1752, BCE: 87.4936, KLD: 24.6816
Saved reconstructions to outputs/recon_epoch_3.png
Saved samples to outputs/samples_epoch_3.png
Saved checkpoint to outputs/vae_epoch_3.pt
New best model saved to outputs/vae_best.pt


                                                                 

Epoch 4 Train: Avg loss: 111.7953, BCE: 87.0726, KLD: 24.7227
Epoch 4 Eval : Avg loss: 109.8246, BCE: 85.2146, KLD: 24.6100
Saved reconstructions to outputs/recon_epoch_4.png
Saved samples to outputs/samples_epoch_4.png
Saved checkpoint to outputs/vae_epoch_4.pt
New best model saved to outputs/vae_best.pt


                                                                 

Epoch 5 Train: Avg loss: 110.0962, BCE: 85.1265, KLD: 24.9697
Epoch 5 Eval : Avg loss: 108.7950, BCE: 84.1767, KLD: 24.6184
Saved reconstructions to outputs/recon_epoch_5.png
Saved samples to outputs/samples_epoch_5.png
Saved checkpoint to outputs/vae_epoch_5.pt
New best model saved to outputs/vae_best.pt


                                                                 

Epoch 6 Train: Avg loss: 108.9465, BCE: 83.8546, KLD: 25.0920
Epoch 6 Eval : Avg loss: 107.7246, BCE: 82.9583, KLD: 24.7663
Saved reconstructions to outputs/recon_epoch_6.png
Saved samples to outputs/samples_epoch_6.png
Saved checkpoint to outputs/vae_epoch_6.pt
New best model saved to outputs/vae_best.pt


                                                                 

Epoch 7 Train: Avg loss: 108.0764, BCE: 82.9074, KLD: 25.1690
Epoch 7 Eval : Avg loss: 107.1779, BCE: 82.2500, KLD: 24.9278
Saved reconstructions to outputs/recon_epoch_7.png
Saved samples to outputs/samples_epoch_7.png
Saved checkpoint to outputs/vae_epoch_7.pt
New best model saved to outputs/vae_best.pt


                                                                 

Epoch 8 Train: Avg loss: 107.4345, BCE: 82.1987, KLD: 25.2357
Epoch 8 Eval : Avg loss: 106.5124, BCE: 81.3601, KLD: 25.1523
Saved reconstructions to outputs/recon_epoch_8.png
Saved samples to outputs/samples_epoch_8.png
Saved checkpoint to outputs/vae_epoch_8.pt
New best model saved to outputs/vae_best.pt


                                                                 

Epoch 9 Train: Avg loss: 106.9065, BCE: 81.6318, KLD: 25.2747
Epoch 9 Eval : Avg loss: 106.2970, BCE: 81.1830, KLD: 25.1140
Saved reconstructions to outputs/recon_epoch_9.png
Saved samples to outputs/samples_epoch_9.png
Saved checkpoint to outputs/vae_epoch_9.pt
New best model saved to outputs/vae_best.pt


                                                                  

Epoch 10 Train: Avg loss: 106.5062, BCE: 81.1640, KLD: 25.3422
Epoch 10 Eval : Avg loss: 105.8104, BCE: 80.7437, KLD: 25.0667
Saved reconstructions to outputs/recon_epoch_10.png
Saved samples to outputs/samples_epoch_10.png
Saved checkpoint to outputs/vae_epoch_10.pt
New best model saved to outputs/vae_best.pt
Training finished.
