In [None]:
# Install Dependencies
!pip install -q transformer-lens datasets tqdm accelerate

In [None]:
# Import Libraries
import torch
from transformer_lens import HookedTransformer
import numpy as np
from IPython.display import HTML, display
import pandas as pd
from typing import List, Tuple

In [None]:
# Get Max Neuron Activation and Visualize

model = HookedTransformer.from_pretrained("gelu-1l", center_unembed=True)
model.eval()
tok = model.tokenizer

# Neuron under study: MLP layer index, and neuron index inside that layer's MLP.
LAYER = 0
NEURON = 2029

# Holds the most recent MLP post-activation captured by `save_mlp_post`,
# stored under the key "mlp_post" (detached, on CPU).
activations = {}


def save_mlp_post(act, hook):
    """Forward hook: stash a detached CPU copy of the hooked activation tensor."""
    activations.update(mlp_post=act.detach().cpu())


model.add_hook(f"blocks.{LAYER}.mlp.hook_post", save_mlp_post)

def get_max_neuron_activation_with_context(text: str, window: int = 3):
    """Locate the token on which neuron NEURON (MLP layer LAYER) fires most strongly.

    Args:
        text: Prompt to analyse; tokenized with `tok`.
        window: Number of tokens kept on each side of the peak token in the
            returned context string. Negative values are treated as 0.

    Returns:
        dict with keys
          "token"      - decoded peak token,
          "position"   - its index in the tokenized prompt,
          "activation" - the neuron's post-activation at that position (float),
          "context"    - decoded text of the surrounding window.

    Raises:
        ValueError: if `text` tokenizes to zero tokens.
    """
    hook_name = f"blocks.{LAYER}.mlp.hook_post"
    inputs = tok(text, return_tensors="pt")
    input_ids = inputs.input_ids
    seq_len = input_ids.size(1)
    if seq_len == 0:
        raise ValueError("text must contain at least one token")
    window = max(0, window)

    # Read activations from a fresh cache rather than the global `activations`
    # dict. That dict is only refreshed while the save_mlp_post hook is attached
    # to `model`; model.reset_hooks() detaches it and leaves stale data behind.
    with torch.no_grad():
        _, cache = model.run_with_cache(input_ids, names_filter=[hook_name])
    acts = cache[hook_name][0, :, NEURON].cpu()

    max_val, max_idx = acts.max(dim=0)
    max_idx = max_idx.item()
    max_val = max_val.item()

    # Clamp the context window to the prompt boundaries.
    start = max(0, max_idx - window)
    end = min(seq_len, max_idx + window + 1)
    context_str = tok.decode(input_ids[0, start:end].tolist())
    token_str = tok.decode([input_ids[0, max_idx].item()])

    return {
        "token": token_str,
        "position": max_idx,
        "activation": max_val,
        "context": context_str
    }

def visualize_activations(text: str):
    """Render `text` as HTML, shading each token by neuron NEURON's activation.

    Activations are min-max normalised over this prompt (0 = lowest, 1 = highest)
    and used as the alpha of a red background. Token text is HTML-escaped so
    tokens such as '<' or '&' render literally instead of being parsed as markup.

    Raises:
        ValueError: if `text` tokenizes to zero tokens.
    """
    from html import escape

    hook_name = f"blocks.{LAYER}.mlp.hook_post"
    inputs = tok(text, return_tensors="pt")
    if inputs.input_ids.size(1) == 0:
        raise ValueError("text must contain at least one token")

    # Fresh cache instead of the global `activations` dict, which goes stale
    # once the save_mlp_post hook is removed from `model`.
    with torch.no_grad():
        _, cache = model.run_with_cache(inputs.input_ids, names_filter=[hook_name])
    acts = cache[hook_name][0, :, NEURON].cpu()

    min_a, max_a = acts.min(), acts.max()
    # Epsilon keeps a constant-activation prompt from dividing by zero.
    norm_acts = (acts - min_a) / (max_a - min_a + 1e-10)

    spans = []
    for token_id, score in zip(inputs.input_ids[0].tolist(), norm_acts.tolist()):
        token_str = escape(tok.decode([token_id]))
        color = f"rgba(225,76,76,{score:.2f})"
        spans.append(
            f"<span style='background-color:{color}; padding:2px; border-radius:3px;'>{token_str}</span>"
        )

    markup = " ".join(spans)
    display(HTML(f"<div style='line-height:1.6; font-size:1.1em'>{markup}</div>"))


