The CelebA dataset is stored on the student's Google Drive, so the drive must be mounted before the data can be accessed.

In [None]:
from google.colab import drive

# Directory under which Google Drive is mounted in the Colab filesystem;
# the dataset and results paths used later in the notebook live below it.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Colab did not recognize torchmetrics out of the box, so the cell below installs it explicitly before it is imported.

In [None]:
# Install (or upgrade) torchmetrics with its image extras, plus torch-fidelity.
# This must run before the imports below.
%pip install --upgrade torchmetrics[image] torch-fidelity

# Import the two metric classes used later in the notebook; an ImportError
# here means the install above did not take effect.
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

# Sanity check: printing the class objects confirms both imports resolved.
print("torchmetrics is available:", FrechetInceptionDistance, InceptionScore)

Entire VGAN code

In [None]:
# -------------------------------------------------------------------------------
#  Import Libraries
# -------------------------------------------------------------------------------

import os, glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import matplotlib.pyplot as plt
from PIL import Image



# -------------------------------------------------------------------------------
#  Hyperparameters
# -------------------------------------------------------------------------------


# Root folder on Drive for everything the VGAN runs produce (created if missing)
save_root = "/content/drive/MyDrive/celeb/Results_VGAN"
os.makedirs(save_root, exist_ok=True)

# Hyperparameter sweep candidates. To sweep, uncomment the two lists here and
# the matching for-loops in the main loop section.
#lr_list = [2e-4, 2e-5, 2e-6]   # chosen lr = 2e-4
#z_list  = [64, 100, 128]       # chosen z_dim = 100

# Optimisation settings. The sweep was evaluated at 50 epochs, so the final
# runs keep the same epoch count.
epochs     = 50
batch_size = 64
lr         = 2e-4

# Model / data geometry: latent size and flattened RGB image size
z_dim    = 100
img_size = 64
img_dim  = img_size * img_size * 3

# -------------------------------------------------------------------------------
#  Classes
# -------------------------------------------------------------------------------

#Loads in CelebA dataset from Path
class CelebADataset(Dataset):
    """Image dataset built from every ``*.jpg`` file found under a directory.

    The directory is searched recursively. Each item is the image converted
    to RGB, passed through ``transform`` when one is given.

    Args:
        root_dir: Directory to search (recursively) for ``*.jpg`` files.
        transform: Optional callable applied to each RGB PIL image.

    Raises:
        RuntimeError: If no matching image is found under ``root_dir``.
    """

    def __init__(self, root_dir, transform=None):
        pattern = os.path.join(root_dir, "**", "*.jpg")
        # glob returns files in filesystem order, which is not guaranteed to
        # be stable; sorting makes index -> file mapping reproducible.
        self.paths = sorted(glob.glob(pattern, recursive=True))
        if not self.paths:
            raise RuntimeError(f"No images found: {pattern}")
        print(f"Found {len(self.paths)} images")
        self.transform = transform

    def __len__(self):
        """Return the number of image files discovered."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply the transform, if any."""
        # Image.open is lazy and keeps the file open; the context manager
        # closes it as soon as convert() has produced an in-memory copy.
        with Image.open(self.paths[idx]) as handle:
            img = handle.convert("RGB")
        if self.transform is not None:
            return self.transform(img)
        return img


#Generator Class
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a flat image.

    Layer widths are ``z_dim -> hidden -> 2*hidden -> img_dim`` with ReLU
    between the linear layers and a final Tanh, so outputs lie in [-1, 1].

    Args:
        z_dim: Size of the latent input vector.
        hidden: Width of the first hidden layer (the second is twice as wide).
        img_dim: Number of output values per sample (flattened image size).
    """

    def __init__(self, z_dim=100, hidden=256, img_dim=3*64*64):
        super().__init__()
        wide = hidden * 2
        layers = [
            nn.Linear(z_dim, hidden),
            nn.ReLU(inplace=False),
            nn.Linear(hidden, wide),
            nn.ReLU(inplace=False),
            nn.Linear(wide, img_dim),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return the generated flat images for the latent batch ``z``."""
        return self.net(z)


