In [None]:
# manipulate_latent_space.py
import os

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

from train_vae import VAE, IMG_SIZE, LATENT_DIM, DEVICE  # Import from the training script

# --- Configuration ---
# All paths below are relative, so they resolve against the current working
# directory of the process running this script.
MODEL_PATH = 'models/vae_celeba_final.pth' # Path to your trained model
# CSV that is read with pandas; the code below expects an 'image_id' column
# plus one column per attribute name (e.g. 'Smiling').
ATTRIBUTES_FILE = 'data/celeba/list_attr_celeba.csv'
# Directory that the 'image_id' filenames are joined onto.
IMAGES_DIR = 'data/celeba/img_align_celeba/img_align_celeba/'
# Upper bound on how many images are encoded for EACH of the two groups
# (attribute present / attribute absent) when building an attribute vector.
N_IMAGES_FOR_VECTOR = 5000 # Number of images to average for attribute vectors
# Number of evenly spaced steps in the sequence; also used as the number of
# images per row of the saved grid.
N_TRANSITIONS = 11 # Number of steps in the generated sequence
# The attribute vector is scaled by factors ranging from -MANIPULATION_STRENGTH
# to +MANIPULATION_STRENGTH.
MANIPULATION_STRENGTH = 3.0 # How strongly to apply the attribute

# --- Load Model ---
# NOTE: everything from here to the end of this section runs at import time
# (it is not guarded by `if __name__ == "__main__"`).
model = VAE().to(DEVICE)
# NOTE(review): torch.load unpickles the checkpoint file, so only load
# checkpoints from a trusted source.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
# Presumably VAE is a torch.nn.Module, in which case this switches it to
# evaluation mode -- confirm against train_vae.
model.eval()
print("Model loaded successfully.")

# --- Load Attribute Data ---
attributes = pd.read_csv(ATTRIBUTES_FILE)
# CelebA labels are -1 (absent) and 1 (present). Convert to 0 and 1.
# The replacement is applied in place to every column of the frame, not just
# to selected attribute columns.
attributes.replace(-1, 0, inplace=True)

# --- Function to get latent vectors for attribute ---
def get_attribute_vectors(attr_name):
    """Estimate the latent-space direction associated with an attribute.

    Encodes up to ``N_IMAGES_FOR_VECTOR`` images that have the attribute and
    up to ``N_IMAGES_FOR_VECTOR`` images that do not, averages the encoder
    means of each group, and returns the difference of the two averages.

    The images used are the first ones in the order they appear in the
    module-level ``attributes`` frame; they are not randomly sampled.

    Args:
        attr_name: Name of a column in ``attributes`` whose values are
            1 (attribute present) or 0 (attribute absent).

    Returns:
        A float32 tensor on ``DEVICE`` equal to
        ``mean(positive latents) - mean(negative latents)``.

    Raises:
        KeyError: If ``attr_name`` is not a column of ``attributes``.
        ValueError: If there is no image with the attribute, or no image
            without it, so that one of the two means would be undefined.
    """
    # Get image filenames for positive and negative samples
    labels = attributes[attr_name]
    positive_files = attributes.loc[labels == 1, 'image_id'].tolist()
    negative_files = attributes.loc[labels == 0, 'image_id'].tolist()

    # Simple transform: resize, centre-crop to IMG_SIZE, convert to a tensor.
    transform = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
    ])

    def _mean_latent_vector(filenames, n_samples, label):
        """Return the mean encoder output over the first n_samples files."""
        selected = filenames[:n_samples]
        if not selected:
            # Without this guard np.mean() of an empty array yields NaN and
            # the failure only surfaces later as an unrelated TypeError.
            raise ValueError(
                f"No {label} samples available for attribute '{attr_name}'."
            )

        vectors = []
        with torch.no_grad():
            for filename in selected:
                img_path = os.path.join(IMAGES_DIR, filename)
                # The context manager closes the file as soon as the pixels
                # have been turned into a tensor.
                with Image.open(img_path) as img:
                    image = transform(img).unsqueeze(0).to(DEVICE)
                mu, _ = model.encode(image)
                vectors.append(mu.cpu().numpy())
        # Average over the sample axis, then drop the singleton batch axis.
        return np.mean(np.array(vectors), axis=0).squeeze()

    print(f"Calculating vector for '{attr_name}'...")
    positive_vec = _mean_latent_vector(positive_files, N_IMAGES_FOR_VECTOR, 'positive')
    negative_vec = _mean_latent_vector(negative_files, N_IMAGES_FOR_VECTOR, 'negative')

    # The attribute vector is the difference between the means
    attribute_vector = positive_vec - negative_vec
    return torch.from_numpy(attribute_vector).float().to(DEVICE)


