In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm.notebook import tqdm, trange
import lpips
import matplotlib.pyplot as plt

In [None]:
# --- Hyperparameters ---------------------------------------------------------
batch_size = 64   # mini-batch size handed to the DataLoaders
beta = 0.5        # NOTE(review): not referenced by the cells visible here; the
                  # training loop derives its own annealed beta — confirm before removing.

# --- Run-mode switches -------------------------------------------------------
LOAD_FROM_CHECKPOINT = False   # True: restore model/optimizer state before training
DEBUG_CONSTANT_AGE = False     # True: feed the same constant age for every sample
DEBUG_HOOK = True              # True: print age_fc output statistics on each forward pass

In [None]:
# Select the "high" float32 matmul precision mode globally (process-wide setting;
# see the PyTorch documentation for the exact speed/accuracy trade-off per backend).
torch.set_float32_matmul_precision("high")

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder of the conditional VAE.

    Six stride-2 conv blocks halve the spatial size each time, so a
    128x128 RGB input ends up as a 128 x 2 x 2 feature map. The flattened
    features are concatenated with a 32-d age embedding and projected to the
    mean and log-variance of a 200-d latent distribution.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) per stage; spatial size for a 128x128
        # input goes 64 -> 32 -> 16 -> 8 -> 4 -> 2.
        channel_plan = [
            (3, 64),
            (64, 128),
            (128, 128),
            (128, 128),
            (128, 128),
            (128, 128),
        ]
        self.conv_layers = nn.Sequential(
            *[self.conv_block(c_in, c_out) for c_in, c_out in channel_plan]
        )
        self.flatten = nn.Flatten()

        # Age embedding; weights start tiny so the age signal begins small.
        self.age_fc = nn.Linear(1, 32)
        nn.init.uniform_(self.age_fc.weight, a=-0.01, b=0.01)

        flat_features = 128 * 2 * 2
        self.fc = nn.Linear(flat_features + 32, flat_features)
        self.z_mu = nn.Linear(flat_features, 200)
        self.z_logvar = nn.Linear(flat_features, 200)

    @staticmethod
    def conv_block(in_channels, out_channels):
        """Return a Conv(3x3, stride 2) -> BatchNorm -> LeakyReLU stage."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(1e-2)
        return nn.Sequential(conv, norm, act)

    def forward(self, x, age):
        """Encode images ``x`` conditioned on ``age``.

        Args:
            x: image batch fed to the conv stack.
            age: tensor of shape (B, 1) consumed by ``age_fc``.

        Returns:
            Tuple ``(z_mean, z_logvar)``, each of shape (B, 200).
        """
        features = self.flatten(self.conv_layers(x))
        # The embedding is scaled by 0.1 to keep its range small.
        age_embed = torch.relu(self.age_fc(age)) * 0.1
        hidden = torch.relu(self.fc(torch.cat([features, age_embed], dim=1)))
        return self.z_mu(hidden), self.z_logvar(hidden)

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder of the conditional VAE.

    A 200-d latent vector is concatenated with a 32-d age embedding, projected
    to a 256 x 4 x 4 tensor and upsampled by five stride-2 transposed
    convolutions to a 3 x 128 x 128 image. The final Sigmoid keeps every
    output value in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Age embedding; weights start tiny so the age signal begins small.
        self.age_fc = nn.Linear(1, 32)
        nn.init.uniform_(self.age_fc.weight, a=-0.01, b=0.01)

        # 200-dimensional latent + 32-dimensional age embedding -> 256 x 4 x 4
        self.fc = nn.Linear(200 + 32, 256 * 4 * 4)

        upsample_stages = [
            self.deconv_block(256, 128),  # B x 128 x 8 x 8
            self.deconv_block(128, 128),  # B x 128 x 16 x 16
            self.deconv_block(128, 128),  # B x 128 x 32 x 32
            self.deconv_block(128, 64),   # B x 64 x 64 x 64
        ]
        to_rgb = nn.ConvTranspose2d(
            64, 3, kernel_size=3, stride=2, padding=1, output_padding=1
        )  # B x 3 x 128 x 128
        self.deconv_layers = nn.Sequential(*upsample_stages, to_rgb, nn.Sigmoid())

    @staticmethod
    def deconv_block(in_channels, out_channels):
        """Return a ConvTranspose(3x3, stride 2) -> BatchNorm -> LeakyReLU stage."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.LeakyReLU(1e-2))

    def forward(self, z, age):
        """Decode latent ``z`` conditioned on ``age``.

        Args:
            z: latent batch of shape (B, 200).
            age: tensor of shape (B, 1) consumed by ``age_fc``.

        Returns:
            Image batch of shape (B, 3, 128, 128) with values in [0, 1].
        """
        # The embedding is scaled by 0.1 to keep its range small.
        age_embed = torch.relu(self.age_fc(age)) * 0.1
        conditioned = torch.cat([z, age_embed], dim=1)
        feature_map = self.fc(conditioned).view(-1, 256, 4, 4)
        return self.deconv_layers(feature_map)

