In [None]:
import ast
import os
from collections import Counter

import cv2
import dlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader

from dataset import MorphII_Dataset

In [None]:
# Epoch number of the checkpoint to evaluate; it selects the file
# checkpoints/checkpoint_epoch_{CHECKPOINT_NUMBER}.pth that is loaded further down.
CHECKPOINT_NUMBER = 500

In [None]:
# Preferred backend order: Apple MPS, then CUDA, falling back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Evaluation preprocessing: resize to the 128x128 input the encoder expects and
# map ToTensor's [0, 1] pixel range onto [-1, 1] via (x - 0.5) / 0.5.
prepipeline = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

val_dataset = MorphII_Dataset(csv_file="Dataset/Index/validation.csv", transform=prepipeline)
test_dataset = MorphII_Dataset(csv_file="Dataset/Index/test.csv", transform=prepipeline)

BATCH_SIZE = 128
# Same (unshuffled) loader configuration for both evaluation splits.
val_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    for split_ds in (val_dataset, test_dataset)
)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder of the conditional VAE.

    Four stride-2 conv blocks reduce a (B, 3, 128, 128) image to a
    (B, 128, 8, 8) feature map. The map is flattened and concatenated with the
    condition vector, then projected to the posterior mean and log-variance.

    Args:
        latent_dim: Size of the latent vector z.
        condition_dim: Number of conditioning features appended to the
            flattened image features.
    """

    def __init__(self, latent_dim=100, condition_dim=1):
        super(Encoder, self).__init__()
        self.conv = nn.Sequential(
            # 128×128 -> 64×64
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.1),
            # 64×64 -> 32×32
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1),
            # 32×32 -> 16×16
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            # 16×16 -> 8×8
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1)
        )
        # 128 x 8 x 8 = 8192 features (so inputs must be 128×128)
        self.fc_mu = nn.Linear(8192 + condition_dim, latent_dim)
        self.fc_logvar = nn.Linear(8192 + condition_dim, latent_dim)

    def forward(self, x, condition):
        """Encode images and their conditions.

        Args:
            x: Image batch of shape (B, 3, 128, 128).
            condition: Float tensor of shape (B, condition_dim).

        Returns:
            (mu, logvar), each of shape (B, latent_dim).
        """
        x = self.conv(x)  # shape: (B, 128, 8, 8)
        # torch.flatten (unlike .view) also works when the feature map is not
        # contiguous, e.g. after the model was converted to channels_last.
        # It flattens in logical (C, H, W) order, matching the old .view.
        x = torch.flatten(x, start_dim=1)  # (B, 8192)
        x = torch.cat([x, condition], dim=1)  # shape: (B, 8192+condition_dim)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

def reparameterize(mu, logvar):
    """Draw z ~ N(mu, exp(logvar)) with the reparameterisation trick.

    Gaussian noise is sampled with the shape of the standard deviation and
    scaled/shifted, so gradients flow through ``mu`` and ``logvar``.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

class Decoder(nn.Module):
    """Transposed-convolution decoder of the conditional VAE.

    The latent vector and condition are concatenated and projected to a
    (B, 128, 8, 8) feature map. Four stride-2 transposed convolutions then
    upsample it to a (B, 3, 128, 128) image in [-1, 1] (Tanh output).
    """

    def __init__(self, latent_dim=100, condition_dim=1):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + condition_dim, 8192)

        # 8×8 -> 16×16 -> 32×32 -> 64×64, each stage followed by BN + LeakyReLU.
        # Module order (and therefore state_dict indices) matches the checkpoint.
        stages = []
        for in_ch, out_ch in ((128, 64), (64, 32), (32, 16)):
            stages.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                                             padding=1, output_padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.LeakyReLU(0.1))
        # 64×64 -> 128×128 RGB, squashed to [-1, 1].
        stages.append(nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2,
                                         padding=1, output_padding=1))
        stages.append(nn.Tanh())
        self.deconv = nn.Sequential(*stages)

    def forward(self, z, condition):
        """Decode (B, latent_dim) latents plus (B, condition_dim) conditions to images."""
        joint = torch.cat([z, condition], dim=1)
        feature_map = self.fc(joint).view(-1, 128, 8, 8)
        return self.deconv(feature_map)

