# 07 — Neuron-Level Attribution Graphs for Gemma-2-2B

**Purpose**: Generate neuron-level attribution graphs (nodes = raw MLP neurons) for the same prompts
already analyzed with transcoder-level graphs. If the same motif signatures (coherent feed-forward loop
(FFL) enrichment, incoherent FFL depletion) appear at the neuron level, that is strong evidence that these
patterns reflect how the model routes information, rather than being artifacts of SAE feature decomposition.

**Runtime**: Google Colab T4 GPU. ~30 min per prompt, ~5 hours total for 10 prompts.

**Output**: 10 JSON graphs in `neuron_graphs/{category}/` matching the circuit-tracer format
so the existing motif pipeline works unchanged.

## 1. Setup

In [None]:
# Install dependencies
!pip install -q transformer-lens torch numpy

In [None]:
import json
import time
import shutil
from pathlib import Path

import numpy as np
import torch

# Report the runtime environment so a missing or undersized GPU is spotted
# before the long-running cells are started.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties field is `total_memory` (in bytes); `total_mem`
    # does not exist and raises AttributeError on a GPU runtime.
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Clone/mount repo for src/ imports (adjust path for your setup)
# Option A: Clone from GitHub
# !git clone https://github.com/YOUR_USER/network-motif-analysis.git /content/repo
# import sys; sys.path.insert(0, '/content/repo')

# Option B: Upload src/neuron_graph.py and src/graph_loader.py directly
# Then add parent to path:
import sys
sys.path.insert(0, '/content')  # Adjust if needed

from src.neuron_graph import (
    NeuronGraphConfig,
    select_top_neurons,
    compute_all_attributions,
    build_neuron_graph_json,
    generate_neuron_graph,
    characterize_graph,
)
from src.graph_loader import load_attribution_graph, graph_summary

## 2. Configuration

In [None]:
# Settings for neuron-graph generation, collected in one mapping so they are
# easy to tweak before the config object is built.
_config_kwargs = {
    "model_name": "google/gemma-2-2b",
    "top_k": 100,          # neurons per layer
    "max_layer_gap": 5,    # max layer distance for edges
    "threshold_pct": 95,   # keep top 5% of attributions
    "device": "cuda",
}
config = NeuronGraphConfig(**_config_kwargs)
print(f"Config: {config}")

## 3. Load Model

In [None]:
from transformer_lens import HookedTransformer

# Load Gemma-2-2B onto the GPU in half precision.
# NOTE(review): Gemma-2 checkpoints are reported to be numerically fragile in
# float16 — confirm activations/attributions contain no inf/NaN, and consider
# a wider dtype if they do.
model = HookedTransformer.from_pretrained("gemma-2-2b", device="cuda", dtype=torch.float16)

# Echo the architecture fields of the loaded model's config.
for _label, _field in (
    ("Model", "model_name"),
    ("Layers", "n_layers"),
    ("d_mlp", "d_mlp"),
    ("d_model", "d_model"),
    ("Vocab size", "d_vocab"),
):
    print(f"{_label}: {getattr(model.cfg, _field)}")

## 4. Define Target Prompts

10 prompts: 8 matched with existing transcoder graphs + 2 neuron-only.

In [None]:
# Each row is (slug, category, prompt); rows are expanded into dicts below.
_TARGET_FIELDS = ("slug", "category", "prompt")
_TARGET_ROWS = [
    # Matched with transcoder graphs
    ("count-by-sevens", "arithmetic", "7, 14, 21, 28, 35,"),
    ("five-plus-three", "arithmetic", "5 + 3 ="),
    ("capital-france", "factual_recall", "The capital of France is"),
    ("opposite-small", "factual_recall", "The opposite of small is"),
    (
        "capital-state-dallas",
        "multihop",
        "Dallas is a city in the state of Texas. The capital of Texas is",
    ),
    (
        "currency-france",
        "multihop",
        "France is a country in Europe. The currency used in France is the",
    ),
    (
        "medical-diagnosis",
        "reasoning",
        "A patient presents with fever, stiff neck, and headache. The most likely diagnosis is",
    ),
    ("sally-school", "reasoning", "Sally went to school. After school, Sally went to"),
    # Neuron-only (no transcoder match)
    ("color-sky", "factual_recall", "The color of the sky is"),
    ("two-times-seven", "arithmetic", "2 * 7 ="),
]
TARGETS = [dict(zip(_TARGET_FIELDS, row)) for row in _TARGET_ROWS]

# List the targets so the run plan is visible in the notebook output.
print(f"{len(TARGETS)} target prompts:")
for _target in TARGETS:
    _category, _slug, _prompt = _target["category"], _target["slug"], _target["prompt"]
    print(f"  [{_category:15s}] {_slug:25s} → {_prompt!r}")

## 5. Pilot Run — Single Prompt

Test on "5 + 3 =" to verify the pipeline works before committing to the full run.

In [None]:
# Pick the pilot prompt by slug rather than by list position, so that
# reordering or extending TARGETS cannot silently change which prompt is piloted.
PILOT_SLUG = "five-plus-three"
pilot = next((t for t in TARGETS if t["slug"] == PILOT_SLUG), None)
if pilot is None:
    raise KeyError(f"Pilot slug {PILOT_SLUG!r} not found in TARGETS")

# The pilot output goes to its own sub-directory of neuron_graphs/.
pilot_dir = Path("neuron_graphs") / "pilot"
pilot_dir.mkdir(parents=True, exist_ok=True)

print(f"Pilot: {pilot['slug']} — {pilot['prompt']!r}")
t0 = time.time()

pilot_path = generate_neuron_graph(
    model=model,
    prompt=pilot["prompt"],
    slug=pilot["slug"],
    category=pilot["category"],
    output_dir=pilot_dir,
    config=config,
)

# Wall-clock time for one prompt.
elapsed = time.time() - t0
print(f"Generated in {elapsed:.0f}s ({elapsed/60:.1f} min)")
print(f"Saved to: {pilot_path}")

In [None]:
# Round-trip check: the pilot file must parse as JSON and also load through
# the existing graph loader.
pilot_json = json.loads(Path(pilot_path).read_text())

g = load_attribution_graph(pilot_path)

node_count, edge_count = g.vcount(), g.ecount()
print(f"\nLoaded graph: {node_count} nodes, {edge_count} edges")
print(f"Directed: {g.is_directed()}")
feature_types = set(g.vs["feature_type"])
print(f"Feature types: {feature_types}")
layer_ids = sorted(set(g.vs["layer"]))
print(f"Layers: {layer_ids}")

print("\nGraph summary:")
pilot_summary = graph_summary(g)
for field, value in pilot_summary.items():
    print(f"  {field}: {value}")

## 6. Characterize Pilot Graph

In [None]:
# Structural statistics of the pilot graph.
props = characterize_graph(pilot_json)
print("Graph properties:")
for prop_name, prop_value in props.items():
    # The per-layer breakdown is shown only as a layer count.
    shown = f"{len(prop_value)} layers" if prop_name == "nodes_per_layer" else prop_value
    print(f"  {prop_name:25s}: {shown}")

# Side-by-side table: reference transcoder-graph ranges vs. this neuron graph.
print("\nComparison with typical transcoder graphs:")
print(f"  {'Property':20s} {'Transcoder':>12s} {'Neuron':>12s}")
comparison_rows = (
    ("Nodes", "883-1886", f"{props['n_nodes']:>12d}"),
    ("Edges", "29K-91K", f"{props['n_edges']:>12d}"),
    ("Density", "0.02-0.03", f"{props['density']:>12.4f}"),
    ("Degree Gini", "high (hubs)", f"{props['degree_gini']:>12.4f}"),
    ("Excitatory %", "80-90%", f"{props['excitatory_fraction']*100:>11.1f}%"),
)
for row_label, transcoder_ref, neuron_cell in comparison_rows:
    print(f"  {row_label:20s} {transcoder_ref:>12s} {neuron_cell}")

## 7. Null Model Viability Check

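Run 10 iterations of the layer-pair configuration null model to verify that the standard deviation of the null counts is > 0 for every motif (a motif with zero null variance is flagged as degenerate).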

In [None]:
from src.unrolled_census import fast_unrolled_counts
from src.unrolled_motifs import build_catalog, get_effective_layer
from collections import defaultdict

# Layer-pair config null model (inline for Colab)
def _lpc_null(graph, rng):
    """Build a randomized copy of *graph* by rewiring within layer pairs.

    Edges are grouped by (effective source layer, effective target layer).
    Within each group, randomly chosen pairs of edges swap targets; a swap is
    rejected if the two edges share a source or a target, or if it would
    duplicate an edge already in the group. Each edge keeps its own
    ``sign``/``weight``/``raw_weight`` values. Vertex attributes are copied
    unchanged.

    Args:
        graph: directed igraph graph whose edges carry ``sign``, ``weight``
            and ``raw_weight`` attributes.
        rng: NumPy random generator used to pick the edge pairs.

    Returns:
        A new directed igraph graph with the same vertices and the rewired edges.
    """
    import igraph as ig

    # Bucket every edge record under its (source layer, target layer) key.
    node_layer = [get_effective_layer(graph, vertex.index) for vertex in graph.vs]
    buckets = defaultdict(list)
    for edge in graph.es:
        src, dst = edge.source, edge.target
        buckets[(node_layer[src], node_layer[dst])].append(
            (src, dst, edge["sign"], edge["weight"], edge["raw_weight"])
        )

    # Start from an edgeless copy that carries all vertex attributes.
    null_graph = ig.Graph(n=graph.vcount(), directed=True)
    for attr_name in graph.vs.attributes():
        null_graph.vs[attr_name] = graph.vs[attr_name]

    rewired = []
    for bucket in buckets.values():
        records = list(bucket)
        n_records = len(records)
        # A bucket with fewer than two edges has nothing to swap.
        if n_records >= 2:
            present = {(rec[0], rec[1]) for rec in records}
            for _ in range(n_records * 10):
                a, b = rng.choice(n_records, size=2, replace=False)
                src_a, dst_a = records[a][:2]
                src_b, dst_b = records[b][:2]
                if src_a == src_b or dst_a == dst_b:
                    continue
                if (src_a, dst_b) in present or (src_b, dst_a) in present:
                    continue
                present -= {(src_a, dst_a), (src_b, dst_b)}
                present |= {(src_a, dst_b), (src_b, dst_a)}
                # Swap targets; the attribute tail stays with its own record.
                records[a] = (src_a, dst_b) + records[a][2:]
                records[b] = (src_b, dst_a) + records[b][2:]
        rewired.extend(records)

    null_graph.add_edges([(rec[0], rec[1]) for rec in rewired])
    null_graph.es["sign"] = [rec[2] for rec in rewired]
    null_graph.es["weight"] = [rec[3] for rec in rewired]
    null_graph.es["raw_weight"] = [rec[4] for rec in rewired]
    return null_graph

# Motif counts on the real pilot graph.
real_counts = fast_unrolled_counts(g)
print("Real counts:")
for motif, observed in real_counts.items():
    print(f"  {motif:35s}: {observed:>8d}")

# Motif counts on N_CHECK independently seeded null graphs.
N_CHECK = 10
null_counts = defaultdict(list)
for seed in range(N_CHECK):
    shuffled = _lpc_null(g, np.random.default_rng(seed=seed))
    for motif, null_value in fast_unrolled_counts(shuffled).items():
        null_counts[motif].append(null_value)

# A motif is viable only if its null counts actually vary across iterations.
print(f"\nNull model viability ({N_CHECK} iterations):")
print(f"  {'Motif':35s} {'real':>8s} {'mean_null':>10s} {'std_null':>10s} {'viable?':>8s}")
for motif, observed in real_counts.items():
    samples = np.array(null_counts[motif], dtype=float)
    null_mean = samples.mean()
    null_std = samples.std()
    verdict = "YES" if null_std > 1e-10 else "DEGEN"
    print(f"  {motif:35s} {observed:>8d} {null_mean:>10.1f} {null_std:>10.2f} {verdict:>8s}")

## 8. Full Generation Loop

Generate all 10 neuron graphs. Each takes ~30 min on T4.

In [None]:
import traceback

base_dir = Path("neuron_graphs")
results_log = []
total_t0 = time.time()


def _graph_path(target):
    """Return the path at which the graph for *target* is looked up on disk."""
    # assumes generate_neuron_graph writes <output_dir>/<slug>.json — TODO confirm
    return base_dir / target["category"] / f"{target['slug']}.json"


# Time spent on prompts that were actually attempted. Skipped prompts are
# excluded so the remaining-time estimate is not dragged down on a resumed run.
attempted = 0
attempted_time = 0.0

for i, target in enumerate(TARGETS):
    slug = target["slug"]
    category = target["category"]
    prompt = target["prompt"]
    out_dir = base_dir / category

    # Skip if already generated (lets the cell be re-run after an interruption)
    out_path = _graph_path(target)
    if out_path.exists():
        print(f"[{i+1}/{len(TARGETS)}] {slug} — already exists, skipping")
        results_log.append({"slug": slug, "status": "skipped"})
        continue

    print(f"\n{'='*60}")
    print(f"[{i+1}/{len(TARGETS)}] {slug} ({category})")
    print(f"  Prompt: {prompt!r}")
    print(f"{'='*60}")

    t0 = time.time()
    try:
        path = generate_neuron_graph(
            model=model,
            prompt=prompt,
            slug=slug,
            category=category,
            output_dir=out_dir,
            config=config,
        )
        elapsed = time.time() - t0

        # Quick validation: the file must parse as JSON, be characterizable,
        # and load through the existing graph loader (return value not needed).
        with open(path) as f:
            gj = json.load(f)
        props = characterize_graph(gj)
        load_attribution_graph(path)

        print(f"  Done in {elapsed:.0f}s ({elapsed/60:.1f} min)")
        print(f"  Nodes: {props['n_nodes']}, Edges: {props['n_edges']}, "
              f"Density: {props['density']:.4f}, Gini: {props['degree_gini']:.3f}")

        results_log.append({
            "slug": slug, "status": "ok", "time_s": elapsed,
            **{k: v for k, v in props.items() if k != "nodes_per_layer"},
        })

    except Exception as e:
        # Deliberately broad: one failing prompt must not abort a multi-hour
        # batch. The failure is logged and the full traceback is printed.
        elapsed = time.time() - t0
        print(f"  FAILED after {elapsed:.0f}s: {e}")
        traceback.print_exc()
        results_log.append({"slug": slug, "status": "error", "error": str(e)})

    attempted += 1
    attempted_time += time.time() - t0

    # Print progress. The estimate is (mean time per attempted prompt) times
    # (number of later prompts whose graph file does not exist yet).
    total_elapsed = time.time() - total_t0
    done = i + 1
    pending = sum(1 for later in TARGETS[done:] if not _graph_path(later).exists())
    remaining = pending * (attempted_time / attempted)
    print(f"  Progress: {done}/{len(TARGETS)}, "
          f"{total_elapsed/60:.0f} min elapsed, ~{remaining/60:.0f} min remaining")

    # Clear GPU cache between prompts
    torch.cuda.empty_cache()

total_elapsed = time.time() - total_t0
print(f"\n\nTotal time: {total_elapsed/60:.1f} min ({total_elapsed/3600:.1f} hours)")
print("\nResults summary:")
for r in results_log:
    print(f"  {r['slug']:25s} {r['status']}")

## 9. Download

Zip the generated graphs for download.

In [None]:
def _json_default(obj):
    """Convert NumPy scalars and arrays to plain Python values for json.dump."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save generation log.
# Create the directory first so the log can be written even if no graph was.
base_dir.mkdir(parents=True, exist_ok=True)
with open(base_dir / "generation_log.json", "w") as f:
    # `default` keeps the dump from failing on NumPy scalars in the logged
    # graph properties; any other unsupported type still raises TypeError.
    json.dump(results_log, f, indent=2, default=_json_default)

# Zip for download. The archive holds everything under neuron_graphs/,
# including the pilot/ sub-directory if the pilot cell was run.
shutil.make_archive("neuron_graphs", "zip", ".", "neuron_graphs")
print("Created neuron_graphs.zip")
print("Download and extract to data/neuron/gemma-2-2b/ in your local repo.")

# In Colab, this triggers a download:
try:
    from google.colab import files
    files.download("neuron_graphs.zip")
except ImportError:
    print("Not in Colab — download neuron_graphs.zip manually.")