In [None]:
class Sampling(nn.Module):
    """Reparameterized sampling layer: ``z = mean + std * (noise_scale * eps)``.

    Args:
        noise_scale: multiplier applied to the standard-normal noise. The
            default of 0.5 reproduces the previously hard-coded value.
            NOTE(review): with a scale other than 1.0 the samples have a
            smaller spread than the N(mean, exp(logvar)) posterior that the
            KL term is computed for — confirm this is intentional.
        logvar_min: lower clamp for the log-variance (default -10).
        logvar_max: upper clamp for the log-variance (default 10).
    """

    def __init__(self, noise_scale=0.5, logvar_min=-10, logvar_max=10):
        super().__init__()
        self.noise_scale = noise_scale
        self.logvar_min = logvar_min
        self.logvar_max = logvar_max

    def forward(self, z_mean, z_logvar):
        """Draw one sample per row.

        Args:
            z_mean: mean of the latent distribution.
            z_logvar: log-variance, same shape as ``z_mean``.

        Returns:
            Tensor with the same shape as ``z_mean``.
        """
        # Clamp so exp() can neither overflow nor collapse to exactly zero.
        z_logvar = torch.clamp(z_logvar, min=self.logvar_min, max=self.logvar_max)
        epsilon = torch.randn_like(z_mean) * self.noise_scale
        return z_mean + torch.exp(0.5 * z_logvar) * epsilon

In [None]:
def KL_Divergence_Loss(z_mean, z_logvar, free_bits=0.1):
    """KL divergence to a standard normal with a per-dimension floor.

    Args:
        z_mean: latent means, shape (B, D).
        z_logvar: latent log-variances, shape (B, D).
        free_bits: minimum value each per-dimension KL term is clamped to
            before summing, so every dimension contributes at least this much.

    Returns:
        Scalar tensor: per-sample KL (summed over dim 1) averaged over the batch.
    """
    per_dimension = (1 + z_logvar - z_mean.pow(2) - torch.exp(z_logvar)) * -0.5
    per_sample = per_dimension.clamp(min=free_bits).sum(dim=1)
    return per_sample.mean()

In [None]:
# Pixel-wise reconstruction loss: mean absolute error (L1Loss averages by default).
criterion = nn.L1Loss()

In [None]:
# Pick the compute device, preferring Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# LPIPS perceptual metric with a VGG backbone, moved to the training device.
loss_fn = lpips.LPIPS(net='vgg')
loss_fn = loss_fn.to(device)


def perceptual_loss(real, generated):
    """Return the batch-averaged LPIPS distance between two image batches.

    Args:
        real: reference image batch.
        generated: image batch to compare against ``real``.

    Returns:
        Scalar tensor (mean of the LPIPS outputs).
    """
    distances = loss_fn(real, generated)
    return distances.mean()