class ConditionalVAE(nn.Module):
    """Conditional VAE wiring an ``Encoder`` and a ``Decoder`` around one latent sample.

    Both halves receive the same condition vector; ``forward`` returns the
    reconstruction together with the posterior parameters.
    """

    def __init__(self, latent_dim=100, condition_dim=1):
        super(ConditionalVAE, self).__init__()
        self.encoder = Encoder(latent_dim, condition_dim)
        self.decoder = Decoder(latent_dim, condition_dim)

    def forward(self, x, condition):
        """Return (reconstruction, mu, logvar) for images ``x`` under ``condition``."""
        mu, logvar = self.encoder(x, condition)
        reconstruction = self.decoder(reparameterize(mu, logvar), condition)
        return reconstruction, mu, logvar

# Architecture hyperparameters; they must match the checkpoint loaded below.
# The condition vector holds two values: (normalized age, gender).
latent_dim, condition_dim = 256, 2
model = ConditionalVAE(latent_dim=latent_dim, condition_dim=condition_dim).to(device)
print(model)

In [None]:
# Best-effort switch to the channels_last memory format; if the backend
# rejects it, report the reason and keep the model as it is.
try:
    model = model.to(memory_format=torch.channels_last)
except Exception as err:
    print("Channels last format not supported:", err)

In [None]:
def load_checkpoint(model, checkpoint_path, device):
    """Load saved weights into ``model`` and switch it to evaluation mode.

    Args:
        model: The nn.Module whose parameters/buffers are overwritten.
        checkpoint_path: Path of the checkpoint file. Its content is passed
            straight to ``load_state_dict``, so it must be a plain state_dict.
        device: Device the stored tensors are mapped onto while loading.
    """
    try:
        # weights_only=True restricts unpickling to tensors/containers, so a
        # tampered checkpoint cannot execute code. A plain state_dict loads fine.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    except TypeError:
        # Older torch releases (< 1.13) do not accept the weights_only keyword.
        checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()
    print(f"Loaded checkpoint from {checkpoint_path}")

# Restore the trained weights for the chosen epoch (also puts the model in eval mode).
checkpoint_path = f"checkpoints/checkpoint_epoch_{CHECKPOINT_NUMBER}.pth"
load_checkpoint(model=model, checkpoint_path=checkpoint_path, device=device)

In [None]:
num_displayed = 5

fig, axs = plt.subplots(num_displayed, 2, figsize=(6, 3 * num_displayed))

# Inference only: no_grad avoids building an autograd graph for every sample.
with torch.no_grad():
    for idx in range(num_displayed):
        sample_img, sample_cond = test_dataset[idx]
        img_tensor = sample_img.unsqueeze(0).to(device)
        cond_tensor = sample_cond.unsqueeze(0).to(device)
        recon, _, _ = model(img_tensor, cond_tensor)

        # Undo Normalize(0.5, 0.5) ([-1, 1] -> [0, 1]) and go CHW -> HWC for imshow.
        orig_np = img_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5
        recon_np = recon.squeeze(0).cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5

        for col, (image_np, title) in enumerate(((orig_np, "Original"), (recon_np, "Reconstructed"))):
            axs[idx, col].imshow(image_np)
            axs[idx, col].set_title(title)
            axs[idx, col].axis("off")

fig.tight_layout()

os.makedirs("figures", exist_ok=True)
fig.savefig("figures/test_reconstruction.png")
plt.show()

