In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

In [2]:
# Select the compute backend in priority order: Apple-silicon MPS, then CUDA,
# and finally the CPU when no accelerator is available.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [None]:
from dataset import MorphII_Dataset

# Preprocessing shared by the evaluation splits. Per the original comments the
# dataset yields NumPy images (presumably HxWxC uint8 -- TODO confirm in
# dataset.py). They are converted to PIL, resized to the model's 64x64 input,
# turned into [0, 1] tensors and finally mapped to [-1, 1] via (x - 0.5) / 0.5.
prepipeline = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Evaluation splits, each indexed by a CSV file.
val_dataset = MorphII_Dataset(
    csv_file="Dataset/Index/Validation.csv",
    transform=prepipeline,
)
test_dataset = MorphII_Dataset(
    csv_file="Dataset/Index/Test.csv",
    transform=prepipeline,
)

BATCH_SIZE = 64
# shuffle=False keeps the validation order fixed between runs.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

In [11]:
def generate_age_variation(model, image, cond, age_values, sample=True):
    """
    Encode one image under its own condition, then decode the same latent
    code under each of several new age conditions.

    Args:
        model: Trained ConditionalVAE exposing ``encoder(image, cond)`` that
            returns ``(mu, logvar)`` and ``decoder(z, cond)``. If the model also
            has a ``reparameterize(mu, logvar)`` method it is used for sampling.
        image: A single image tensor (C x H x W).
        cond: Its corresponding condition tensor (age), shape [1]; a 0-d
            scalar tensor is accepted as well.
        age_values: Iterable of new normalized age values.
        sample: If True (default, the original behaviour), draw ``z`` from the
            posterior with the reparameterization trick, so repeated calls
            differ. If False, use the posterior mean ``mu`` for a
            deterministic sweep.

    Returns:
        List of generated images (tensors), one decoder output per entry of
        ``age_values``, in the same order; each is decoded from a batch of 1.
    """
    model.eval()
    # Send inputs to the device the model actually lives on instead of
    # trusting the notebook-global `device`; that global is only a fallback
    # for a model without parameters.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else device
    outputs = []
    with torch.no_grad():
        image = image.unsqueeze(0).to(model_device)
        # reshape(1, -1) equals unsqueeze(0) for a [k] condition, and also turns
        # a 0-d scalar into [1, 1], matching the shape of `new_cond` below.
        cond = cond.reshape(1, -1).to(model_device)

        mu, logvar = model.encoder(image, cond)
        if not sample:
            z = mu
        elif hasattr(model, "reparameterize"):
            z = model.reparameterize(mu, logvar)
        else:
            # No `reparameterize` is defined in this notebook, so apply the
            # standard trick inline, assuming `logvar` is the log-variance as
            # its name suggests: z = mu + eps * exp(0.5 * logvar).
            z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        for age in age_values:
            new_cond = torch.tensor([[age]], dtype=torch.float32, device=model_device)
            outputs.append(model.decoder(z, new_cond))
    return outputs

# Age-sweep demo on a single photo ("image1"), preprocessed by hand with
# PIL + NumPy instead of torchvision.
import os
from PIL import Image
import numpy as np

# NOTE(review): `model` (the trained ConditionalVAE) is not created in any cell
# visible in this notebook; it only exists if an earlier, since-deleted cell
# defined it. Load it in a cell above so the notebook survives Restart & Run All.

# Assuming image1 is in the same directory or specify the path
image_path = "image1.jpg"  # Adjust path as needed

# Fail early with a clear message rather than an opaque PIL error.
if not os.path.exists(image_path):
    raise FileNotFoundError(f"Could not find {image_path}. Please ensure the file exists.")

# Match `prepipeline` from the dataset cell: torchvision's Resize interpolates
# bilinearly by default, so resize with the same filter (LANCZOS would feed the
# model differently-resampled pixels than the evaluation data). The `with`
# block closes the file handle that PIL would otherwise keep open lazily.
with Image.open(image_path) as raw_img:
    img = raw_img.convert("RGB").resize((64, 64), Image.Resampling.BILINEAR)
img_np = np.array(img) / 255.0  # uint8 [0, 255] -> float [0, 1]

# HWC -> CHW, then [0, 1] -> [-1, 1]; equivalent to Normalize(mean=0.5, std=0.5).
img_tensor = torch.from_numpy(img_np.transpose(2, 0, 1)).float()
img_tensor = img_tensor * 2.0 - 1.0

# The photo's true age is unknown, so encode it with a mid-range normalized age.
sample_cond = torch.tensor([0.5], dtype=torch.float32)

# Decode the same latent code at 10 evenly spaced normalized ages.
age_range = np.linspace(0.0, 1.0, 10)
generated_images = generate_age_variation(model, img_tensor, sample_cond, age_range)

# One row: the input followed by each age variation. squeeze=False keeps the
# axes array 2-D even with a single panel, so indexing works for any length.
fig, axes = plt.subplots(1, len(generated_images) + 1, figsize=(15, 3), squeeze=False)
axes = axes[0]

# Undo the [-1, 1] scaling and clip: float error (or an unbounded decoder
# output) can leave values just outside [0, 1], which imshow clips with a warning.
img_display = np.clip(img_tensor.cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5, 0.0, 1.0)
axes[0].imshow(img_display)
axes[0].set_title(f"Original (Age: {sample_cond.item():.2f})")
axes[0].axis("off")