In [None]:
# Run example text and visualize

example = '''
The kings and the queens and the princes.
'''
visualize_activations(example)

# Summarise where the neuron peaks on the example, one fact per line.
res = get_max_neuron_activation_with_context(example)
print(
    f"Max‐activating token: {res['token']} (position {res['position']})",
    f"Activation value: {res['activation']:.4f}",
    f"Context window: \"{res['context']}\"",
    sep="\n",
)

In [None]:
# Print neuron 2029's top ten most boosted tokens and top ten least boosted tokens

mlp    = model.blocks[LAYER].mlp
W_out  = mlp.W_out.detach().cpu()
W_U    = model.unembed.W_U.detach().cpu()

print("W_out shape:", W_out.shape)
print("W_U shape:  ", W_U.shape)
d_mlp, d_model = W_out.shape

# Direct logit effect of the neuron: its W_out row projected straight through the
# unembedding (no final LayerNorm is applied in this computation).
proj_vec = W_out[NEURON]
logit_weights = proj_vec @ W_U

topk_pos = torch.topk(logit_weights, k=10)
topk_neg = torch.topk(logit_weights, k=10, largest=False)

# Same report for both ends of the ranking.
for heading, ranked in (("BOOSTED", topk_pos), ("SUPPRESSED", topk_neg)):
    print(f"\nTop 10 tokens {heading} by neuron", NEURON)
    for idx, score in zip(ranked.indices.tolist(), ranked.values.tolist()):
        token = model.tokenizer.decode([idx]).strip()
        print(f"  {token!r}: {score:.2f}")

In [None]:
# Print the top ten co-activating neurons (highest peak activation) for each example prompt

def top_coactivating_neurons(text: str, top_k: int = 10):
    """Print the `top_k` layer-LAYER MLP neurons with the highest peak activation on `text`.

    "Co-activating" here means "strongly active on the same prompt". Every neuron
    is ranked by its maximum post-activation over all positions, and the ranking
    is not conditioned on neuron NEURON.

    Args:
        text: Prompt to run through the model (tokenized with `tok`).
        top_k: Number of neurons to list; clamped to the number of neurons available.

    Raises:
        ValueError: if `text` tokenizes to zero tokens.
    """
    hook_name = f"blocks.{LAYER}.mlp.hook_post"
    toks = tok(text, return_tensors="pt")
    if toks.input_ids.size(1) == 0:
        raise ValueError("text must contain at least one token")

    # Read from a fresh cache instead of the global `activations` dict, which is
    # only refreshed while the save_mlp_post hook is still attached to `model`.
    with torch.no_grad():
        _, cache = model.run_with_cache(toks.input_ids, names_filter=[hook_name])
    acts = cache[hook_name][0].cpu()  # (positions, neurons)

    peak, _ = acts.max(dim=0)  # per-neuron maximum over positions
    k = min(top_k, peak.numel())  # torch.topk raises if k exceeds the size
    vals, idxs = torch.topk(peak, k=k)

    print(f"\nTop {k} co-activating neurons for “{text[:30]}…”:")
    for i, v in zip(idxs.tolist(), vals.tolist()):
        print(f"  Neuron {i:4d}  peak act = {v:.3f}")

# One report per contrasting prompt (natural language vs. code).
for prompt in ("the kings and three camels", "for i in range(10): print(i)"):
    top_coactivating_neurons(prompt)

In [None]:
# Print the top-k logits at the final position and their tokens with and without neuron 2029 ablated.

test_prompt = "The king and queen are great."
TOP_K = 10  # number of logits printed per run


def reset():
    """Clear hooks from the global `model` by calling `model.reset_hooks()`."""
    model.reset_hooks()

