The CelebA dataset is stored on the student's Google Drive, so the drive must be mounted before the images can be accessed.

In [None]:
# Mount Google Drive at /content/drive so that files stored on the drive can be
# read through ordinary filesystem paths (uses the Colab-only google.colab API).
from google.colab import drive
drive.mount('/content/drive')

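torchmetrics could not be imported in Colab out of the box, so this cell installs it together with its image extras and torch-fidelity (which the FID and Inception Score metrics depend on). The imports and the print below simply confirm that the installation worked.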

In [None]:
%pip install --upgrade torchmetrics[image] torch-fidelity

from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

print("torchmetrics is available:", FrechetInceptionDistance, InceptionScore)

Main VAE code

In [None]:
# -------------------------------------------------------------------------------
#  Import Libraries
# -------------------------------------------------------------------------------

import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from glob import glob
from PIL import Image
import matplotlib.pyplot as plt

# -------------------------------------------------------------------------------
#  Hyperparameters
# -------------------------------------------------------------------------------
batch_size    = 64      # images per batch handed out by the training DataLoader
learning_rate = 2e-4    # learning rate passed to the Adam optimizer
epochs        = 50      # number of passes over the training set
latent_dim    = 100     # length of the latent vector z (encoder output / decoder input)
image_size    = 64      # size passed to transforms.Resize; the VAE's 256*4*4 linear layers assume 64x64 inputs
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device        = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# -----------------------
# Show Images during training
# -----------------------
def show_images(img_batch, title=None):
    """Display a batch of images as a single tiled figure.

    Args:
        img_batch: Batch of image tensors accepted by ``make_grid``.
        title: Optional figure title; skipped when falsy.
    """
    plt.figure(figsize=(8, 8))
    # Tile the batch 8 images per row, rescaling values for display.
    tiled = make_grid(img_batch, nrow=8, normalize=True)
    # Reorder axes to (1, 2, 0) and move to host memory for matplotlib
    # (presumably channels-first -> channels-last).
    pixels = tiled.permute(1, 2, 0).cpu().numpy()
    plt.imshow(pixels)
    plt.axis('off')
    if title:
        plt.title(title)
    plt.show()

# -----------------------
# Dataset Load
# -----------------------
class CelebADataset(Dataset):
    """Image dataset backed by the ``*.jpg`` files directly inside one folder.

    Args:
        root_dir: Directory searched (non-recursively) for ``*.jpg`` files.
        transform: Optional callable applied to each RGB PIL image before it
            is returned from ``__getitem__``.

    Attributes are ``files`` (sorted list of image paths) and ``transform``.
    """

    def __init__(self, root_dir, transform=None):
        # glob returns paths in filesystem-dependent order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.files = sorted(glob(os.path.join(root_dir, '*.jpg')))
        self.transform = transform

    def __len__(self):
        """Return the number of image files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply the transform, if any."""
        # The context manager releases the file handle deterministically;
        # convert() returns an independent in-memory image, so it stays usable.
        with Image.open(self.files[idx]) as handle:
            img = handle.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

# Preprocessing applied to every image: centre-crop to 178 px, resize to
# image_size, convert to a tensor, then normalise each of the 3 channels
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.CenterCrop(size=178),
    transforms.Resize(size=image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


# Folder on the student's Google Drive that holds the training images
data_dir = '/content/drive/MyDrive/celeb/celeb5000'
train_dataset = CelebADataset(root_dir=data_dir, transform=transform)
# Shuffled mini-batches, loaded by two background worker processes
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# -------------------------------------------------------------------------------
# VAE Class
# -------------------------------------------------------------------------------
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 64x64 images.

    The encoder applies four stride-2 convolutions (3 -> 32 -> 64 -> 128 -> 256
    channels) and two linear heads that produce the latent mean and
    log-variance. The decoder mirrors it with a linear layer followed by four
    stride-2 transposed convolutions and a final tanh.

    Args:
        latent_dim: Length of the latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Shared settings: a 4x4 kernel with stride 2 and padding 1 halves
        # (Conv2d) or doubles (ConvTranspose2d) the spatial size.
        scale2 = dict(kernel_size=4, stride=2, padding=1)
        # Flattened size of the 256-channel 4x4 bottleneck feature map.
        bottleneck = 256 * 4 * 4

        # Encoder
        self.enc1 = nn.Conv2d(3, 32, **scale2)
        self.enc2 = nn.Conv2d(32, 64, **scale2)
        self.enc3 = nn.Conv2d(64, 128, **scale2)
        self.enc4 = nn.Conv2d(128, 256, **scale2)
        self.flat = nn.Flatten()
        self.fc_mu = nn.Linear(bottleneck, latent_dim)
        self.fc_logvar = nn.Linear(bottleneck, latent_dim)

        # Decoder
        self.fc_dec = nn.Linear(latent_dim, bottleneck)
        self.dec4 = nn.ConvTranspose2d(256, 128, **scale2)
        self.dec3 = nn.ConvTranspose2d(128, 64, **scale2)
        self.dec2 = nn.ConvTranspose2d(64, 32, **scale2)
        self.dec1 = nn.ConvTranspose2d(32, 3, **scale2)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        for conv in (self.enc1, self.enc2, self.enc3, self.enc4):
            x = F.relu(conv(x))
        features = self.flat(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to images with values in [-1, 1] (tanh)."""
        h = F.relu(self.fc_dec(z)).view(-1, 256, 4, 4)
        for deconv in (self.dec4, self.dec3, self.dec2):
            h = F.relu(deconv(h))
        return torch.tanh(self.dec1(h))

    def forward(self, x):
        """Encode, sample a latent, decode; return ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        recon = self.decode(self.reparameterize(mu, logvar))
        return recon, mu, logvar

# -------------------------------------------------------------------------------
# Loss Function
# -------------------------------------------------------------------------------
def vae_loss(recon, x, mu, logvar):
    """Return the summed VAE objective for one batch.

    The result is the squared reconstruction error summed over all elements
    plus the KL divergence term computed from ``mu`` and ``logvar``, also
    summed over all elements (no averaging over the batch).
    """
    reconstruction = F.mse_loss(recon, x, reduction='sum')
    # Per-element term of the closed-form KL divergence to a unit Gaussian.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_terms.sum()
    return reconstruction + divergence

# -------------------------------------------------------------------------------
#  Model, Optimizer, and Metrics
# -------------------------------------------------------------------------------
# Build the VAE on the selected device and train all of its parameters with Adam.
model         = VAE(latent_dim).to(device)
optimizer     = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Image-quality metrics, created with torchmetrics' default settings and moved
# to the same device as the model.
fid_metric    = FrechetInceptionDistance().to(device)
is_metric     = InceptionScore().to(device)

# Store results for graphs; the training loop appends one entry per epoch.
loss_history      = []   # average training loss per image
fid_history       = []   # FID between a real batch and the generated samples
inception_history = []   # Inception Score of the generated samples
sample_collection = []   # CPU copies of the samples decoded each epoch

# -------------------------------------------------------------------------------
#  Training Loop
# -------------------------------------------------------------------------------
def _to_uint8(images):
    """Convert a float image batch in [-1, 1] to uint8 in [0, 255].

    Values are clamped first and then rounded to the nearest integer; a bare
    cast to uint8 would truncate and bias every pixel downwards.
    """
    return ((images.clamp(-1, 1) + 1) * 127.5).round().to(torch.uint8)


start_time = time.time()

# Real reference batch for FID, fetched once. Calling next(iter(train_loader))
# inside the epoch loop would start a fresh set of loader workers every epoch
# and score against a different random batch each time, which adds noise to
# the FID curve and wastes I/O.
real_batch = next(iter(train_loader)).to(device)
real_uint8 = _to_uint8(real_batch)

for epoch in range(1, epochs+1):
    # ---- one optimisation pass over the training set ----
    model.train()
    total_loss = 0
    for imgs in train_loader:
        imgs = imgs.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(imgs)
        loss = vae_loss(recon, imgs, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # vae_loss sums over the batch, so dividing by the dataset size gives the
    # average loss per image.
    avg_loss = total_loss / len(train_dataset)
    loss_history.append(avg_loss)
    print(f'Epoch {epoch:03d}  Avg Loss: {avg_loss:.4f}')

    # ---- evaluation: decode fresh latent samples and score them ----
    model.eval()
    with torch.no_grad():
        z       = torch.randn(64, latent_dim, device=device)
        sample  = model.decode(z)

        # Store raw float samples for later use
        sample_collection.append(sample.cpu())

        # The metrics are fed uint8 images
        sample_uint8 = _to_uint8(sample)

        # NOTE(review): both scores are estimated from only 64 images per
        # epoch, so they are noisy; confirm whether a larger evaluation set
        # is wanted.

        # Inception Score (reset afterwards so epochs do not accumulate)
        is_metric.update(sample_uint8)
        is_val, _ = is_metric.compute()
        inception_history.append(is_val.item())
        is_metric.reset()

        # FID Score (reset afterwards so epochs do not accumulate)
        fid_metric.update(real_uint8, real=True)
        fid_metric.update(sample_uint8, real=False)
        fid_val = fid_metric.compute()
        fid_history.append(fid_val.item())
        fid_metric.reset()

        #Print result out every Epoch
        print(f"Epoch {epoch}/{epochs} | Avg Loss: {avg_loss:.4f} | "
              f"FID {fid_val:.3f} | IS {is_val:.3f}")

        # Optional: display sample grid inline
        #Can comment in or out
        #show_images(sample, title=f'Epoch {epoch:03d} Samples')

#Calculate time
elapsed = time.time() - start_time
print(f"\nTraining complete in {elapsed/60:.2f} minutes.")

# -------------------------------------------------------------------------------
#  Plot Figures
# -------------------------------------------------------------------------------

# One line chart per recorded metric, in the order loss, IS, FID.
# Each entry is (per-epoch values, y-axis label, chart title).
for series, y_label, chart_title in (
    (loss_history, 'Average Loss', 'VAE Training Loss over Epochs'),
    (inception_history, 'Inception Score', 'Inception Score over Epochs'),
    (fid_history, 'FID Score', 'FID Score over Epochs'),
):
    plt.figure()
    plt.plot(range(1, epochs+1), series)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(chart_title)
    plt.grid(True)
    plt.show()