#Discriminator Class
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flat images as real or fake.

    Layer widths are ``img_dim -> 2*hidden -> hidden -> 1`` with
    LeakyReLU(0.2) between the linear layers and a final Sigmoid, so each
    sample gets a single probability-like score in (0, 1).

    Args:
        img_dim: Number of input values per sample (flattened image size).
        hidden: Width of the second hidden layer (the first is twice as wide).
    """

    def __init__(self, img_dim=3*64*64, hidden=256):
        super().__init__()
        wide = hidden * 2
        layers = [
            nn.Linear(img_dim, wide),
            nn.LeakyReLU(0.2, inplace=False),
            nn.Linear(wide, hidden),
            nn.LeakyReLU(0.2, inplace=False),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per row of the flat image batch ``x``."""
        return self.net(x)


# Preprocessing applied to every image:
#   resize + centre-crop to img_size x img_size, convert to a tensor,
#   normalise each channel with mean 0.5 / std 0.5, then flatten to a 1-D
#   vector because the generator and discriminator are fully connected.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),
    transforms.Lambda(lambda x: x.view(-1)),
])

#Dataset pathway specified here: had three locations celeb2500, celeb5000, and celeb10000
dataset = CelebADataset('/content/drive/MyDrive/celeb/celeb10000', transform=transform)

# os.cpu_count() may return None when the CPU count cannot be determined;
# DataLoader needs an int, so fall back to loading in the main process (0).
loader  = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                     num_workers=os.cpu_count() or 0, pin_memory=True)

# Train on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------------------------------------------------------
#  Main Loop
# -------------------------------------------------------------------------------


#Uncomment for loops for hyperparameter testing (and indent everything below)
#for lr in lr_list:
    #for z_dim in z_list:
print(f"\n=== RUN: lr={lr}, z_dim={z_dim} ===")
run_dir = os.path.join(save_root, f"lr_{lr}_z_{z_dim}")
os.makedirs(run_dir, exist_ok=True)

# Evaluation settings
n_eval = 1000             # number of real / generated images scored per epoch
save_epoch_grids = False  # set True to save a grid of samples every epoch (slower)

# Build models
G = Generator(z_dim=z_dim, img_dim=img_dim).to(device)
D = Discriminator(img_dim=img_dim).to(device)
#Optimize
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5,0.999))
opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5,0.999))
#Metrics
criterion   = nn.BCELoss()
# NOTE(review): n_eval (1000) is smaller than the 2048 feature dimensions,
# so treat FID here as a relative trend across epochs, not an absolute score.
fid_metric  = FrechetInceptionDistance(feature=2048).to(device)
# Inception Score is defined on the classifier's class logits, which is the
# torchmetrics default. Passing feature=2048 (as FID does) would score pooled
# features instead and give a number that is not a standard IS.
is_metric   = InceptionScore().to(device)

# Store Scores
loss_D_list, loss_G_list = [], []
fid_scores, inception_scores = [], []

# Fixed reference set of real images for FID. It is drawn once, before
# training, so every epoch is compared against the same real images and the
# shuffled loader is not re-read from disk each epoch.
real_batches, n_real = [], 0
for flat in loader:
    real_batches.append(flat)
    n_real += flat.size(0)
    if n_real >= n_eval:
        break
real_eval = torch.cat(real_batches)[:n_eval].to(device)
# Undo the [-1, 1] normalisation and restore the (N, 3, H, W) image layout
real_eval = ((real_eval + 1)/2).view(-1,3,img_size,img_size)
# The metrics are fed uint8 images in [0, 255]
real_uint8 = (real_eval * 255).to(torch.uint8)
del real_batches, real_eval

