In [None]:
import cv2
import numpy as np
import pandas as pd
import os
import torch
import torchvision.transforms as transforms
from PIL import Image
import dlib
import ast
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dataset import MorphII_Dataset

In [None]:
# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
def load_morphii_stats(stats_csv):
    """
    Read the precomputed Morph-II statistics from a CSV file.

    Only the row labelled 0 is used. Its "mean_brightness" cell is converted
    to a float, and its "histogram" cell (a Python list literal stored as
    text) is parsed into a NumPy array.

    Args:
        stats_csv: Path (or anything pandas.read_csv accepts) of the stats CSV.

    Returns:
        Tuple (mean_brightness, mean_histogram) of (float, np.ndarray).
    """
    stats = pd.read_csv(stats_csv)
    mean_brightness = float(stats.loc[0, "mean_brightness"])

    # Remove surrounding whitespace and stray double quotes before parsing.
    histogram_literal = stats.loc[0, "histogram"].strip().strip('"')
    histogram_values = ast.literal_eval(histogram_literal)

    return mean_brightness, np.array(histogram_values)

def adjust_brightness(image, target_brightness):
    """
    Scale an image so its mean grayscale brightness matches a target value.

    Args:
        image: BGR image array (it is converted with cv2.COLOR_BGR2GRAY).
        target_brightness: Desired mean grayscale intensity.

    Returns:
        A new uint8 image in which every value has been multiplied by
        target_brightness / current_brightness and clipped to [0, 255].
        An entirely black image (mean brightness 0) cannot be rescaled
        multiplicatively, so it is returned as an unchanged uint8 copy.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    current_brightness = float(np.mean(gray))
    if current_brightness <= 0.0:
        # Dividing by zero would yield inf/NaN values whose cast to uint8 is
        # undefined; leave the (black) image as it is instead.
        return np.array(image, dtype=np.uint8)
    factor = target_brightness / current_brightness
    # Multiplying by a float promotes to floating point, so clip before
    # converting back to uint8 to avoid wrap-around.
    return np.clip(image * factor, 0, 255).astype(np.uint8)

def normalize_background(image, background_value=128, threshold=200):
    """
    Replace bright pixels (assumed to be background) with a uniform value.

    Args:
        image: BGR image array (it is converted with cv2.COLOR_BGR2GRAY).
        background_value: Value written to every channel of a bright pixel.
        threshold: Pixels whose grayscale intensity is strictly greater than
            this are treated as background. Defaults to 200.

    Returns:
        A new array with the bright pixels replaced; the input array is left
        untouched.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    mask = gray > threshold  # assume bright background
    # Work on a copy so the caller's array is not modified behind its back.
    normalized = image.copy()
    normalized[mask] = background_value
    return normalized

def manual_match_histogram(source, target_hist):
    """
    Match the histogram of a grayscale image to a target histogram.

    Implements a cumulative-distribution mapping: every source intensity is
    sent to the target intensity that has the same cumulative probability.

    Args:
        source: Integer-typed grayscale image with values in [0, 255]
            (values are used directly as lookup-table indices).
        target_hist: 256 non-negative bin weights. They may be raw counts or
            probabilities; the target CDF is normalized here so both give the
            same result.

    Returns:
        uint8 array with the same shape as `source`.

    Raises:
        ValueError: if `target_hist` does not have 256 bins or its sum is not
            a positive finite number.
    """
    source = np.asarray(source)
    target_hist = np.asarray(target_hist, dtype=np.float64).ravel()
    if target_hist.size != 256:
        raise ValueError(
            f"target_hist must have 256 bins, got {target_hist.size}."
        )

    target_cdf = np.cumsum(target_hist)
    total = target_cdf[-1]
    if not np.isfinite(total) or total <= 0:
        raise ValueError("target_hist must have a positive, finite sum.")
    # The source CDF below ends at 1.0, so the target CDF must share that
    # scale; otherwise a count-valued histogram maps every pixel to ~0.
    target_cdf = target_cdf / total

    if source.size == 0:
        # Nothing to remap, and a density histogram of no samples is NaN.
        return source.astype(np.uint8)

    # Compute source histogram and CDF
    source_hist, _ = np.histogram(
        source.ravel(), bins=256, range=(0, 256), density=True
    )
    source_cdf = np.cumsum(source_hist)

    # Create mapping from source to target intensities
    mapping = np.interp(source_cdf, target_cdf, np.arange(256))
    # Round to nearest instead of truncating so the output is not biased dark.
    lookup = np.rint(mapping).astype(np.uint8)
    return lookup[source]