In [None]:
class AutoEncoder(nn.Module):
    """Conditional VAE: encoder -> reparameterized sampling -> decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.sampling_layer = Sampling()

    def forward(self, x, age):
        """Run a full encode/sample/decode pass.

        Args:
            x: image batch.
            age: age conditioning tensor passed to both encoder and decoder.

        Returns:
            Tuple ``(kl_loss, reconstruction)`` where ``kl_loss`` is computed
            with ``KL_Divergence_Loss``'s default ``free_bits``.
        """
        latent_mean, latent_logvar = self.encoder(x, age)
        kl_loss = KL_Divergence_Loss(latent_mean, latent_logvar)
        latent = self.sampling_layer(latent_mean, latent_logvar)
        reconstruction = self.decoder(latent, age)
        return kl_loss, reconstruction


model = AutoEncoder()
model = model.to(device)

In [None]:
# Split parameters so the age-embedding layers can use a much smaller learning rate.
named_params = list(model.named_parameters())
age_params = [param for name, param in named_params if 'age_fc' in name]
main_params = [param for name, param in named_params if 'age_fc' not in name]

optimizer = optim.AdamW(
    [
        {"params": age_params, "lr": 1e-5},    # param group 0: age_fc layers
        {"params": main_params, "lr": 0.001},  # param group 1: everything else
    ],
    weight_decay=0.0001,
)

warmup_epochs = 10
max_lr = 0.0005


def warmup_factor(epoch):
    """LR multiplier: linear ramp (epoch + 1) / warmup_epochs, capped at max_lr / 0.001."""
    return min((epoch + 1) / warmup_epochs, max_lr / 0.001)


# Both schedulers below act on the same optimizer.
warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_factor)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=3,
    factor=0.1,
    threshold=0.001,
)

In [None]:
def debug_hook(module, input, output):
    """Forward hook that prints the mean and std of a module's output."""
    layer_name = module.__class__.__name__
    out_mean = output.mean().item()
    out_std = output.std().item()
    print(f"{layer_name} output: mean={out_mean:.4f}, std={out_std:.4f}")


if DEBUG_HOOK:
    # Watch the raw age embeddings of both halves of the model.
    for age_layer in (model.encoder.age_fc, model.decoder.age_fc):
        age_layer.register_forward_hook(debug_hook)

In [None]:
# Preprocessing: resize to 128x128, convert to a tensor, then shift/scale each
# channel with mean 0.5 and std 0.5 (maps ToTensor's [0, 1] range to [-1, 1]).
resize_step = transforms.Resize((128, 128))
to_tensor_step = transforms.ToTensor()
normalize_step = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

image_transforms = transforms.Compose([resize_step, to_tensor_step, normalize_step])

In [None]:
class ImageDataset(Dataset):
    """Image dataset indexed by a CSV with ``filepath``, ``age`` and ``gender`` columns.

    Each item is ``(image, label)`` where ``label`` is a float32 tensor
    ``[normalized_age, gender]``.

    Args:
        csv_file: path to the index CSV.
        transform: optional callable applied to the loaded RGB PIL image.
        min_age: optional lower bound for age normalization. Defaults to the
            minimum of this CSV's ``age`` column.
        max_age: optional upper bound for age normalization. Defaults to the
            maximum of this CSV's ``age`` column. Pass the training split's
            bounds here so every split normalizes ages identically.
    """

    def __init__(self, csv_file, transform=None, min_age=None, max_age=None):
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform
        ages = self.data_frame['age']
        self.min_age = ages.min() if min_age is None else min_age
        self.max_age = ages.max() if max_age is None else max_age

    def __len__(self):
        return len(self.data_frame)

    def normalize_age(self, age):
        """Map ``age`` to ``(age - min_age) / (max_age - min_age)``.

        Returns 0.0 when the range is empty (all ages equal), which would
        otherwise divide by zero and feed NaN/inf into the model.
        """
        age_range = self.max_age - self.min_age
        if age_range == 0:
            return 0.0
        return (age - self.min_age) / age_range

    def __getitem__(self, idx):
        # Fetch the row once instead of indexing the frame per column.
        row = self.data_frame.iloc[idx]
        age = float(row['age'])
        # presumably gender is numerically encoded, since it goes straight into
        # a float32 tensor below — verify against the CSV.
        gender = row['gender']

        # convert() returns a fully loaded copy, so the file can be closed here.
        with Image.open(row['filepath']) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = torch.tensor([self.normalize_age(age), gender], dtype=torch.float32)
        return image, label

In [None]:
train_csv = "Dataset/Index/Train.csv"
val_csv = "Dataset/Index/Validation.csv"
test_csv = "Dataset/Index/Test.csv"

