# Feature Importance

This notebook estimates how important each SD latent feature is for predicting the smile score.

## Setup

In [None]:
import numpy as np
import torch

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

## Load latents

In [None]:
# Read the precomputed SD latents from disk.
# NOTE(review): weights_only=False allows arbitrary unpickling -- only load
# files from a trusted source.
latents = torch.load("../data/ffhq/sd_latents.pt", weights_only=False)

# Remember the per-sample shape (every dimension after the first) for later
latent_shape = latents.shape[1:]

# Collapse each sample into one flat feature vector: (N, ...) -> (N, D)
latents = latents.reshape(len(latents), -1)

### Optionally: Encode further to low-dimensional latents

In [None]:
# Load latent model
from src.models.latent_models import LatentVQVAE2
import yaml

latent_model_path = "../models/latent_vqvae2/version_1_2"

# Load latent model configuration; the context manager closes the file handle
# as soon as the YAML has been parsed instead of leaving it open
with open(f"{latent_model_path}/hparams.yaml", "r") as config_file:
    latent_model_config = yaml.safe_load(config_file)

# Initialize latent model from the stored hyperparameters and the last
# checkpoint of the selected run
latent_model = LatentVQVAE2(
    ddconfig=latent_model_config["ddconfig"],
    lossconfig=latent_model_config["lossconfig"],
    ckpt_path=f"{latent_model_path}/checkpoints/last.ckpt",
)
latent_model = latent_model.to(device)
# Switch to evaluation mode before the model is used for encoding
latent_model.eval()

In [None]:
from tqdm import tqdm

# Encode sd latents further into VQ latents
batch_size = 256
vq_latents = []
with torch.no_grad():
    for i in tqdm(range(0, latents.shape[0], batch_size)):
        batch = latents[i : i + batch_size].to(device)

        # Restore the per-sample shape recorded in `latent_shape` when the
        # latents were flattened, instead of hard-coding the dimensions
        batch = batch.reshape(batch.shape[0], *latent_shape)

        # Encode the batch using the latent model; only the first two of the
        # four returned values are used here
        latents_b, latents_t, _, _ = latent_model.encode(batch)

        # Flatten the two parts to (batch, features); flatten() also handles
        # non-contiguous tensors, where view() would raise
        latents_b = latents_b.flatten(start_dim=1)
        latents_t = latents_t.flatten(start_dim=1)

        # Concatenate the two parts along the feature axis
        batch = torch.cat([latents_b, latents_t], dim=1)

        # Move to CPU and store, so GPU memory does not grow with the dataset
        vq_latents.append(batch.cpu())

vq_latents = torch.cat(vq_latents, dim=0)

# From here on the notebook works with the VQ latents instead of the SD latents
latents = vq_latents

## Load smile scores

In [None]:
import json

# Load smile scores; the context manager closes the file once parsing is done
with open("../data/ffhq/ffhq_smile_scores.json", "r") as scores_file:
    smile_scores = json.load(scores_file)

# Convert to a tensor with the scores ordered by file name (the JSON keys).
# Keys are unique, so sorting the (key, value) pairs orders by key only.
# NOTE(review): presumably this order matches the row order of the stored
# latents -- confirm against the script that produced them.
smile_scores = torch.tensor(
    [score for _, score in sorted(smile_scores.items())],
    dtype=torch.float32,
)

smile_scores.shape

## Train Regressor

In [None]:
from src.utils import zero_mean_unit_var_normalization

# Normalize the regressor inputs. The helper returns three values, unpacked
# here as the normalized latents plus (going by the names) the mean and
# standard deviation it used.
normalization_outputs = zero_mean_unit_var_normalization(latents)
latents_norm, latents_mean, latents_std = normalization_outputs

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# Inputs are the normalized latents; the targets get an extra axis at
# position 1 so every sample has its own row
X_train = latents_norm
y_train = smile_scores[:, None]

# Pair inputs with targets and serve them in shuffled mini-batches
ds = TensorDataset(X_train, y_train)
loader = DataLoader(dataset=ds, batch_size=256, shuffle=True)

In [None]:
import torch.nn as nn

hidden_dims = [512, 256, 128]  # Widths of the hidden layers

