In [1]:
#autoreload
%load_ext autoreload
%autoreload 2
import torch
import matplotlib.pyplot as plt
import numpy as np
import featureman.gen_data as man
import featureman.utils as utils
from sklearn.cluster import SpectralClustering
import pickle
from torch.utils.data import Dataset, DataLoader

# Run on the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Restore the batched sparse autoencoder (5 models, 512-d input, 4x width).
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while loading.
sae_dict = torch.load(
    "sae_model_small_batch_2025-08-07_00-27-27.pth",
    map_location=device,
    weights_only=True,
)
# .eval() is chained here (it returns the module) so that inference-time
# behaviour of layers such as dropout is used and the cell's final expression
# is still the load_state_dict result.
sae = man.BatchedSAE_Updated(input_dim=512, n_models=5, width_ratio=4).to(device).eval()
sae.load_state_dict(sae_dict)

# Restore the modular-arithmetic transformer (p=113) the SAE is analysed against.
model_dict = torch.load(
    "modular_arithmetic_model.pth",
    map_location=device,
    weights_only=True,
)
model = man.OneLayerTransformer(p=113, d_model=128, nheads=4).to(device).eval()
model.load_state_dict(model_dict)

<All keys matched successfully>

In [3]:
torch.manual_seed(1337)
# Enumerate every (a, b) pair with a and b in [0, 113).
a_values = np.arange(113)
b_values = np.arange(113)
# Each row is the 4-token sequence [a, 113, b, 114].
# NOTE(review): 113 and 114 are presumably the operator / "=" tokens - confirm.
inputs = np.array([[a_i, 113, b_i, 114] for a_i in a_values for b_i in b_values])
# Result has one row per (a, b) pair; the rows already form the batch, so no
# extra batch dimension is added here.
inputs = torch.tensor(inputs).to(device)

# Inference only: disable autograd so no graph is kept alive in the kernel for
# the full set of inputs.
with torch.no_grad():
    logits, activations = model(inputs, return_activations=True)
# Keep only the last index along dim 1 of the returned activations.
activation_final = activations[:, -1, :].detach()
# Replicate the activations 5 times along a new leading axis, matching the
# n_models=5 used when the SAE was constructed.
batched_acts = activation_final.unsqueeze(0).repeat(5, 1, 1).to(device)

In [4]:
import pickle

# Decoder weights of the SAE at index 3 (original note: 2048 x 512).
decoder = sae.W_d[3].detach()

# Inference only: run the SAE without building an autograd graph; only the
# third returned value (feature activations) is used.
with torch.no_grad():
    _, _, feat_acts, _ = sae(batched_acts)
# Feature activations of the same SAE (index 3).
# NOTE(review): the original note said 12769 x 512, but the columns of this
# tensor are later indexed with the same cluster ids as the rows of `decoder`,
# so the second dimension is presumably 2048 - confirm.
features = feat_acts[3].detach()

# NOTE: pickle.load can execute arbitrary code; only load cluster files that
# were produced by a trusted run of this project.
with open("sae_clusters_small_batch_15.pkl", "rb") as f:
    clusters = pickle.load(f)

# Drop singleton clusters, then order the rest from largest to smallest.
clusters = [c for c in clusters if len(c) > 1]
clusters = sorted(clusters, key=len, reverse=True)


In [5]:
# (a + b) mod 113 for every input row; columns 0 and 2 of `inputs` hold a and b.
mod_additions = torch.remainder(inputs[:, 0] + inputs[:, 2], 113)

In [6]:
# Recompute from `inputs` instead of converting `mod_additions` in place, so
# that re-running this cell does not fail by calling .cpu() on an array that
# has already been converted to NumPy.
mod_additions = ((inputs[:, 0] + inputs[:, 2]) % 113).cpu().numpy()

In [8]:
# Import the module
import featureman.reducibility as irr
import pandas as pd
from datetime import datetime
from sklearn.decomposition import PCA