def match_histogram(image, target_hist):
    """
    Histogram-match only the lightness of a BGR image.

    The image is converted to LAB, its L channel is remapped with
    manual_match_histogram against `target_hist`, and the result is merged
    with the untouched A and B channels and converted back to BGR, so the
    original colour information is preserved.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2LAB))

    # Only the L channel is adjusted; A and B pass through unchanged.
    remapped_lightness = manual_match_histogram(lightness, target_hist)

    recombined = cv2.merge([remapped_lightness, chan_a, chan_b])
    return cv2.cvtColor(recombined, cv2.COLOR_LAB2BGR)

# Parsed statistics keyed by CSV path, so the file is read and parsed once per
# process rather than once per image. Re-running this cell clears the cache.
_MORPHII_STATS_CACHE = {}


def preprocess_image(image, stats_csv):
    """
    Applies brightness adjustment, background normalization, and histogram matching.

    Args:
        image: BGR image array.
        stats_csv: Path of the Morph-II statistics CSV understood by
            load_morphii_stats.

    Returns:
        The processed BGR image.

    Note:
        The statistics are cached per `stats_csv` value for the lifetime of
        the process; edits to the CSV are not picked up until the cache is
        reset (for example by re-running the cell that defines it).
    """
    cache_key = str(stats_csv)
    if cache_key not in _MORPHII_STATS_CACHE:
        _MORPHII_STATS_CACHE[cache_key] = load_morphii_stats(stats_csv)
    mean_brightness, mean_histogram = _MORPHII_STATS_CACHE[cache_key]

    image = adjust_brightness(image, mean_brightness)
    image = normalize_background(image)
    image = match_histogram(image, mean_histogram)
    return image

In [None]:
# dlib objects used by align_face: the frontal face detector and a landmark
# shape predictor loaded from a model file in the working directory.
LANDMARK_MODEL_PATH = 'shape_predictor_68_face_landmarks.dat'
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(LANDMARK_MODEL_PATH)

def align_face(image, stats_csv, output_size=256, eye_offset=0.35):
    """
    Aligns the face in the input image and applies dataset normalization.

    The first detected face is rotated so the eyes are level, scaled so the
    eyes sit at horizontal fractions `eye_offset` and `1 - eye_offset` of the
    output, translated so the midpoint between the eyes lands at
    (0.5, eye_offset) of the output, and finally passed to preprocess_image.

    Args:
        image: Input image in BGR format.
        stats_csv: Path of the statistics CSV forwarded to preprocess_image.
        output_size: Side length in pixels of the square output. Defaults to 256.
        eye_offset: Fraction of the output at which the eyes are placed
            (from the left/right edges horizontally, from the top
            vertically). Defaults to 0.35.

    Returns:
        The aligned and normalized BGR image of shape
        (output_size, output_size, channels).

    Raises:
        ValueError: if no face is detected, or if the two eye centres
            coincide so that no scale can be computed.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = detector(gray, 1)
    if len(faces) == 0:
        raise ValueError("No face detected in the image.")
    face = faces[0]
    landmarks = [(p.x, p.y) for p in predictor(gray, face).parts()]

    # Compute eye centers (landmark slices 36:42 and 42:48).
    left_eye_center = np.mean(landmarks[36:42], axis=0).astype("int")
    right_eye_center = np.mean(landmarks[42:48], axis=0).astype("int")

    # Calculate rotation angle and scaling.
    dY = right_eye_center[1] - left_eye_center[1]
    dX = right_eye_center[0] - left_eye_center[0]
    dist = np.sqrt(dX**2 + dY**2)
    if dist == 0:
        # Without this guard the scale below would be a division by zero.
        raise ValueError("Degenerate eye landmarks: eye centers coincide.")
    angle = np.degrees(np.arctan2(dY, dX))
    desired_right_eye_x = 1.0 - eye_offset
    desired_dist = (desired_right_eye_x - eye_offset) * output_size
    scale = desired_dist / dist
    # Plain Python floats keep cv2's argument parsing happy.
    eyes_center = (float((left_eye_center[0] + right_eye_center[0]) / 2),
                   float((left_eye_center[1] + right_eye_center[1]) / 2))

    # Compute affine transform: rotate/scale about the eye midpoint, then
    # shift that midpoint to its target position in the output image.
    M = cv2.getRotationMatrix2D(eyes_center, angle, scale)
    tX = output_size * 0.5
    tY = output_size * eye_offset
    M[0, 2] += (tX - eyes_center[0])
    M[1, 2] += (tY - eyes_center[1])

    aligned_face = cv2.warpAffine(
        image, M, (output_size, output_size), flags=cv2.INTER_CUBIC
    )
    processed_face = preprocess_image(aligned_face, stats_csv)
    return processed_face