# --- Generate Image Sequence ---
def generate_transition(start_img_path, attribute_vector, attribute_name):
    """Decode a sequence of images that sweeps one attribute and save it.

    The starting image is encoded, its latent mean is shifted by
    ``alpha * attribute_vector`` for ``N_TRANSITIONS`` evenly spaced values of
    ``alpha`` in ``[-MANIPULATION_STRENGTH, MANIPULATION_STRENGTH]``, each
    shifted latent is decoded, and the results are written as a single-row
    grid to ``results/transition_<attribute_name>.png``. The ``results``
    directory is created if it does not exist.

    Args:
        start_img_path: Path of the image to start from.
        attribute_vector: Tensor added (scaled) to the latent mean; expected
            to be on ``DEVICE`` and broadcastable against the encoder output.
        attribute_name: Label used in the output filename and log message.

    Returns:
        None. The image grid is written to disk.
    """
    transform = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
    ])

    # Get latent vector of the starting image. The context manager closes the
    # file once the pixels have been converted to a tensor.
    with Image.open(start_img_path) as img:
        start_image = transform(img).unsqueeze(0).to(DEVICE)

    # .tolist() yields plain Python floats, so the scaling below never mixes
    # NumPy float64 scalars with torch tensors.
    alphas = np.linspace(
        -MANIPULATION_STRENGTH, MANIPULATION_STRENGTH, N_TRANSITIONS
    ).tolist()

    # Generate sequence
    with torch.no_grad():
        start_mu, _ = model.encode(start_image)
        images = [
            model.decode(start_mu + alpha * attribute_vector).cpu()
            for alpha in alphas
        ]

    # Save the grid; make sure the output directory exists first, otherwise
    # save_image fails with FileNotFoundError on a fresh checkout.
    os.makedirs('results', exist_ok=True)
    output = torch.cat(images)
    save_image(output, f'results/transition_{attribute_name}.png', nrow=N_TRANSITIONS)
    print(f"Saved image grid for '{attribute_name}' transition.")


# --- Main Execution ---
if __name__ == "__main__":
    # Attributes to sweep, in order:
    #   1. 'Smiling'
    #   2. 'Eyeglasses' -- used as a strong proxy for eye openness because it
    #      affects the eye region
    #   3. 'Blond_Hair' -- used as an example of a hairstyle change
    attribute_names = ['Smiling', 'Eyeglasses', 'Blond_Hair']

    # Find a neutral starting image: one for which every attribute that will
    # be manipulated is absent (not smiling, no eyeglasses, not blond).
    is_neutral = (attributes[attribute_names] == 0).all(axis=1)
    neutral_candidates = attributes[is_neutral]
    if neutral_candidates.empty:
        raise SystemExit("No neutral starting image found in the attribute file.")

    # Pick an arbitrary one (the 11th), falling back to the last candidate
    # when fewer than 11 exist instead of failing with an IndexError.
    candidate_index = min(10, len(neutral_candidates) - 1)
    start_image_file = neutral_candidates.iloc[candidate_index]['image_id']
    start_image_path = os.path.join(IMAGES_DIR, start_image_file)

    # Build each attribute direction and render its transition grid.
    for attribute_name in attribute_names:
        attribute_vector = get_attribute_vectors(attribute_name)
        generate_transition(start_image_path, attribute_vector, attribute_name)