# Irreducibility analysis: for each cluster, rebuild the part of the signal
# explained by its features, skip clusters whose second principal component
# carries too little variance, and run the (silent) irreducibility analysis on
# the rest. Plots and a CSV summary are written under `save_dir`.
PCA_COMPONENTS = 6              # number of principal components fitted per cluster
MIN_PC2_VARIANCE_RATIO = 0.1    # clusters below this PC2 share are skipped

all_summaries = []
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = f"irreducibility_results_{timestamp}"

print(f"🔬 Starting irreducibility analysis for {len(clusters)} clusters...")
print(f"📁 Results will be saved to: {save_dir}")

for i, cluster in enumerate(clusters):
    cluster_label = i + 1  # clusters are reported 1-based
    cluster_size = len(cluster)

    # Contribution of this cluster's features: activations (columns selected by
    # the cluster) times the matching decoder rows.
    features_cluster = features[:, cluster]
    decoder_cluster = decoder[cluster, :]
    reconstructions = features_cluster @ decoder_cluster
    reconstructions = reconstructions.detach().cpu().numpy()

    pca = PCA(n_components=PCA_COMPONENTS).fit(reconstructions)

    # Quick variance check: a NaN ratio or a weak second component means
    # the cluster is skipped.
    pc2_ratio = pca.explained_variance_ratio_[1]
    if np.isnan(pc2_ratio) or pc2_ratio < MIN_PC2_VARIANCE_RATIO:
        print(f"Cluster {cluster_label:3d}: Size {cluster_size:3d} - Skipped (low PC2 variance)")
        continue

    # Run silent analysis (plots are saved to disk rather than displayed).
    summary = irr.analyze_cluster_irreducibility_silent(
        reconstructions,
        cluster_idx=cluster_label,
        save_plots=True,
        save_dir=save_dir,
    )

    all_summaries.append(summary)

    # Concise one-line report per analysed cluster.
    irreducible_flag = "✅" if summary['is_irreducible'] else "❌"
    print(f"Cluster {cluster_label:3d}: Size {cluster_size:3d} - S={summary['mean_separability']:.3f}, M={summary['mean_mixture']:.3f} {irreducible_flag}")

# Save final summary
print(f"\n{'='*60}")
df, csv_path = irr.save_analysis_summary(all_summaries, save_dir)
print(f"🎯 Analysis complete! Check {save_dir}/ for all plots and {csv_path} for summary.")

🔬 Starting irreducibility analysis for 15 clusters...
📁 Results will be saved to: irreducibility_results_20250815_220430
Cluster   1: Size 198 - Skipped (low PC2 variance)
Cluster   2: Size 179 - S=0.642, M=0.522 ❌
Cluster   3: Size 152 - S=0.011, M=0.289 ❌
Cluster   4: Size 151 - S=0.632, M=0.487 ❌
Cluster   5: Size 150 - S=0.668, M=0.565 ❌
Cluster   6: Size 141 - S=0.496, M=0.518 ❌
Cluster   7: Size 139 - S=0.716, M=0.542 ❌
Cluster   8: Size 137 - S=0.630, M=0.588 ❌
Cluster   9: Size 123 - S=0.624, M=0.485 ❌
Cluster  10: Size 121 - S=0.559, M=0.551 ❌
Cluster  11: Size 118 - S=0.501, M=0.725 ❌
Cluster  12: Size 117 - S=0.634, M=0.552 ❌
Cluster  13: Size 116 - S=0.488, M=0.722 ❌
Cluster  14: Size 105 - S=0.009, M=0.289 ❌
Cluster  15: Size 101 - S=0.677, M=0.644 ❌

📊 Summary saved to: irreducibility_results_20250815_220430/irreducibility_summary_20250815_221901.csv
🏆 Top 5 irreducible clusters:
   cluster_idx  irreducibility_score  mean_separability  mean_mixture  \
5            7      