In [6]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [7]:
BATCH_SIZE = 128

train_start_epoch = 50

In [8]:
device = (
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("cpu")
)
print(f"Using device: {device}")

Using device: mps


In [10]:
from dataset import MorphII_Dataset
import torchvision.transforms as transforms

prepipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


In [11]:
# Encoder: maps image and condition -> latent mean and logvar.
class Encoder(nn.Module):
    def __init__(self, latent_dim=100, condition_dim=1):
        super(Encoder, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),   # B x 3 x 64 x 64 -> B x 16 x 32 x 32
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # B x 16 x 32 x 32 -> B x 32 x 16 x 16
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # B x 32 x 16 x 16 -> B x 64 x 8 x 8
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # B x 64 x 8 x 8 -> B x 128 x 4 x 4
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1)
        )

        # 128 x 4 x 4 = 2048
        self.fc_mu = nn.Linear(2048 + condition_dim, latent_dim)
        self.fc_logvar = nn.Linear(2048 + condition_dim, latent_dim)

    def forward(self, x, condition):
        batch_size = x.size(0)
        x = self.conv(x)              # shape: (B, 128, 4, 4)
        x = x.view(batch_size, -1)    # flatten to (B, 2048)
        x = torch.cat([x, condition], dim=1)  # concatenate condition (B, 2048+condition_dim)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

class Decoder(nn.Module):
    def __init__(self, latent_dim=100, condition_dim=1):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + condition_dim, 2048)   # B x (latent_dim+condition_dim) -> B x 2048
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                               padding=1, output_padding=1),    # B x 128 x 4 x 4 -> B x 64 x 8 x 8
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2,
                               padding=1, output_padding=1),    # B x 64 x 8 x 8 -> B x 32 x 16 x 16
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2,
                               padding=1, output_padding=1),    # B x 32 x 16 x 16 -> B x 16 x 32 x 32
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2,
                               padding=1, output_padding=1),    # B x 16 x 32 x 32 -> B x 3 x 64 x 64
            nn.Tanh()  # Output in [-1, 1]
        )

    def forward(self, z, condition):
        x = torch.cat([z, condition], dim=1)  # shape: (B, latent_dim+condition_dim)
        x = self.fc(x)                        # (B, 2048)
        x = x.view(-1, 128, 4, 4)
        x = self.deconv(x)
        return x

class ConditionalVAE(nn.Module):
    def __init__(self, latent_dim=100, condition_dim=1):
        super(ConditionalVAE, self).__init__()
        self.encoder = Encoder(latent_dim, condition_dim)
        self.decoder = Decoder(latent_dim, condition_dim)

    def forward(self, x, condition):
        mu, logvar = self.encoder(x, condition)
        z = reparameterize(mu, logvar)
        recon_x = self.decoder(z, condition)
        return recon_x, mu, logvar

latent_dim = 100
condition_dim = 1  # only using age, we could expand this to gender and race
model = ConditionalVAE(latent_dim, condition_dim).to(device)
print(model)

ConditionalVAE(
  (encoder): Encoder(
    (conv): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.1)
      (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.1)
      (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): LeakyReLU(negative_slope=0.1)
      (9): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): LeakyReLU(negative_slope=0.1)
    )
    (fc_mu): Linear(in_features=2049, out_features=100, bias=True)
    (fc_logvar): Linear(in

In [12]:
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (recon_loss + kld) / x.size(0)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0008)

In [13]:
def load_checkpoint(model, optimizer, checkpoint_dir, epoch, device):
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint)
    print(f"Loaded checkpoint from {checkpoint_path}")

checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

if train_start_epoch > 0:
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{train_start_epoch}.pth")
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print(f"Resumed model from checkpoint: {checkpoint_path}")

FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints/checkpoint_epoch_50.pth'

In [None]:
num_epochs = 500
train_loss_history = []
val_loss_history = []

checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in trange(train_start_epoch, num_epochs, position=0, leave=True, desc="Epochs", unit="epoch"):
    # Training phase
    model.train()
    epoch_train_loss = 0.0
    train_progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False, position=1)
    for images, conditions in train_progress:
        images = images.to(device)
        conditions = conditions.to(device)

        optimizer.zero_grad()
        recon_images, mu, logvar = model(images, conditions)
        loss = vae_loss(recon_images, images, mu, logvar)
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()
        train_progress.set_postfix(loss=f"{loss.item():.4f}")

    avg_train_loss = epoch_train_loss / len(train_loader)
    train_loss_history.append(avg_train_loss)

    # Validation phase
    model.eval()
    epoch_val_loss = 0.0
    val_progress = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", leave=False)
    with torch.no_grad():
        for images, conditions in val_progress:
            images = images.to(device)
            conditions = conditions.to(device)
            recon_images, mu, logvar = model(images, conditions)
            loss = vae_loss(recon_images, images, mu, logvar)
            epoch_val_loss += loss.item()
            val_progress.set_postfix(loss=f"{loss.item():.4f}")

    avg_val_loss = epoch_val_loss / len(val_loader)
    val_loss_history.append(avg_val_loss)

    print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")
    print(f"Reconstruction Loss: {F.mse_loss(recon_images, images).item():.4f}, KL Divergence Loss: {-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()).item():.4f}")

    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
    torch.save(model.state_dict(), checkpoint_path)

In [None]:
model.eval()
with torch.no_grad():
    for i in range(5):
        img, cond = val_dataset[i]
        img = img.unsqueeze(0).to(device)
        cond = cond.unsqueeze(0).to(device)
        recon, _, _ = model(img, cond)

        orig_np = (img.squeeze().cpu().numpy().transpose(1,2,0) * 0.5 + 0.5)
        recon_np = (recon.squeeze().cpu().numpy().transpose(1,2,0) * 0.5 + 0.5)

        plt.figure(figsize=(6, 3))
        plt.subplot(1, 2, 1)
        plt.imshow(orig_np)
        plt.title("Original")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(recon_np)
        plt.title("Reconstructed")
        plt.axis("off")
        plt.show()

In [None]:
checkpoint_path = "conditional_vae_morphii.pth"
torch.save(model.state_dict(), checkpoint_path)
print("Model checkpoint saved to", checkpoint_path)