for ax, age, gen in zip(axes[1:], age_range, generated_images):
    # Drop only the leading batch dimension, then CHW -> HWC for imshow.
    gen_np = np.clip(gen.squeeze(0).cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5, 0.0, 1.0)
    ax.imshow(gen_np)
    ax.set_title(f"Age: {age:.2f}")
    ax.axis("off")

fig.tight_layout()
plt.show()

FileNotFoundError: Could not find image1.jpg. Please ensure the file exists.

In [None]:
# NOTE(review): this cell redefines `generate_age_variation` (and reruns the
# same demo) from the previous cell; the later definition silently shadows the
# earlier one. Keep only one copy.
def generate_age_variation(model, image, cond, age_values, sample=True):
    """
    Encode one image under its own condition, then decode the same latent
    code under each of several new age conditions.

    Args:
        model: Trained ConditionalVAE exposing ``encoder(image, cond)`` that
            returns ``(mu, logvar)`` and ``decoder(z, cond)``. If the model also
            has a ``reparameterize(mu, logvar)`` method it is used for sampling.
        image: A single image tensor (C x H x W).
        cond: Its corresponding condition tensor (age), shape [1]; a 0-d
            scalar tensor is accepted as well.
        age_values: Iterable of new normalized age values.
        sample: If True (default, the original behaviour), draw ``z`` from the
            posterior with the reparameterization trick, so repeated calls
            differ. If False, use the posterior mean ``mu`` for a
            deterministic sweep.

    Returns:
        List of generated images (tensors), one decoder output per entry of
        ``age_values``, in the same order; each is decoded from a batch of 1.
    """
    model.eval()
    # Send inputs to the device the model actually lives on instead of
    # trusting the notebook-global `device`; that global is only a fallback
    # for a model without parameters.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else device
    outputs = []
    with torch.no_grad():
        image = image.unsqueeze(0).to(model_device)
        # reshape(1, -1) equals unsqueeze(0) for a [k] condition, and also turns
        # a 0-d scalar into [1, 1], matching the shape of `new_cond` below.
        cond = cond.reshape(1, -1).to(model_device)

        mu, logvar = model.encoder(image, cond)
        if not sample:
            z = mu
        elif hasattr(model, "reparameterize"):
            z = model.reparameterize(mu, logvar)
        else:
            # No `reparameterize` is defined in this notebook, so apply the
            # standard trick inline, assuming `logvar` is the log-variance as
            # its name suggests: z = mu + eps * exp(0.5 * logvar).
            z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        for age in age_values:
            new_cond = torch.tensor([[age]], dtype=torch.float32, device=model_device)
            outputs.append(model.decoder(z, new_cond))
    return outputs

# Age-sweep demo on a single photo ("image1"), preprocessed by hand with
# PIL + NumPy instead of torchvision.
import os
from PIL import Image
import numpy as np

# NOTE(review): `model` (the trained ConditionalVAE) is not created in any cell
# visible in this notebook; it only exists if an earlier, since-deleted cell
# defined it. Load it in a cell above so the notebook survives Restart & Run All.

# Assuming image1 is in the same directory or specify the path
image_path = "image1.jpg"  # Adjust path as needed

# Fail early with a clear message rather than an opaque PIL error.
if not os.path.exists(image_path):
    raise FileNotFoundError(f"Could not find {image_path}. Please ensure the file exists.")

# Match `prepipeline` from the dataset cell: torchvision's Resize interpolates
# bilinearly by default, so resize with the same filter (LANCZOS would feed the
# model differently-resampled pixels than the evaluation data). The `with`
# block closes the file handle that PIL would otherwise keep open lazily.
with Image.open(image_path) as raw_img:
    img = raw_img.convert("RGB").resize((64, 64), Image.Resampling.BILINEAR)
img_np = np.array(img) / 255.0  # uint8 [0, 255] -> float [0, 1]

# HWC -> CHW, then [0, 1] -> [-1, 1]; equivalent to Normalize(mean=0.5, std=0.5).
img_tensor = torch.from_numpy(img_np.transpose(2, 0, 1)).float()
img_tensor = img_tensor * 2.0 - 1.0

# The photo's true age is unknown, so encode it with a mid-range normalized age.
sample_cond = torch.tensor([0.5], dtype=torch.float32)

# Decode the same latent code at 10 evenly spaced normalized ages.
age_range = np.linspace(0.0, 1.0, 10)
generated_images = generate_age_variation(model, img_tensor, sample_cond, age_range)

# One row: the input followed by each age variation. squeeze=False keeps the
# axes array 2-D even with a single panel, so indexing works for any length.
fig, axes = plt.subplots(1, len(generated_images) + 1, figsize=(15, 3), squeeze=False)
axes = axes[0]

# Undo the [-1, 1] scaling and clip: float error (or an unbounded decoder
# output) can leave values just outside [0, 1], which imshow clips with a warning.
img_display = np.clip(img_tensor.cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5, 0.0, 1.0)
axes[0].imshow(img_display)
axes[0].set_title(f"Original (Age: {sample_cond.item():.2f})")
axes[0].axis("off")

for ax, age, gen in zip(axes[1:], age_range, generated_images):
    # Drop only the leading batch dimension, then CHW -> HWC for imshow.
    gen_np = np.clip(gen.squeeze(0).cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5, 0.0, 1.0)
    ax.imshow(gen_np)
    ax.set_title(f"Age: {age:.2f}")
    ax.axis("off")

fig.tight_layout()
plt.show()