In [None]:
stats_csv = "morphii_train_stats.csv"


def _align_rgb_image(img):
    """
    Run align_face on the image produced by the preceding ToPILImage step.

    The image is converted to a NumPy array and swapped RGB -> BGR for
    align_face (which expects BGR), then swapped back to RGB on the way out.
    """
    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    aligned_bgr = align_face(bgr, stats_csv)
    return cv2.cvtColor(aligned_bgr, cv2.COLOR_BGR2RGB)


# Full preprocessing: face alignment + normalization, downscale to 64x64,
# convert to a tensor and rescale each channel with mean 0.5 / std 0.5.
prepipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Lambda(_align_rgb_image),
    transforms.ToPILImage(),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [None]:
img_path = "Dataset/Team pics/Kyler.jpeg"
bgr_img = cv2.imread(img_path)
# cv2.imread signals a missing/unreadable file by returning None.
if bgr_img is None:
    raise FileNotFoundError(f"Image not found: {img_path}")

# The pipeline is fed RGB, while OpenCV loads BGR.
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
processed_img = prepipeline(rgb_img)

# Move channels last and undo the mean=0.5 / std=0.5 normalization for display.
fig, ax = plt.subplots()
ax.imshow(processed_img.permute(1, 2, 0) * 0.5 + 0.5)
ax.axis("off")
ax.set_title("Processed (Face-Aligned) Image")
plt.show()

In [None]:
# Both datasets read the same team-photo CSV and share the same transform.
TEAM_CSV = "Dataset/Team pics/team.csv"
BATCH_SIZE = 64

val_dataset = MorphII_Dataset(csv_file=TEAM_CSV, transform=prepipeline)
test_dataset = MorphII_Dataset(csv_file=TEAM_CSV, transform=prepipeline)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Encoder: maps image and condition -> latent mean and logvar.
class Encoder(nn.Module):
    """
    Convolutional encoder of the conditional VAE.

    Four stride-2 convolutions reduce a (B, 3, 64, 64) image to
    (B, 128, 4, 4). The flattened features are concatenated with the
    condition vector and projected to the latent mean and log-variance.

    Args:
        latent_dim: Size of the latent vector.
        condition_dim: Number of conditioning features per sample.
    """

    def __init__(self, latent_dim=100, condition_dim=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),   # B x 3 x 64 x 64 -> B x 16 x 32 x 32
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # B x 16 x 32 x 32 -> B x 32 x 16 x 16
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # B x 32 x 16 x 16 -> B x 64 x 8 x 8
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # B x 64 x 8 x 8 -> B x 128 x 4 x 4
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1)
        )

        # 128 x 4 x 4 = 2048
        self.fc_mu = nn.Linear(2048 + condition_dim, latent_dim)
        self.fc_logvar = nn.Linear(2048 + condition_dim, latent_dim)

    def forward(self, x, condition):
        """
        Encode a batch of images.

        Args:
            x: Image batch of shape (B, 3, 64, 64).
            condition: Condition batch of shape (B, condition_dim).

        Returns:
            Tuple (mu, logvar), each of shape (B, latent_dim).
        """
        x = self.conv(x)              # shape: (B, 128, 4, 4)
        # reshape (not view): when the model or input uses channels-last
        # memory format the conv output is not contiguous in NCHW order and
        # .view() raises; reshape copies if needed and keeps the same
        # logical element order.
        x = x.reshape(x.size(0), -1)  # flatten to (B, 2048)
        x = torch.cat([x, condition], dim=1)  # concatenate condition (B, 2048+condition_dim)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

