In [1]:
import json
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tiktoken
import torch

from sparse_auto_encoder import SparseAutoencoder
from utils.model import load_GPT_model
from saes__extract_latent_activations import exract_latent_activations
from saes__filter_selective_neurons import find_selective_neurons
from saes__neuron_concept_assoc import calculate_neuron_to_concept_assoc
from saes__top_texts_for_neuron import top_texts_for_neuron
from saes__neuron_concept_mapping import build_neuron_concept_map

In [2]:
# Device string passed to load_GPT_model and to .to() on every SAE below.
device = "cpu"

In [3]:
# GPT checkpoint to probe. The 896 in the name matches the SAEs' input_dim below;
# the other numbers presumably encode further architecture settings — TODO confirm.
GPT_CHECKPOINT_PATH = "model_896_14_8_256.pth"
model = load_GPT_model(path=GPT_CHECKPOINT_PATH, device=device)

In [4]:
# Sparse autoencoders for transformer layers 1-8. Each one must be rebuilt with the
# hidden width it was trained with, or load_state_dict will reject the checkpoint.
SAE_INPUT_DIM = 896
SAE_HIDDEN_DIMS = {
    1: 2688, 2: 2688,
    3: 3584, 4: 3584, 5: 3584,
    6: 4480, 7: 4480, 8: 4480,
}


def load_sae(layer: int, hidden_dim: int, input_dim: int = SAE_INPUT_DIM):
    """Build the SAE for `layer`, load sae_models/sae_layer{layer}.pth into it and set eval mode.

    Args:
        layer: layer number; selects the checkpoint file.
        hidden_dim: SAE hidden width; must match the checkpoint.
        input_dim: SAE input width (defaults to SAE_INPUT_DIM).

    Returns:
        The SparseAutoencoder on `device`, in eval mode.
    """
    sae = SparseAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim).to(device)
    # map_location follows `device` so the weights are loaded where the module lives.
    state_dict = torch.load(f"sae_models/sae_layer{layer}.pth", map_location=device)
    sae.load_state_dict(state_dict)
    sae.eval()
    return sae


saes = {layer: load_sae(layer, hidden_dim) for layer, hidden_dim in SAE_HIDDEN_DIMS.items()}
# Later cells refer to the SAEs by these per-layer names.
sae_1, sae_2, sae_3, sae_4, sae_5, sae_6, sae_7, sae_8 = (saes[layer] for layer in range(1, 9))

In [None]:
# Extract SAE latent activations for layers 1-8: one call per layer, each with that layer's SAE.
# Judging by the captured output below, each call also saves sae_probing/latent_activations_l<N>.pt.
# `exract_latent_activations` is spelled this way in its module, so the name is kept as-is.
# NOTE(review): the captured output stops after layer 6 — confirm layers 7 and 8 completed.
latents_l1 = exract_latent_activations(model, sae_1, layer=1)
latents_l2 = exract_latent_activations(model, sae_2, layer=2)
latents_l3 = exract_latent_activations(model, sae_3, layer=3)
latents_l4 = exract_latent_activations(model, sae_4, layer=4)
latents_l5 = exract_latent_activations(model, sae_5, layer=5)
latents_l6 = exract_latent_activations(model, sae_6, layer=6)
latents_l7 = exract_latent_activations(model, sae_7, layer=7)
latents_l8 = exract_latent_activations(model, sae_8, layer=8)

✅ Saved sae_probing/latent_activations_l1.pt with latents shape torch.Size([665, 2688]) and 665 ids.
✅ Saved sae_probing/latent_activations_l2.pt with latents shape torch.Size([665, 2688]) and 665 ids.
✅ Saved sae_probing/latent_activations_l3.pt with latents shape torch.Size([665, 3584]) and 665 ids.
✅ Saved sae_probing/latent_activations_l4.pt with latents shape torch.Size([665, 3584]) and 665 ids.
✅ Saved sae_probing/latent_activations_l5.pt with latents shape torch.Size([665, 3584]) and 665 ids.
✅ Saved sae_probing/latent_activations_l6.pt with latents shape torch.Size([665, 4480]) and 665 ids.


In [None]:
def map_layer_neurons(layer, activation_threshold=5.0, n_preview=5):
    """Run the neuron→concept pipeline for one SAE layer and return a preview of the mapping.

    Calls, in order, the project helpers ``find_selective_neurons``,
    ``calculate_neuron_to_concept_assoc`` and ``build_neuron_concept_map``. The return
    values of the first two are ignored, so their effects live in those helpers.

    Args:
        layer: SAE layer number to process.
        activation_threshold: forwarded to ``find_selective_neurons`` and, under the name
            ``threshold``, to ``calculate_neuron_to_concept_assoc``.
        n_preview: number of leading rows of the mapping to return (5 matches ``.head()``).

    Returns:
        ``build_neuron_concept_map(layer=layer).head(n_preview)``. This is presumably a
        DataFrame preview — TODO confirm the helper's return type.
    """
    find_selective_neurons(layer=layer, activation_threshold=activation_threshold)
    # The association step names its threshold argument `threshold`.
    calculate_neuron_to_concept_assoc(layer=layer, threshold=activation_threshold)

    mappings = build_neuron_concept_map(layer=layer)
    print(f"✅ Done for layer {layer}")
    print("=" * 10)

    return mappings.head(n_preview)

In [None]:
# Run the neuron→concept pipeline for every SAE layer and show each layer's preview.
# `display` is IPython's rich-display builtin (available in notebook kernels without import).
# To redo a single layer, call map_layer_neurons(layer=N) in a scratch cell.
for layer_idx in range(1, 9):
    display(map_layer_neurons(layer=layer_idx))

---

## Plot dual-themed neurons

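Plotting function: reads a layer's `neuron_concept_primary_secondary_l<N>.csv` and draws a graph whose nodes are concepts and whose edges link each row's primary and secondary concept. Edge width grows with the number of rows sharing that concept pair.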

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx

# Node fill colour per concept for the dual-theme graphs. Keys are lower-case because
# concept names are lower-cased before lookup; unknown concepts fall back to grey.
color_map = dict([
    ("female", "#FAD7AC"),
    ("male", "#cc9290"),
    ("marriage", "#ed76b3"),
    ("love", "#CC0000"),
    ("wealth", "#007FFF"),
    ("emotion", "#9467bd"),
    ("family", "#CC6600"),
    ("duty", "#a4d9f2"),
    ("scandal and reputation", "#B3B3B3"),
    ("society", "#67AB9F"),
    ("neutral", "#e5ced0"),
    ("class", "#90ee90"),
])

def plot_dual_theme_graph_from_csv(
    csv_path, *, color_map=color_map, ax=None, return_fig=False, seed=62,
    layout="spring", spread=1.8, layer="UNK"
):
    """Plot a concept graph from a per-layer primary/secondary concept CSV.

    Each CSV row contributes one edge between its ``primary_concept`` and
    ``secondary_concept``. Values are converted to text, stripped of surrounding spaces
    and commas, and lower-cased. Rows where either value is missing, empty or "unk" are
    skipped. Edge width grows linearly with the number of rows linking the same unordered
    concept pair, reaching 20 for the most frequent pair.

    Args:
        csv_path: CSV with ``primary_concept`` and ``secondary_concept`` columns.
        color_map: concept -> node colour; concepts not in it are drawn grey ("#cccccc").
        ax: Axes to draw on. When None, a new figure is created (12x12, or 6x3 if there
            is nothing to plot).
        return_fig: if True, return ``(fig, ax)``; otherwise return None.
        seed: RNG seed for the spring layout.
        layout: "kk" for Kamada-Kawai; anything else uses a spring layout.
        spread: multiplier on the spring-layout spacing ``k``. Any value > 1 also expands
            positions by a fixed 15% around their centroid.
        layer: shown in the title as "Layer {layer}".

    Raises:
        ValueError: if either concept column is missing from the CSV.
    """
    df = pd.read_csv(csv_path)
    if "primary_concept" not in df.columns or "secondary_concept" not in df.columns:
        raise ValueError("CSV must contain 'primary_concept' and 'secondary_concept' columns.")
    df = df.dropna(subset=["primary_concept", "secondary_concept"])

    # Normalise both concept columns identically (vectorised instead of iterrows).
    primary = df["primary_concept"].astype(str).str.strip(" ,").str.lower()
    secondary = df["secondary_concept"].astype(str).str.strip(" ,").str.lower()
    # Keep rows where both concepts are non-empty and not the "unk" placeholder.
    keep = primary.ne("") & secondary.ne("") & primary.ne("unk") & secondary.ne("unk")
    edges = list(zip(primary[keep], secondary[keep]))

    # Use the caller's Axes whenever one is given. This includes the empty case, so a
    # subplot grid never gets a stray extra figure.
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 12) if edges else (6, 3))
    else:
        fig = ax.figure

    if not edges:
        ax.text(0.5, 0.5, "No dual-theme data", ha="center", va="center")
        ax.set_title(f"Layer {layer}", fontsize=25)
        ax.axis("off")
        fig.tight_layout()
        return (fig, ax) if return_fig else None

    # A MultiGraph keeps one parallel edge per CSV row, so edge multiplicity = rows per pair.
    G = nx.MultiGraph()
    G.add_edges_from(edges)

    # ---- Layout ----
    if layout == "kk":
        pos = nx.kamada_kawai_layout(G)
    else:
        # Spacing shrinks slowly with graph size; `spread` scales it up.
        base_k = 0.75 / max(len(G.nodes), 1) ** 0.25
        k = base_k * spread
        pos = nx.spring_layout(G, k=k, seed=seed, iterations=200)

    if spread > 1.0:
        # Expand positions 15% about their centroid (a fixed factor, whatever `spread` is).
        xs, ys = zip(*pos.values())
        cx, cy = sum(xs) / len(xs), sum(ys) / len(ys)
        pos = {n: (cx + (x - cx) * 1.15, cy + (y - cy) * 1.15) for n, (x, y) in pos.items()}

    node_colors = [color_map.get(node, "#cccccc") for node in G.nodes()]

    # ---- Edge widths ----
    # One entry per unordered concept pair; the value is the number of rows linking the pair.
    edge_counts = {tuple(sorted((u, v))): G.number_of_edges(u, v) for u, v in G.edges()}
    max_count = max(edge_counts.values())
    # Linear scale 0.5 + 19.5 * count / max_count: just above 0.5 up to 20 for the top pair.
    edge_widths = {pair: 0.5 + 19.5 * (count / max_count) for pair, count in edge_counts.items()}

    # ---- Draw ----
    nx.draw_networkx_nodes(G, pos, node_size=8000, node_color=node_colors, ax=ax)
    nx.draw_networkx_labels(G, pos, font_size=14, font_weight="bold", ax=ax)
    # The keys of edge_widths are already unique pairs, so each edge is drawn exactly once.
    for pair, width in edge_widths.items():
        nx.draw_networkx_edges(G, pos, edgelist=[pair], width=width, alpha=0.6, ax=ax)

    ax.set_title(f"Layer {layer}", fontsize=25)
    ax.axis("off")
    fig.tight_layout()

    return (fig, ax) if return_fig else None

---

In [None]:
# Layout settings per layer: Kamada-Kawai for layers 1-4, a widely spaced spring layout for 5-8.
DUAL_THEME_PLOT_KWARGS = {
    1: {"layout": "kk"},
    2: {"layout": "kk"},
    3: {"layout": "kk"},
    4: {"layout": "kk"},
    5: {"spread": 100},
    6: {"spread": 100},
    7: {"spread": 100},
    8: {"spread": 100},
}

for layer_idx, plot_kwargs in DUAL_THEME_PLOT_KWARGS.items():
    # Build the CSV path and the title from the same layer number so they cannot disagree.
    plot_dual_theme_graph_from_csv(
        f"sae_probing/neuron_concept_primary_secondary_l{layer_idx}.csv",
        layer=layer_idx,
        **plot_kwargs,
    )

In [None]:
import glob, os
import pandas as pd

# Directory holding the per-layer probing CSVs; the merged all-layer CSVs are written here too.
base_dir = "sae_probing"

In [None]:
def merge_layer_csvs(directory, stem):
    """Concatenate ``{directory}/{stem}_l<N>.csv`` files into ``{directory}/{stem}_all_layers.csv``.

    Files are ordered by numeric layer, so l10 sorts after l9. A ``layer`` column parsed
    from the filename is added to any file that lacks one. Matches whose suffix after
    ``_l`` is not an integer are ignored.

    Args:
        directory: folder containing the per-layer CSVs; the merged CSV is written there.
        stem: filename prefix before ``_l<N>.csv``.

    Returns:
        The merged DataFrame.

    Raises:
        FileNotFoundError: if no ``{stem}_l<N>.csv`` file exists in ``directory``.
    """
    layer_name = re.compile(rf"^{re.escape(stem)}_l(\d+)\.csv$")
    layer_files = []
    for path in glob.glob(os.path.join(directory, f"{stem}_l*.csv")):
        match = layer_name.match(os.path.basename(path))
        if match:
            layer_files.append((int(match.group(1)), path))
    if not layer_files:
        # pd.concat([]) would otherwise fail with an opaque "No objects to concatenate".
        raise FileNotFoundError(f"No files matching {stem}_l<N>.csv in {directory!r}")
    layer_files.sort()

    frames = []
    for layer_num, path in layer_files:
        frame = pd.read_csv(path)
        if "layer" not in frame.columns:
            frame["layer"] = layer_num
        frames.append(frame)

    merged = pd.concat(frames, ignore_index=True)
    out_path = os.path.join(directory, f"{stem}_all_layers.csv")
    merged.to_csv(out_path, index=False)
    print(f"✅ Merged {len(layer_files)} files into {out_path}, total rows: {len(merged)}")
    return merged


label_assoc_all_layers = merge_layer_csvs(base_dir, "neuron_label_assoc")
concept_map_all_layers = merge_layer_csvs(base_dir, "neuron_concept_primary_secondary")

In [None]:
# Mean association metrics per layer, over neuron–concept rows with non-negative ΔP.
all_layers_assoc = (
    pd.read_csv(os.path.join(base_dir, "neuron_label_assoc_all_layers.csv"))
    .loc[lambda assoc: assoc["ΔP"] >= 0]
)
averages = (
    all_layers_assoc
    # neuron/concept identify a row rather than measure it, so keep them out of the means.
    .drop(columns=["neuron", "concept"])
    .groupby("layer")
    .mean(numeric_only=True)
)

# to_csv does not create missing directories, so make sure analysis/ exists on a fresh checkout.
analysis_dir = os.path.join(base_dir, "analysis")
os.makedirs(analysis_dir, exist_ok=True)
out_path = os.path.join(analysis_dir, "mean_assoc_metrics_all_layers.csv")
averages.to_csv(out_path)

# One row per layer — show them all (.head() would hide every layer after the fifth).
averages

In [None]:
# Neuron–concept rows whose AP is at least 0.5, strongest association first.
strong_ap = all_layers_assoc["AP"].ge(0.5)
top_neurons = all_layers_assoc.loc[strong_ap].sort_values("AP", ascending=False)
top_neurons

In [None]:
all_mappings = pd.read_csv(os.path.join(base_dir, "neuron_concept_primary_secondary_all_layers.csv"))

# top_neurons has one row per neuron–concept pair, so the same (layer, neuron) can appear
# several times. De-duplicate the keys so the inner merge does not repeat mapping rows.
keys = top_neurons[["layer", "neuron"]].drop_duplicates()

filtered = (
    all_mappings
    .merge(keys, on=["layer", "neuron"], how="inner")
    .sort_values(by="primary_AP", ascending=False)
)
filtered

In [None]:
# Stage, commit and push everything in the working tree.
# NOTE(review): `git add .` stages every file not excluded by .gitignore. This notebook writes
# .pt latent-activation files under sae_probing/ (see the extraction output) — confirm they are
# ignored before pushing. A more descriptive commit message than "new results" would also help.
!git add .
!git commit -m "new results"
!git push