def print_topk_logits(logits, k=10):
    """Print the k largest logits at the final position of batch item 0, then a blank line."""
    final_logits = logits[0, -1]
    best = torch.topk(final_logits, k)

    report = [f"Top {k} logits:"]
    report.extend(
        f"  {tok.decode([token_id]).strip()!r}: {value:.4f}"
        for token_id, value in zip(best.indices.tolist(), best.values.tolist())
    )
    print("\n".join(report))
    print()

# Baseline vs. neuron-NEURON-ablated logits for `test_prompt`.
# The ablation hook is attached only for a single forward pass (run_with_hooks),
# so it cannot leak into later cells the way a persistent model.add_hook does.
# We also skip model.reset_hooks(): it would detach the save_mlp_post capture
# hook registered earlier, which fills the `activations` dict.

with torch.no_grad():
    logits_base = model(test_prompt)
print("=== Baseline ===")
print_topk_logits(logits_base, TOP_K)


def ablate_fn(acts, hook):
    """Forward hook: zero neuron NEURON's post-activation at every position (in place)."""
    acts[..., NEURON] = 0.0
    return acts


with torch.no_grad():
    logits_abl = model.run_with_hooks(
        test_prompt,
        return_type="logits",
        fwd_hooks=[(f"blocks.{LAYER}.mlp.hook_post", ablate_fn)],
    )
print("=== Ablated ===")
print_topk_logits(logits_abl, TOP_K)

In [None]:
# Script for identifying and ablating cancellers as well as baseline comparison

DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME    = "gelu-1l"
LAYER         = 0
NEURON_ID     = 2029
PROMPT_A      = "the king and"
PROMPT_B      = "for i in range"
TOP_CANCELERS = 5
TOP_K_PRINT   = 10

model = (
    HookedTransformer.from_pretrained(MODEL_NAME, center_unembed=True)
    .to(DEVICE)
    .eval()
)
tok = model.tokenizer

# Layer-LAYER MLP output weights and the unembedding, detached and left on DEVICE.
W_out = model.blocks[LAYER].mlp.W_out.detach()
W_U = model.unembed.W_U.detach()
d_mlp, d_model = W_out.shape

def run_with_cache(prompt: str):
    """Run `model` on `prompt`, caching only the layer-LAYER MLP post-activations on DEVICE.

    Returns whatever `model.run_with_cache` returns (logits, cache).
    """
    hook_name = f"blocks.{LAYER}.mlp.hook_post"
    return model.run_with_cache(prompt, names_filter=[hook_name], device=DEVICE)

