# Analyzing the weekday representations in GPT-2 (following Engels et al.)

In [None]:
# Run from the repository root so the `src.*` imports and the relative
# `exp/...` paths below resolve. NOTE(review): hard-coded, user-specific
# path — adjust for your own checkout.
%cd /home/can/feature_zoo/ 
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import json

# Temporal feature clusters for this experiment; only the "clusters" mapping is kept.
fname = "exp/weekdays/temporal_clusters.json"
# JSON is UTF-8 by specification; don't depend on the platform's default
# text encoding when decoding it.
with open(fname, "r", encoding="utf-8") as f:
    clusters = json.load(f)["clusters"]

clusters

In [None]:
# Indices of the SAE latents that make up the days-of-week cluster
# (a plain Python list, as loaded from JSON).
days_cluster = clusters["days_of_week"]
days_idxs = days_cluster["sae_neuron_indices"]
days_idxs

In [None]:
from src.config import load_config
from src.loading import load_sae

# Build the experiment config from three overrides: "sae=gpt2_relu",
# "llm=gpt2" and "data=days".
config_overrides = ["sae=gpt2_relu", "llm=gpt2", "data=days"]
cfg = load_config(overrides=config_overrides)

# Load the SAE described by cfg and show the shape of its `Ad` matrix.
sae = load_sae(cfg)
sae.Ad.shape

In [None]:
# Transpose of the SAE's `Ad` matrix. Following the notebook's shape-suffix
# naming (presumably S = SAE latents, D = model width), each row is meant to be
# one latent's decoder direction. NOTE(review): assumes `sae.Ad` is stored as
# (D, S) — confirm against the SAE class.
W_dec_SD = sae.Ad.T
# Keep only the rows of the days-of-week latents (s = len(days_idxs) rows).
W_dec_days_sD = W_dec_SD[days_idxs]

# Effective Rank

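Expected to be about 2: Engels et al. report that the days-of-week features lie on a circle, and a circle spans a 2-D plane. The spectrum below tests this.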

In [None]:
import torch as th

# Thin SVD of the days-of-week decoder rows, computed in float32.
# The rows of Vh are orthonormal directions spanning the same row space.
decoder_days_f32 = W_dec_days_sD.float()
U, S, Vh = th.linalg.svd(decoder_days_f32, full_matrices=False)

In [None]:
# Rows of Vh are the right singular vectors of the days-of-week decoder
# matrix; with full_matrices=False there are min(s, D) of them.
Vh.shape

In [None]:
# Fraction of the total squared singular-value mass ("energy") carried by
# each singular component.
energy = S**2
energy_norm = energy / energy.sum()

# Entropy-based effective rank (Roy & Vetterli, 2007): exp of the Shannon
# entropy of the singular values normalized to sum to 1. Zero entries
# contribute nothing to the entropy, so clamp before the log to avoid
# 0 * -inf = nan. A value near 2 supports a planar (circular) representation.
sv_prob = S / S.sum()
sv_entropy = -(sv_prob * sv_prob.clamp_min(1e-12).log()).sum()
effective_rank = sv_entropy.exp()
effective_rank

In [None]:
import matplotlib.pyplot as plt

# Plot the fraction of energy per singular component on a fixed [0, 1] axis.
spectrum = energy_norm.detach().cpu().numpy()
plt.plot(spectrum)
plt.grid()
plt.ylim(0, 1)

# Load the labeled days-of-the-week activations

In [None]:
from src.cache_llm import load_labeled_acts

# NOTE(review): force_recompute=True presumably bypasses any cached
# activations — confirm in src.cache_llm.
return_dict = load_labeled_acts(
    cfg,
    force_recompute=True,
)
# Model activations used below (presumably one row per example, per the _BD suffix).
llm_BD = return_dict["llm_BD"]
llm_BD.shape

In [None]:
# Coordinates of each activation along the rows of Vh (the orthonormal right
# singular vectors of the days-of-week decoder matrix), as a NumPy array.
llm_BC = (llm_BD.float() @ Vh.T).detach().cpu().numpy()

In [None]:
# Select the days-of-week SAE latents from every example's sparse code.
# `days_idxs` comes straight from JSON, so it is a plain Python list:
# `days_idxs[None, :]` would raise TypeError. `Tensor.gather` also takes `dim`
# as its first argument, and it would need an index with the input's batch size.
# Column indexing does the intended selection and yields one row per example.
sparse_code_BD = return_dict["codes_BD"][:, days_idxs]
sparse_code_BD.shape

In [None]:
# Display the selected days-of-week latent activations.
sparse_code_BD

In [None]:
import numpy as np
from sklearn.decomposition import PCA

# PCA of the activations expressed in the days-of-week decoder basis (llm_BC).
# sklearn's PCA requires n_components <= min(n_samples, n_features). llm_BC has
# only min(s, D) columns (one per row of Vh), so clamp the requested count.
# The plotting cell reads num_pca_components, so it sees the clamped value.
num_pca_components = min(7, *llm_BC.shape)
EPS = 1e-8
# Optional: set to True to L2-normalize each row before PCA (EPS guards
# all-zero rows). llm_BC is a NumPy array here, so use np.linalg.norm.
normalize_rows = False

pca_input_BC = llm_BC
if normalize_rows:
    row_norms_B1 = np.linalg.norm(llm_BC, axis=-1, keepdims=True)
    pca_input_BC = llm_BC / (row_norms_B1 + EPS)

pca = PCA(n_components=num_pca_components)
llm_pca_BD = pca.fit_transform(pca_input_BC)

print(f"PCA transformed shape: {llm_pca_BD.shape}")
print(f"Explained variance ratio: {pca.explained_variance_ratio_}")

In [None]:
import math

import numpy as np

# Scatter consecutive principal-component pairs (PC i vs PC i+1), with one
# colour per label.
labels = return_dict["labels"]
unique_labels = list(dict.fromkeys(labels))  # de-duplicated, first-seen order
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}

# Tab10 colors for categorical coloring
tab10_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

# Boolean row mask per label, built once instead of once per subplot.
label_masks = {
    label: np.asarray([l == label for l in labels], dtype=bool)
    for label in unique_labels
}

# Define consecutive PC pairs
pc_pairs = [(i, i + 1) for i in range(num_pca_components - 1)]

# Size the grid from the number of pairs so changing num_pca_components never
# silently drops pairs (a fixed 2x3 grid holds only 6). Each panel is 5x5 in;
# squeeze=False keeps `axes` 2-D even when there is a single row.
n_cols = 3
n_rows = max(1, math.ceil(len(pc_pairs) / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows),
                         squeeze=False)
axes = axes.flatten()

for idx, (ax, (pc_x, pc_y)) in enumerate(zip(axes, pc_pairs)):
    for label in unique_labels:
        mask = label_masks[label]
        color = tab10_colors[label_to_idx[label] % len(tab10_colors)]
        ax.scatter(llm_pca_BD[mask, pc_x], llm_pca_BD[mask, pc_y],
                   label=label, color=color, alpha=0.8, s=50)

    ax.set_xlabel(f'PC {pc_x}', fontsize=12)
    ax.set_ylabel(f'PC {pc_y}', fontsize=12)
    ax.set_title(f'PC{pc_x} vs PC{pc_y}', fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3)

    # One legend is enough: every subplot uses the same label colours.
    if idx == 0:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)

# Hide unused subplots
for ax in axes[len(pc_pairs):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
from src.loading import load_llm

# Load the language model described by cfg (presumably selected by the
# "llm=gpt2" override) and display it.
llm = load_llm(cfg)
llm