In [None]:
def generate_age_variation(model, image, cond, age_values):
    """
    Encode an image once, then decode it under several age conditions.

    The condition is assumed to be [normalized_age, gender]; only the age is
    replaced, and the gender is kept. A single latent sample z is shared by
    every age, so the outputs differ only through the condition.

    Args:
        model: Trained ConditionalVAE.
        image: A single image tensor (C x H x W).
        cond: Its corresponding condition tensor (age, gender), shape [2].
        age_values: Iterable of new normalized age values.

    Returns:
        List of generated images, one tensor of shape (1, C, H, W) per age
        value, in the same order as ``age_values`` (empty list if no ages).
    """
    model.eval()
    # Run on the device the model actually lives on instead of a notebook global.
    model_device = next(model.parameters()).device
    ages = [float(age) for age in age_values]
    if not ages:
        return []

    with torch.no_grad():
        image = image.unsqueeze(0).to(model_device)
        cond = cond.unsqueeze(0).to(model_device)

        mu, logvar = model.encoder(image, cond)
        z = reparameterize(mu, logvar)

        orig_gender = cond[0, 1].item()
        new_conds = torch.tensor([[age, orig_gender] for age in ages],
                                 dtype=torch.float32, device=model_device)
        # Decode every age in one batch. In eval mode BatchNorm uses running
        # statistics, so each row is processed independently, just as with
        # one call per age. expand() repeats z without copying it.
        outs = model.decoder(z.expand(len(ages), -1), new_conds)
    return list(outs.split(1, dim=0))

def generate_gender_variation(model, image, cond, genders=(0, 1)):
    """
    Encode an image once, then decode it under several gender conditions.

    The condition is assumed to be [normalized_age, gender]; the age is kept
    and only the gender value is replaced. One latent sample z is shared by
    all outputs.

    Args:
        model: Trained ConditionalVAE.
        image: A single image tensor (C x H x W).
        cond: Its corresponding condition tensor (age, gender), shape [2].
        genders: Gender values to decode with (default: 0 and 1).

    Returns:
        List of generated images, one tensor of shape (1, C, H, W) per gender
        value, in the same order as ``genders`` (empty list if none given).
    """
    model.eval()
    # Run on the device the model actually lives on instead of a notebook global.
    model_device = next(model.parameters()).device
    gender_values = [float(gender) for gender in genders]
    if not gender_values:
        return []

    with torch.no_grad():
        image = image.unsqueeze(0).to(model_device)
        cond = cond.unsqueeze(0).to(model_device)

        mu, logvar = model.encoder(image, cond)
        z = reparameterize(mu, logvar)

        orig_age = cond[0, 0].item()
        new_conds = torch.tensor([[orig_age, gender] for gender in gender_values],
                                 dtype=torch.float32, device=model_device)
        # Single batched decode; rows are independent with BatchNorm in eval mode.
        outs = model.decoder(z.expand(len(gender_values), -1), new_conds)
    return list(outs.split(1, dim=0))

In [None]:
# Sweep the normalized age condition from 0 to `normalizer`; titles map that
# sweep linearly onto 16..80 years.
normalizer = 1.5
age_range = np.linspace(0.0, normalizer, 10)

fig, axs = plt.subplots(num_displayed, len(age_range), figsize=(2*len(age_range), 2*num_displayed))

for row in range(num_displayed):
    sample_img, sample_cond = test_dataset[row]
    generated_age_images = generate_age_variation(model, sample_img, sample_cond, age_range)

    # Draw into `fig` only. Creating another figure here would make it the
    # current one, so savefig below would write a blank image.
    for col, gen in enumerate(generated_age_images):
        gen_np = (gen.squeeze().cpu().detach().numpy().transpose(1, 2, 0) * 0.5 + 0.5)
        axs[row, col].imshow(gen_np)
        axs[row, col].set_title(f"Age: {int(16 + (1/normalizer)*age_range[col]*(80-16))}")
        axs[row, col].axis("off")

fig.tight_layout()

os.makedirs("figures", exist_ok=True)
fig.savefig("figures/test_age.png")
plt.show()

In [None]:
sample_img, sample_cond = test_dataset[0]
generated_gender_images = generate_gender_variation(model, sample_img, sample_cond)

# One panel per gender condition (0 and 1), side by side.
gender_fig, gender_axes = plt.subplots(1, len(generated_gender_images), figsize=(15, 3))
for ax, gender, gen in zip(np.atleast_1d(gender_axes), [0, 1], generated_gender_images):
    ax.imshow(gen.squeeze().cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5)
    ax.set_title(f"Gender: {gender:.2f}")
    ax.axis("off")