def reparameterize(mu, logvar):
    """
    Sample z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(logvar / 2).

    Args:
        mu: Tensor of latent means.
        logvar: Tensor of latent log-variances, same shape as `mu`.

    Returns:
        A sampled latent tensor with the same shape as `mu`.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

class Decoder(nn.Module):
    """
    Transposed-convolution decoder of the conditional VAE.

    A linear layer expands the concatenated (latent, condition) vector to
    2048 features, which are viewed as (B, 128, 4, 4) and upsampled by four
    stride-2 transposed convolutions to a (B, 3, 64, 64) image in [-1, 1].
    """

    def __init__(self, latent_dim=100, condition_dim=1):
        super().__init__()
        # B x (latent_dim+condition_dim) -> B x 2048
        self.fc = nn.Linear(latent_dim + condition_dim, 2048)

        def upsample(in_channels, out_channels):
            # Each of these doubles the spatial resolution.
            return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3,
                                      stride=2, padding=1, output_padding=1)

        # 128x4x4 -> 64x8x8 -> 32x16x16 -> 16x32x32, each followed by BN + LeakyReLU.
        stages = []
        for in_channels, out_channels in ((128, 64), (64, 32), (32, 16)):
            stages.append(upsample(in_channels, out_channels))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.LeakyReLU(0.1))
        # Final stage: 16x32x32 -> 3x64x64, squashed to [-1, 1] by Tanh.
        stages.append(upsample(16, 3))
        stages.append(nn.Tanh())
        self.deconv = nn.Sequential(*stages)

    def forward(self, z, condition):
        """Decode latent `z` (B, latent_dim) with `condition` (B, condition_dim)."""
        joined = torch.cat((z, condition), dim=1)           # (B, latent_dim+condition_dim)
        feature_map = self.fc(joined).view(-1, 128, 4, 4)   # (B, 2048) -> (B, 128, 4, 4)
        return self.deconv(feature_map)

class ConditionalVAE(nn.Module):
    """
    Conditional VAE: an Encoder and a Decoder sharing one condition vector.

    Args:
        latent_dim: Size of the latent vector.
        condition_dim: Number of conditioning features per sample.
    """

    def __init__(self, latent_dim=100, condition_dim=1):
        super().__init__()
        self.encoder = Encoder(latent_dim, condition_dim)
        self.decoder = Decoder(latent_dim, condition_dim)

    def forward(self, x, condition):
        """
        Encode `x`, sample a latent, and decode it under the same condition.

        Returns:
            Tuple (recon_x, mu, logvar).
        """
        mu, logvar = self.encoder(x, condition)
        latent = reparameterize(mu, logvar)
        return self.decoder(latent, condition), mu, logvar

# Model hyper-parameters. Only age is used as the condition for now; gender
# and race could be added by widening condition_dim.
latent_dim = 100
condition_dim = 1
model = ConditionalVAE(latent_dim=latent_dim, condition_dim=condition_dim)
model = model.to(device)
print(model)

In [None]:
# Best effort: request channels-last memory format for the model. If the
# conversion raises, report it and continue with the model as it was.
try:
    model = model.to(memory_format=torch.channels_last)
except Exception as err:
    print("Channels last format not supported:", err)


In [None]:
def load_checkpoint(model, checkpoint_path, device):
    """
    Load weights from a checkpoint file into `model` and switch it to eval mode.

    Accepts either a bare state_dict or a training checkpoint dictionary that
    stores the weights under "model_state_dict" or "state_dict".

    Args:
        model: Module whose parameters are overwritten in place.
        checkpoint_path: Path of the file written by torch.save.
        device: Device the stored tensors are mapped to while loading.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        # Unwrap the common "training checkpoint" layouts.
        for key in ("model_state_dict", "state_dict"):
            if isinstance(checkpoint.get(key), dict):
                state_dict = checkpoint[key]
                break

    model.load_state_dict(state_dict)
    model.eval()
    print(f"Loaded checkpoint from {checkpoint_path}")