# Assemble a simple MLP: one (Linear -> ReLU) pair per hidden width, followed
# by a single-output linear head
D = X_train.shape[1]
widths = [D, *hidden_dims]
layers = []
for fan_in, fan_out in zip(widths[:-1], widths[1:]):
    layers.append(nn.Linear(fan_in, fan_out))
    layers.append(nn.ReLU())
prev = widths[-1]  # width feeding the output layer
layers.append(nn.Linear(prev, 1))
model = nn.Sequential(*layers).to(device)

In [None]:
from tqdm import tqdm
import os
import pickle

epochs = 1000
lr = 1e-3
model_out_path = "../models/feature_selection/latents_smile_model.pkl"

# Create the output directory up front so a missing folder cannot make the
# save step fail after the whole training run has finished
os.makedirs(os.path.dirname(model_out_path), exist_ok=True)

# Setup progress bar to track loss
pbar = tqdm(range(epochs), desc="Training Progress", unit="epoch")


def update_pbar(epoch, loss):
    """Show the given loss on the training progress bar.

    The bar itself is advanced by iterating over ``pbar`` in the training
    loop, so this function must not call ``pbar.update`` -- doing so would
    count every epoch twice.

    Args:
        epoch: Index of the epoch that just finished (currently unused).
        loss: Scalar loss tensor to display.
    """
    pbar.set_postfix({"loss": loss.item()})


# Train model
opt = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()
model.train()
for epoch in pbar:
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        opt.step()
    # `loss` is the loss of the last mini-batch of this epoch
    update_pbar(epoch, loss)

# Pickle the whole model object (not just its state dict), matching the
# pickle-based loading cell of this notebook
with open(model_out_path, "wb") as f:
    pickle.dump(model, f)

In [None]:
# Restore a previously trained regressor from disk
import pickle

# NOTE(review): pickle executes code while loading -- only load trusted files.
# Alternative model file, kept for reference:
# with open("../models/feature_selection/sd_latents_smile_model.pkl", "rb") as f:
#     model = pickle.load(f)
model_path = "../models/feature_selection/latents_smile_model.pkl"
with open(model_path, "rb") as f:
    model = pickle.load(f)

## Feature Importance

### Gradient-based Feature Importance

In [None]:
# Compute gradients w.r.t. inputs
model.eval()

# Build a fresh float32 leaf tensor on `device` that tracks gradients.
# `copy=True` guarantees a new tensor, so enabling requires_grad never touches
# X_train itself; this also avoids the copy-construct warning that
# torch.tensor(existing_tensor) emits.
X_all = (
    X_train.detach()
    .to(device=device, dtype=torch.float32, copy=True)
    .requires_grad_(True)
)
y_pred = model(X_all)

# Backprop a uniform gradient of 1 over all outputs
grad_outputs = torch.ones_like(y_pred)

# Compute ∂y_pred / ∂X_all
grads = torch.autograd.grad(
    outputs=y_pred,
    inputs=X_all,
    grad_outputs=grad_outputs,
    create_graph=False,
    retain_graph=False,
)[0]  # same shape as X_all: [N, D]

# Feature importance = mean absolute gradient across samples
gb_importances = grads.abs().mean(dim=0).cpu().numpy()  # shape (D,)

### Permutation-based Feature Importance

In [None]:
# Get a reproducible random subset of samples (fixed seed). The size is capped
# at the dataset length because sampling without replacement cannot draw more
# rows than exist.
subset_size = min(5000, len(X_train))
idx = np.random.RandomState(42).choice(len(X_train), size=subset_size, replace=False)
X_sub, y_sub = X_train[idx], y_train[idx]

In [None]:
from tqdm import tqdm
import numpy as np
import torch
from torch.nn.functional import mse_loss