plt.show()

In [None]:
def load_morphii_stats(stats_csv):
    """Read the reference brightness and histogram from a statistics CSV.

    Only the first data row is used. ``mean_brightness`` must be numeric.
    ``histogram`` must hold a Python-literal list of numbers, optionally
    wrapped in an extra pair of double quotes.

    Args:
        stats_csv: Path (or file-like object) accepted by ``pd.read_csv``.

    Returns:
        (mean_brightness, mean_histogram): a float and a 1-D float ndarray.

    Raises:
        ValueError: if the histogram is not a flat sequence of numbers.
    """
    # Positional access: the first row regardless of the frame's index labels.
    row = pd.read_csv(stats_csv).iloc[0]
    mean_brightness = float(row["mean_brightness"])
    histogram_str = str(row["histogram"]).strip().strip('"')
    mean_histogram = np.asarray(ast.literal_eval(histogram_str), dtype=float)
    if mean_histogram.ndim != 1:
        raise ValueError(
            f"expected a 1-D histogram in {stats_csv!r}, got shape {mean_histogram.shape}"
        )
    return mean_brightness, mean_histogram

def adjust_brightness_masked(image, target_brightness, mask, custom_brightness_constant = 1):
    """Scale the masked pixels so their mean grey level matches a target.

    Args:
        image: BGR uint8 image (converted with ``cv2.COLOR_BGR2GRAY`` to measure brightness).
        target_brightness: Desired mean grey level of the masked pixels.
        mask: Boolean array (same height/width as ``image``) selecting the pixels to adjust.
        custom_brightness_constant: Extra multiplicative factor applied on top of the match.

    Returns:
        A new uint8 image where only masked pixels are scaled (clipped to 0..255).
        ``image`` itself is returned unchanged when the mask is empty or the
        masked region is completely black.
    """
    mask = np.asarray(mask, dtype=bool)
    if not mask.any():
        return image
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    current_brightness = float(np.mean(gray[mask]))
    if current_brightness == 0:
        # A black region cannot be brightened multiplicatively. The factor
        # would be infinite, and 0 * inf = NaN would corrupt the uint8 cast.
        return image
    factor = target_brightness / current_brightness * custom_brightness_constant
    image_adj = image.astype(np.float32)  # astype already returns a copy
    image_adj[mask] *= factor
    return np.clip(image_adj, 0, 255).astype(np.uint8)

def manual_match_histogram(source, target_hist):
    """Remap integer intensities in ``source`` so their distribution follows ``target_hist``.

    Classic CDF-based histogram specification on a flat array: each source level
    is sent to the target level whose cumulative frequency matches its own.

    Args:
        source: Integer array with values in 0..255 (used as indices into the mapping).
        target_hist: 256-bin reference histogram. It need not be normalized,
            because its CDF is rescaled to end at 1, like the source CDF.

    Returns:
        uint8 array with the same shape as ``source``.
    """
    source = np.asarray(source)
    if source.size == 0:
        # Nothing to match (and np.histogram(density=True) would divide by zero).
        return source.astype(np.uint8)
    hist, _ = np.histogram(source.ravel(), bins=256, range=(0, 256), density=True)
    cdf_source = np.cumsum(hist)  # ends at 1 (unit-width bins)
    cdf_target = np.cumsum(np.asarray(target_hist, dtype=np.float64))
    if cdf_target[-1] > 0:
        # Bring the target CDF to the same [0, 1] scale as the source CDF.
        cdf_target = cdf_target / cdf_target[-1]
    mapping = np.interp(cdf_source, cdf_target, np.arange(256))
    # Round to the nearest level (truncation would bias everything darker).
    lut = np.clip(np.rint(mapping), 0, 255).astype(np.uint8)
    return lut[source]