# Restore the trained weights into the model built above.
checkpoint_path = "checkpoints/checkpoint_epoch_500.pth"
load_checkpoint(model, checkpoint_path=checkpoint_path, device=device)

In [None]:
# Show up to five validation images next to their reconstructions.
model.eval()
# Never index past the end of the dataset when it holds fewer than 5 items.
num_examples = min(5, len(val_dataset))
with torch.no_grad():
    for i in range(num_examples):
        img, cond = val_dataset[i]
        # Add the batch dimension the model expects.
        img = img.unsqueeze(0).to(device)
        cond = cond.unsqueeze(0).to(device)
        recon, _, _ = model(img, cond)

        # Drop the batch dimension, move channels last, and undo the
        # mean=0.5 / std=0.5 normalization so values are back in [0, 1].
        orig_np = (img.squeeze(0).cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5)
        recon_np = (recon.squeeze(0).cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5)

        plt.figure(figsize=(6, 3))
        plt.subplot(1, 2, 1)
        plt.imshow(orig_np)
        plt.title("Original")
        plt.axis("off")
        plt.subplot(1, 2, 2)
        plt.imshow(recon_np)
        plt.title("Reconstructed")
        plt.axis("off")
        plt.show()

In [None]:
def generate_age_variation(model, image, cond, age_values):
    """
    Given an image and its condition, encode it and then decode it
    with varying age conditions.

    The latent vector is sampled once from the encoder's output distribution
    and reused for every age, so the outputs differ only in the condition.

    Args:
        model: Trained ConditionalVAE.
        image: A single image tensor (C x H x W).
        cond: Its corresponding condition tensor (age), shape [1].
        age_values: Iterable of new normalized age values.

    Returns:
        List of generated images (tensors), one per entry of `age_values`,
        each keeping the leading batch dimension of size 1.
    """
    model.eval()
    # Use the device the model actually lives on instead of relying on a
    # notebook-level global that may be out of sync with it.
    model_device = next(model.parameters()).device
    outputs = []
    with torch.no_grad():
        image = image.unsqueeze(0).to(model_device)
        cond = cond.unsqueeze(0).to(model_device)

        mu, logvar = model.encoder(image, cond)
        z = reparameterize(mu, logvar)
        for age in age_values:
            # Shape (1, 1): one sample with a single conditioning feature.
            new_cond = torch.tensor(
                [[float(age)]], dtype=torch.float32, device=model_device
            )
            outputs.append(model.decoder(z, new_cond))
    return outputs

# Decode one test image under ten evenly spaced normalized ages.
sample_img, sample_cond = test_dataset[3]
age_range = np.linspace(0.0, 1.0, 10)
generated_images = generate_age_variation(model, sample_img, sample_cond, age_range)

# One panel per age, laid out in a single row.
fig, axes = plt.subplots(1, len(generated_images), figsize=(15, 3), squeeze=False)
for ax, age, gen in zip(axes[0], age_range, generated_images):
    # Drop the batch dimension, move channels last, undo the 0.5/0.5 normalization.
    gen_np = gen.squeeze().cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5
    ax.imshow(gen_np)
    ax.set_title(f"Age: {age:.2f}")
    ax.axis("off")
plt.show()