In [1]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from sklearn.cluster import SpectralClustering
from sklearn.decomposition import PCA
from torch.utils.data import Dataset, DataLoader

import featureman.gen_data as man

# Select the compute device: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

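### What to do next
Shapes:
- decoder / dictionary atoms: `sae_hidden × mlp_dim` (for the SAE used here, `2048 × 512`)
- `relu_acts`: `sae_hidden`
- `out`: `mlp_dim`

Plan:
- Build the `(samples × filtered_sae_hidden) × mlp_dim` matrix by concatenating over samples and the filtered SAE hidden units.
- Flatten it with `einops.rearrange(x, 's f m -> (s f) m')`.
- Run PCA on the flattened matrix.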

In [2]:
# Load the trained batched SAE and the one-layer modular-arithmetic transformer.
# weights_only=True restricts unpickling to tensors/plain containers (both files are
# state dicts), so a tampered checkpoint cannot execute arbitrary code on load.
sae_dict = torch.load(
    "sae_model_small_batch_2025-08-07_00-27-27.pth", map_location=device, weights_only=True
)
# .eval(): this notebook only runs inference, so disable any dropout-style training behaviour.
sae = man.BatchedSAE_Updated(input_dim=512, n_models=5, width_ratio=4).to(device).eval()
sae.load_state_dict(sae_dict)

model_dict = torch.load("modular_arithmetic_model.pth", map_location=device, weights_only=True)
model = man.OneLayerTransformer(p=113, d_model=128, nheads=4).to(device).eval()
model.load_state_dict(model_dict)

<All keys matched successfully>

In [3]:
torch.manual_seed(1337)
P = 113  # modulus of the task; matches OneLayerTransformer(p=113)

# Every (a, b) pair with a, b in [0, P): one row [a, P, b, P + 1] per pair, P**2 = 12769 rows.
# NOTE(review): tokens P and P + 1 are presumably the operator and "=" tokens — confirm.
a = np.arange(P)
b = np.arange(P)
inputs = torch.tensor(
    np.array([[a_i, P, b_i, P + 1] for a_i in a for b_i in b]), device=device
)
print(inputs.shape)

# Activation extraction only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    logits, activations = model(inputs, return_activations=True)
# Keep the activations at the last sequence position (512 features per row).
activations_data = activations[:, -1, :].detach()
# One identical copy of the data per batched SAE (the SAE was built with n_models=5).
batched_acts = activations_data.unsqueeze(0).repeat(5, 1, 1).to(device)
# Free the transformer; note this makes the cell non-re-runnable without re-running the load cell.
del model, model_dict
print(batched_acts.shape)

torch.Size([12769, 4])
torch.Size([5, 12769, 512])


In [4]:
MODEL_IDX = 3  # which of the 5 batched SAEs to analyse

# Decoder of the selected SAE; the printed shape is (2048, 512).
decoder = sae.W_d[MODEL_IDX].detach()
print(decoder.shape)

# Inference only: avoid building an autograd graph over the full (5, 12769, ·) batch.
with torch.no_grad():
    _, _, feat_acts, _ = sae(batched_acts)
features = feat_acts[MODEL_IDX].detach()

# NOTE(review): pickle.load can execute arbitrary code — only load cluster files you produced.
with open("sae_clusters_small_batch_15.pkl", "rb") as f:
    clusters = pickle.load(f)
# Drop single-element clusters, then order by size, largest first.
clusters = sorted((c for c in clusters if len(c) > 1), key=len, reverse=True)
print(features.shape)


torch.Size([2048, 512])
torch.Size([12769, 2048])


In [20]:
# For a few SAE feature clusters: take the cluster's part of features @ decoder, project it
# onto its principal components, and colour the points by cos(2*pi*k*a/P) for three
# frequencies k, where a is the first operand of each input row.
P = 113                          # modulus (matches OneLayerTransformer(p=113))
KEY_FREQS = (6, 44, 54)          # colouring frequencies; presumably the model's key frequencies — confirm
CLUSTER_IDS = (1, 2, 3)          # indices into the size-sorted `clusters`
# NOTE(review): index 0 (the largest cluster) is skipped — confirm this is intended.
COLORBAR_X = (0.25, 0.62, 0.99)  # colour-bar x position for each subplot column

a_vals = inputs[:, 0].detach().cpu().numpy()
results_a_first_k, results_a_second_k, results_a_third_k = (
    np.cos(2 * np.pi * k * a_vals / P) for k in KEY_FREQS
)
cos_colors = (results_a_first_k, results_a_second_k, results_a_third_k)


def _pca_trace(coords, color, k, colorbar_x, p=P):
    """Return a 3-D scatter of the first three columns of `coords`.

    Points are coloured by `color`, which holds cos(2*pi*k*a/p) per row. The colour bar
    is labelled with that expression, since the colour depends on a, not on b.
    """
    return go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode="markers",
        marker=dict(
            size=2,
            color=color,
            colorscale="viridis",
            opacity=0.3,
            symbol="x",
            colorbar=dict(title=f"cos(2π·{k}·a/{p})", x=colorbar_x, len=0.8),
        ),
        name=f"k = {k}",
    )


for i in (idx for idx in CLUSTER_IDS if idx < len(clusters)):
    features_interest = features[:, clusters[i]]  # (samples, len(cluster))
    decoder_interest = decoder[clusters[i], :]     # (len(cluster), 512)
    reconstructions = features_interest @ decoder_interest

    pca = PCA()
    output_pca = pca.fit_transform(reconstructions.detach().cpu().numpy())

    fig = make_subplots(
        rows=1,
        cols=len(KEY_FREQS),
        subplot_titles=[f"k = {k}: PC0 vs PC1 vs PC2" for k in KEY_FREQS],
        specs=[[{"type": "scatter3d"} for _ in KEY_FREQS]],
        horizontal_spacing=0.02,
    )
    for col, (k, color, cbar_x) in enumerate(zip(KEY_FREQS, cos_colors, COLORBAR_X), start=1):
        fig.add_trace(_pca_trace(output_pca, color, k, cbar_x), row=1, col=col)

    fig.update_layout(
        title=f"Cluster {i + 1} - Size: {len(clusters[i])}",
        width=1400,
        height=500,
        showlegend=False,
    )
    # With no row/col selector, update_scenes applies these axis titles to every 3-D scene.
    fig.update_scenes(
        xaxis_title="Principal Component 0",
        yaxis_title="Principal Component 1",
        zaxis_title="Principal Component 2",
    )
    fig.show()