def match_histogram_masked(image, target_hist, mask):
    """Histogram-match the Lab lightness channel of the masked pixels only.

    The BGR image is converted to Lab. L is remapped inside ``mask`` with
    ``manual_match_histogram``, while A/B and all unmasked pixels stay as
    they were. The result is converted back to BGR.
    """
    l_chan, a_chan, b_chan = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2LAB))
    new_l = l_chan.copy()
    if np.sum(mask) > 0:
        new_l[mask] = manual_match_histogram(l_chan[mask], target_hist)
    return cv2.cvtColor(cv2.merge([new_l, a_chan, b_chan]), cv2.COLOR_LAB2BGR)

def preprocess_image_custom_masked(image, stats_csv, mask, custom_brightness_constant = 1):
    """Tone-match the masked region of a BGR image to the reference statistics.

    Loads the target brightness/histogram from ``stats_csv``. Masked pixels
    are first scaled to the target brightness, then their lightness histogram
    is matched to the reference.
    """
    target_brightness, target_hist = load_morphii_stats(stats_csv)
    brightened = adjust_brightness_masked(image, target_brightness, mask, custom_brightness_constant)
    return match_histogram_masked(brightened, target_hist, mask)

def estimate_background_color_top(image, height=10, width=2):
    """
    Estimate the background colour as the most frequent pixel value in two
    small patches: the top ``height`` rows of the leftmost and rightmost
    ``width`` columns. Ties go to the colour encountered first (left patch
    before right patch, row-major).
    """
    _, img_w, _ = image.shape
    top_rows = image[:height]
    left_px = top_rows[:, :width].reshape(-1, 3)
    right_px = top_rows[:, img_w - width:img_w].reshape(-1, 3)
    votes = Counter(map(tuple, np.vstack((left_px, right_px))))
    mode_color, _ = votes.most_common(1)[0]
    return np.array(mode_color, dtype=np.uint8)

def remove_background(image, tolerance=30, height=10, width=2):
    """
    Build a boolean foreground mask for ``image``.

    A pixel is foreground when its Euclidean RGB distance from the background
    colour exceeds ``tolerance``. The background colour is the mode of the
    top-corner patches (see ``estimate_background_color_top``). The raw mask
    is smoothed with 6 iterations of 3x3 morphological opening, then 6 of
    closing.
    """
    background = estimate_background_color_top(image, height, width).astype(np.float32)
    distance = np.linalg.norm(image.astype(np.float32) - background, axis=2)
    fg = (distance > tolerance).astype(np.uint8) * 255

    structuring = np.ones((3, 3), np.uint8)
    for morph_op in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        fg = cv2.morphologyEx(fg, morph_op, structuring, iterations=6)
    return fg.astype(bool)