train_dataset = ImageDataset(train_csv, transform=image_transforms)
val_dataset = ImageDataset(val_csv, transform=image_transforms)

# Each ImageDataset derives min_age/max_age from its own CSV. Reuse the
# training split's range for validation so a given age maps to the same
# normalized conditioning value in both splits.
val_dataset.min_age = train_dataset.min_age
val_dataset.max_age = train_dataset.max_age

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

In [None]:
# Sanity check: display the first image of one training batch.
sample_batch = next(iter(train_loader))
sample_image = sample_batch[0][0]

# CHW tensor -> HWC array, then undo Normalize(mean=0.5, std=0.5) for display.
sample_image = sample_image.cpu().permute(1, 2, 0).numpy()
sample_image = sample_image * 0.5 + 0.5

fig, ax = plt.subplots()
ax.imshow(sample_image)
ax.axis("off")
plt.show()

In [None]:
def _constant_age(labels, batch_size, device):
    """Debug variant: ignore the labels and repeat CONSTANT_AGE for the batch."""
    return CONSTANT_AGE.expand(batch_size, 1)


def _label_age(labels, batch_size, device):
    """Take column 0 of the labels (normalized age) as a (B, 1) tensor on ``device``."""
    return labels[:, :1].to(device, non_blocking=True)


# The variant is chosen once, when this cell runs.
get_age = _constant_age if DEBUG_CONSTANT_AGE else _label_age

In [None]:
# Directory that receives one checkpoint file per epoch (created if missing).
checkpoint_dir = "../checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)


def save_checkpoint(model, optimizer, epoch, loss, filename):
    """Serialize the training state to ``filename`` with ``torch.save``.

    The saved dict holds ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``loss``.
    """
    torch.save(
        dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            loss=loss,
        ),
        filename,
    )

In [None]:
if LOAD_FROM_CHECKPOINT:
    # Restore model and optimizer state and continue after the saved epoch.
    checkpoint_path = "../previous_checkpoints/checkpoints_CVAE_03_06_morning/checkpoint_epoch_15.pth"
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    last_loss = checkpoint['loss']
    start_epoch = checkpoint['epoch'] + 1

    print(f"Resuming training from Epoch {start_epoch} with last loss {last_loss:.4f}")
else:
    # Fresh run: start counting epochs from zero.
    start_epoch = 0

In [None]:
# Fixed normalized age (shape 1 x 1) used by get_age when DEBUG_CONSTANT_AGE is on.
CONSTANT_AGE = torch.tensor([[0.5]], dtype=torch.float32, device=device)

In [None]:
val_losses = []  # average validation loss, one entry appended per epoch
epochs = 30      # upper bound of the epoch range used by the training loop

In [None]:
# Settings that do not change from epoch to epoch.
freeze_epochs = 18        # both age_fc layers stay frozen while epoch < freeze_epochs
kl_free_bits = 0.01       # per-dimension KL floor, shared by training and validation
perceptual_weight = 0.2   # weight of the LPIPS term inside the reconstruction loss


def compute_losses(images, age):
    """Run encoder -> sampling -> decoder and return ``(rec_loss, kl_loss)``.

    Used for both training and validation so the two reported losses measure
    the same objective (same free-bits floor, same reconstruction terms).

    Args:
        images: image batch on ``device``, normalized to [-1, 1] by the
            dataset's Normalize(mean=0.5, std=0.5) transform.
        age: age conditioning tensor on ``device``.

    Returns:
        rec_loss: L1 + ``perceptual_weight`` * LPIPS.
        kl_loss: KL term with ``kl_free_bits`` (not yet multiplied by beta).
    """
    z_mean, z_logvar = model.encoder(images, age)
    z_sample = model.sampling_layer(z_mean, z_logvar)
    # The decoder ends in a Sigmoid, i.e. outputs live in [0, 1], while the
    # targets are in [-1, 1]. Rescale the outputs so both losses compare like
    # with like (LPIPS expects [-1, 1] inputs by default as well).
    reconstruction = model.decoder(z_sample, age) * 2.0 - 1.0
    rec_loss = criterion(images, reconstruction) + perceptual_weight * perceptual_loss(images, reconstruction)
    kl_loss = KL_Divergence_Loss(z_mean, z_logvar, free_bits=kl_free_bits)
    return rec_loss, kl_loss