def resid_split(prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split the last-position MLP write into neuron NEURON_ID's share and everyone else's.

    Returns:
        (v_amb, rest) where v_amb = act[NEURON_ID] * W_out[NEURON_ID] and
        rest = (acts @ W_out) - v_amb, i.e. the summed write vectors of all other
        layer-LAYER neurons. No bias, attention or embedding term is included.
    """
    _, cache = run_with_cache(prompt)
    last_acts = cache[f"blocks.{LAYER}.mlp.hook_post"][0, -1]
    own_write = last_acts[NEURON_ID] * W_out[NEURON_ID]
    total_write = last_acts @ W_out
    return own_write, total_write - own_write

def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
    """Cosine similarity of two 1-D tensors, returned as a Python float.

    A small epsilon in the denominator keeps zero vectors from dividing by zero
    (the result is then 0.0). The whole ratio is reduced with .item(); previously
    only the numerator was, so a tensor leaked out.
    """
    return (torch.dot(a, b) / (a.norm() * b.norm() + 1e-9)).item()

def print_topk_logits(vec_or_logits: torch.Tensor, k=10, note=""):
    """Print the `k` highest-scoring vocabulary tokens for a vector or logits tensor.

    Accepted inputs:
      * 1-D logits of length d_vocab (== W_U.shape[1]), used as-is;
      * 1-D residual-stream vector of length d_model (== W_U.shape[0]), projected
        through W_U (no final LayerNorm applied);
      * 3-D logits (batch, pos, d_vocab), of which the [0, -1] slice is used. This
        keeps calls written for the earlier `print_topk_logits(logits, k)`
        definition, which this one shadows, working.

    `k` is clamped to the vocabulary size.

    Raises:
        ValueError: if the input shape matches none of the above.
    """
    vec = vec_or_logits
    if vec.ndim == 3:
        vec = vec[0, -1]
    if vec.ndim != 1:
        raise ValueError(
            "expected a 1-D vector or (batch, pos, d_vocab) logits, "
            f"got shape {tuple(vec_or_logits.shape)}"
        )

    if vec.shape[0] == W_U.shape[1]:
        logits = vec
    elif vec.shape[0] == W_U.shape[0]:
        logits = vec @ W_U
    else:
        raise ValueError(
            f"length {vec.shape[0]} matches neither d_vocab={W_U.shape[1]} "
            f"nor d_model={W_U.shape[0]}"
        )

    k = min(k, logits.numel())
    topk = torch.topk(logits, k)
    print(f"\nTop {k} tokens for {note}:")
    for idx, score in zip(topk.indices.tolist(), topk.values.tolist()):
        token = tok.decode([idx]).strip()
        print(f"  {token!r:<10} {score:7.2f}")

def ablate_neurons(prompt: str, neuron_ids: List[int]) -> torch.Tensor:
    """Return logits for `prompt` with the given layer-LAYER MLP neurons zeroed at every position.

    The zeroing hook is passed via `fwd_hooks`, scoped to this single run.
    """
    def _zero_selected(acts, hook):
        acts[..., neuron_ids] = 0.0
        return acts

    hook_name = f"blocks.{LAYER}.mlp.hook_post"
    return model.run_with_hooks(
        prompt,
        return_type="logits",
        fwd_hooks=[(hook_name, _zero_selected)],
    )

# 1. Identify cancellation direction.
# resid_split returns, at the last position, v = act[NEURON_ID] * W_out[NEURON_ID]
# and rest = the summed write vectors of all *other* layer-LAYER neurons. Despite
# the "rest_of_stream" label printed below, `rest` contains no bias, attention or
# embedding contribution.
vA, restA = resid_split(PROMPT_A)
vB, restB = resid_split(PROMPT_B)

print("\n=== COSINE(v_ambiguous, rest_of_stream) ===")
print(f"Prompt A ({PROMPT_A[:15]}…): {cosine(vA, restA):+.3f}")
print(f"Prompt B ({PROMPT_B[:15]}…): {cosine(vB, restB):+.3f}")

# 2. Identify Cancellers for Prompt A
_, cacheA = run_with_cache(PROMPT_A)
actsA = cacheA[f"blocks.{LAYER}.mlp.hook_post"][0][-1]
contrA = actsA[:, None] * W_out                           # row i = act_i * W_out[i]
align_scores = torch.mv(contrA, vA) / (vA.norm() + 1e-9)  # signed projection onto vA

# Row NEURON_ID of contrA equals vA, so it scores |vA| against itself and would
# trivially appear among the "reinforcing" neurons. Mask it out of both rankings.
is_self = torch.zeros_like(align_scores, dtype=torch.bool)
is_self[NEURON_ID] = True
cancelers = (
    align_scores.masked_fill(is_self, float("inf"))
    .topk(TOP_CANCELERS, largest=False).indices.tolist()
)
helpers = (
    align_scores.masked_fill(is_self, float("-inf"))
    .topk(TOP_CANCELERS, largest=True).indices.tolist()
)

print("\nTop canceling neurons (Prompt A):", cancelers)
print("Top reinforcing neurons (Prompt A):", helpers)

# 3. Causal test with baseline and ablated top-k
def show_logit_change(prompt: str, neuron_group: List[int], label: str, target: str = "]):"):
    """Compare `target`'s logit margin on `prompt` before and after ablating `neuron_group`.

    The margin is logit[target] - max(logit) at the final position of batch item 0,
    so it is <= 0 and equals 0 when `target` is the top prediction. Prints both
    margins plus the top-TOP_K_PRINT tokens for the baseline and ablated runs.

    Args:
        prompt: Text fed to `model`.
        neuron_group: Layer-LAYER MLP neuron indices zeroed via `ablate_neurons`.
        label: Name of the group, used only in the printed output.
        target: Token string whose margin is tracked (default "]):"). Only its
            first token id is scored; a warning is emitted if it spans several tokens.

    Raises:
        ValueError: if `target` tokenizes to no tokens.
    """
    import warnings

    # add_special_tokens=False so a tokenizer-added BOS can never be taken as index 0.
    tgt_ids = tok.encode(target, add_special_tokens=False)
    if not tgt_ids:
        raise ValueError(f"target {target!r} produced no tokens")
    if len(tgt_ids) > 1:
        warnings.warn(
            f"target {target!r} spans {len(tgt_ids)} tokens; only the first "
            f"({tok.decode([tgt_ids[0]])!r}) is scored"
        )
    tgt_id = tgt_ids[0]

    # Inference only: no autograd graph is needed for either forward pass.
    with torch.no_grad():
        base_logits = model(prompt)
        ablated = ablate_neurons(prompt, neuron_group)

    last_base = base_logits[0, -1]
    last_ab = ablated[0, -1]
    margin_base = (last_base[tgt_id] - last_base.max()).item()
    margin_ab = (last_ab[tgt_id] - last_ab.max()).item()

    print(f"\n→ Patch test on prompt: {prompt!r}")
    print(f"  Target token: '{target}'")
    print(f"  Margin baseline: {margin_base:+.3f}   after ablation ({label}): {margin_ab:+.3f}")

    print_topk_logits(last_base, TOP_K_PRINT, note="baseline")
    print_topk_logits(last_ab,   TOP_K_PRINT, note=f"ablated {label}")

# The cancelers were selected on Prompt A; apply them in-context (A) and, as a
# cross-context check, on Prompt B.
show_logit_change(PROMPT_A, cancelers, "cancelers")
show_logit_change(PROMPT_B, cancelers, "cancelers")

# 4. Display raw vector top tokens (as before); label follows the configured neuron.
print_topk_logits(W_out[NEURON_ID], 20, note=f"raw W_out[{NEURON_ID}]")


In [None]:
# Script showing that a small, different set of neighbouring neurons cancels (or reinforces) neuron 2029 in two distinct contexts.

DEVICE  = "cuda" if torch.cuda.is_available() else "cpu"
MODEL   = "gelu-1l"
LAYER   = 0
NEURON  = 2029
K       = 5  # size of each prompt's canceler set
PROMPT_A = "the king and"
PROMPT_B = "for i in range"

model = HookedTransformer.from_pretrained(MODEL, center_unembed=True)
model = model.to(DEVICE).eval()
tok = model.tokenizer

# Layer-LAYER MLP output weights and the unembedding, detached and left on DEVICE.
W_out = model.blocks[LAYER].mlp.W_out.detach()
W_U = model.unembed.W_U.detach()

def resid_and_acts(prompt: str):
    """Return (acts, resid) for the final position of `prompt`.

    acts:  layer-LAYER MLP post-activations at the last position.
    resid: acts @ W_out, the neurons' summed write vector (no bias term added here).
    """
    hook_name = f"blocks.{LAYER}.mlp.hook_post"
    _, cache = model.run_with_cache(prompt, names_filter=[hook_name], device=DEVICE)
    final_acts = cache[hook_name][0][-1]
    return final_acts, final_acts @ W_out

def top_cancelers(acts, v_amb, k: int) -> List[int]:
    """Return the k neurons whose contribution points most *against* `v_amb`.

    Neuron i contributes acts[i] * W_out[i]. Its score is the projection of that
    vector onto `v_amb`, divided by |v_amb| (plus an epsilon). The k lowest scores
    are returned, most negative first.
    """
    per_neuron = acts.unsqueeze(1) * W_out
    projections = torch.mv(per_neuron, v_amb)
    scores = projections / (v_amb.norm() + 1e-9)
    return torch.topk(scores, k, largest=False).indices.tolist()

def logit_margin(logits, target_id):
    """Gap between `target_id`'s logit and the top logit at the final position of batch item 0.

    Never positive; 0.0 when `target_id` is (tied for) the argmax.
    """
    final_row = logits[0][-1]
    gap = final_row[target_id] - final_row.max()
    return gap.item()

def ablate(prompt: str, neurons: List[int]):
    """Return logits for `prompt` with the listed layer-LAYER MLP neurons zeroed at all positions."""
    def zero_neurons(acts, hook):
        acts[..., neurons] = 0.0
        return acts

    hook_point = f"blocks.{LAYER}.mlp.hook_post"
    return model.run_with_hooks(
        prompt,
        return_type="logits",
        fwd_hooks=[(hook_point, zero_neurons)]
    )

# 1.  Reference ("ambiguous") vector: neuron NEURON's write vector scaled by its
#     activation at the last position of PROMPT_A. Its direction is W_out[NEURON]
#     for any prompt, but its sign follows the sign of that activation.
acts_A, _ = resid_and_acts(PROMPT_A)
v_amb = acts_A[NEURON] * W_out[NEURON]

# 2.  Build canceler sets for each prompt. acts_A is reused from step 1, since a
#     second forward pass on the same prompt would give identical activations.
acts_B, _ = resid_and_acts(PROMPT_B)


def _neighbour_cancelers(acts) -> List[int]:
    """Like top_cancelers, but never returns NEURON itself (only its neighbours)."""
    # If NEURON's activation has the opposite sign to that in PROMPT_A, its own
    # score is negative and it could rank as its own canceler. Request one extra
    # candidate so that dropping it still leaves K neurons.
    ranked = top_cancelers(acts, v_amb, K + 1)
    return [n for n in ranked if n != NEURON][:K]


cancel_A = _neighbour_cancelers(acts_A)
cancel_B = _neighbour_cancelers(acts_B)


print(f"\nCancelers for Prompt A: {cancel_A}")
print(f"Cancelers for Prompt B: {cancel_B}")
print(f"Union size = {len(set(cancel_A) | set(cancel_B))}  (should be small)")

# 3.  Measure margin changes
def test(prompt:str, cancelers:List[int], label:str, target: str = "]):"):
    """Print how ablating `cancelers` changes `target`'s logit margin on `prompt`.

    The margin comes from `logit_margin` (target logit minus max logit at the final
    position), computed for a plain run and for a run with `cancelers` zeroed.

    Args:
        prompt: Text fed to `model`.
        cancelers: Layer-LAYER MLP neuron indices zeroed via `ablate`.
        label: Which prompt the canceler set came from (printed only).
        target: Token string whose margin is tracked (default "]):"). Only its
            first token id is scored; a warning is emitted if it spans several tokens.

    Raises:
        ValueError: if `target` tokenizes to no tokens.
    """
    import warnings

    # add_special_tokens=False so a tokenizer-added BOS can never be taken as index 0.
    tgt_ids = tok.encode(target, add_special_tokens=False)
    if not tgt_ids:
        raise ValueError(f"target {target!r} produced no tokens")
    if len(tgt_ids) > 1:
        warnings.warn(
            f"target {target!r} spans {len(tgt_ids)} tokens; only the first "
            f"({tok.decode([tgt_ids[0]])!r}) is scored"
        )
    tgt_id = tgt_ids[0]

    # Inference only: no autograd graph is needed for either forward pass.
    with torch.no_grad():
        base_logits = model(prompt)
        abl_logits  = ablate(prompt, cancelers)
    m_base = logit_margin(base_logits, tgt_id)
    m_abl  = logit_margin(abl_logits,  tgt_id)
    print(f"\nPrompt: {prompt!r}")
    print(f"Using cancelers from {label}")
    print(f"  logit margin baseline  : {m_base:+.3f}")
    print(f"  after ablation ({label}): {m_abl:+.3f}   Δ = {m_abl-m_base:+.3f}")

# Each case: (prompt to run, canceler set to ablate, prompt that set was built from).
print("\n─── In‑context ablations ───────────────────────────")
for run_prompt, run_cancelers, run_label in (
    (PROMPT_A, cancel_A, "the king and"),
    (PROMPT_B, cancel_B, "for i in range"),
):
    test(run_prompt, run_cancelers, run_label)

print("\n─── Cross‑context ablations ─────────────────────────")
for run_prompt, run_cancelers, run_label in (
    (PROMPT_A, cancel_B, "for i in range"),
    (PROMPT_B, cancel_A, "the king and"),
):
    test(run_prompt, run_cancelers, run_label)