In [None]:
def align_face(image, stats_csv, custom_brightness_constant=1, output_size=256,
               desired_left_eye=(0.35, 0.35), bg_tolerance=30, bg_value=128):
    """
    Align the face in a BGR image, tone-match it and flatten the background.

    Steps: detect faces with the global dlib ``detector`` and keep the
    largest; place the eyes with the ``predictor`` landmarks; rotate, scale
    and translate so the eyes are level at fixed relative positions in an
    ``output_size`` x ``output_size`` crop. Then apply the masked
    brightness/histogram adjustment from ``stats_csv`` to the foreground and
    paint the background a constant grey.

    Args:
        image: Input image in BGR channel order.
        stats_csv: Statistics CSV consumed by ``preprocess_image_custom_masked``.
        custom_brightness_constant: Extra brightness multiplier.
        output_size: Side length of the square output crop.
        desired_left_eye: (x, y) of the first eye centre as a fraction of
            ``output_size``; the second eye is mirrored about the vertical centre.
        bg_tolerance: Colour-distance threshold passed to ``remove_background``.
        bg_value: Grey level used for warp padding and for removed background.

    Returns:
        (output_size, output_size, 3) uint8 BGR image.

    Raises:
        ValueError: if no face is detected or the eye landmarks coincide.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = detector(gray, 1)
    if len(faces) == 0:
        raise ValueError("No face detected in the image.")
    # faces[0] is not necessarily the main subject when several boxes are
    # returned; keep the detection with the largest area instead.
    face = max(faces, key=lambda rect: rect.width() * rect.height())
    shape = predictor(gray, face)
    landmarks = np.array([(p.x, p.y) for p in shape.parts()], dtype=np.float64)

    # 68-point layout: points 36-41 and 42-47 outline the two eyes.
    # Keep sub-pixel eye centres (no int truncation) for a more precise angle/scale.
    left_eye = landmarks[36:42].mean(axis=0)
    right_eye = landmarks[42:48].mean(axis=0)
    dX = right_eye[0] - left_eye[0]
    dY = right_eye[1] - left_eye[1]
    dist = np.hypot(dX, dY)
    if dist == 0:
        raise ValueError("Degenerate eye landmarks: both eye centres coincide.")
    angle = np.degrees(np.arctan2(dY, dX))

    # Scale so the inter-eye distance equals the gap between the desired eye positions.
    desired_right_eye_x = 1.0 - desired_left_eye[0]
    desired_dist = (desired_right_eye_x - desired_left_eye[0]) * output_size
    scale = desired_dist / dist

    eyes_center = (float((left_eye[0] + right_eye[0]) / 2),
                   float((left_eye[1] + right_eye[1]) / 2))
    M = cv2.getRotationMatrix2D(eyes_center, angle, scale)
    # Shift so the eye midpoint lands at (horizontal centre, desired eye height).
    tX, tY = output_size * 0.5, output_size * desired_left_eye[1]
    M[0, 2] += (tX - eyes_center[0])
    M[1, 2] += (tY - eyes_center[1])
    grey_bg = (bg_value, bg_value, bg_value)
    aligned_face = cv2.warpAffine(image, M, (output_size, output_size), flags=cv2.INTER_CUBIC,
                                  borderMode=cv2.BORDER_CONSTANT, borderValue=grey_bg)

    # Tone-match only the foreground, then replace everything else with flat grey.
    mask = remove_background(aligned_face, tolerance=bg_tolerance)
    final_img = preprocess_image_custom_masked(aligned_face, stats_csv, mask,
                                               custom_brightness_constant).copy()
    final_img[~mask] = np.array(grey_bg, dtype=np.uint8)
    return final_img

In [None]:
# Reference brightness/histogram statistics used by align_face's tone matching.
stats_csv = "morphii_train_stats.csv"

# dlib models used by align_face: the frontal face detector and the
# 68-point landmark predictor (the .dat model file must exist locally).
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor('shape_predictor_68_face_landmarks.dat')

In [None]:
def _align_rgb_face(img):
    """Run align_face on an RGB image (PIL or array) and return the result in RGB order."""
    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    return cv2.cvtColor(align_face(bgr, stats_csv), cv2.COLOR_BGR2RGB)


# Preprocessing for arbitrary photos: align/tone-match the face first, then
# apply the same resize + [-1, 1] normalisation as the evaluation pipeline.
prepipeline_custom = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Lambda(_align_rgb_face),
        transforms.ToPILImage(),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [None]:
img_path = "Dataset/Team pics/Brad.jpeg"
bgr_img = cv2.imread(img_path)
if bgr_img is None:  # imread signals an unreadable/missing file by returning None
    raise FileNotFoundError(f"Image not found: {img_path}")

processed_img = prepipeline_custom(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)

# Undo the [-1, 1] normalisation and move channels last for display.
preview_fig, preview_ax = plt.subplots()
preview_ax.imshow(processed_img.permute(1, 2, 0) * 0.5 + 0.5)
preview_ax.axis("off")
preview_ax.set_title("Processed (Face-Aligned) Image")
plt.show()

In [None]:
custom_dataset = MorphII_Dataset(csv_file="Dataset/Team pics/team.csv", transform=prepipeline_custom)
# num_workers=0: prepipeline_custom calls functions defined in this notebook
# (align_face and helpers). Under the "spawn" start method, the default on
# macOS and Windows, worker processes must pickle the dataset and re-import
# __main__, which fails for notebook-defined callables. Loading in the main
# process avoids that.
custom_loader = DataLoader(custom_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
num_team_displayed = 5

team_fig, team_axs = plt.subplots(num_team_displayed, 2, figsize=(6, 3 * num_team_displayed))

# Inference only: no_grad avoids building an autograd graph for every sample.
with torch.no_grad():
    for idx in range(num_team_displayed):
        custom_img, custom_cond = custom_dataset[idx]
        img_tensor = custom_img.unsqueeze(0).to(device)
        cond_tensor = custom_cond.unsqueeze(0).to(device)
        recon, _, _ = model(img_tensor, cond_tensor)

        # Undo Normalize(0.5, 0.5) ([-1, 1] -> [0, 1]) and go CHW -> HWC for imshow.
        orig_np = img_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5
        recon_np = recon.squeeze(0).cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5

        for col, (image_np, title) in enumerate(((orig_np, "Original"), (recon_np, "Reconstructed"))):
            team_axs[idx, col].imshow(image_np)
            team_axs[idx, col].set_title(title)
            team_axs[idx, col].axis("off")

team_fig.tight_layout()

os.makedirs("figures", exist_ok=True)
team_fig.savefig("figures/team_reconstruction.png")
plt.show()

In [None]:
num_custom = 5
num_age = 5

# Same normalized-age sweep as the test-set grid, with fewer columns.
age_range = np.linspace(0, normalizer, num_age)

fig, axs = plt.subplots(num_custom, num_age, figsize=(3*num_age, 3*num_custom))

for row in range(num_custom):
    custom_img, custom_cond = custom_dataset[row]
    variations = generate_age_variation(model, custom_img, custom_cond, age_range)

    for ax, age, variation in zip(axs[row], age_range, variations):
        # [-1, 1] -> [0, 1], CHW -> HWC for display.
        ax.imshow(variation.squeeze().cpu().detach().numpy().transpose(1, 2, 0) * 0.5 + 0.5)
        ax.set_title(f"Age: {int(16 + (1/normalizer)*age*(80-16))}")
        ax.axis("off")

plt.tight_layout()

plt.savefig("figures/team_age.png")
plt.show()

In [None]:
import imageio
from IPython.display import Image as IPyImage, display
import io
import numpy as np

name_dictionary = {
    1: "Kyler",
    2: "Brad",
    3: "John",
    4: "Casey",
    5: "Batu"
}

num_frames = 100
# Normalized-age sweep shared by every GIF (loop-invariant, so built once).
age_range_frames = np.linspace(0, 2, num_frames)

gif_dir = "figures/gifs"
os.makedirs(gif_dir, exist_ok=True)
# Frame delay in milliseconds. GIFs are written with Pillow directly because
# imageio's `duration` unit depends on the installed version (seconds in the
# legacy GIF plugin, milliseconds in the Pillow-based one). Many viewers also
# clamp delays of 10 ms or less to about 100 ms.
frame_duration_ms = 20

gif_filenames = []
for idx in range(num_custom):
    custom_img, custom_cond = custom_dataset[idx]
    generated_age_images_frames = generate_age_variation(model, custom_img, custom_cond, age_range_frames)

    frames = []
    for gen in generated_age_images_frames:
        gen_np = (gen.squeeze().cpu().detach().numpy().transpose(1, 2, 0) * 0.5 + 0.5)
        frame = np.clip(np.rint(gen_np * 255), 0, 255).astype(np.uint8)
        frames.append(np.ascontiguousarray(frame))

    # Forward then backward, without repeating either end frame, for a seamless loop.
    bounce_frames = frames + frames[-2:0:-1]

    person = name_dictionary.get(idx + 1, str(idx + 1))
    gif_filename = f"{gif_dir}/age_variation_{person}.gif"
    pil_frames = [Image.fromarray(frame) for frame in bounce_frames]
    pil_frames[0].save(gif_filename, save_all=True, append_images=pil_frames[1:],
                       duration=frame_duration_ms, loop=0)  # loop=0: repeat forever
    gif_filenames.append(gif_filename)

print("Generated GIFs:")
for gif_filename in gif_filenames:
    display(IPyImage(filename=gif_filename))