for epoch in trange(start_epoch, epochs, desc="Epoch Progress", position=0, leave=True):
    model.train()
    total_loss = 0
    total_kl_loss = 0
    total_rec_loss = 0
    num_batches = len(train_loader)

    # Annealed KL weight: ramps from ~0.1 to 1.0 over the warmup epochs.
    # Computed once per epoch so validation never depends on a variable that
    # only exists if the training loop ran at least one batch.
    current_beta = 0.01 + 0.99 * min((epoch + 1) / warmup_epochs, 1.0)

    # Freeze the age embeddings early on, release them afterwards.
    train_age_layers = epoch >= freeze_epochs
    for age_layer in (model.encoder.age_fc, model.decoder.age_fc):
        for param in age_layer.parameters():
            param.requires_grad = train_age_layers

    with tqdm(train_loader, desc=f"Training Epoch {epoch+1}", position=1, leave=False, dynamic_ncols=True) as pbar:
        for xb, labels in pbar:
            optimizer.zero_grad(set_to_none=True)
            # Local name on purpose: do not overwrite the global `batch_size`
            # that configures the DataLoaders.
            current_batch_size = xb.size(0)
            age = get_age(labels, current_batch_size, device)
            xb = xb.to(device, non_blocking=True)

            rec_loss, kl_loss = compute_losses(xb, age)
            loss = rec_loss + current_beta * kl_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()

            total_loss += loss.item()
            total_rec_loss += rec_loss.item()
            total_kl_loss += kl_loss.item()

    avg_train_loss = total_loss / num_batches
    avg_rec_loss = total_rec_loss / num_batches
    avg_kl_loss = total_kl_loss / num_batches

    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        with tqdm(val_loader, desc=f"Validation Epoch {epoch+1}", position=2, leave=False, dynamic_ncols=True) as pbar:
            for xb, labels in pbar:
                current_batch_size = xb.size(0)
                xb = xb.to(device, non_blocking=True)
                age = get_age(labels, current_batch_size, device)

                val_rec_loss, val_kl_loss = compute_losses(xb, age)
                val_loss = val_rec_loss + current_beta * val_kl_loss
                total_val_loss += val_loss.item()
                pbar.set_postfix({"Val Loss": val_loss.item()})

    avg_val_loss = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {avg_train_loss:.4f}, "
          f"Validation Loss: {avg_val_loss:.4f}, Rec Loss: {avg_rec_loss:.4f}, "
          f"KL Loss: {avg_kl_loss:.4f}")
    # param_groups[0] is the first group handed to the optimizer (the age_fc group).
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    checkpoint_filename = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
    save_checkpoint(model, optimizer, epoch, avg_train_loss, checkpoint_filename)

    # LambdaLR recomputes the LR from the base LR on every step(), which would
    # overwrite any reduction made by ReduceLROnPlateau. Run the schedulers in
    # sequence: warmup first, plateau-based decay once warmup is finished (by
    # then beta is constant, so validation losses are comparable across epochs).
    if epoch < warmup_epochs:
        warmup_scheduler.step()
    else:
        scheduler.step(avg_val_loss)

In [None]:
# Age (in years) to condition the generated face on.
target_age_years = 30.0

# Normalize exactly like ImageDataset does for the training split, rather than
# assuming a hard-coded 0..80 range that may not match the data.
train_min_age = float(train_dataset.min_age)
train_age_range = float(train_dataset.max_age) - train_min_age
desired_age = (target_age_years - train_min_age) / train_age_range if train_age_range else 0.0
desired_age_tensor = torch.tensor([[desired_age]], dtype=torch.float32).to(device)

model.eval()
# Random latent vector with the decoder's 200-d latent size.
zsample = torch.randn(1, 200).to(device)
with torch.no_grad():
    # The decoder's Sigmoid output is already in [0, 1], so it can be shown as-is.
    gen_img = model.decoder(zsample, desired_age_tensor).cpu().squeeze(0).numpy().transpose(1, 2, 0)
    plt.imshow(gen_img)
    plt.axis("off")
    plt.show()