In [None]:
# IPython shell escape: list the current working directory as a quick sanity check of the runtime filesystem.
!ls

In [None]:
# Mount Google Drive so checkpoints and figures outlive the Colab session.
from google.colab import drive
drive.mount('/content/drive')

import os

SAVE_DIR = "/content/drive/MyDrive/VAE_MNIST_Results"

# Results root first, then one sub-folder per artifact type
# (model checkpoints, sampled digits, reconstruction grids).
for sub_dir in ("", "/checkpoints", "/generated", "/recon"):
    os.makedirs(SAVE_DIR + sub_dir, exist_ok=True)

print("Folders created in Drive!")




In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


In [None]:
# Run on the GPU when one is available; every tensor below is moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ToTensor() converts each PIL image to a float tensor (torchvision scales pixels
# to [0, 1]), which suits the binary cross-entropy target used in vae_loss.
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

# Reshuffled every epoch; batch size 128.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

print("Using device:", device)


In [None]:
latent_dim = 20   # size of the latent code z (author's choice for MNIST)


class VAE(nn.Module):
    """Fully-connected variational autoencoder for 28x28 single-channel images.

    The attribute names and Sequential layouts (encoder, fc_mu, fc_logvar,
    decoder_fc) determine the state_dict keys written to / read from the
    checkpoints, and decoder_fc is called directly for sampling, so keep them.
    """

    _n_pixels = 28 * 28   # flattened image size

    def __init__(self):
        super().__init__()

        # Encoder trunk: 784 -> 512 -> 256, ReLU after each layer.
        self.encoder = nn.Sequential(
            nn.Linear(self._n_pixels, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
        )

        # Two heads on the shared trunk: mean and log-variance of q(z|x).
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder mirrors the encoder; the final Sigmoid keeps pixels in [0, 1] for BCE.
        self.decoder_fc = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, self._n_pixels), nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I), so gradients flow to mu/logvar."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode, sample z, decode.

        Returns (reconstruction reshaped to (-1, 1, 28, 28), mu, logvar).
        z is always sampled, regardless of train/eval mode.
        """
        hidden = self.encoder(x.view(-1, self._n_pixels))
        mu, logvar = self.fc_mu(hidden), self.fc_logvar(hidden)
        z = self.reparameterize(mu, logvar)
        return self.decoder_fc(z).view(-1, 1, 28, 28), mu, logvar


In [None]:
def vae_loss(recon, x, mu, logvar):
    """Negative ELBO summed over every element of the batch.

    recon_loss: binary cross-entropy between the reconstruction and the target pixels.
    kl_loss: closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent dims.
    """
    recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_loss = -0.5 * kl_terms.sum()
    return recon_loss + kl_loss


In [None]:
# Seed torch's global RNG (all devices) so weight initialisation and the
# DataLoader's shuffling are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 50   # MNIST converges well in roughly 40–60 epochs
losses = []   # average per-sample training loss, one entry per epoch


In [None]:
for epoch_idx in range(epochs):
    epoch_no = epoch_idx + 1
    model.train()
    running_loss = 0

    for imgs, _ in train_loader:
        imgs = imgs.to(device)

        # Forward pass yields the reconstruction plus the posterior parameters for the KL term.
        recon, mu, logvar = model(imgs)
        loss = vae_loss(recon, imgs, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # vae_loss sums over the batch, so dividing by the dataset size gives a per-image loss.
    avg_loss = running_loss / len(train_dataset)
    losses.append(avg_loss)
    print(f"Epoch {epoch_no}/{epochs} | Loss: {avg_loss:.2f}")

    # Checkpoint the weights after every epoch.
    torch.save(model.state_dict(), f"{SAVE_DIR}/checkpoints/vae_epoch_{epoch_no}.pth")

    # Grid from the epoch's last batch: originals on the top row, reconstructions below.
    # .detach() already separates these tensors from the autograd graph.
    original_imgs = imgs[:8].detach().cpu()
    sample_imgs = recon[:8].detach().cpu()

    plt.figure(figsize=(8,4))
    for col in range(8):
        for row, batch in enumerate((original_imgs, sample_imgs)):
            plt.subplot(2, 8, row * 8 + col + 1)
            plt.imshow(batch[col][0], cmap="gray")
            plt.axis("off")

    plt.savefig(f"{SAVE_DIR}/recon/recon_epoch_{epoch_no}.png")
    plt.close()


In [None]:
model.eval()

# Draw 16 latent codes from the N(0, I) prior and decode them into 1x28x28 images.
with torch.no_grad():
    z = torch.randn(16, latent_dim).to(device)
    generated = model.decoder_fc(z).view(-1,1,28,28)

# Drop the channel axis and move to CPU once for matplotlib: (16, 28, 28).
digits = generated[:, 0].cpu()

plt.figure(figsize=(8,8))
for idx, digit in enumerate(digits):
    plt.subplot(4, 4, idx + 1)
    plt.imshow(digit, cmap="gray")
    plt.axis("off")

plt.savefig(f"{SAVE_DIR}/generated/generated_digits.png")
plt.show()


In [None]:
# losses[i] holds the average loss of epoch i + 1, so plot against 1-based
# epoch numbers rather than the default 0-based list index.
epoch_numbers = range(1, len(losses) + 1)

plt.plot(epoch_numbers, losses)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("VAE Training Loss on MNIST")
plt.savefig(f"{SAVE_DIR}/loss_curve.png")
plt.show()


In [None]:
model = VAE().to(device)
# Only model weights are checkpointed, so Adam starts with fresh moment estimates here.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # smaller LR for fine tuning

resume_epoch = 50
checkpoint_path = f"{SAVE_DIR}/checkpoints/vae_epoch_{resume_epoch}.pth"
# map_location remaps stored tensors onto the current device, so a checkpoint
# written on a GPU runtime still loads on a CPU-only runtime (and vice versa).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

model.train()
print(f"Loaded checkpoint from epoch {resume_epoch}, resuming training...")


In [None]:
more_epochs = 50
start_epoch = 50   # last completed epoch
end_epoch = start_epoch + more_epochs   # total shown in the progress message

# Keep extending the loss history so the loss-curve plot includes the resumed epochs.
# After a kernel restart in which the first training cell was skipped, `losses`
# does not exist; start a new list then (it will hold only the resumed epochs).
if "losses" not in globals():
    losses = []

for epoch in range(more_epochs):
    total_loss = 0
    model.train()

    for imgs, _ in train_loader:
        imgs = imgs.to(device)

        recon, mu, logvar = model(imgs)
        loss = vae_loss(recon, imgs, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # vae_loss sums over each batch, so this is the average loss per training image.
    avg_loss = total_loss / len(train_dataset)
    losses.append(avg_loss)
    current_epoch = start_epoch + epoch + 1

    print(f"Epoch {current_epoch}/{end_epoch} | Loss: {avg_loss:.2f}")

    # Save checkpoint (model weights only)
    torch.save(model.state_dict(),
               f"{SAVE_DIR}/checkpoints/vae_epoch_{current_epoch}.pth")

    # Save reconstruction images (original vs reconstructed) from the epoch's last batch.
    # .detach() already cuts the autograd graph, so no torch.no_grad() block is needed.
    sample_recon = recon[:8].detach().cpu()
    sample_orig = imgs[:8].detach().cpu()

    plt.figure(figsize=(8,4))
    for i in range(8):
        # original
        plt.subplot(2,8,i+1)
        plt.imshow(sample_orig[i][0], cmap="gray")
        plt.axis("off")

        # reconstructed
        plt.subplot(2,8,i+9)
        plt.imshow(sample_recon[i][0], cmap="gray")
        plt.axis("off")

    plt.savefig(f"{SAVE_DIR}/recon/recon_epoch_{current_epoch}.png")
    plt.close()


In [None]:
model.eval()

# Draw 16 latent codes from the N(0, I) prior and decode them into 1x28x28 images.
with torch.no_grad():
    z = torch.randn(16, latent_dim).to(device)
    generated = model.decoder_fc(z).view(-1,1,28,28)

# Drop the channel axis and move to CPU once for matplotlib: (16, 28, 28).
digits = generated[:, 0].cpu()

plt.figure(figsize=(8,8))
for idx, digit in enumerate(digits):
    plt.subplot(4, 4, idx + 1)
    plt.imshow(digit, cmap="gray")
    plt.axis("off")

plt.savefig(f"{SAVE_DIR}/generated/generated_digits_epoch100.png")
plt.show()


In [None]:
# Qualitative check on one shuffled training batch: originals (top row) vs.
# reconstructions (bottom row). forward() always samples z via reparameterize(),
# so the reconstructions vary slightly between runs even in eval mode.
imgs, _ = next(iter(train_loader))
imgs = imgs.to(device)

with torch.no_grad():
    recon, mu, logvar = model(imgs)

orig = imgs[:8].cpu()
recon_img = recon[:8].detach().cpu()

plt.figure(figsize=(10,4))
for col in range(8):
    for row, panel in enumerate((orig, recon_img)):
        plt.subplot(2, 8, row * 8 + col + 1)
        plt.imshow(panel[col][0], cmap="gray")
        plt.axis("off")

plt.show()


In [None]:
# Plot every loss recorded so far against 1-based epoch numbers
# (assumes losses[0] is epoch 1, as recorded by the first training loop).
epoch_numbers = range(1, len(losses) + 1)

plt.plot(epoch_numbers, losses)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("VAE Training Loss on MNIST")
plt.savefig(f"{SAVE_DIR}/loss_curve.png")
plt.show()


In [None]:
model = VAE().to(device)
# Only model weights are checkpointed, so Adam starts with fresh moment estimates here.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # small LR for fine tuning

resume_epoch = 100
checkpoint_path = f"{SAVE_DIR}/checkpoints/vae_epoch_{resume_epoch}.pth"
# map_location remaps stored tensors onto the current device, so a checkpoint
# written on a GPU runtime still loads on a CPU-only runtime (and vice versa).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

model.train()
print(f"Loaded checkpoint from epoch {resume_epoch}. Continuing training...")


In [None]:
more_epochs = 50
start_epoch = 100   # last completed epoch
end_epoch = start_epoch + more_epochs   # total shown in the progress message

# Keep extending the loss history so the loss-curve plot includes the resumed epochs.
# After a kernel restart in which earlier training cells were skipped, `losses`
# does not exist; start a new list then (it will hold only the resumed epochs).
if "losses" not in globals():
    losses = []

for epoch in range(more_epochs):
    model.train()
    total_loss = 0

    for imgs, _ in train_loader:
        imgs = imgs.to(device)

        recon, mu, logvar = model(imgs)
        loss = vae_loss(recon, imgs, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # vae_loss sums over each batch, so this is the average loss per training image.
    avg_loss = total_loss / len(train_dataset)
    losses.append(avg_loss)
    current_epoch = start_epoch + epoch + 1

    print(f"Epoch {current_epoch}/{end_epoch} | Loss: {avg_loss:.2f}")

    # Save checkpoint (model weights only)
    torch.save(model.state_dict(),
               f"{SAVE_DIR}/checkpoints/vae_epoch_{current_epoch}.pth")

    # Save reconstruction images (original vs reconstructed) from the epoch's last batch.
    # .detach() already cuts the autograd graph, so no torch.no_grad() block is needed.
    # matplotlib.pyplot is imported once in the imports cell, not inside this loop.
    sample_recon = recon[:8].detach().cpu()
    sample_orig = imgs[:8].detach().cpu()

    plt.figure(figsize=(8,4))
    for i in range(8):
        # original
        plt.subplot(2,8,i+1)
        plt.imshow(sample_orig[i][0], cmap="gray")
        plt.axis("off")

        # reconstructed
        plt.subplot(2,8,i+9)
        plt.imshow(sample_recon[i][0], cmap="gray")
        plt.axis("off")

    plt.savefig(f"{SAVE_DIR}/recon/recon_epoch_{current_epoch}.png")
    plt.close()


In [None]:
model.eval()

# Draw 16 latent codes from the N(0, I) prior and decode them into 1x28x28 images.
with torch.no_grad():
    z = torch.randn(16, latent_dim).to(device)
    generated = model.decoder_fc(z).view(-1,1,28,28)

# Drop the channel axis and move to CPU once for matplotlib: (16, 28, 28).
digits = generated[:, 0].cpu()

plt.figure(figsize=(8,8))
for idx, digit in enumerate(digits):
    plt.subplot(4, 4, idx + 1)
    plt.imshow(digit, cmap="gray")
    plt.axis("off")

plt.savefig(f"{SAVE_DIR}/generated/generated_digits_epoch150.png")
plt.show()


In [None]:
# Qualitative check on one shuffled training batch: originals (top row) vs.
# reconstructions (bottom row). forward() always samples z via reparameterize(),
# so the reconstructions vary slightly between runs even in eval mode.
imgs, _ = next(iter(train_loader))
imgs = imgs.to(device)

with torch.no_grad():
    recon, mu, logvar = model(imgs)

orig = imgs[:8].cpu()
recon_img = recon[:8].detach().cpu()

plt.figure(figsize=(10,4))
for col in range(8):
    for row, panel in enumerate((orig, recon_img)):
        plt.subplot(2, 8, row * 8 + col + 1)
        plt.imshow(panel[col][0], cmap="gray")
        plt.axis("off")

plt.show()