# Training Loop
for epoch in range(1, epochs+1):
    G.train(); D.train()
    running_D, running_G = 0.0, 0.0

    for real_flat in loader:
        real = real_flat.view(-1, img_dim).to(device)
        bsz  = real.size(0)
        real_labels = torch.ones(bsz,1, device=device)
        fake_labels = torch.zeros(bsz,1, device=device)

        # Discriminator step: real images -> 1, generated images -> 0.
        # The fakes are detached so this step does not update the generator.
        D.zero_grad()
        loss_D_real = criterion(D(real), real_labels)
        fake_flat   = G(torch.randn(bsz, z_dim, device=device))
        loss_D_fake = criterion(D(fake_flat.detach()), fake_labels)
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        opt_D.step()

        # Generator step: try to make D label the same fakes as real.
        # Gradients also accumulate in D here; they are cleared by the
        # D.zero_grad() at the start of the next discriminator step.
        G.zero_grad()
        loss_G = criterion(D(fake_flat), real_labels)
        loss_G.backward()
        opt_G.step()

        running_D += loss_D.item()
        running_G += loss_G.item()

    # Compute Averages (per batch)
    avg_D = running_D / len(loader)
    avg_G = running_G / len(loader)
    loss_D_list.append(avg_D)
    loss_G_list.append(avg_G)

    # Generate the evaluation samples; they stay on `device` because the
    # metrics live there too.
    G.eval()
    with torch.no_grad():
        noise      = torch.randn(n_eval, z_dim, device=device)
        fake_eval  = (G(noise).view(-1,3,img_size,img_size) + 1)/2
        fake_uint8 = (fake_eval * 255).to(torch.uint8)

    # Optional per-epoch sample grid (disabled by default for speed)
    if save_epoch_grids:
        grid = make_grid(fake_eval.cpu(), nrow=8, padding=2, normalize=True)
        save_image(grid, os.path.join(run_dir, f"fake_epoch_{epoch}.png"))

    # Compute FID
    fid_metric.reset()
    fid_metric.update(real_uint8, real=True)
    fid_metric.update(fake_uint8, real=False)
    fid_val = fid_metric.compute().item()
    fid_scores.append(fid_val)

    # Compute Inception Score (compute() returns mean and std over splits)
    is_metric.reset()
    is_metric.update(fake_uint8)
    is_mean, _ = is_metric.compute()
    is_val = is_mean.item()
    inception_scores.append(is_val)

    # Print all metrics
    print(f"Epoch {epoch}/{epochs} | "
          f"D_loss {avg_D:.3f} | G_loss {avg_G:.3f} | "
          f"FID {fid_val:.3f} | IS {is_val:.3f}")

# -------------------------------------------------------------------------
#  Plot and save graphs
# -------------------------------------------------------------------------

#Change the titles and file names below to match the dataset size in use

epochs_range = list(range(1, epochs+1))

# Single-series curves (FID, then Inception Score):
# (values per epoch, plot title, y-axis label, output file name)
for series, title, y_label, file_name in (
    (fid_scores, "FID for 10000 Images", "FID", "fid10000.png"),
    (inception_scores, "Inception Score for 10000 Images", "IS", "is10000.png"),
):
    fig, ax = plt.subplots()
    ax.plot(epochs_range, series)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.grid(True)
    fig.savefig(os.path.join(run_dir, file_name))
    plt.close(fig)

# Generator and discriminator losses share one figure
fig, ax = plt.subplots()
ax.plot(epochs_range, loss_G_list, label="Gen Loss")
ax.plot(epochs_range, loss_D_list, label="Disc Loss")
ax.set_title("Losses for 10000 Images")
ax.set_xlabel("Epoch")
ax.set_ylabel("BCE Loss")
ax.legend()
ax.grid(True)
fig.savefig(os.path.join(run_dir, "losses10000.png"))
plt.close(fig)


# ---------------------------------------------------------------------
#  Generate final Image
# ---------------------------------------------------------------------
# Sample 64 latent vectors and run them through the trained generator
n_samples = 64
with torch.no_grad():
    noise     = torch.randn(n_samples, z_dim, device=device)
    fake_flat = G(noise)

# Restore the (N, 3, H, W) layout on the CPU and map the generator's
# Tanh output from [-1, 1] to [0, 1]
fake_imgs = fake_flat.view(-1, 3, img_size, img_size).cpu()
fake_imgs = (fake_imgs + 1) / 2

# The saved file and the inline figure use the same 8-per-row grid layout
grid_kwargs = dict(nrow=8, padding=2, normalize=True)

final_path = os.path.join(run_dir, "final_samples10000.png")
save_image(fake_imgs, final_path, **grid_kwargs)

# inline display (make_grid gives C x H x W; imshow wants H x W x C)
grid = make_grid(fake_imgs, **grid_kwargs)
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(grid.permute(1, 2, 0))
ax.set_title("Final 8×8 Samples at 10000 Images")
ax.axis("off")
plt.show()
