# Eigenfaces

## Setup

In [None]:
import numpy as np
import torch

### Load SD model

In [None]:
# Load Stable Diffusion VAE model
from diffusers import AutoencoderKL

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained VAE and switch it to inference mode.
sd_vae = AutoencoderKL.from_pretrained("../models/sd_vae/version_0/huggingface").eval()

# Move the model to the selected device (left as the cell's last expression
# so the notebook still echoes the module).
sd_vae.to(device)

### Load PCA

See `pca.ipynb` for details on how the PCA model was computed.

In [None]:
from sklearn.decomposition import PCA
import pickle

# Restore the fitted PCA object that was serialised with pickle.
# NOTE(review): unpickling can execute arbitrary code, so this file should
# only ever come from a trusted source.
with open("../data/ffhq/sd_latents_pca_model.pkl", "rb") as pca_file:
    pca = pickle.load(pca_file)

## Group Comparison

### Load SD latents

In [None]:
# Read the precomputed SD latents from disk.
# NOTE(review): weights_only=False allows arbitrary pickled objects to be
# loaded, so the file must be trusted; confirm whether weights_only=True
# would be sufficient for this file.
latents = torch.load("../data/ffhq/sd_latents.pt", weights_only=False)

# Remember the per-sample shape (every dimension after the first) so that
# flattened vectors can be reshaped back before decoding.
latent_shape = latents.size()[1:]

In [None]:
# Flatten every latent to one row vector and hand it to NumPy, which is the
# input format used for the PCA transform.
latents_flat = (
    latents.view(latents.shape[0], -1)
    .cpu()
    .numpy()
)

### Load smile scores

In [None]:
import json

# Load the smile scores (a JSON object keyed by file name).
# The context manager guarantees the file handle is closed after parsing.
with open("../data/ffhq/ffhq_smile_scores.json", "r") as scores_file:
    smile_scores = json.load(scores_file)

# Order the scores by file name and keep only the values as an array.
# NOTE(review): presumably this order matches the row order of the SD
# latents; verify against the script that produced both files.
smile_scores = np.array([smile_scores[name] for name in sorted(smile_scores)])

In [None]:
import matplotlib.pyplot as plt

# Show how the smile scores are distributed across the dataset.
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(smile_scores, bins=50, color='blue', alpha=0.7)
ax.set_title("Histogram of Smile Scores")
ax.set_xlabel("Smile Score")
ax.set_ylabel("Frequency")
ax.grid()
plt.show()

### Transform SD latents into PCA space

In [None]:
# Transform the flattened SD latents with the loaded PCA model; each row of
# the result holds the PCA coefficients of one latent.
pca_latents = pca.transform(latents_flat)

In [None]:
# Split the PCA latents into a "high smile" group (score > 3.0) and a
# "low smile" group (score < 1.0) using boolean masks over the scores.
high_smile_mask = smile_scores > 3.0
low_smile_mask = smile_scores < 1.0

pca_latents_high = pca_latents[high_smile_mask]
pca_latents_low = pca_latents[low_smile_mask]

In [None]:
def sample_latent(latents, n_samples=1, rng=None):
    """Draw ``n_samples`` distinct rows from ``latents`` uniformly at random.

    Args:
        latents: Array indexed along its first axis; ``latents.shape[0]`` is
            the number of rows available for sampling.
        n_samples: Number of rows to draw (without replacement).
        rng: Optional ``numpy.random.Generator`` for reproducible draws. When
            omitted, NumPy's global random state is used.

    Returns:
        ``latents[indices]`` for the drawn indices, so the first axis of the
        result has length ``n_samples``.

    Raises:
        ValueError: If more rows are requested than are available, e.g.
            because a score threshold produced an empty subset.
    """
    n_available = latents.shape[0]
    if n_samples > n_available:
        raise ValueError(
            f"Cannot sample {n_samples} latent(s) without replacement from "
            f"{n_available} available row(s)."
        )

    # Fall back to the global NumPy RNG unless a dedicated generator is given.
    choice = np.random.choice if rng is None else rng.choice
    indices = choice(n_available, n_samples, replace=False)
    return latents[indices]

# Draw one random representative from the high-smile group and one from the
# low-smile group (in that order).
latent_high, latent_low = (
    sample_latent(pca_latents_high),
    sample_latent(pca_latents_low),
)

In [None]:
def prepare_image(image):
    """Convert an image tensor into a NumPy array suitable for ``imshow``.

    Args:
        image: Three-axis tensor whose axes are reordered with
            ``permute(1, 2, 0)`` (presumably channels-first, i.e. C x H x W).

    Returns:
        NumPy array with the first axis moved to the end, values mapped by
        ``x * 0.5 + 0.5`` and clipped to the range [0, 1].
    """
    # detach() lets tensors that still track gradients be converted to NumPy.
    pixels = image.detach().cpu().permute(1, 2, 0).numpy()
    return (pixels * 0.5 + 0.5).clip(0, 1)