def fast_perm_imp(
    model,
    X_train,
    y_train,
    repeats=1,
    batch_size=256,
    device="cuda",
    *,
    progress=True,
):
    """Estimate permutation feature importance as the increase in MSE.

    For every input column, the column is randomly shuffled across samples,
    the model is re-evaluated, and the resulting MSE is compared with the MSE
    on the unshuffled data. The importance of a feature is the mean MSE over
    ``repeats`` shuffles minus the baseline MSE.

    Args:
        model: Module mapping a ``(batch, D)`` tensor to predictions. It is
            moved to ``device`` and evaluated in eval mode; its previous
            training flag is restored before returning.
        X_train: Input tensor of shape ``(N, D)``. It is never modified; all
            shuffling happens on a private copy.
        y_train: Target tensor; should match the shape of the model output so
            that ``mse_loss`` compares element-wise.
        repeats: Number of shuffles per feature (must be >= 1).
        batch_size: Number of rows per forward pass (must be >= 1).
        device: Device on which the evaluation runs.
        progress: Whether to wrap the feature loop in a ``tqdm`` progress bar.

    Returns:
        NumPy float32 array of shape ``(D,)`` with one importance per feature.

    Raises:
        ValueError: If ``repeats`` or ``batch_size`` is smaller than 1.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Work on a private copy: `.to(device)` alone returns the very same tensor
    # when it already lives on `device`, and the in-place shuffling below
    # would then modify the caller's data.
    X_perm = X_train.detach().to(device, copy=True)
    y_true = y_train.to(device)
    model = model.to(device)

    def predict():
        """Return the model predictions for all rows of X_perm, batch by batch."""
        with torch.no_grad():
            return torch.cat(
                [
                    model(X_perm[i : i + batch_size])
                    for i in range(0, len(X_perm), batch_size)
                ]
            )

    # Evaluate in eval mode so layers such as dropout do not add noise to the
    # scores; restore the previous mode afterwards.
    was_training = model.training
    model.eval()
    try:
        # Baseline score, computed with the same batched path as the permuted
        # scores
        base_mse = mse_loss(predict(), y_true)

        D = X_perm.shape[1]
        imps = torch.zeros(D, dtype=torch.float32, device=device)

        feature_indices = tqdm(range(D)) if progress else range(D)
        for j in feature_indices:
            col = X_perm[:, j].clone()  # pristine column, restored afterwards
            scores = []
            for _ in range(repeats):
                # Shuffle the column across samples
                perm = torch.randperm(len(X_perm), device=device)
                X_perm[:, j] = col[perm]
                scores.append(mse_loss(predict(), y_true))

            # Restore the column before moving on to the next feature
            X_perm[:, j] = col

            # importance = increase in MSE
            imps[j] = torch.stack(scores).mean() - base_mse
    finally:
        model.train(was_training)

    return imps.cpu().numpy()

In [None]:
# Reuse the notebook's `device` instead of hard-coding "cuda", so this cell
# also runs on CPU-only machines
pb_importances = fast_perm_imp(model, X_sub, y_sub, repeats=5, batch_size=256, device=device)

### Analysis and Visualization

In [None]:
# Normalize each importance vector so that its entries sum to 1
gb_fi_norm = gb_importances / gb_importances.sum()
pb_fi_norm = pb_importances / pb_importances.sum()

# Rank features from most to least important. Each ranking is derived from its
# own method's scores: gradient-based from gb_fi_norm, permutation-based from
# pb_fi_norm.
gb_sorted_indices = np.argsort(gb_fi_norm)[::-1]
pb_sorted_indices = np.argsort(pb_fi_norm)[::-1]

In [None]:
import matplotlib.pyplot as plt

# Running totals of the normalized importances, reordered by the
# corresponding sorted index arrays
gb_fi_cum = gb_fi_norm[gb_sorted_indices].cumsum()
pb_fi_cum = pb_fi_norm[pb_sorted_indices].cumsum()

# Draw both cumulative curves on a shared axis
fig, ax = plt.subplots(figsize=(10, 5))
for curve, curve_label in (
    (gb_fi_cum, "Gradient-Based Importance"),
    (pb_fi_cum, "Permutation-Based Importance"),
):
    ax.plot(curve, label=curve_label)
ax.set_title("Cumulative Feature Importance")
ax.set_xlabel("Feature Index")
ax.set_ylabel("Cumulative Feature Importance")
ax.grid()
ax.legend()
plt.show()

In [None]:
# Show the first 512 entries of the pb_sorted_indices ranking
top_feature_indices = pb_sorted_indices[:512].tolist()
print(top_feature_indices)