In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def pool_hidden_states(hidden_states: np.ndarray, pool: str = "mean") -> np.ndarray:
    """
    Pool hidden states over the token dimension.

    Parameters:
        hidden_states (np.ndarray): Array of shape (L, N, T, H)
        pool (str): Pooling method — "mean", "first", or "last"

    Returns:
        np.ndarray: Pooled embeddings of shape (L, N, H)

    Raises:
        ValueError: If `pool` is not one of the supported methods, or if
            `hidden_states` is not a 4-dimensional array.
    """
    if pool not in ("mean", "first", "last"):
        raise ValueError(f"Unsupported pool type: {pool}")

    # Every pooling mode indexes axis 2 as the token axis. Without this check,
    # a lower-rank input would be averaged over the wrong axis ("mean") or
    # fail with an opaque IndexError ("first"/"last").
    if hidden_states.ndim != 4:
        raise ValueError(
            "hidden_states must have shape (L, N, T, H); "
            f"got an array with shape {hidden_states.shape}"
        )

    if pool == "mean":
        pooled = hidden_states.mean(axis=2)
    elif pool == "first":
        pooled = hidden_states[:, :, 0, :]
    else:  # pool == "last"
        pooled = hidden_states[:, :, -1, :]

    return pooled  # shape: (L, N, H)


# Load the per-sample metadata and the saved hidden states.
metadata = pd.read_csv("outputs/book_of_life_sample_1.csv")
data = np.load("outputs/book_of_life_hidden_states_sample1.npz")
hidden_states = data["hidden_states"]  # shape: (L, N, T, H)
pooled_raw = pool_hidden_states(hidden_states, pool="mean")  # shape: (L, N, H)

# The plotting code colours each of the N points by a metadata column, so the
# two inputs must describe the same number of samples. Fail here with a clear
# message rather than later inside matplotlib.
if len(metadata) != pooled_raw.shape[1]:
    raise ValueError(
        f"metadata has {len(metadata)} rows but hidden states contain "
        f"{pooled_raw.shape[1]} samples"
    )

# Count non-finite values BEFORE sanitizing: once nan_to_num has run these
# checks can only ever report zero, which would hide upstream problems.
n_nan = int(np.isnan(pooled_raw).sum())
n_inf = int(np.isinf(pooled_raw).sum())

# Replace NaN and +/-Inf with 0.0 so downstream PCA receives finite input.
pooled = np.nan_to_num(pooled_raw, nan=0.0, posinf=0.0, neginf=0.0)

print("Shape of pooled hidden states:", pooled.shape)
print("NaNs replaced with 0:", n_nan)
print("Infs replaced with 0:", n_inf)
print("Max value:", np.max(pooled))
print("Min value:", np.min(pooled))

In [None]:
from sklearn.decomposition import PCA
import numpy as np

def run_pca(X: np.ndarray, n_components: int = 2, random_state: int = 0) -> np.ndarray:
    """
    Run PCA on a (N, H) matrix.

    Parameters:
        X (np.ndarray): Input data of shape (N, H)
        n_components (int): Number of PCA components (default = 2)
        random_state (int): Seed forwarded to scikit-learn's PCA (default = 0).
            It only matters when PCA selects a randomized SVD solver, which
            the default "auto" policy may do for large inputs; fixing it makes
            the projection reproducible across runs.

    Returns:
        np.ndarray: Transformed data of shape (N, n_components)
    """
    pca = PCA(n_components=n_components, random_state=random_state)
    return pca.fit_transform(X)

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

def plot_pca_layers(pooled: np.ndarray, metadata: pd.DataFrame, color_by: str = "age"):
    """
    Plot PCA snapshots for first, middle, and last layers in a clean horizontal grid.

    Each panel shows the first two principal components of one layer's pooled
    embeddings; points are coloured by a metadata column using one colour
    scale and one colourbar shared by all three panels.

    Parameters:
        pooled (np.ndarray): Pooled embeddings; the first axis indexes layers
            and each `pooled[i]` is passed to `run_pca` as a 2-D matrix.
        metadata (pd.DataFrame): Table holding the column used for colouring;
            it needs one value per plotted point.
        color_by (str): Name of the metadata column used for colouring.

    Raises:
        KeyError: If `color_by` is not a column of `metadata`.
    """
    # Validate before creating the figure so a typo does not leave an empty
    # figure behind.
    if color_by not in metadata.columns:
        raise KeyError(f"Column {color_by!r} not found in metadata")

    n_layers = pooled.shape[0]
    layer_idxs = [0, n_layers // 2, n_layers - 1]
    titles = ["First Layer", "Middle Layer", "Last Layer"]

    # Look the column up once; the same values and normalisation are shared by
    # every panel so colours are comparable across layers.
    color_values = metadata[color_by]
    norm = plt.Normalize(color_values.min(), color_values.max())
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is
    # available in both old and new releases.
    cmap = plt.get_cmap("viridis")

    fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)

    for ax, layer_idx, title in zip(axes, layer_idxs, titles):
        X_pca = run_pca(pooled[layer_idx])
        sc = ax.scatter(
            X_pca[:, 0],
            X_pca[:, 1],
            c=color_values,
            cmap=cmap,
            norm=norm,
            alpha=0.7,
        )
        ax.set_title(title)
        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.grid(True)

    # Add single shared colorbar to the right. Any panel's scatter handle
    # works here because all panels use the same cmap and norm.
    fig.colorbar(sc, ax=axes, location="right", shrink=0.8, label=color_by)
    plt.show()

In [None]:
# Render one three-panel PCA figure per metadata column of interest.
for column in ("age", "income"):
    plot_pca_layers(pooled, metadata, color_by=column)