In [None]:
def get_eigenfaces(pca, n_components=5):
    """Decode single-component reconstructions of the sampled smile latents.

    For each of the first ``n_components`` principal components ``k`` the
    vector ``pca.mean_ + c_k * pca.components_[k]`` is built, where ``c_k`` is
    the k-th PCA coefficient of the sampled latent, and is then decoded into
    an image with the SD VAE.

    Relies on the notebook globals ``latent_high``, ``latent_low``,
    ``latent_shape``, ``sd_vae`` and ``device``.

    NOTE(review): this reconstruction assumes the PCA was fitted without
    whitening; confirm against the notebook that fitted the model.

    Args:
        pca: Fitted PCA object exposing ``mean_`` and ``components_``.
        n_components: Number of leading components to visualise.

    Returns:
        Tuple ``(eigenfaces_high, eigenfaces_low)``; each is a list with one
        displayable image (as returned by ``prepare_image``) per component.
    """

    def _decode_component(coefficient, component_idx):
        # Step away from the PCA mean along a single principal axis.
        flat_latent = pca.mean_ + coefficient * pca.components_[component_idx]

        # A single reshape adds the batch axis and restores the VAE layout.
        latent = (
            torch.tensor(flat_latent, dtype=torch.float32)
            .reshape(1, *latent_shape)
            .to(device)
        )

        # Decoding is inference only, so no gradients are needed.
        with torch.no_grad():
            decoded = sd_vae.decode(latent).sample

        return prepare_image(decoded[0])

    eigenfaces_high = []
    eigenfaces_low = []

    for component_idx in range(n_components):
        eigenfaces_high.append(
            _decode_component(latent_high[0, component_idx], component_idx)
        )
        eigenfaces_low.append(
            _decode_component(latent_low[0, component_idx], component_idx)
        )

    return eigenfaces_high, eigenfaces_low

In [None]:
n_components = 5

eigenfaces_high, eigenfaces_low = get_eigenfaces(pca, n_components=n_components)

# Plot the high-smile and low-smile reconstructions side by side, one row per
# component. squeeze=False keeps `axes` two-dimensional even when
# n_components == 1, so the `axes[i, col]` indexing below always works.
fig, axes = plt.subplots(
    nrows=n_components,
    ncols=2,
    figsize=(6, n_components * 3),
    squeeze=False,
)

for i, (face_high, face_low) in enumerate(zip(eigenfaces_high, eigenfaces_low)):
    axes[i, 0].imshow(face_high)
    axes[i, 0].axis('off')
    axes[i, 0].set_title(f"High Smile Eigenface {i+1}")

    axes[i, 1].imshow(face_low)
    axes[i, 1].axis('off')
    axes[i, 1].set_title(f"Low Smile Eigenface {i+1}")
plt.tight_layout()
plt.show()

## Optimization Analysis

### Load LSO Results

In [None]:
# Load the optimisation results. An .npz archive stays open until it is
# closed, so read the needed arrays inside a context manager.
with np.load("../results/gbo_pca_sd_03/opt/iter_0/gbo_opt_res.npz") as results_npz:
    z_opt = results_npz["z_opt"]
    z_init = results_npz["z_init"]

# NOTE(review): this hard-coded shape replaces the latent_shape derived from
# the loaded SD latents earlier in the notebook; confirm it matches the
# latents stored in the results file.
latent_shape = [16, 32, 32]

z_opt.shape, z_init.shape

In [None]:
# Flatten every sample to a row vector and transform it with the PCA model,
# for both the optimised and the initial latents.
z_opt_pca = pca.transform(z_opt.reshape(len(z_opt), -1))
z_init_pca = pca.transform(z_init.reshape(len(z_init), -1))

z_opt_pca.shape, z_init_pca.shape

In [None]:
# Per sample, pick the 5 PCA components whose coefficients changed the most
# between the initial and the optimised latent. The indices come out in
# ascending order of absolute change, so the largest change is last.
diff = z_opt_pca - z_init_pca
top_components = np.abs(diff).argsort(axis=1)[:, -5:]

top_components

In [None]:
# For every sample and each of its selected PCA components, decode the
# single-component reconstruction "mean + coefficient * component" twice:
# once with the initial coefficient and once with the optimised one.
eigenfaces_init = [[] for _ in range(z_init_pca.shape[0])]
eigenfaces_opt = [[] for _ in range(z_opt_pca.shape[0])]

for i, sample_components in enumerate(top_components):
    for idx in sample_components:
        component = pca.components_[idx]
        flat_init = pca.mean_ + z_init_pca[i, idx] * component
        flat_opt = pca.mean_ + z_opt_pca[i, idx] * component

        # Decoding is inference only, so gradients are disabled.
        with torch.no_grad():
            # Add a batch axis and restore the VAE latent layout.
            sd_latent_init = (
                torch.tensor(flat_init, dtype=torch.float32)
                .reshape(1, *latent_shape)
                .to(device)
            )
            sd_latent_opt = (
                torch.tensor(flat_opt, dtype=torch.float32)
                .reshape(1, *latent_shape)
                .to(device)
            )

            sd_image_init = sd_vae.decode(sd_latent_init).sample
            sd_image_opt = sd_vae.decode(sd_latent_opt).sample

        eigenfaces_init[i].append(prepare_image(sd_image_init[0]))
        eigenfaces_opt[i].append(prepare_image(sd_image_opt[0]))

In [None]:
# Display the selected component indices again, for reference when reading
# the comparison plot in the next cell.
top_components

In [None]:
# One row per sample; each selected component takes two adjacent columns:
# the reconstruction from the initial latent, then from the optimised one.
n_rows = len(eigenfaces_opt)
n_pairs = len(eigenfaces_opt[0])

fig, axes = plt.subplots(
    nrows=n_rows,
    ncols=n_pairs * 2,
    figsize=(n_pairs * 5, n_rows * 3.2),
    squeeze=False
)

for i in range(n_rows):
    for j in range(n_pairs):
        ax_init = axes[i, j * 2]
        ax_opt = axes[i, j * 2 + 1]

        # Only the first pair in a row carries the sample number.
        title_init = (
            f"Image {i+1}\nComponent {top_components[i][j]}\nOriginal"
            if j == 0
            else f"Component {top_components[i][j]}\nOriginal"
        )

        ax_init.imshow(eigenfaces_init[i][j])
        ax_init.axis('off')
        ax_init.set_title(title_init)

        ax_opt.imshow(eigenfaces_opt[i][j])
        ax_opt.axis('off')
        ax_opt.set_title("Optimized")
plt.tight_layout()
plt.show()