In [None]:
%%writefile bdh_europarl_train_probe.py
# --- START PASTE ---
#!/usr/bin/env python3
# bdh_europarl_train_probe.py
# One-file: download -> train -> probe (monosemantic neuron IDs)

import argparse
import math
import os
import random
import tarfile
import urllib.request
from dataclasses import dataclass
from typing import List, Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


# =========================
#        BDH MODEL
# =========================

@dataclass
class BDHConfig:
    """Hyper-parameters of the byte-level BDH model.

    The field order below is also the positional-constructor order, so it is
    part of the public interface.
    """

    n_layer: int = 6  # how many times the shared weights are applied in forward()
    n_embd: int = 256  # model width D
    dropout: float = 0.1  # dropout on the gated sparse activations
    n_head: int = 4  # number of heads nh
    mlp_internal_dim_multiplier: int = 128  # per-head sparse width N = mult * D // nh
    vocab_size: int = 256  # byte-level


def get_freqs(n, theta, dtype):
    """Rotary frequencies, in turns per position, for ``n`` channels.

    Adjacent channels (2i, 2i+1) share one frequency,
    ``theta ** (-2i / n) / (2 * pi)``.
    """
    channel = torch.arange(0, n, 1, dtype=dtype)
    paired = (channel / 2).floor() * 2  # 0, 0, 2, 2, 4, 4, ...
    inv_scale = theta ** (paired / n)
    return 1.0 / inv_scale / (2 * math.pi)


class Attention(torch.nn.Module):
    """Causal, softmax-free attention with rotary (RoPE-style) position phases.

    Queries and keys must be the same tensor (checked in ``forward``). Each
    position attends only to strictly earlier positions. The raw dot products
    are used as weights (no softmax), so the output is an unnormalised sum over
    past values.
    """

    def __init__(self, config: BDHConfig):
        super().__init__()
        nh = config.n_head
        D = config.n_embd
        N = config.mlp_internal_dim_multiplier * D // nh
        # register_buffer works on every PyTorch 2.x release, whereas
        # torch.nn.Buffer only exists from 2.5 onward. The buffer is persistent
        # (the default), so the state_dict key "freqs" is unchanged and existing
        # checkpoints still load.
        self.register_buffer(
            "freqs",
            get_freqs(N, theta=2**16, dtype=torch.float32).view(1, 1, 1, N),
        )

    @staticmethod
    def phases_cos_sin(phases):
        """Return (cos, sin) of ``phases`` expressed in turns (fractions of 2*pi)."""
        phases = (phases % 1) * (2 * math.pi)
        return torch.cos(phases), torch.sin(phases)

    @staticmethod
    def rope(phases, v):
        """Rotate each adjacent feature pair (2i, 2i+1) of ``v`` by ``phases``."""
        v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size())
        phases_cos, phases_sin = Attention.phases_cos_sin(phases)
        return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype)

    def forward(self, Q, K, V):
        """Apply rotary phases to Q (== K) and accumulate V over strictly past positions.

        Q is unpacked as (_, _, T, _); the result is ``scores @ V``, where
        ``scores`` is the (..., T, T) strictly-lower-triangular product QR @ QR^T.
        """
        assert self.freqs.dtype == torch.float32
        assert K is Q
        _, _, T, _ = Q.size()

        # phase of position t on channel c = t * freqs[c]  (in turns)
        r_phases = (
            torch.arange(0, T, device=self.freqs.device, dtype=self.freqs.dtype)
            .view(1, 1, -1, 1)
        ) * self.freqs

        QR = self.rope(r_phases, Q)
        KR = QR

        # causal: only past (strictly lower triangle)
        scores = (QR @ KR.mT).tril(diagonal=-1)

        # NOTE: no softmax -> associative accumulation
        return scores @ V


class BDH(nn.Module):
    """Byte-level BDH language model.

    A single set of parameters (encoder, encoder_v, decoder, lm_head) is applied
    ``config.n_layer`` times in ``forward``. There are no per-layer weights, so
    ``n_layer`` is NOT recorded in the state_dict, and a checkpoint loads into a
    model with any ``n_layer``.

    Dimensions used in the shape comments:
      D  = config.n_embd
      nh = config.n_head
      N  = config.mlp_internal_dim_multiplier * D // nh   (sparse width per head)
    """

    def __init__(self, config: BDHConfig):
        super().__init__()
        assert config.vocab_size is not None
        self.config = config
        nh = config.n_head
        D = config.n_embd
        N = config.mlp_internal_dim_multiplier * D // nh

        # decoder: concatenated per-head sparse code (nh*N) -> D
        self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02))
        # encoder / encoder_v: per-head projections D -> N
        self.encoder = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
        self.encoder_v = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))

        self.attn = Attention(config)
        # Parameter-free LayerNorm (no affine weight, no bias); reused everywhere.
        self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False)
        self.embed = nn.Embedding(config.vocab_size, D)
        self.drop = nn.Dropout(config.dropout)

        # Output projection D -> vocab_size, kept as a raw matrix (not nn.Linear).
        self.lm_head = nn.Parameter(torch.zeros((D, config.vocab_size)).normal_(std=0.02))

        # In practice this only re-initialises the embedding: no nn.Linear is
        # created here, and the raw Parameters above are initialised explicitly.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero any Linear bias."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, return_sparse: bool = False):
        """Run the model on byte ids.

        Args:
            idx: integer tensor of byte ids, unpacked as (B, T).
            targets: optional next-byte ids with B*T elements (e.g. shape (B, T));
                when given, ``loss`` is F.cross_entropy over all positions.
            return_sparse: also return the detached ``x_sparse`` activations of
                every layer, each (B, nh, T, N).

        Returns:
            ``(logits, loss)`` or, with ``return_sparse``, ``(logits, loss,
            sparse_cache)``. ``logits`` is (B, T, vocab_size); ``loss`` is None
            when ``targets`` is None.
        """
        C = self.config
        B, T = idx.size()
        D = C.n_embd
        nh = C.n_head
        N = D * C.mlp_internal_dim_multiplier // nh

        x = self.embed(idx).unsqueeze(1)  # (B,1,T,D); the singleton axis broadcasts over heads
        x = self.ln(x)

        sparse_cache = []  # list of x_sparse per layer

        # Same weights on every iteration: depth is recurrent, not stacked.
        for level in range(C.n_layer):
            x_latent = x @ self.encoder                  # (B,1,T,D)@(nh,D,N) -> (B,nh,T,N)
            x_sparse = F.relu(x_latent)                  # (B,nh,T,N), non-negative

            if return_sparse:
                sparse_cache.append(x_sparse.detach())

            yKV = self.attn(Q=x_sparse, K=x_sparse, V=x) # scores (B,nh,T,T) @ V (B,1,T,D) -> (B,nh,T,D)
            yKV = self.ln(yKV)

            y_latent = yKV @ self.encoder_v              # (B,nh,T,N)
            y_sparse = F.relu(y_latent)                  # (B,nh,T,N)

            # Elementwise gate: a unit is active only if both x_sparse and y_sparse are.
            xy_sparse = x_sparse * y_sparse              # (B,nh,T,N)
            xy_sparse = self.drop(xy_sparse)

            # (B,T,nh,N) -> (B,1,T,nh*N): heads concatenated per position
            # (index head*N + feat), then decoded back to D.
            yMLP = xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh) @ self.decoder  # (B,1,T,D)
            y = self.ln(yMLP)
            x = self.ln(x + y)

        logits = x.view(B, T, D) @ self.lm_head  # (B,T,vocab_size)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        if return_sparse:
            return logits, loss, sparse_cache
        return logits, loss


# =========================
#     EUROPARL PIPELINE
# =========================

# Europarl v7 parallel corpora (English paired with another EU language),
# keyed by archive file name. To add a pair, extend the tuple of pair codes
# below, e.g. with "es-en".
EUROPARL_URLS = {
    f"{pair}.tgz": f"https://www.statmt.org/europarl/v7/{pair}.tgz"
    for pair in ("de-en", "fr-en")
}

def download(url: str, out_path: str):
    """Fetch ``url`` into ``out_path`` unless that file already exists.

    The payload is first written to ``out_path + ".part"`` and renamed into
    place only after it completes. An interrupted download is therefore never
    mistaken for a finished one by the existence check on the next run.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises (after removing the
        partial ``.part`` file).
    """
    out_dir = os.path.dirname(out_path)
    if out_dir:  # bare filename -> dirname "" -> os.makedirs("") would raise
        os.makedirs(out_dir, exist_ok=True)
    if os.path.exists(out_path):
        print(f"[download] exists: {out_path}")
        return
    print(f"[download] {url}")
    tmp_path = out_path + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, out_path)  # atomic rename once complete
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a truncated file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"[download] saved: {out_path}")

def extract_tgz(tgz_path: str, out_dir: str):
    """Extract a .tgz archive into ``out_dir`` once.

    A marker file ``out_dir/.extracted_<archive name>`` records completion, so
    later calls return immediately.

    The archive comes from the network, so on interpreters that support
    extraction filters (PEP 706) the "data" filter is used. It rejects absolute
    paths, ``..`` components, links leaving ``out_dir`` and special files, and
    raises ``tarfile.FilterError`` instead of writing outside ``out_dir``.
    """
    marker = os.path.join(out_dir, ".extracted_" + os.path.basename(tgz_path))
    if os.path.exists(marker):
        print(f"[extract] already extracted: {tgz_path}")
        return
    print(f"[extract] {tgz_path}")
    os.makedirs(out_dir, exist_ok=True)  # marker is written here even for an empty archive
    with tarfile.open(tgz_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(out_dir, filter="data")
        else:
            tar.extractall(out_dir)  # older interpreter without extraction filters
    with open(marker, "w") as f:
        f.write("ok\n")

def iter_text_lines(path: str):
    """Lazily yield the stripped, non-empty lines of ``path``.

    Lines beginning with "<" (markup tags such as document headers) are
    skipped. Undecodable bytes are ignored.
    """
    with open(path, "r", encoding="utf-8", errors="ignore") as handle:
        stripped = (raw.strip() for raw in handle)
        yield from (text for text in stripped if text and not text.startswith("<"))

def _is_europarl_text_file(fn: str) -> bool:
    """True for extracted corpus files such as ``europarl-v7.de-en.de``."""
    lang = fn.rsplit(".", 1)[-1]
    return "europarl" in fn and len(lang) == 2 and lang.isalpha()


def build_train_txt(data_dir: str, out_txt: str, max_lines_per_file: int = 200_000, seed: int = 0):
    """Merge extracted Europarl text files under ``data_dir`` into ``out_txt``.

    Every file whose name contains "europarl" and ends in a two-letter language
    suffix (.en, .de, .fr, .es, ...) is used, so new pairs added to
    EUROPARL_URLS contribute their non-English side too. At most
    ``max_lines_per_file`` lines (as filtered by ``iter_text_lines``) are taken
    from each file.

    Lines from the different files are written round-robin rather than one file
    after another. ByteDataset only reads the first ``max_bytes`` of the result,
    and with sequential writing that prefix would hold just the first one or
    two files (i.e. one or two languages).

    The file order is sorted, then shuffled with a private RNG seeded by
    ``seed``. The result is reproducible and the global ``random`` state is
    left untouched.

    Raises:
        RuntimeError: if no Europarl text files are found under ``data_dir``.
    """
    rng = random.Random(seed)
    # os.walk order depends on the filesystem; sort before shuffling so the
    # seed alone determines the order.
    files = sorted(
        os.path.join(root, fn)
        for root, _, fnames in os.walk(data_dir)
        for fn in fnames
        if _is_europarl_text_file(fn)
    )

    if not files:
        raise RuntimeError(f"No europarl text files found under {data_dir}. Did extraction work?")

    rng.shuffle(files)
    out_dir = os.path.dirname(out_txt)
    if out_dir:  # bare filename -> dirname "" -> os.makedirs("") would raise
        os.makedirs(out_dir, exist_ok=True)

    readers = [iter_text_lines(fp) for fp in files]  # generators open lazily
    remaining = [max_lines_per_file] * len(readers)
    n_written = 0
    try:
        with open(out_txt, "w", encoding="utf-8") as out:
            active = [i for i, budget in enumerate(remaining) if budget > 0]
            while active:
                still_active = []
                for i in active:
                    line = next(readers[i], None)
                    if line is None:  # file exhausted
                        continue
                    out.write(line + "\n")
                    n_written += 1
                    remaining[i] -= 1
                    if remaining[i] > 0:
                        still_active.append(i)
                active = still_active
    finally:
        for reader in readers:
            reader.close()  # release file handles of partially consumed generators

    print(f"[dataset] wrote: {out_txt}")
    print(f"[dataset] included files: {len(files)}")
    print(f"[dataset] lines written: {n_written}")


# =========================
#      BYTE DATASET
# =========================

class ByteDataset(torch.utils.data.Dataset):
    """Sliding-window next-byte dataset over a UTF-8 text file.

    Only the first ``max_bytes`` bytes of the UTF-8 encoding are kept. Item
    ``i`` is ``(x, y)`` with ``x = data[i : i + block_size]`` and ``y`` the same
    window shifted right by one byte; both are ``torch.long`` tensors.
    """

    def __init__(self, train_txt: str, block_size: int, max_bytes: int):
        with open(train_txt, "r", encoding="utf-8", errors="ignore") as f:
            text = f.read()
        b = text.encode("utf-8", errors="ignore")[:max_bytes]
        if b:
            # frombuffer avoids materialising a Python list with one int object
            # per byte (tens of millions for the default max_bytes).
            # bytearray gives a writable buffer; .long() makes an owned copy.
            self.data = torch.frombuffer(bytearray(b), dtype=torch.uint8).long()
        else:
            self.data = torch.empty(0, dtype=torch.long)  # frombuffer rejects empty buffers
        self.block_size = block_size

    def __len__(self):
        # Window i reads targets up to data[i + block_size], so valid starts are
        # 0 .. len(data) - block_size - 1, i.e. len(data) - block_size windows.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, i):
        x = self.data[i : i + self.block_size]
        y = self.data[i + 1 : i + self.block_size + 1]
        return x, y


# =========================
#      NEURON PROBING
# =========================

def neuron_id(layer: int, head: int, feat: int, nh: int, N: int) -> int:
    """Flatten (layer, head, feat) into one global integer neuron index.

    The layout is layer-major, then head, then feature; ``decode_neuron``
    performs the inverse mapping.
    """
    per_head = N
    per_layer = nh * per_head
    return feat + per_head * head + per_layer * layer

def ids_for_text(text: str) -> torch.Tensor:
    """Encode ``text`` as UTF-8 byte ids shaped (1, T), i.e. a batch of one.

    Empty input (or input that encodes to nothing) becomes a single space, so
    T is always at least 1.
    """
    raw = text.encode("utf-8", errors="ignore") or b" "
    return torch.tensor([[byte for byte in raw]], dtype=torch.long)

@torch.no_grad()
def top_neurons_for_input(
    model: BDH,
    text: str,
    topk: int = 200,
    aggregate: str = "mean",   # "mean" over positions is better than just last byte
    min_activation: float = 0.0,
) -> List[List[Tuple[int, float, int, int, int]]]:
    """
    Rank the sparse (post-ReLU) neurons of every layer for one input text.

    Args:
        model: BDH model; it is called with ``return_sparse=True``.
        text: input encoded as UTF-8 bytes.
        topk: maximum number of neurons returned per layer.
        aggregate: "last" scores neurons by the final byte's activation; any
            other value averages over all byte positions.
        min_activation: only neurons whose aggregated activation is strictly
            greater than this are returned. Activations are ReLU outputs, so
            with the default 0.0 inactive neurons are dropped. Otherwise, when
            ``topk`` exceeds the number of active neurons, ``torch.topk`` would
            pad the list with arbitrary zero-valued neurons. Those can coincide
            across inputs and show up as spurious "shared" neurons.

    Returns per-layer list of hits:
      hits[layer] = [(global_neuron_id, activation, layer, head, feat), ...] sorted by activation desc
    A layer's list may hold fewer than ``topk`` entries.
    """
    cfg = model.config
    nh = cfg.n_head
    N = cfg.n_embd * cfg.mlp_internal_dim_multiplier // nh

    x = ids_for_text(text).to(next(model.parameters()).device)
    _, _, sparse_cache = model(x, return_sparse=True)

    hits_per_layer = []
    for layer, x_sparse in enumerate(sparse_cache):
        a = x_sparse[0]  # (nh, T, N)
        vec = a[:, -1, :] if aggregate == "last" else a.mean(dim=1)  # (nh, N)

        flat = vec.reshape(-1)  # nh*N, index = head*N + feat
        k = min(topk, flat.numel())
        vals, idxs = torch.topk(flat, k=k)  # sorted descending

        layer_hits = []
        for v, ix in zip(vals.tolist(), idxs.tolist()):
            if v <= min_activation:
                break  # descending order: every remaining value is also too small
            head, feat = divmod(ix, N)
            gid = neuron_id(layer, head, feat, nh, N)
            layer_hits.append((gid, float(v), layer, head, feat))

        hits_per_layer.append(layer_hits)

    return hits_per_layer

def shared_neurons_across_texts(
    all_hits: List[List[List[Tuple[int, float, int, int, int]]]],
    topk_intersection: int,
) -> List[List[int]]:
    """
    Intersect, layer by layer, the top neuron IDs of every input.

    ``all_hits[word_i][layer]`` is a best-first hit list whose first element is
    the global neuron ID. Only the first ``topk_intersection`` hits of each list
    take part. The layer count comes from the first input, so an empty
    ``all_hits`` raises IndexError. Returns one ascending list of shared IDs
    per layer.
    """
    def top_ids(hits):
        return {hit[0] for hit in hits[:topk_intersection]}

    result = []
    for layer in range(len(all_hits[0])):
        per_input = [top_ids(word_hits[layer]) for word_hits in all_hits]
        common = set.intersection(*per_input) if per_input else set()
        result.append(sorted(common))
    return result

def decode_neuron(global_id: int, cfg: BDHConfig) -> Tuple[int, int, int]:
    """Map a global neuron index back to its (layer, head, feat) triple.

    This is the inverse of ``neuron_id`` for the head count and per-head width
    implied by ``cfg``.
    """
    n_per_head = cfg.n_embd * cfg.mlp_internal_dim_multiplier // cfg.n_head
    layer, within_layer = divmod(global_id, cfg.n_head * n_per_head)
    head, feat = divmod(within_layer, n_per_head)
    return layer, head, feat


# =========================
#      TRAIN / PROBE
# =========================

def train(args):
    """Train a byte-level BDH model on Europarl text and save its state_dict.

    Steps:
      1. With ``--download``, fetch and extract the requested Europarl archives.
      2. (Re)build ``args.train_txt`` if it is missing or ``--rebuild_txt`` is set.
      3. Run ``args.steps`` AdamW steps, cycling the DataLoader as often as needed.
      4. Save ``model.state_dict()`` to ``args.ckpt_path``.

    Note: ``--resume`` restores model weights only. The optimizer state and the
    step counter start fresh, so ``--steps`` counts additional steps.

    Raises:
        ValueError: an unknown key in ``--pairs`` (checked before any download).
        RuntimeError: the dataset holds no complete training window.
    """
    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    print(f"[train] device={device}")

    # 1) Download + extract Europarl (optional)
    if args.download:
        # Validate every key first, so a typo does not surface only after a
        # multi-hundred-MB download of the preceding archives.
        unknown = [key for key in args.pairs if key not in EUROPARL_URLS]
        if unknown:
            raise ValueError(f"Unknown pair key '{unknown[0]}'. Available: {list(EUROPARL_URLS.keys())}")
        os.makedirs(args.data_dir, exist_ok=True)
        for key in args.pairs:
            tgz_path = os.path.join(args.data_dir, key)
            download(EUROPARL_URLS[key], tgz_path)
            extract_tgz(tgz_path, args.data_dir)

    # 2) Build train.txt
    if not os.path.exists(args.train_txt) or args.rebuild_txt:
        build_train_txt(
            data_dir=args.data_dir,
            out_txt=args.train_txt,
            max_lines_per_file=args.max_lines_per_file,
            seed=args.seed,
        )

    # 3) Dataset
    ds = ByteDataset(args.train_txt, block_size=args.block_size, max_bytes=args.max_bytes)
    if len(ds) <= 0:
        raise RuntimeError("Dataset too small. Increase max_bytes / ensure train.txt is non-empty.")
    dl = torch.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=True, num_workers=0)

    # 4) Model
    cfg = BDHConfig(
        vocab_size=256,
        n_layer=args.n_layer,
        n_embd=args.n_embd,
        n_head=args.n_head,
        mlp_internal_dim_multiplier=args.mult,
        dropout=args.dropout,
    )
    model = BDH(cfg).to(device)

    if args.resume and os.path.exists(args.ckpt_path):
        model.load_state_dict(torch.load(args.ckpt_path, map_location=device))
        print(f"[train] resumed from {args.ckpt_path}")
    elif args.resume:
        print(f"[train] --resume given but no checkpoint at {args.ckpt_path}; starting from scratch")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

    model.train()
    step = 0
    while step < args.steps:
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            _, loss = model(x, y)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            # log_every <= 0 disables logging instead of dividing by zero.
            if args.log_every > 0 and step % args.log_every == 0:
                print(f"[train] step={step} loss={float(loss):.4f}")

            step += 1
            if step >= args.steps:
                break

    ckpt_dir = os.path.dirname(args.ckpt_path)
    if ckpt_dir:  # bare filename -> dirname "" -> os.makedirs("") would raise
        os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(model.state_dict(), args.ckpt_path)
    print(f"[train] saved checkpoint: {args.ckpt_path}")

def probe(args):
    """Print the neuron IDs that rank among the strongest for *every* input text.

    The checkpoint is loaded into a BDH built from the CLI hyper-parameters.
    Each text's neurons are ranked per layer with ``top_neurons_for_input``,
    and the per-layer intersection of the first ``--topk_intersection`` IDs is
    printed. With ``--print_top``, each text's own top neurons are printed too.

    Raises:
        ValueError: fewer than two texts were given.
        FileNotFoundError: ``args.ckpt_path`` does not exist.
    """
    texts = args.texts
    # Cheap argument checks first: a bad invocation should fail before any
    # model is built or checkpoint is read.
    if len(texts) < 2:
        raise ValueError("Give at least 2 texts/words to find shared neurons.")
    if not os.path.exists(args.ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {args.ckpt_path}. Train first.")

    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    print(f"[probe] device={device}")

    cfg = BDHConfig(
        vocab_size=256,
        n_layer=args.n_layer,
        n_embd=args.n_embd,
        n_head=args.n_head,
        mlp_internal_dim_multiplier=args.mult,
        dropout=0.0,
    )
    model = BDH(cfg).to(device)
    # BDH reuses one weight set for every layer, so n_layer is not stored in the
    # checkpoint. A --n_layer different from training loads without any error.
    model.load_state_dict(torch.load(args.ckpt_path, map_location=device))
    model.eval()

    all_hits = [
        top_neurons_for_input(model, t, topk=args.topk, aggregate=args.aggregate)
        for t in texts
    ]

    shared = shared_neurons_across_texts(all_hits, topk_intersection=args.topk_intersection)

    print("\nInputs:")
    for i, t in enumerate(texts):
        print(f"  [{i}] {t}")

    print("\nShared neuron IDs per layer (intersection of top-k):")
    for layer, ids in enumerate(shared):
        show = ids[:args.show]
        print(f"\nLayer {layer}: {len(ids)} shared IDs")
        if not show:
            print("  (none) -> try: train longer, increase --topk / --topk_intersection, or probe with short sentences")
            continue
        for gid in show:
            L, H, Fidx = decode_neuron(gid, cfg)
            print(f"  neuron_id={gid}  (layer={L}, head={H}, feat={Fidx})")

    if args.print_top:
        # Also print each input's own top neurons for the first --layers_print layers.
        print("\nTop neurons per input (first few layers):")
        for i, t in enumerate(texts):
            print(f"\n=== Input [{i}] {t} ===")
            hits = all_hits[i]
            for layer in range(min(args.layers_print, len(hits))):
                top = hits[layer][:args.show]
                print(f"  Layer {layer}:")
                for gid, val, L, H, Fidx in top:
                    print(f"    neuron_id={gid} act={val:.4f} (head={H}, feat={Fidx})")


def main():
    """Parse the command line and dispatch to the ``train`` or ``probe`` sub-command."""

    def add_model_args(parser, with_dropout):
        # Architecture flags. probe must be given the values used for training;
        # dropout only exists on the train parser.
        parser.add_argument("--n_layer", type=int, default=6)
        parser.add_argument("--n_embd", type=int, default=256)
        parser.add_argument("--n_head", type=int, default=4)
        parser.add_argument("--mult", type=int, default=128)
        if with_dropout:
            parser.add_argument("--dropout", type=float, default=0.1)

    ckpt_default = "checkpoints/bdh_europarl_bytes.pt"

    root = argparse.ArgumentParser()
    commands = root.add_subparsers(dest="cmd", required=True)

    # ---- train sub-command ----
    tr = commands.add_parser("train")
    tr.add_argument("--data_dir", type=str, default="data_europarl")
    tr.add_argument("--train_txt", type=str, default="data_europarl/train.txt")
    tr.add_argument("--download", action="store_true", help="download+extract Europarl")
    tr.add_argument("--pairs", nargs="+", default=["de-en.tgz", "fr-en.tgz"], help="which Europarl tgz keys")
    tr.add_argument("--rebuild_txt", action="store_true")
    for flag, default in (
        ("--max_lines_per_file", 200_000),
        ("--max_bytes", 50_000_000),
        ("--block_size", 256),
        ("--batch_size", 16),
        ("--steps", 5000),
    ):
        tr.add_argument(flag, type=int, default=default)
    tr.add_argument("--lr", type=float, default=3e-4)
    tr.add_argument("--log_every", type=int, default=100)
    tr.add_argument("--seed", type=int, default=0)
    tr.add_argument("--ckpt_path", type=str, default=ckpt_default)
    for flag in ("--resume", "--cpu"):
        tr.add_argument(flag, action="store_true")
    add_model_args(tr, with_dropout=True)

    # ---- probe sub-command ----
    pr = commands.add_parser("probe")
    pr.add_argument("--texts", nargs="+", required=True, help="words/sentences you provide at test time")
    pr.add_argument("--ckpt_path", type=str, default=ckpt_default)
    pr.add_argument("--topk", type=int, default=300, help="top neurons to compute per layer per input")
    pr.add_argument("--topk_intersection", type=int, default=200, help="intersection over this many top neurons")
    pr.add_argument("--show", type=int, default=30, help="how many shared IDs to print per layer")
    pr.add_argument("--aggregate", choices=["mean", "last"], default="mean", help="aggregate over bytes")
    pr.add_argument("--print_top", action="store_true", help="also print top neurons per input")
    pr.add_argument("--layers_print", type=int, default=2)
    pr.add_argument("--cpu", action="store_true")
    add_model_args(pr, with_dropout=False)

    args = root.parse_args()

    # The global seeds are fixed at 1337 for every sub-command. --seed (train
    # only) is used solely for the file shuffle in build_train_txt.
    torch.manual_seed(1337)
    random.seed(1337)

    dispatch = {"train": train, "probe": probe}
    handler = dispatch.get(args.cmd)
    if handler is not None:
        handler(args)


if __name__ == "__main__":
    main()

# --- END PASTE ---


In [None]:
!wc -l bdh_europarl_train_probe.py
!ls -lh bdh_europarl_train_probe.py


In [None]:
# Free unreachable Python objects and cached CUDA blocks held by this notebook
# kernel. The training cells below run as separate `python` processes, so this
# only releases memory owned by the kernel itself.
import gc
import torch

gc.collect()
torch.cuda.empty_cache()
print("cleared")


In [None]:
!python -u bdh_europarl_train_probe.py train --download \
  --steps 200 --log_every 10 \
  --batch_size 2 --block_size 128 \
  --n_layer 4 --n_embd 128 --n_head 4 --mult 16 \
  --max_lines_per_file 20000 --max_bytes 5000000


In [None]:
!python -u bdh_europarl_train_probe.py train \
  --steps 2000 --log_every 50 \
  --batch_size 4 --block_size 128 \
  --n_layer 6 --n_embd 128 --n_head 4 --mult 16


In [None]:
!python -u bdh_europarl_train_probe.py train --resume \
  --steps 12000 --log_every 200 \
  --batch_size 4 --block_size 128 \
  --n_layer 6 --n_embd 128 --n_head 4 --mult 16


In [None]:
# # ============================================
# # MONOSEMANTIC SENTENCE-LEVEL PROBE (LATE LAYERS, FIXED)
# # ============================================

# import torch
# import importlib.util
# from transformers import MarianMTModel, MarianTokenizer

# # ---------- Load BDH code ----------
# spec = importlib.util.spec_from_file_location(
#     "bdhmod", "/content/bdh_europarl_train_probe.py"
# )
# bdhmod = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(bdhmod)

# BDHConfig = bdhmod.BDHConfig
# BDH = bdhmod.BDH
# ids_for_text = bdhmod.ids_for_text
# neuron_id = bdhmod.neuron_id

# device = "cuda" if torch.cuda.is_available() else "cpu"

# # ---------- Load checkpoint ----------
# CKPT = "checkpoints/bdh_europarl_bytes.pt"
# state = torch.load(CKPT, map_location=device)

# nh, D, N_tensor = state["encoder"].shape
# mult = int((N_tensor * nh) // D)

# cfg = BDHConfig(
#     vocab_size=256,
#     n_layer=6,
#     n_embd=D,
#     n_head=nh,
#     mlp_internal_dim_multiplier=mult,
#     dropout=0.0,
# )

# model = BDH(cfg).to(device)
# model.load_state_dict(state, strict=True)
# model.eval()

# # Derived N per head for this config (must match model)
# N = (cfg.n_embd * cfg.mlp_internal_dim_multiplier) // cfg.n_head

# print("BDH loaded")
# print("Model config:", {"n_layer": cfg.n_layer, "n_embd": cfg.n_embd, "n_head": cfg.n_head, "mult": cfg.mlp_internal_dim_multiplier, "N_per_head": N})

# # ---------- Translation ----------
# ALL_LANGS = {
#     "German": "Helsinki-NLP/opus-mt-en-de",
#     "French": "Helsinki-NLP/opus-mt-en-fr",
#     "Spanish": "Helsinki-NLP/opus-mt-en-es",
#     "Italian": "Helsinki-NLP/opus-mt-en-it",
# }

# _trans = {}

# def translate(lang, text):
#     if lang not in _trans:
#         tok = MarianTokenizer.from_pretrained(ALL_LANGS[lang])
#         mdl = MarianMTModel.from_pretrained(ALL_LANGS[lang]).to(device)
#         _trans[lang] = (tok, mdl)
#     tok, mdl = _trans[lang]
#     batch = tok([text], return_tensors="pt", padding=True).to(device)
#     out = mdl.generate(**batch, max_new_tokens=64)
#     return tok.decode(out[0], skip_special_tokens=True)

# # ---------- Neuron activation (FIXED: layer decoding) ----------
# TOPK = 5

# @torch.no_grad()
# def topk_neurons_for_word(word, layer_idx, k=5):
#     """
#     Returns TOP-K neurons for a given word at a specific layer.
#     """
#     x = ids_for_text(word).to(device)
#     _, _, sparse = model(x, return_sparse=True)

#     a = sparse[layer_idx][0].mean(dim=1)   # (nh, N)
#     flat = a.reshape(-1)                   # (nh*N)

#     vals, idxs = torch.topk(flat, k)

#     hits = []
#     for v, ix in zip(vals.tolist(), idxs.tolist()):
#         head = ix // N
#         feat = ix % N
#         gid = neuron_id(layer_idx, head, feat, cfg.n_head, N)
#         hits.append((layer_idx, head, feat, gid, v))

#     return hits


# # ============================================
# # PHASE 2: MONOSEMANTIC EXPERIMENT (BEST VERSION)
# # - Demo: Top-K neuron IDs for EN vs translations
# # - Dataset mode: enter many concept-words until END
# # - Metrics: Jaccard overlap + Selectivity + Entropy
# # - Plots: Jaccard distribution + Selectivity bars
# # ============================================

# import re
# import numpy as np
# import matplotlib.pyplot as plt
# import torch

# # ---------- Settings ----------
# LAYERS = [4, 5]                      # late layers
# TOPK_PRINT = 5                       # show top-5 neuron IDs (demo)
# TOPK_SET = 50                        # for overlap / intersection (more stable)
# MAX_CAND = 400                       # cap number of candidate neurons to score (speed)
# EPS = 1e-9

# LAYERS = [L for L in LAYERS if 0 <= L < cfg.n_layer]

# # ---------- Helpers ----------
# def tokenize_words(s: str):
#     # English-ish + accented letters (for translations)
#     return re.findall(r"[A-Za-zÀ-ÿ]+", s.lower())

# @torch.no_grad()
# def topk_set_for_text(text: str, layer_idx: int, k: int):
#     """Top-K neuron IDs for FULL text at a layer."""
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0].mean(dim=1)       # (nh, N)
#     flat = a.reshape(-1)                       # (nh*N,)
#     k = min(k, flat.numel())
#     _, idxs = torch.topk(flat, k)
#     S = set()
#     for ix in idxs.tolist():
#         head = ix // N
#         feat = ix % N
#         S.add(neuron_id(layer_idx, head, feat, cfg.n_head, N))
#     return S

# def jaccard(a: set, b: set) -> float:
#     return len(a & b) / (len(a | b) + EPS)

# def decode_gid(gid: int):
#     per_layer = cfg.n_head * N
#     layer = gid // per_layer
#     rem = gid % per_layer
#     head = rem // N
#     feat = rem % N
#     return layer, head, feat

# @torch.no_grad()
# def activation_gid(text: str, layer_idx: int, gid: int) -> float:
#     """Mean activation of ONE neuron on FULL text."""
#     layer, head, feat = decode_gid(gid)
#     if layer != layer_idx:
#         return 0.0
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0].mean(dim=1)       # (nh,N)
#     return float(a[head, feat].item())

# def normalized_entropy(vals):
#     v = np.array(vals, dtype=np.float64)
#     v = np.clip(v, 0, None)
#     p = v / (v.sum() + EPS)
#     H = -np.sum(p * np.log(p + 1e-12))
#     return float(H / (np.log(len(p) + 1e-12)))

# def read_list_until_end(prompt, min_n=1):
#     print(prompt)
#     out = []
#     while True:
#         s = input("> ").strip()
#         if s.upper() == "END":
#             break
#         if s:
#             out.append(s)
#     if len(out) < min_n:
#         raise ValueError(f"Need at least {min_n} items.")
#     return out

# # ============================================
# # (1) DEMO: one input -> translation -> Top-K table
# # ============================================

# print("\n=== DEMO MODE (one example) ===")
# sentence = input("Enter an English word/sentence for demo:\n> ").strip()

# print("\nAvailable languages:")
# for k in ALL_LANGS:
#     print(" -", k)

# langs_raw = input("\nChoose languages (comma-separated):\n> ").strip()
# LANGS = [l.strip() for l in langs_raw.split(",") if l.strip() in ALL_LANGS]
# if len(LANGS) == 0:
#     LANGS = ["German", "French", "Spanish", "Italian"]

# translations = {lang: translate(lang, sentence) for lang in LANGS}

# print("\nTranslations:")
# print("EN:", sentence)
# for lang, txt in translations.items():
#     print(f"{lang[:2].upper()}: {txt}")

# # Print Top-K neurons for full text (not per word) — cleaner for proof
# for LAYER in LAYERS:
#     print(f"\nTOP-{TOPK_PRINT} NEURONS (FULL TEXT) — Layer {LAYER}")
#     print(f"{'TEXT':<10} {'NEURON_ID':<10} {'(layer,head,feat)':<18} {'ACT'}")
#     print("-" * 60)

#     def topk_list(text, k=TOPK_PRINT):
#         x = ids_for_text(text).to(device)
#         _, _, sparse = model(x, return_sparse=True)
#         a = sparse[LAYER][0].mean(dim=1)
#         flat = a.reshape(-1)
#         k = min(k, flat.numel())
#         vals, idxs = torch.topk(flat, k)
#         out = []
#         for v, ix in zip(vals.tolist(), idxs.tolist()):
#             head = ix // N
#             feat = ix % N
#             gid = neuron_id(LAYER, head, feat, cfg.n_head, N)
#             out.append((gid, (LAYER, head, feat), float(v)))
#         return out

#     for gid, trip, act in topk_list(sentence):
#         print(f"{'EN':<10} {gid:<10} {str(trip):<18} {act:.3f}")
#     for lang, txt in translations.items():
#         for gid, trip, act in topk_list(txt):
#             print(f"{lang[:2].upper():<10} {gid:<10} {str(trip):<18} {act:.3f}")

# # ============================================
# # (2) DATASET MODE: enter many concept words (no hardcoded list)
# # ============================================

# print("\n\n=== DATASET MODE (monosemantic test) ===")
# concepts = read_list_until_end(
#     "Enter concept WORDS one by one (e.g., dog, bicycle, oxygen). Type END to finish:",
#     min_n=5
# )

# # translate each concept into chosen languages
# concept_trans = []
# for w in concepts:
#     trans = { "EN": w }
#     for lang in LANGS:
#         trans[lang] = translate(lang, w)
#     concept_trans.append(trans)

# # ============================================
# # (3) OVERLAP METRIC: Jaccard EN vs translations for each concept
# # ============================================

# print("\n=== Jaccard overlap per concept (EN vs each translation) ===")
# all_jaccards = {L: [] for L in LAYERS}

# for LAYER in LAYERS:
#     for trans in concept_trans:
#         en = trans["EN"]
#         S_en = topk_set_for_text(en, LAYER, TOPK_SET)

#         js = []
#         for lang in LANGS:
#             tr = trans[lang]
#             S_tr = topk_set_for_text(tr, LAYER, TOPK_SET)
#             js.append(jaccard(S_en, S_tr))

#         all_jaccards[LAYER].append(np.mean(js))

#     print(f"Layer {LAYER}: mean={np.mean(all_jaccards[LAYER]):.3f}  std={np.std(all_jaccards[LAYER]):.3f}")

# # ============================================
# # PROPER VISUALS (readable proof plots)
# # - No histograms
# # - Show concept-by-concept results (readable)
# # - For each concept: show activation bars with REAL text labels
# # - Also show Jaccard actual vs baseline per concept
# # ============================================

# import numpy as np
# import matplotlib.pyplot as plt

# def baseline_jaccard_for_concept(i, layer_idx):
#     """Baseline: compare EN concept i with translations of a different random concept."""
#     idxs = [j for j in range(len(concept_trans)) if j != i]
#     j_idx = int(np.random.choice(idxs))
#     en = concept_trans[i]["EN"]
#     S_en = topk_set_for_text(en, layer_idx, TOPK_SET)

#     base_js = []
#     other = concept_trans[j_idx]
#     for lang in LANGS:
#         S_other = topk_set_for_text(other[lang], layer_idx, TOPK_SET)
#         base_js.append(jaccard(S_en, S_other))
#     return float(np.mean(base_js))

# def best_shared_neuron_for_concept(i, layer_idx):
#     """Pick best shared neuron (intersection across EN+translations) with max (pos-neg)."""
#     trans = concept_trans[i]
#     POS = [trans["EN"]] + [trans[lang] for lang in LANGS]

#     # Intersection of TopK neuron sets across POS texts
#     inter = None
#     for t in POS:
#         S = topk_set_for_text(t, layer_idx, TOPK_SET)
#         inter = S if inter is None else (inter & S)

#     if not inter:
#         return None

#     inter = list(inter)[:MAX_CAND]

#     best = None
#     for gid in inter:
#         pos_mean = float(np.mean([activation_gid(t, layer_idx, gid) for t in POS]))
#         neg_mean = float(np.mean([activation_gid(t, layer_idx, gid) for t in NEG]))
#         sel = pos_mean - neg_mean
#         if (best is None) or (sel > best[1]):
#             best = (gid, sel, pos_mean, neg_mean)

#     return best  # (gid, sel, pos_mean, neg_mean)

# # --------- 1) Show Jaccard table (actual vs baseline) ----------
# np.random.seed(0)

# for LAYER in LAYERS:
#     print(f"\n==============================")
#     print(f"Layer {LAYER}: Jaccard(EN, translations) per concept + baseline")
#     print(f"==============================")
#     print(f"{'CONCEPT':<14} {'ACTUAL':>8} {'BASE':>8} {'MARGIN':>8}")

#     actual_list, base_list = [], []

#     for i, trans in enumerate(concept_trans):
#         en = trans["EN"]
#         S_en = topk_set_for_text(en, LAYER, TOPK_SET)

#         js = []
#         for lang in LANGS:
#             S_tr = topk_set_for_text(trans[lang], LAYER, TOPK_SET)
#             js.append(jaccard(S_en, S_tr))
#         actual = float(np.mean(js))
#         base = baseline_jaccard_for_concept(i, LAYER)
#         margin = actual - base

#         actual_list.append(actual)
#         base_list.append(base)

#         print(f"{en:<14} {actual:>8.3f} {base:>8.3f} {margin:>8.3f}")

#     # Simple readable plot: actual vs baseline lines
#     x = np.arange(len(concept_trans))
#     names = [ct["EN"] for ct in concept_trans]

#     plt.figure(figsize=(12,4))
#     x = np.arange(len(concept_names))
#     plt.bar(x - 0.2, actual, width=0.4, label="Actual (EN vs its translations)")
#     plt.bar(x + 0.2, baseline, width=0.4, label="Baseline (EN vs other concept translations)")
#     plt.xticks(x, concept_names, rotation=45, ha="right")
#     plt.ylim(0,1)
#     plt.title(f"Layer {LAYER}: Actual vs Baseline Jaccard overlap (K={TOPK_SET})")
#     plt.ylabel("Jaccard")
#     plt.legend()
#     plt.tight_layout()
#     plt.show()

# # --------- 2) For each layer: show 2–3 best concepts proof plots ----------
# TOP_SHOW = min(3, len(concept_trans))

# for LAYER in LAYERS:
#     # rank concepts by (actual-baseline)
#     scored = []
#     for i, trans in enumerate(concept_trans):
#         en = trans["EN"]
#         S_en = topk_set_for_text(en, LAYER, TOPK_SET)

#         js = []
#         for lang in LANGS:
#             js.append(jaccard(S_en, topk_set_for_text(trans[lang], LAYER, TOPK_SET)))
#         actual = float(np.mean(js))
#         base = baseline_jaccard_for_concept(i, LAYER)
#         scored.append((i, actual - base, actual, base))

#     scored.sort(key=lambda x: x[1], reverse=True)
#     picks = scored[:TOP_SHOW]

#     print(f"\n==============================")
#     print(f"Layer {LAYER}: Proof plots for top {TOP_SHOW} concepts")
#     print(f"==============================")

#     for (i, margin, actual, base) in picks:
#         trans = concept_trans[i]
#         concept = trans["EN"]
#         best = best_shared_neuron_for_concept(i, LAYER)

#         if best is None:
#             print(f"\nConcept '{concept}': no shared neuron intersection.")
#             continue

#         best_gid, sel, pos_m, neg_m = best
#         POS_texts = [trans["EN"]] + [trans[lang] for lang in LANGS]
#         POS_labels = ["EN"] + [lang[:2].upper() for lang in LANGS]

#         # Prepare labels with short snippets so plot is readable
#         def short(s, n=22):
#             s = s.replace("\n", " ")
#             return s if len(s) <= n else s[:n-3] + "..."

#         NEG_labels = [f"NEG{i+1}" for i in range(len(NEG))]
#         all_labels = POS_labels + NEG_labels
#         all_texts = POS_texts + NEG

#         acts = [activation_gid(t, LAYER, best_gid) for t in all_texts]

#         # Plot
#         plt.figure(figsize=(11, 4))
#         plt.bar(np.arange(len(all_labels)), acts)
#         plt.xticks(np.arange(len(all_labels)),
#                    [f"{lab}\n{short(tx)}" for lab, tx in zip(all_labels, all_texts)],
#                    rotation=0)
#         plt.title(
#             f"Layer {LAYER} | concept='{concept}' | best neuron={best_gid} (layer,head,feat={decode_gid(best_gid)})\n"
#             f"Jaccard actual={actual:.3f} baseline={base:.3f} margin={margin:.3f} | selectivity(pos-neg)={sel:.3f}"
#         )
#         plt.ylabel("Mean activation")
#         plt.tight_layout()
#         plt.show()

#         # Also print a clean table (so evaluator can read exact numbers)
#         print(f"\nConcept: {concept}")
#         print(f"  Jaccard actual={actual:.3f}, baseline={base:.3f}, margin={margin:.3f}")
#         print(f"  Best neuron: {best_gid}  decode={decode_gid(best_gid)}")
#         print(f"  Selectivity(pos-neg)={sel:.4f}  pos_mean={pos_m:.4f}  neg_mean={neg_m:.4f}")
#         print("  Activations:")
#         for lab, tx, av in zip(all_labels, all_texts, acts):
#             print(f"    {lab:<5} act={av:.4f} | {tx}")


In [None]:
# # ============================================================
# # PHASE 2 — BDH MONOSEMANTIC / POLYSEMANTIC PROBE (CLEAN, FINAL)
# # What you get (clear + report-friendly):
# #  (A) DEMO (single concept): translations + Top-5 neuron IDs table (Layer 4/5)
# #  (B) DATASET MODE: user inputs concepts (>=5) and NEG texts (>=3)
# #  (C) JACCARD, for each layer:
# #      -> Actual vs Baseline bar chart (per concept)
# #      -> Clear table: Concept | Actual | Baseline | Margin
# #  (D) PROOF PLOTS (top N_PROOF_CONCEPTS concepts by margin; default 3):
# #      -> For each chosen concept + each layer:
# #         Find the best shared neuron among POS texts and show
# #         EN/FR/.../NEG bars with labels (text shown under each tick)
# #
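# # NEG texts are ALWAYS entered by the user (no hardcoded negatives).
# # Baseline pairs each concept with a randomly chosen OTHER concept; the RNG is
# # seeded (BASELINE_SEED), so whole runs are repeatable. Note: section (D) draws
# # fresh pairings from the same RNG, so its margins may differ from the table in (C).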
# #
# # How to "clear memory" / start again:
# #  - Just re-run the cell (variables reset).
# #  - In Colab: Runtime -> Restart runtime (clears model cache + variables).
# #  - Translation models are cached in the _trans dict; call clear_translation_cache() to free them.
# # ============================================================

# import re
# import numpy as np
# import torch
# import importlib.util
# import matplotlib.pyplot as plt
# from transformers import MarianMTModel, MarianTokenizer

# # -----------------------------
# # 0) Load BDH code
# # -----------------------------
# spec = importlib.util.spec_from_file_location("bdhmod", "/content/bdh_europarl_train_probe.py")
# bdhmod = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(bdhmod)

# BDHConfig = bdhmod.BDHConfig
# BDH = bdhmod.BDH
# ids_for_text = bdhmod.ids_for_text
# neuron_id = bdhmod.neuron_id

# device = "cuda" if torch.cuda.is_available() else "cpu"

# # -----------------------------
# # 1) Load checkpoint
# # -----------------------------
# CKPT = "checkpoints/bdh_europarl_bytes.pt"
# state = torch.load(CKPT, map_location=device)

# nh, D, N_tensor = state["encoder"].shape
# mult = int((N_tensor * nh) // D)

# cfg = BDHConfig(
#     vocab_size=256,
#     n_layer=6,
#     n_embd=D,
#     n_head=nh,
#     mlp_internal_dim_multiplier=mult,
#     dropout=0.0,
# )

# model = BDH(cfg).to(device)
# model.load_state_dict(state, strict=True)
# model.eval()

# # Derived N per head (must match model)
# N = (cfg.n_embd * cfg.mlp_internal_dim_multiplier) // cfg.n_head

# print("BDH loaded")
# print("Model config:", {"n_layer": cfg.n_layer, "n_embd": cfg.n_embd, "n_head": cfg.n_head,
#                       "mult": cfg.mlp_internal_dim_multiplier, "N_per_head": N})

# # -----------------------------
# # 2) Translation
# # -----------------------------
# ALL_LANGS = {
#     "German":  "Helsinki-NLP/opus-mt-en-de",
#     "French":  "Helsinki-NLP/opus-mt-en-fr",
#     "Spanish": "Helsinki-NLP/opus-mt-en-es",
#     "Italian": "Helsinki-NLP/opus-mt-en-it",
# }

# _trans = {}  # caches tokenizers/models per language

# def clear_translation_cache():
#     """Call this if you want to free memory from translator models."""
#     global _trans
#     _trans = {}
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()
#     print("Translation cache cleared.")

# def translate(lang: str, text: str) -> str:
#     if lang not in _trans:
#         tok = MarianTokenizer.from_pretrained(ALL_LANGS[lang])
#         mdl = MarianMTModel.from_pretrained(ALL_LANGS[lang]).to(device)
#         _trans[lang] = (tok, mdl)
#     tok, mdl = _trans[lang]
#     batch = tok([text], return_tensors="pt", padding=True).to(device)
#     out = mdl.generate(**batch, max_new_tokens=64)
#     return tok.decode(out[0], skip_special_tokens=True)

# # -----------------------------
# # 3) Helpers (neurons / activations / sets)
# # -----------------------------
# EPS = 1e-9

# def decode_gid(gid: int):
#     """gid -> (layer, head, feat)"""
#     per_layer = cfg.n_head * N
#     layer = gid // per_layer
#     rem = gid % per_layer
#     head = rem // N
#     feat = rem % N
#     return layer, head, feat

# @torch.no_grad()
# def topk_list_for_text(text: str, layer_idx: int, k: int):
#     """
#     Returns list of (gid, (layer,head,feat), act) for top-k neurons for FULL text.
#     """
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0].mean(dim=1)    # (nh, N)
#     flat = a.reshape(-1)                   # (nh*N,)
#     k = min(k, flat.numel())
#     vals, idxs = torch.topk(flat, k)
#     out = []
#     for v, ix in zip(vals.tolist(), idxs.tolist()):
#         head = ix // N
#         feat = ix % N
#         gid = neuron_id(layer_idx, head, feat, cfg.n_head, N)
#         out.append((gid, (layer_idx, head, feat), float(v)))
#     return out

# @torch.no_grad()
# def topk_set_for_text(text: str, layer_idx: int, k: int):
#     """
#     Returns set of top-k neuron IDs for FULL text.
#     """
#     hits = topk_list_for_text(text, layer_idx, k)
#     return set([gid for (gid, _, _) in hits])

# def jaccard(a: set, b: set) -> float:
#     return len(a & b) / (len(a | b) + EPS)

# @torch.no_grad()
# def activation_gid(text: str, layer_idx: int, gid: int) -> float:
#     """
#     Mean activation of ONE neuron gid on FULL text.
#     """
#     layer, head, feat = decode_gid(gid)
#     if layer != layer_idx:
#         return 0.0
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0].mean(dim=1)    # (nh, N)
#     return float(a[head, feat].item())

# def read_list_until_end(prompt: str, min_n: int = 1):
#     print(prompt)
#     out = []
#     while True:
#         s = input("> ").strip()
#         if s.upper() == "END":
#             break
#         if s:
#             out.append(s)
#     if len(out) < min_n:
#         raise ValueError(f"Need at least {min_n} items.")
#     return out

# def shorten(s: str, maxlen: int = 18):
#     s = s.replace("\n", " ").strip()
#     return s if len(s) <= maxlen else (s[:maxlen-1] + "…")

# # -----------------------------
# # 4) Settings (choose what you want)
# # -----------------------------
# LAYERS = [4, 5]        # late layers
# TOPK_PRINT = 5         # demo table: top 5 neurons
# TOPK_SET = 50          # jaccard sets
# MAX_CAND = 400         # candidate shared neurons to score (speed)
# BASELINE_SEED = 0      # baseline randomness repeatable
# N_PROOF_CONCEPTS = 3   # top-N concepts to show proof plots for

# LAYERS = [L for L in LAYERS if 0 <= L < cfg.n_layer]

# # ============================================================
# # (A) DEMO MODE (single concept)
# # ============================================================
# print("\n=== DEMO MODE (one example) ===")
# demo_text = input("Enter an English word/sentence for demo:\n> ").strip()

# print("\nAvailable languages:")
# for k in ALL_LANGS:
#     print(" -", k)

# langs_raw = input("\nChoose languages (comma-separated):\n> ").strip()
# LANGS = [l.strip() for l in langs_raw.split(",") if l.strip() in ALL_LANGS]
# if len(LANGS) == 0:
#     LANGS = ["French"]  # safe default

# demo_trans = {lang: translate(lang, demo_text) for lang in LANGS}

# print("\nTranslations:")
# print("EN:", demo_text)
# for lang, txt in demo_trans.items():
#     print(f"{lang[:2].upper()}: {txt}")

# for LAYER in LAYERS:
#     print(f"\nTOP-{TOPK_PRINT} NEURONS (FULL TEXT) — Layer {LAYER}")
#     print(f"{'TEXT':<6} {'NEURON_ID':<8} {'(layer,head,feat)':<16} {'ACT'}")
#     print("-" * 48)

#     for gid, trip, act in topk_list_for_text(demo_text, LAYER, TOPK_PRINT):
#         print(f"{'EN':<6} {gid:<8} {str(trip):<16} {act:.3f}")
#     for lang, txt in demo_trans.items():
#         for gid, trip, act in topk_list_for_text(txt, LAYER, TOPK_PRINT):
#             print(f"{lang[:2].upper():<6} {gid:<8} {str(trip):<16} {act:.3f}")

# # ============================================================
# # (B) DATASET MODE: concepts + negatives (USER INPUT)
# # ============================================================
# print("\n\n=== DATASET MODE (monosemantic test) ===")
# concepts = read_list_until_end(
#     "Enter CONCEPT words (>=5). Type END to finish:\n"
#     "Example: dog, cat, school, doctor, music\n"
#     "Type END when done.",
#     min_n=5
# )

# NEG = read_list_until_end(
#     "\nEnter NEGATIVE texts (different meaning, >=3). Type END to finish:\n"
#     "Tip: use generic unrelated words: laptop, file, table, river, stone ...",
#     min_n=3
# )

# print("\nConcepts:", concepts)
# print("Negatives:", NEG)
# print("Languages:", LANGS)

# # Build translations for each concept
# concept_trans = []
# for w in concepts:
#     row = {"EN": w}
#     for lang in LANGS:
#         row[lang] = translate(lang, w)
#     concept_trans.append(row)

# # ============================================================
# # (C) JACCARD ACTUAL vs BASELINE (CLEAR BAR CHART)
# # ============================================================
# rng = np.random.default_rng(BASELINE_SEED)
# concept_names = [ct["EN"] for ct in concept_trans]

# def pick_other_index(i, n):
#     choices = [j for j in range(n) if j != i]
#     return int(rng.choice(choices))

# print("\n==============================")
# print("Jaccard(EN, translations) per concept + baseline")
# print("Baseline = EN compared to translations of a DIFFERENT random concept (same language list)")
# print("==============================")

# for LAYER in LAYERS:
#     actual = []
#     baseline = []
#     rows = []

#     for i, trans in enumerate(concept_trans):
#         en = trans["EN"]
#         S_en = topk_set_for_text(en, LAYER, TOPK_SET)

#         # actual: EN vs its translations
#         js = []
#         for lang in LANGS:
#             js.append(jaccard(S_en, topk_set_for_text(trans[lang], LAYER, TOPK_SET)))
#         a = float(np.mean(js))
#         actual.append(a)

#         # baseline: EN vs translations of another random concept
#         j = pick_other_index(i, len(concept_trans))
#         other = concept_trans[j]
#         bjs = []
#         for lang in LANGS:
#             bjs.append(jaccard(S_en, topk_set_for_text(other[lang], LAYER, TOPK_SET)))
#         b = float(np.mean(bjs))
#         baseline.append(b)

#         rows.append((en, a, b, a - b))

#     # Print a clean table
#     print(f"\n--- Layer {LAYER} ---")
#     print(f"{'CONCEPT':<16} {'ACTUAL':>8} {'BASE':>8} {'MARGIN':>8}")
#     for (c, a, b, m) in rows:
#         print(f"{c:<16} {a:>8.3f} {b:>8.3f} {m:>8.3f}")

#     # Clear bar chart (Actual vs Baseline)
#     x = np.arange(len(concept_names))
#     plt.figure(figsize=(12, 4))
#     plt.bar(x - 0.2, actual, width=0.4, label="Actual (EN vs translations)")
#     plt.bar(x + 0.2, baseline, width=0.4, label="Baseline (EN vs other concept translations)")
#     plt.xticks(x, concept_names, rotation=35, ha="right")
#     plt.ylim(0, 1.0)
#     plt.title(f"Layer {LAYER}: Actual vs Baseline Jaccard overlap (TopK={TOPK_SET})")
#     plt.ylabel("Jaccard overlap")
#     plt.legend()
#     plt.tight_layout()
#     plt.show()

# # ============================================================
# # (D) PROOF PLOTS (TOP N concepts by margin): best shared neuron
# # ============================================================
# print("\n==============================")
# print(f"PROOF PLOTS: top {N_PROOF_CONCEPTS} concepts by (Actual - Baseline)")
# print("We find a shared neuron among POS texts (EN + translations), then compare vs NEG texts.")
# print("==============================")

# for LAYER in LAYERS:
#     # recompute margins for this layer to choose top concepts
#     margins = []
#     for i, trans in enumerate(concept_trans):
#         en = trans["EN"]
#         S_en = topk_set_for_text(en, LAYER, TOPK_SET)

#         # actual
#         a = float(np.mean([
#             jaccard(S_en, topk_set_for_text(trans[lang], LAYER, TOPK_SET))
#             for lang in LANGS
#         ]))

#         # baseline
#         j = pick_other_index(i, len(concept_trans))
#         other = concept_trans[j]
#         b = float(np.mean([
#             jaccard(S_en, topk_set_for_text(other[lang], LAYER, TOPK_SET))
#             for lang in LANGS
#         ]))

#         margins.append((i, a - b, a, b))

#     margins.sort(key=lambda x: x[1], reverse=True)
#     top_idxs = [margins[k][0] for k in range(min(N_PROOF_CONCEPTS, len(margins)))]

#     print(f"\n--- Layer {LAYER}: top concepts by margin ---")
#     for k in range(min(N_PROOF_CONCEPTS, len(margins))):
#         i, m, a, b = margins[k]
#         print(f"{k+1}. {concept_trans[i]['EN']:<16} margin={m:>7.3f}  actual={a:>6.3f}  base={b:>6.3f}")

#     # For each selected concept: find best shared neuron and plot activations
#     for idx in top_idxs:
#         trans = concept_trans[idx]
#         concept = trans["EN"]

#         POS_texts = [trans["EN"]] + [trans[lang] for lang in LANGS]

#         # candidate neurons = intersection of TopK sets across POS texts
#         inter = None
#         for t in POS_texts:
#             S = topk_set_for_text(t, LAYER, TOPK_SET)
#             inter = S if inter is None else (inter & S)

#         if not inter:
#             print(f"\nLayer {LAYER} | concept='{concept}': no shared neurons in TopK intersection. (Try bigger TOPK_SET)")
#             continue

#         cand = list(inter)[:MAX_CAND]

#         # pick best neuron by selectivity = pos_mean - neg_mean
#         best = None
#         for gid in cand:
#             pos_mean = float(np.mean([activation_gid(t, LAYER, gid) for t in POS_texts]))
#             neg_mean = float(np.mean([activation_gid(t, LAYER, gid) for t in NEG]))
#             sel = pos_mean - neg_mean
#             if (best is None) or (sel > best[0]):
#                 best = (sel, gid, pos_mean, neg_mean)

#         sel, best_gid, pos_mean, neg_mean = best
#         lay, head, feat = decode_gid(best_gid)

#         # Build a single clear bar plot
#         labels = ["EN"] + [lang[:2].upper() for lang in LANGS] + [f"NEG{i+1}" for i in range(len(NEG))]
#         texts = POS_texts + NEG
#         acts = [activation_gid(t, LAYER, best_gid) for t in texts]

#         xtick = []
#         for lab, tx in zip(labels, texts):
#             xtick.append(f"{lab}\n{shorten(tx, 16)}")

#         plt.figure(figsize=(14, 4))
#         plt.bar(np.arange(len(labels)), acts)
#         plt.xticks(np.arange(len(labels)), xtick, rotation=0)
#         plt.ylabel("Mean activation")
#         plt.title(
#             f"Layer {LAYER} | concept='{concept}' | best neuron={best_gid} (head={head}, feat={feat})\n"
#             f"selectivity(pos-neg)={sel:.3f} | pos_mean={pos_mean:.3f} | neg_mean={neg_mean:.3f}"
#         )
#         plt.tight_layout()
#         plt.show()

#         # Print proof table (report-friendly)
#         print(f"\n[PROOF TABLE] Layer {LAYER} concept='{concept}' best_gid={best_gid} decode={(lay,head,feat)}")
#         print(f"  selectivity={sel:.4f}  pos_mean={pos_mean:.4f}  neg_mean={neg_mean:.4f}")
#         for lab, tx, av in zip(labels, texts, acts):
#             print(f"  {lab:<5} act={av:.4f} | {tx}")

# print("\nDONE ")
# print("Tip: If you want to restart fresh, re-run this cell. In Colab you can also Runtime -> Restart runtime.")


In [None]:
# Select matplotlib's "Agg" backend before pyplot is imported in this cell.
# NOTE(review): Agg is a non-interactive backend that renders to in-memory
# buffers; plt.show() in later cells will likely not display figures inline in
# Jupyter/Colab. Confirm later cells render via savefig / canvas buffers /
# IPython.display instead of relying on plt.show().
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


In [None]:
# # ============================================
# # BDH Phase-2 Probe (clear visuals + deterministic baseline + user-entered or AUTO NEG + optional GIF)
# # ============================================

# import os, re, math, random
# import numpy as np
# import torch
# import matplotlib.pyplot as plt
# import importlib.util
# from transformers import MarianMTModel, MarianTokenizer

# # ----------------------------
# # 0) Repro / clean start helpers
# # ----------------------------
# SEED = 1337
# random.seed(SEED)
# np.random.seed(SEED)
# torch.manual_seed(SEED)

# def hard_reset_state():
#     """Notebook-level 'forget': clears translation cache + CUDA cache (does NOT restart runtime)."""
#     global _trans
#     _trans = {}
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()

# # ----------------------------
# # 1) Load BDH module + checkpoint
# # ----------------------------
# spec = importlib.util.spec_from_file_location("bdhmod", "/content/bdh_europarl_train_probe.py")
# bdhmod = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(bdhmod)

# BDHConfig   = bdhmod.BDHConfig
# BDH         = bdhmod.BDH
# ids_for_text= bdhmod.ids_for_text
# neuron_id   = bdhmod.neuron_id

# device = "cuda" if torch.cuda.is_available() else "cpu"

# CKPT = "checkpoints/bdh_europarl_bytes.pt"
# state = torch.load(CKPT, map_location=device)

# nh, D, N_tensor = state["encoder"].shape
# mult = int((N_tensor * nh) // D)

# cfg = BDHConfig(vocab_size=256, n_layer=6, n_embd=D, n_head=nh, mlp_internal_dim_multiplier=mult, dropout=0.0)
# model = BDH(cfg).to(device)
# model.load_state_dict(state, strict=True)
# model.eval()

# N = (cfg.n_embd * cfg.mlp_internal_dim_multiplier) // cfg.n_head  # per-head features

# print("BDH loaded")
# print("Model config:", {"n_layer": cfg.n_layer, "n_embd": cfg.n_embd, "n_head": cfg.n_head, "mult": cfg.mlp_internal_dim_multiplier, "N_per_head": N})

# # ----------------------------
# # 2) Translation (cached) + cache clear
# # ----------------------------
# ALL_LANGS = {
#     "German":  "Helsinki-NLP/opus-mt-en-de",
#     "French":  "Helsinki-NLP/opus-mt-en-fr",
#     "Spanish": "Helsinki-NLP/opus-mt-en-es",
#     "Italian": "Helsinki-NLP/opus-mt-en-it",
# }
# _trans = {}

# def translate(lang, text):
#     if lang not in _trans:
#         tok = MarianTokenizer.from_pretrained(ALL_LANGS[lang])
#         mdl = MarianMTModel.from_pretrained(ALL_LANGS[lang]).to(device)
#         _trans[lang] = (tok, mdl)
#     tok, mdl = _trans[lang]
#     batch = tok([text], return_tensors="pt", padding=True).to(device)
#     out = mdl.generate(**batch, max_new_tokens=64)
#     return tok.decode(out[0], skip_special_tokens=True)

# # ----------------------------
# # 3) Core helpers
# # ----------------------------
# EPS = 1e-9
# LAYERS = [4, 5]
# LAYERS = [L for L in LAYERS if 0 <= L < cfg.n_layer]

# TOPK_PRINT = 5
# TOPK_SET   = 50
# MAX_CAND   = 400

# def tokenize_words(s: str):
#     return re.findall(r"[A-Za-zÀ-ÿ]+", s.lower())

# @torch.no_grad()
# def topk_set_for_text(text: str, layer_idx: int, k: int):
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0].mean(dim=1)  # (nh, N)
#     flat = a.reshape(-1)
#     k = min(k, flat.numel())
#     _, idxs = torch.topk(flat, k)
#     S = set()
#     for ix in idxs.tolist():
#         head = ix // N
#         feat = ix % N
#         S.add(neuron_id(layer_idx, head, feat, cfg.n_head, N))
#     return S

# def jaccard(a: set, b: set) -> float:
#     return len(a & b) / (len(a | b) + EPS)

# def decode_gid(gid: int):
#     per_layer = cfg.n_head * N
#     layer = gid // per_layer
#     rem = gid % per_layer
#     head = rem // N
#     feat = rem % N
#     return layer, head, feat

# @torch.no_grad()
# def activation_gid_mean(text: str, layer_idx: int, gid: int) -> float:
#     """Mean activation of ONE neuron over positions."""
#     layer, head, feat = decode_gid(gid)
#     if layer != layer_idx:
#         return 0.0
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0]  # (nh, T, N) since B=1
#     return float(a[head, :, feat].mean().item())

# @torch.no_grad()
# def activation_gid_over_positions(text: str, layer_idx: int, gid: int):
#     """Per-position activations: (T,) for ONE neuron (good for GIF)."""
#     layer, head, feat = decode_gid(gid)
#     if layer != layer_idx:
#         return np.array([0.0])
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0]  # (nh, T, N)
#     return a[head, :, feat].detach().float().cpu().numpy()

# def read_list_until_end(prompt, min_n=1):
#     print(prompt)
#     out = []
#     while True:
#         s = input("> ").strip()
#         if s.upper() == "END":
#             break
#         if s:
#             out.append(s)
#     if len(out) < min_n:
#         raise ValueError(f"Need at least {min_n} items.")
#     return out

# # ----------------------------
# # 4) DEMO MODE (Top-K neuron IDs, full-text)
# # ----------------------------
# print("\n=== DEMO MODE ===")
# sentence = input("Enter an English word/sentence:\n> ").strip()

# print("\nAvailable languages:")
# for k in ALL_LANGS: print(" -", k)

# langs_raw = input("\nChoose languages (comma-separated, blank=all):\n> ").strip()
# LANGS = [l.strip() for l in langs_raw.split(",") if l.strip() in ALL_LANGS]
# if len(LANGS) == 0:
#     LANGS = list(ALL_LANGS.keys())

# translations = {lang: translate(lang, sentence) for lang in LANGS}

# print("\nTranslations:")
# print("EN:", sentence)
# for lang, txt in translations.items():
#     print(f"{lang[:2].upper()}: {txt}")

# @torch.no_grad()
# def topk_list_fulltext(text, layer_idx, k=TOPK_PRINT):
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0].mean(dim=1)  # (nh,N)
#     flat = a.reshape(-1)
#     k = min(k, flat.numel())
#     vals, idxs = torch.topk(flat, k)
#     out = []
#     for v, ix in zip(vals.tolist(), idxs.tolist()):
#         head = ix // N
#         feat = ix % N
#         gid = neuron_id(layer_idx, head, feat, cfg.n_head, N)
#         out.append((gid, (layer_idx, head, feat), float(v)))
#     return out

# for L in LAYERS:
#     print(f"\nTOP-{TOPK_PRINT} NEURONS (FULL TEXT) — Layer {L}")
#     print(f"{'TEXT':<6} {'NEURON_ID':<10} {'(layer,head,feat)':<18} {'ACT'}")
#     print("-"*60)
#     for gid, trip, act in topk_list_fulltext(sentence, L):
#         print(f"{'EN':<6} {gid:<10} {str(trip):<18} {act:.3f}")
#     for lang, txt in translations.items():
#         for gid, trip, act in topk_list_fulltext(txt, L):
#             print(f"{lang[:2].upper():<6} {gid:<10} {str(trip):<18} {act:.3f}")

# # ----------------------------
# # 5) DATASET MODE (concepts + NEG words)
# # ----------------------------
# print("\n=== DATASET MODE (monosemantic test) ===")
# concepts = read_list_until_end(
#     "Enter CONCEPT words (>=5). Type END to finish:",
#     min_n=5
# )

# # translate each concept into chosen languages
# concept_trans = []
# for w in concepts:
#     trans = {"EN": w}
#     for lang in LANGS:
#         trans[lang] = translate(lang, w)
#     concept_trans.append(trans)

# # NEG words: ALWAYS ask user (or AUTO)
# print("\nNEGATIVE SET:")
# print("Type negative words (different meaning).")
# print("Tip: use unrelated categories (tools, places, verbs, numbers).")
# neg_mode = input("Type 'AUTO' to use a default negative list, else press Enter to type manually:\n> ").strip().upper()

# DEFAULT_NEG = ["table","river","engine","laptop","file","money","music","doctor","mountain","battery","cloud","kitchen","orange","ten","run"]

# if neg_mode == "AUTO":
#     NEG = [w for w in DEFAULT_NEG if w.lower() not in set(c.lower() for c in concepts)]
#     NEG = NEG[:8]  # keep small and readable
#     print("Using AUTO NEG:", NEG)
# else:
#     NEG = read_list_until_end("Enter NEGATIVE words (>=3). Type END to finish:", min_n=3)

# # ----------------------------
# # 6) Jaccard actual vs baseline (DETERMINISTIC baseline)
# # ----------------------------
# def baseline_pair_index(i, n):
#     """Deterministic baseline: compare concept i with translations of concept (i+1)%n."""
#     return (i + 1) % n

# def compute_actual_and_baseline(layer_idx):
#     names = [ct["EN"] for ct in concept_trans]
#     actual_list = []
#     base_list = []

#     for i, trans in enumerate(concept_trans):
#         en = trans["EN"]
#         S_en = topk_set_for_text(en, layer_idx, TOPK_SET)

#         # actual = mean jaccard (EN vs its translations)
#         js = []
#         for lang in LANGS:
#             S_tr = topk_set_for_text(trans[lang], layer_idx, TOPK_SET)
#             js.append(jaccard(S_en, S_tr))
#         actual = float(np.mean(js))

#         # baseline = EN vs OTHER concept translations
#         j = baseline_pair_index(i, len(concept_trans))
#         other = concept_trans[j]
#         base_js = []
#         for lang in LANGS:
#             S_other = topk_set_for_text(other[lang], layer_idx, TOPK_SET)
#             base_js.append(jaccard(S_en, S_other))
#         baseline = float(np.mean(base_js))

#         actual_list.append(actual)
#         base_list.append(baseline)

#     return names, np.array(actual_list), np.array(base_list)

# # Plot: clear actual vs baseline per layer
# results = {}
# for L in LAYERS:
#     names, actual_arr, base_arr = compute_actual_and_baseline(L)
#     results[L] = (names, actual_arr, base_arr)

#     print(f"\nLayer {L}: actual_mean={actual_arr.mean():.3f}  baseline_mean={base_arr.mean():.3f}  (want actual > baseline)")

#     x = np.arange(len(names))
#     plt.figure(figsize=(12,4))
#     plt.bar(x - 0.2, actual_arr, width=0.4, label="Actual (EN vs its translations)")
#     plt.bar(x + 0.2, base_arr,   width=0.4, label="Baseline (EN vs other-concept translations)")
#     plt.xticks(x, names, rotation=45, ha="right")
#     plt.ylim(0, 1)
#     plt.ylabel("Jaccard overlap")
#     plt.title(f"Layer {L}: Actual vs Baseline Jaccard (TopK={TOPK_SET})")
#     plt.legend()
#     plt.tight_layout()
#     plt.show()

# # ----------------------------
# # 7) Best neuron per concept by selectivity (pos_mean - neg_mean)
# #    Candidates = neurons shared by all POS texts' top-K sets, plus EN's own top-K set.
# # ----------------------------
# @torch.no_grad()
# @torch.no_grad()
# def best_shared_neuron_for_concept(i, layer_idx):
#     trans = concept_trans[i]
#     POS = [trans["EN"]] + [trans[lang] for lang in LANGS]

#     inter = None
#     for t in POS:
#         S = topk_set_for_text(t, layer_idx, TOPK_SET)
#         inter = S if inter is None else (inter & S)

#     S_en = topk_set_for_text(POS[0], layer_idx, TOPK_SET)

#     if inter is None:
#         inter = set()

#     cand_set = set(inter) | set(S_en)
#     if len(cand_set) == 0:
#         return None

#     cand = list(cand_set)[:MAX_CAND]

#     best = None
#     for gid in cand:
#         pos_mean = float(np.mean([activation_gid_mean(t, layer_idx, gid) for t in POS]))
#         neg_mean = float(np.mean([activation_gid_mean(t, layer_idx, gid) for t in NEG]))
#         sel = pos_mean - neg_mean
#         if (best is None) or (sel > best[1]):
#             best = (gid, sel, pos_mean, neg_mean)

#     return best
# def short(s, n=18):
#     s = s.replace("\n"," ")
#     return s if len(s) <= n else s[:n-3] + "..."

# TOP_SHOW = min(3, len(concept_trans))

# for L in LAYERS:
#     names, actual_arr, base_arr = results[L]
#     margins = actual_arr - base_arr
#     order = np.argsort(-margins)[:TOP_SHOW]

#     print(f"\n=== Layer {L}: proof plots for top {TOP_SHOW} margin concepts ===")

#     for idx in order:
#         concept = concept_trans[idx]["EN"]
#         best = best_shared_neuron_for_concept(idx, L)
#         if best is None:
#             print(f"Concept '{concept}': no shared neuron intersection.")
#             continue

#         gid, sel, pos_m, neg_m = best

#         POS_texts  = [concept_trans[idx]["EN"]] + [concept_trans[idx][lang] for lang in LANGS]
#         POS_labels = ["EN"] + [lang[:2].upper() for lang in LANGS]

#         NEG_texts  = NEG
#         NEG_labels = [f"NEG{i+1}" for i in range(len(NEG_texts))]

#         all_texts  = POS_texts + NEG_texts
#         all_labels = POS_labels + NEG_labels

#         acts = np.array([activation_gid_mean(t, L, gid) for t in all_texts], dtype=np.float64)

#         # 7A) Activation BAR plot (clear labels)
#         plt.figure(figsize=(12,4))
#         plt.bar(np.arange(len(all_labels)), acts)
#         plt.xticks(np.arange(len(all_labels)),
#                    [f"{lab}\n{short(tx)}" for lab, tx in zip(all_labels, all_texts)],
#                    rotation=0)
#         plt.ylabel("Mean activation")
#         plt.title(
#             f"Layer {L} | concept='{concept}' | best neuron={gid} decode={decode_gid(gid)}\n"
#             f"selectivity(pos-neg)={sel:.3f} (pos_mean={pos_m:.3f}, neg_mean={neg_m:.3f})"
#         )
#         plt.tight_layout()
#         plt.show()

#         # 7B) POS vs NEG boxplot (shows “fires for POS, not NEG” much more clearly)
#         pos_vals = acts[:len(POS_texts)]
#         neg_vals = acts[len(POS_texts):]

#         plt.figure(figsize=(6,4))
#         plt.boxplot([pos_vals, neg_vals], labels=["POS (EN+translations)", "NEG"], showmeans=True)
#         plt.ylabel("Mean activation")
#         plt.title(f"Layer {L} neuron {gid}: POS vs NEG activation distribution")
#         plt.tight_layout()
#         plt.show()

#         print(f"\nConcept: {concept}")
#         print(f"  Neuron: {gid} decode={decode_gid(gid)}")
#         print(f"  Selectivity: {sel:.4f}  pos_mean={pos_m:.4f}  neg_mean={neg_m:.4f}")
#         for lab, tx, av in zip(all_labels, all_texts, acts.tolist()):
#             print(f"    {lab:<5} act={av:.4f} | {tx}")

# # ----------------------------
# # 8) OPTIONAL: Make a GIF (activation over byte positions)
# # ----------------------------
# make_gif = input("\nMake a GIF of per-byte activation for ONE concept? (y/n)\n> ").strip().lower() == "y"
# if make_gif:
#     import imageio.v2 as imageio
#     from IPython.display import Image, display

#     # pick: best margin concept in first layer for GIF
#     L = LAYERS[0]
#     names, actual_arr, base_arr = results[L]
#     best_idx = int(np.argmax(actual_arr - base_arr))
#     concept = concept_trans[best_idx]["EN"]
#     best = best_shared_neuron_for_concept(best_idx, L)

#     if best is None:
#         print("No shared neuron found for GIF.")
#     else:
#         gid, sel, pos_m, neg_m = best

#         en_text = concept_trans[best_idx]["EN"]
#         other_lang = LANGS[0]  # first chosen language
#         tr_text = concept_trans[best_idx][other_lang]

#         seqs = [("EN", en_text), (other_lang[:2].upper(), tr_text)]

#         gif_path = "bdh_activation.gif"
#         out_frames = []

#         for tag, text in seqs:
#             vals = activation_gid_over_positions(text, L, gid)  # (T,)

#             max_frames = min(len(vals), 120)  # cap frames so it renders nicely
#             for t in range(1, max_frames + 1):
#                 fig = plt.figure(figsize=(8, 3))
#                 plt.plot(np.arange(t), vals[:t])
#                 plt.ylim(0, max(vals.max() * 1.1, 1e-3))
#                 plt.xlabel("Byte position")
#                 plt.ylabel("Activation")
#                 plt.title(f"Layer {L} neuron {gid} | {tag} '{short(text, 30)}' | t={t}/{len(vals)}")
#                 plt.tight_layout()

#                 fig.canvas.draw()

#                 # Read pixels via buffer_rgba(); the older tostring_rgb() is deprecated on recent Matplotlib
#                 buf = np.asarray(fig.canvas.buffer_rgba())      # (H, W, 4)
#                 out_frames.append(buf[:, :, :3].copy())         # (H, W, 3)

#                 plt.close(fig)

#         imageio.mimsave(gif_path, out_frames, fps=10)
#         print("Saved GIF:", gif_path)

#         # show gif in output
#         display(Image(filename=gif_path))


In [None]:
# # ============================================
# # BDH Phase-2 Probe (PLOTS + GIF working reliably in Colab/Jupyter)
# # Fixes:
# # - Forces inline plotting
# # - Fixes best_shared_neuron_for_concept (no dead code / no double decorator)
# # - Uses Matplotlib 3.9+ safe frame capture for GIF (buffer_rgba)
# # - Displays the GIF in notebook output
# # ============================================

# # --- (A) Notebook display setup: MUST be at top in Colab/Jupyter ---
# import matplotlib
# import matplotlib.pyplot as plt

# # Force inline backend in notebooks (safe even if already inline)
# try:
#     get_ipython().run_line_magic("matplotlib", "inline")
# except Exception:
#     pass

# # Turn interactive mode off so figures render only at explicit plt.show() calls (Colab friendly)
# plt.ioff()

# import os, re, math, random
# import numpy as np
# import torch
# import importlib.util
# from transformers import MarianMTModel, MarianTokenizer

# # ----------------------------
# # 0) Repro / clean start helpers
# # ----------------------------
# SEED = 1337
# random.seed(SEED)
# np.random.seed(SEED)
# torch.manual_seed(SEED)

# def hard_reset_state():
#     """Notebook-level 'forget': clears translation cache + CUDA cache (does NOT restart runtime)."""
#     global _trans
#     _trans = {}
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()

# # ----------------------------
# # 1) Load BDH module + checkpoint
# # ----------------------------
# spec = importlib.util.spec_from_file_location("bdhmod", "/content/bdh_europarl_train_probe.py")
# bdhmod = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(bdhmod)

# BDHConfig    = bdhmod.BDHConfig
# BDH          = bdhmod.BDH
# ids_for_text = bdhmod.ids_for_text
# neuron_id    = bdhmod.neuron_id

# device = "cuda" if torch.cuda.is_available() else "cpu"

# CKPT = "checkpoints/bdh_europarl_bytes.pt"
# state = torch.load(CKPT, map_location=device)

# nh, D, N_tensor = state["encoder"].shape
# mult = int((N_tensor * nh) // D)

# cfg = BDHConfig(vocab_size=256, n_layer=6, n_embd=D, n_head=nh,
#                 mlp_internal_dim_multiplier=mult, dropout=0.0)

# model = BDH(cfg).to(device)
# model.load_state_dict(state, strict=True)
# model.eval()

# N = (cfg.n_embd * cfg.mlp_internal_dim_multiplier) // cfg.n_head  # per-head features

# print("BDH loaded")
# print("Model config:", {
#     "n_layer": cfg.n_layer, "n_embd": cfg.n_embd,
#     "n_head": cfg.n_head, "mult": cfg.mlp_internal_dim_multiplier,
#     "N_per_head": N
# })

# # ----------------------------
# # 2) Translation (cached)
# # ----------------------------
# ALL_LANGS = {
#     "German":  "Helsinki-NLP/opus-mt-en-de",
#     "French":  "Helsinki-NLP/opus-mt-en-fr",
#     "Spanish": "Helsinki-NLP/opus-mt-en-es",
#     "Italian": "Helsinki-NLP/opus-mt-en-it",
# }
# _trans = {}

# def translate(lang, text):
#     if lang not in _trans:
#         tok = MarianTokenizer.from_pretrained(ALL_LANGS[lang])
#         mdl = MarianMTModel.from_pretrained(ALL_LANGS[lang]).to(device)
#         _trans[lang] = (tok, mdl)
#     tok, mdl = _trans[lang]
#     batch = tok([text], return_tensors="pt", padding=True).to(device)
#     out = mdl.generate(**batch, max_new_tokens=64)
#     return tok.decode(out[0], skip_special_tokens=True)

# # ----------------------------
# # 3) Core helpers
# # ----------------------------
# EPS = 1e-9
# LAYERS = [4, 5]
# LAYERS = [L for L in LAYERS if 0 <= L < cfg.n_layer]

# TOPK_PRINT = 5
# TOPK_SET   = 50
# MAX_CAND   = 400

# @torch.no_grad()
# def topk_set_for_text(text: str, layer_idx: int, k: int):
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0].mean(dim=1)  # (nh, N)
#     flat = a.reshape(-1)
#     k = min(k, flat.numel())
#     _, idxs = torch.topk(flat, k)
#     S = set()
#     for ix in idxs.tolist():
#         head = ix // N
#         feat = ix % N
#         S.add(neuron_id(layer_idx, head, feat, cfg.n_head, N))
#     return S

# def jaccard(a: set, b: set) -> float:
#     return len(a & b) / (len(a | b) + EPS)

# def decode_gid(gid: int):
#     per_layer = cfg.n_head * N
#     layer = gid // per_layer
#     rem = gid % per_layer
#     head = rem // N
#     feat = rem % N
#     return layer, head, feat

# @torch.no_grad()
# def activation_gid_mean(text: str, layer_idx: int, gid: int) -> float:
#     """Mean activation of ONE neuron over positions."""
#     layer, head, feat = decode_gid(gid)
#     if layer != layer_idx:
#         return 0.0
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0]  # (nh, T, N)
#     return float(a[head, :, feat].mean().item())

# @torch.no_grad()
# def activation_gid_over_positions(text: str, layer_idx: int, gid: int):
#     """Per-position activations (T,) for ONE neuron."""
#     layer, head, feat = decode_gid(gid)
#     if layer != layer_idx:
#         return np.array([0.0])
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0]
#     return a[head, :, feat].detach().float().cpu().numpy()

# def read_list_until_end(prompt, min_n=1):
#     print(prompt)
#     out = []
#     while True:
#         s = input("> ").strip()
#         if s.upper() == "END":
#             break
#         if s:
#             out.append(s)
#     if len(out) < min_n:
#         raise ValueError(f"Need at least {min_n} items.")
#     return out

# def short(s, n=18):
#     s = s.replace("\n", " ")
#     return s if len(s) <= n else s[:n-3] + "..."

# def showfig():
#     """Make sure figures actually render in notebooks."""
#     plt.show()
#     plt.close("all")

# # ----------------------------
# # 4) DEMO MODE (Top-K neurons)
# # ----------------------------
# print("\n=== DEMO MODE ===")
# sentence = input("Enter an English word/sentence:\n> ").strip()

# print("\nAvailable languages:")
# for k in ALL_LANGS:
#     print(" -", k)

# langs_raw = input("\nChoose languages (comma-separated, blank=all):\n> ").strip()
# LANGS = [l.strip() for l in langs_raw.split(",") if l.strip() in ALL_LANGS]
# if len(LANGS) == 0:
#     LANGS = list(ALL_LANGS.keys())

# translations = {lang: translate(lang, sentence) for lang in LANGS}

# print("\nTranslations:")
# print("EN:", sentence)
# for lang, txt in translations.items():
#     print(f"{lang[:2].upper()}: {txt}")

# @torch.no_grad()
# def topk_list_fulltext(text, layer_idx, k=TOPK_PRINT):
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0].mean(dim=1)  # (nh,N)
#     flat = a.reshape(-1)
#     k = min(k, flat.numel())
#     vals, idxs = torch.topk(flat, k)
#     out = []
#     for v, ix in zip(vals.tolist(), idxs.tolist()):
#         head = ix // N
#         feat = ix % N
#         gid = neuron_id(layer_idx, head, feat, cfg.n_head, N)
#         out.append((gid, (layer_idx, head, feat), float(v)))
#     return out

# for L in LAYERS:
#     print(f"\nTOP-{TOPK_PRINT} NEURONS (FULL TEXT) — Layer {L}")
#     print(f"{'TEXT':<6} {'NEURON_ID':<10} {'(layer,head,feat)':<18} {'ACT'}")
#     print("-"*60)
#     for gid, trip, act in topk_list_fulltext(sentence, L):
#         print(f"{'EN':<6} {gid:<10} {str(trip):<18} {act:.3f}")
#     for lang, txt in translations.items():
#         for gid, trip, act in topk_list_fulltext(txt, L):
#             print(f"{lang[:2].upper():<6} {gid:<10} {str(trip):<18} {act:.3f}")

# # ----------------------------
# # 5) DATASET MODE (concepts + NEG)
# # ----------------------------
# print("\n=== DATASET MODE (monosemantic test) ===")
# concepts = read_list_until_end("Enter CONCEPT words (>=5). Type END to finish:", min_n=5)

# concept_trans = []
# for w in concepts:
#     trans = {"EN": w}
#     for lang in LANGS:
#         trans[lang] = translate(lang, w)
#     concept_trans.append(trans)

# print("\nNEGATIVE SET:")
# print("Type negative words (different meaning).")
# print("Tip: use unrelated categories (tools, places, verbs, numbers).")
# neg_mode = input("Type 'AUTO' to use a default negative list, else press Enter to type manually:\n> ").strip().upper()

# DEFAULT_NEG = ["table","river","engine","laptop","file","money","music","doctor","mountain","battery","cloud","kitchen","orange","ten","run"]
# if neg_mode == "AUTO":
#     NEG = [w for w in DEFAULT_NEG if w.lower() not in set(c.lower() for c in concepts)]
#     NEG = NEG[:8]
#     print("Using AUTO NEG:", NEG)
# else:
#     NEG = read_list_until_end("Enter NEGATIVE words (>=3). Type END to finish:", min_n=3)

# # ----------------------------
# # 6) Actual vs Baseline Jaccard
# # ----------------------------
# def baseline_pair_index(i, n):
#     return (i + 1) % n

# def compute_actual_and_baseline(layer_idx):
#     names = [ct["EN"] for ct in concept_trans]
#     actual_list, base_list = [], []

#     for i, trans in enumerate(concept_trans):
#         en = trans["EN"]
#         S_en = topk_set_for_text(en, layer_idx, TOPK_SET)

#         js = []
#         for lang in LANGS:
#             S_tr = topk_set_for_text(trans[lang], layer_idx, TOPK_SET)
#             js.append(jaccard(S_en, S_tr))
#         actual = float(np.mean(js))

#         j = baseline_pair_index(i, len(concept_trans))
#         other = concept_trans[j]
#         base_js = []
#         for lang in LANGS:
#             S_other = topk_set_for_text(other[lang], layer_idx, TOPK_SET)
#             base_js.append(jaccard(S_en, S_other))
#         baseline = float(np.mean(base_js))

#         actual_list.append(actual)
#         base_list.append(baseline)

#     return names, np.array(actual_list), np.array(base_list)

# results = {}
# for L in LAYERS:
#     names, actual_arr, base_arr = compute_actual_and_baseline(L)
#     results[L] = (names, actual_arr, base_arr)

#     print(f"\nLayer {L}: actual_mean={actual_arr.mean():.3f}  baseline_mean={base_arr.mean():.3f}  (want actual > baseline)")

#     x = np.arange(len(names))
#     plt.figure(figsize=(12,4))
#     plt.bar(x - 0.2, actual_arr, width=0.4, label="Actual (EN vs its translations)")
#     plt.bar(x + 0.2, base_arr,   width=0.4, label="Baseline (EN vs other-concept translations)")
#     plt.xticks(x, names, rotation=45, ha="right")
#     plt.ylim(0, 1)
#     plt.ylabel("Jaccard overlap")
#     plt.title(f"Layer {L}: Actual vs Baseline Jaccard (TopK={TOPK_SET})")
#     plt.legend()
#     plt.tight_layout()
#     showfig()

# # ----------------------------
# # 7) Best shared neuron per concept (FIXED)
# # ----------------------------
# @torch.no_grad()
# def best_shared_neuron_for_concept(i, layer_idx):
#     trans = concept_trans[i]
#     POS = [trans["EN"]] + [trans[lang] for lang in LANGS]

#     inter = None
#     for t in POS:
#         S = topk_set_for_text(t, layer_idx, TOPK_SET)
#         inter = S if inter is None else (inter & S)

#     S_en = topk_set_for_text(POS[0], layer_idx, TOPK_SET)
#     if inter is None:
#         inter = set()

#     cand_set = set(inter) | set(S_en)
#     if len(cand_set) == 0:
#         return None

#     cand = list(cand_set)[:MAX_CAND]

#     best = None
#     for gid in cand:
#         pos_mean = float(np.mean([activation_gid_mean(t, layer_idx, gid) for t in POS]))
#         neg_mean = float(np.mean([activation_gid_mean(t, layer_idx, gid) for t in NEG]))
#         sel = pos_mean - neg_mean
#         if (best is None) or (sel > best[1]):
#             best = (gid, sel, pos_mean, neg_mean)

#     return best  # (gid, sel, pos_mean, neg_mean)

# TOP_SHOW = min(3, len(concept_trans))

# for L in LAYERS:
#     names, actual_arr, base_arr = results[L]
#     margins = actual_arr - base_arr
#     order = np.argsort(-margins)[:TOP_SHOW]

#     print(f"\n=== Layer {L}: proof plots for top {TOP_SHOW} margin concepts ===")

#     for idx in order:
#         concept = concept_trans[idx]["EN"]
#         best = best_shared_neuron_for_concept(idx, L)
#         if best is None:
#             print(f"Concept '{concept}': no shared neuron candidate set.")
#             continue

#         gid, sel, pos_m, neg_m = best

#         POS_texts  = [concept_trans[idx]["EN"]] + [concept_trans[idx][lang] for lang in LANGS]
#         POS_labels = ["EN"] + [lang[:2].upper() for lang in LANGS]

#         NEG_texts  = NEG
#         NEG_labels = [f"NEG{i+1}" for i in range(len(NEG_texts))]

#         all_texts  = POS_texts + NEG_texts
#         all_labels = POS_labels + NEG_labels

#         acts = np.array([activation_gid_mean(t, L, gid) for t in all_texts], dtype=np.float64)

#         # 7A) Bar plot
#         plt.figure(figsize=(12,4))
#         plt.bar(np.arange(len(all_labels)), acts)
#         plt.xticks(np.arange(len(all_labels)),
#                    [f"{lab}\n{short(tx)}" for lab, tx in zip(all_labels, all_texts)],
#                    rotation=0)
#         plt.ylabel("Mean activation")
#         plt.title(
#             f"Layer {L} | concept='{concept}' | best neuron={gid} decode={decode_gid(gid)}\n"
#             f"selectivity(pos-neg)={sel:.3f} (pos_mean={pos_m:.3f}, neg_mean={neg_m:.3f})"
#         )
#         plt.tight_layout()
#         showfig()

#         # 7B) Boxplot (tick_labels= requires Matplotlib >= 3.9; on older versions use labels= instead)
#         pos_vals = acts[:len(POS_texts)]
#         neg_vals = acts[len(POS_texts):]

#         plt.figure(figsize=(6,4))
#         plt.boxplot([pos_vals, neg_vals], tick_labels=["POS (EN+translations)", "NEG"], showmeans=True)
#         plt.ylabel("Mean activation")
#         plt.title(f"Layer {L} neuron {gid}: POS vs NEG activation distribution")
#         plt.tight_layout()
#         showfig()

#         print(f"\nConcept: {concept}")
#         print(f"  Neuron: {gid} decode={decode_gid(gid)}")
#         print(f"  Selectivity: {sel:.4f}  pos_mean={pos_m:.4f}  neg_mean={neg_m:.4f}")
#         for lab, tx, av in zip(all_labels, all_texts, acts.tolist()):
#             print(f"    {lab:<5} act={av:.4f} | {tx}")

# # ----------------------------
# # 8) OPTIONAL: GIF (Matplotlib 3.9+ safe)
# # ----------------------------
# make_gif = input("\nMake a GIF of per-byte activation for ONE concept? (y/n)\n> ").strip().lower() == "y"
# if make_gif:
#     import imageio.v2 as imageio
#     from IPython.display import Image, display

#     L = LAYERS[0]
#     names, actual_arr, base_arr = results[L]
#     best_idx = int(np.argmax(actual_arr - base_arr))
#     concept = concept_trans[best_idx]["EN"]
#     best = best_shared_neuron_for_concept(best_idx, L)

#     if best is None:
#         print("No shared neuron found for GIF.")
#     else:
#         gid, sel, pos_m, neg_m = best

#         en_text = concept_trans[best_idx]["EN"]
#         other_lang = LANGS[0]
#         tr_text = concept_trans[best_idx][other_lang]

#         seqs = [("EN", en_text), (other_lang[:2].upper(), tr_text)]

#         gif_path = "bdh_activation.gif"
#         out_frames = []

#         for tag, text in seqs:
#             vals = activation_gid_over_positions(text, L, gid)
#             max_frames = min(len(vals), 120)

#             for t in range(1, max_frames + 1):
#                 fig = plt.figure(figsize=(8, 3))
#                 plt.plot(np.arange(t), vals[:t])
#                 plt.ylim(0, max(vals.max() * 1.1, 1e-3))
#                 plt.xlabel("Byte position")
#                 plt.ylabel("Activation")
#                 plt.title(f"Layer {L} neuron {gid} | {tag} '{short(text, 30)}' | t={t}/{len(vals)}")
#                 plt.tight_layout()

#                 fig.canvas.draw()
#                 buf = np.asarray(fig.canvas.buffer_rgba())      # (H, W, 4)
#                 out_frames.append(buf[:, :, :3].copy())         # (H, W, 3)
#                 plt.close(fig)

#         imageio.mimsave(gif_path, out_frames, fps=10)
#         print("Saved GIF:", gif_path)

#         # show gif in output (works in notebook)
#         display(Image(filename=gif_path))

# print("\nDONE.")


In [None]:
# # ============================================
# # BDH Phase-2 Probe (STRONGER NEG + Strong baseline + EXPLAINABLE GIF)
# # Colab/Jupyter ready (matplotlib inline) + Matplotlib 3.9+ safe capture
# # ============================================

# import os, re, random, html
# import numpy as np
# import torch
# import matplotlib.pyplot as plt
# import importlib.util

# from transformers import MarianMTModel, MarianTokenizer

# # ---------- IMPORTANT for notebooks ----------
# try:
#     get_ipython  # noqa
#     get_ipython().run_line_magic("matplotlib", "inline")
# except Exception:
#     pass

# plt.ioff()  # interactive mode off: figures render only on explicit plt.show()

# # ----------------------------
# # 0) Repro / clean reset helpers
# # ----------------------------
# SEED = 1337

# def seed_all(seed=SEED):
#     random.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)

# seed_all(SEED)

# _trans = {}  # translation cache

# def hard_reset_state(also_seed=True):
#     global _trans
#     _trans = {}
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()
#     if also_seed:
#         seed_all(SEED)
#     print("[reset] cleared translation cache + cuda cache (if any) + reseeded.")

# # ----------------------------
# # 1) Load BDH module + checkpoint
# # ----------------------------
# BDH_SCRIPT_PATH = "/content/bdh_europarl_train_probe.py"   # change if needed
# CKPT_PATH = "checkpoints/bdh_europarl_bytes.pt"           # change if needed

# spec = importlib.util.spec_from_file_location("bdhmod", BDH_SCRIPT_PATH)
# bdhmod = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(bdhmod)

# BDHConfig    = bdhmod.BDHConfig
# BDH          = bdhmod.BDH
# ids_for_text = bdhmod.ids_for_text
# neuron_id    = bdhmod.neuron_id

# device = "cuda" if torch.cuda.is_available() else "cpu"

# state = torch.load(CKPT_PATH, map_location=device)
# nh, D, N_tensor = state["encoder"].shape
# mult = int((N_tensor * nh) // D)

# cfg = BDHConfig(
#     vocab_size=256, n_layer=6, n_embd=D, n_head=nh,
#     mlp_internal_dim_multiplier=mult, dropout=0.0
# )
# model = BDH(cfg).to(device)
# model.load_state_dict(state, strict=True)
# model.eval()

# N = (cfg.n_embd * cfg.mlp_internal_dim_multiplier) // cfg.n_head  # per-head features
# print("BDH loaded")
# print("Model config:", {"n_layer": cfg.n_layer, "n_embd": cfg.n_embd, "n_head": cfg.n_head,
#                       "mult": cfg.mlp_internal_dim_multiplier, "N_per_head": N})

# # ----------------------------
# # 2) Translation (cached)
# # ----------------------------
# ALL_LANGS = {
#     "German":  "Helsinki-NLP/opus-mt-en-de",
#     "French":  "Helsinki-NLP/opus-mt-en-fr",
#     "Spanish": "Helsinki-NLP/opus-mt-en-es",
#     "Italian": "Helsinki-NLP/opus-mt-en-it",
# }

# def translate(lang, text):
#     global _trans
#     if lang not in _trans:
#         tok = MarianTokenizer.from_pretrained(ALL_LANGS[lang])
#         mdl = MarianMTModel.from_pretrained(ALL_LANGS[lang]).to(device)
#         _trans[lang] = (tok, mdl)
#     tok, mdl = _trans[lang]
#     batch = tok([text], return_tensors="pt", padding=True).to(device)
#     out = mdl.generate(**batch, max_new_tokens=64)
#     return tok.decode(out[0], skip_special_tokens=True)

# # ----------------------------
# # 3) Core helpers
# # ----------------------------
# EPS = 1e-9
# LAYERS = [4, 5]
# LAYERS = [L for L in LAYERS if 0 <= L < cfg.n_layer]

# TOPK_PRINT = 5
# TOPK_SET   = 200   # IMPORTANT: higher = more stable; 50 was too small
# MAX_CAND   = 800

# @torch.no_grad()
# def topk_set_for_text(text: str, layer_idx: int, k: int):
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0].mean(dim=1)  # (nh, N)
#     flat = a.reshape(-1)
#     k = min(k, flat.numel())
#     _, idxs = torch.topk(flat, k)
#     S = set()
#     for ix in idxs.tolist():
#         head = ix // N
#         feat = ix % N
#         S.add(neuron_id(layer_idx, head, feat, cfg.n_head, N))
#     return S

# def jaccard(a: set, b: set) -> float:
#     return len(a & b) / (len(a | b) + EPS)

# def decode_gid(gid: int):
#     per_layer = cfg.n_head * N
#     layer = gid // per_layer
#     rem = gid % per_layer
#     head = rem // N
#     feat = rem % N
#     return layer, head, feat

# @torch.no_grad()
# def activation_gid_mean(text: str, layer_idx: int, gid: int) -> float:
#     layer, head, feat = decode_gid(gid)
#     if layer != layer_idx:
#         return 0.0
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0]  # (nh, T, N)
#     return float(a[head, :, feat].mean().item())

# @torch.no_grad()
# def activation_gid_over_positions(text: str, layer_idx: int, gid: int):
#     layer, head, feat = decode_gid(gid)
#     if layer != layer_idx:
#         return np.array([0.0], dtype=np.float32)
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0]  # (nh, T, N)
#     return a[head, :, feat].detach().float().cpu().numpy()

# def read_list_until_end(prompt, min_n=1):
#     print(prompt)
#     out = []
#     while True:
#         s = input("> ").strip()
#         if s.upper() == "END":
#             break
#         if s:
#             out.append(s)
#     if len(out) < min_n:
#         raise ValueError(f"Need at least {min_n} items.")
#     return out

# def short(s, n=28):
#     s = s.replace("\n", " ")
#     return s if len(s) <= n else s[:n-3] + "..."

# # ----------------------------
# # 0.5) Optional reset prompt
# # ----------------------------
# do_reset = input("Reset caches before running? (y/n)\n> ").strip().lower() == "y"
# if do_reset:
#     hard_reset_state()

# # ----------------------------
# # 4) DEMO MODE (Top-K neuron IDs)
# # ----------------------------
# print("\n=== DEMO MODE ===")
# sentence = input("Enter an English word/sentence:\n> ").strip()

# print("\nAvailable languages:")
# for k in ALL_LANGS:
#     print(" -", k)

# langs_raw = input("\nChoose languages (comma-separated, blank=all):\n> ").strip()
# LANGS = [l.strip() for l in langs_raw.split(",") if l.strip() in ALL_LANGS]
# if len(LANGS) == 0:
#     LANGS = list(ALL_LANGS.keys())

# translations = {lang: translate(lang, sentence) for lang in LANGS}

# print("\nTranslations:")
# print("EN:", sentence)
# for lang, txt in translations.items():
#     print(f"{lang[:2].upper()}: {txt}")

# @torch.no_grad()
# def topk_list_fulltext(text, layer_idx, k=TOPK_PRINT):
#     x = ids_for_text(text).to(device)
#     _, _, sparse = model(x, return_sparse=True)
#     a = sparse[layer_idx][0].mean(dim=1)  # (nh,N)
#     flat = a.reshape(-1)
#     k = min(k, flat.numel())
#     vals, idxs = torch.topk(flat, k)
#     out = []
#     for v, ix in zip(vals.tolist(), idxs.tolist()):
#         head = ix // N
#         feat = ix % N
#         gid = neuron_id(layer_idx, head, feat, cfg.n_head, N)
#         out.append((gid, (layer_idx, head, feat), float(v)))
#     return out

# for L in LAYERS:
#     print(f"\nTOP-{TOPK_PRINT} NEURONS — Layer {L}")
#     print(f"{'TEXT':<6} {'NEURON_ID':<10} {'(layer,head,feat)':<18} {'ACT'}")
#     print("-"*60)
#     for gid, trip, act in topk_list_fulltext(sentence, L):
#         print(f"{'EN':<6} {gid:<10} {str(trip):<18} {act:.3f}")
#     for lang, txt in translations.items():
#         for gid, trip, act in topk_list_fulltext(txt, L):
#             print(f"{lang[:2].upper():<6} {gid:<10} {str(trip):<18} {act:.3f}")

# # ----------------------------
# # 5) DATASET MODE (concepts)
# # ----------------------------
# print("\n=== DATASET MODE (monosemantic test) ===")
# concepts = read_list_until_end("Enter CONCEPT words (>=5). Type END to finish:", min_n=5)

# concept_trans = []
# for w in concepts:
#     trans = {"EN": w}
#     for lang in LANGS:
#         trans[lang] = translate(lang, w)
#     concept_trans.append(trans)

# # ----------------------------
# # 5.5) STRONG NEG SETS (3 types)
# # ----------------------------
# print("\n=== NEG SETS (3 groups) ===")
# print("NEG-A: Unrelated words (>=5)")
# NEG_A = read_list_until_end("Enter NEG-A (unrelated). Type END:", min_n=5)

# print("\nNEG-B: Spelling/byte-similar controls (>=5)  [important for byte models]")
# NEG_B = read_list_until_end("Enter NEG-B (similar spelling). Type END:", min_n=5)

# print("\nNEG-C: Same-domain-but-different concept (>=5)  [hardest test]")
# NEG_C = read_list_until_end("Enter NEG-C (same domain). Type END:", min_n=5)

# NEG_ALL = NEG_A + NEG_B + NEG_C
# print("\nNEG summary:")
# print("NEG-A:", NEG_A)
# print("NEG-B:", NEG_B)
# print("NEG-C:", NEG_C)

# # ----------------------------
# # 6) Actual vs Baseline Jaccard (STRONG baseline with shuffles)
# # ----------------------------
# def compute_actual(layer_idx):
#     names = [ct["EN"] for ct in concept_trans]
#     actual = []
#     for trans in concept_trans:
#         S_en = topk_set_for_text(trans["EN"], layer_idx, TOPK_SET)
#         js = []
#         for lang in LANGS:
#             S_tr = topk_set_for_text(trans[lang], layer_idx, TOPK_SET)
#             js.append(jaccard(S_en, S_tr))
#         actual.append(float(np.mean(js)))
#     return names, np.array(actual)

# def compute_shuffle_baseline(layer_idx, n_shuffles=20):
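#     # for each shuffle, permute translations among concepts
#     # NOTE: np.random.permutation can leave fixed points (perm[i] == i), which pairs a concept
#     # with its own translations and biases the baseline upward; a derangement would avoid this.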
#     n = len(concept_trans)
#     base = []
#     for _ in range(n_shuffles):
#         perm = np.random.permutation(n)
#         vals = []
#         for i in range(n):
#             S_en = topk_set_for_text(concept_trans[i]["EN"], layer_idx, TOPK_SET)
#             js = []
#             for lang in LANGS:
#                 S_wrong = topk_set_for_text(concept_trans[perm[i]][lang], layer_idx, TOPK_SET)
#                 js.append(jaccard(S_en, S_wrong))
#             vals.append(float(np.mean(js)))
#         base.append(np.mean(vals))
#     return np.array(base)

# results = {}
# for L in LAYERS:
#     names, actual_arr = compute_actual(L)
#     base_dist = compute_shuffle_baseline(L, n_shuffles=20)

#     results[L] = (names, actual_arr, base_dist)

#     print(f"\nLayer {L}: actual_mean={actual_arr.mean():.3f}  baseline_mean={base_dist.mean():.3f}  baseline_std={base_dist.std():.3f}")

#     # Plot: actual per concept + baseline mean line
#     x = np.arange(len(names))
#     plt.figure(figsize=(12,4))
#     plt.bar(x, actual_arr, label="Actual (EN vs its translations)")
#     plt.axhline(base_dist.mean(), linestyle="--", label="Baseline mean (shuffled)")
#     plt.axhspan(base_dist.mean()-base_dist.std(), base_dist.mean()+base_dist.std(), alpha=0.2, label="Baseline ±1σ")
#     plt.xticks(x, names, rotation=45, ha="right")
#     plt.ylim(0, 1)
#     plt.ylabel("Jaccard overlap")
#     plt.title(f"Layer {L}: Actual vs Shuffle Baseline (TopK={TOPK_SET})")
#     plt.legend()
#     plt.tight_layout()
#     plt.show()

# # ----------------------------
# # 7) Best shared neuron per concept + stronger selectivity metrics
# # ----------------------------
# @torch.no_grad()
# def best_shared_neuron_for_concept(i, layer_idx):
#     trans = concept_trans[i]
#     POS = [trans["EN"]] + [trans[lang] for lang in LANGS]

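#     # candidates: the cross-language intersection of topK sets, unioned with the EN topK set.
#     # Since POS[0] is the EN text itself, the intersection is a subset of S_en (given deterministic
#     # forward passes), so the candidate set is effectively S_en.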
#     inter = None
#     for t in POS:
#         S = topk_set_for_text(t, layer_idx, TOPK_SET)
#         inter = S if inter is None else (inter & S)
#     if inter is None:
#         inter = set()

#     S_en = topk_set_for_text(trans["EN"], layer_idx, TOPK_SET)
#     cand = list((set(inter) | set(S_en)))[:MAX_CAND]
#     if not cand:
#         return None

#     best = None
#     for gid in cand:
#         pos_vals = np.array([activation_gid_mean(t, layer_idx, gid) for t in POS], dtype=np.float64)
#         neg_vals = np.array([activation_gid_mean(t, layer_idx, gid) for t in NEG_ALL], dtype=np.float64)

#         pos_mean = float(pos_vals.mean())
#         neg_mean = float(neg_vals.mean())
#         neg_std  = float(neg_vals.std() + 1e-6)

#         sel = pos_mean - neg_mean
#         z   = sel / neg_std
#         sep = float(pos_vals.min() - neg_vals.max())  # > 0 is strong separation

#         score = (z, sep, sel)  # prioritize high z, then separation, then sel
#         if (best is None) or (score > best["score"]):
#             best = {
#                 "gid": gid,
#                 "pos_vals": pos_vals,
#                 "neg_vals": neg_vals,
#                 "pos_mean": pos_mean,
#                 "neg_mean": neg_mean,
#                 "neg_std": neg_std,
#                 "sel": sel,
#                 "z": z,
#                 "sep": sep,
#                 "score": score,
#             }
#     return best

# TOP_SHOW = min(3, len(concept_trans))

# for L in LAYERS:
#     names, actual_arr, base_dist = results[L]
#     margins = actual_arr - base_dist.mean()
#     order = np.argsort(-margins)[:TOP_SHOW]

#     print(f"\n=== Layer {L}: top {TOP_SHOW} concepts by (actual - baseline_mean) ===")

#     for idx in order:
#         concept = concept_trans[idx]["EN"]
#         best = best_shared_neuron_for_concept(idx, L)
#         if best is None:
#             print(f"Concept '{concept}': no candidates.")
#             continue

#         gid = best["gid"]
#         pos_vals = best["pos_vals"]
#         neg_vals = best["neg_vals"]

#         # PASS rules (you can tune)
#         pass_sep = best["sep"] > 0.0
#         pass_z   = best["z"] >= 2.0
#         verdict = "PASS " if (pass_sep and pass_z) else "WEAK "

#         print(f"\nConcept: {concept} | neuron={gid} decode={decode_gid(gid)} | {verdict}")
#         print(f"  pos_mean={best['pos_mean']:.3f}  neg_mean={best['neg_mean']:.3f}  sel={best['sel']:.3f}")
#         print(f"  neg_std={best['neg_std']:.3f}  z={best['z']:.2f}  sep(minPOS-maxNEG)={best['sep']:.3f}")

#         # Bar plot (POS vs NEG groups)
#         POS_texts  = [concept_trans[idx]["EN"]] + [concept_trans[idx][lang] for lang in LANGS]
#         POS_labels = ["EN"] + [lang[:2].upper() for lang in LANGS]

#         NEG_labels = (
#             [f"A{i+1}" for i in range(len(NEG_A))] +
#             [f"B{i+1}" for i in range(len(NEG_B))] +
#             [f"C{i+1}" for i in range(len(NEG_C))]
#         )

#         all_labels = POS_labels + NEG_labels
#         all_vals = np.concatenate([pos_vals, neg_vals], axis=0)

#         plt.figure(figsize=(14,4))
#         plt.bar(np.arange(len(all_labels)), all_vals)
#         plt.xticks(np.arange(len(all_labels)), all_labels, rotation=0)
#         plt.ylabel("Mean activation")
#         plt.title(f"Layer {L} neuron {gid} | concept='{concept}' | {verdict}")
#         plt.tight_layout()
#         plt.show()

#         # Boxplot (POS vs NEG)
#         plt.figure(figsize=(6,4))
#         plt.boxplot([pos_vals, neg_vals], tick_labels=["POS", "NEG(all)"], showmeans=True)
#         plt.ylabel("Mean activation")
#         plt.title(f"Layer {L} neuron {gid}: POS vs NEG")
#         plt.tight_layout()
#         plt.show()

# # ----------------------------
# # 8) EXPLAINABLE GIF (EN + translation + 2 NEG controls) with BYTES shown
# # ----------------------------
# make_gif = input("\nMake an EXPLAINABLE GIF? (y/n)\n> ").strip().lower() == "y"
# if make_gif:
#     import imageio.v2 as imageio
#     from IPython.display import HTML, display

#     print("Available layers:", LAYERS)
#     L_in = input(f"Which layer? (default {LAYERS[0]})\n> ").strip()
#     L = int(L_in) if L_in.isdigit() else LAYERS[0]
#     if L not in LAYERS:
#         L = LAYERS[0]

#     concept_names = [ct["EN"] for ct in concept_trans]
#     print("\nConcepts:", concept_names)
#     chosen = input("Which EN concept for GIF? (type exactly)\n> ").strip().lower()
#     idx_map = {ct["EN"].lower(): i for i, ct in enumerate(concept_trans)}
#     if chosen not in idx_map:
#         print("Not matched; using first concept:", concept_names[0])
#         best_idx = 0
#     else:
#         best_idx = idx_map[chosen]

#     print("Chosen LANGS:", LANGS)
#     other_lang = input(f"Which translation language for GIF? (default {LANGS[0]})\n> ").strip()
#     if other_lang not in LANGS:
#         other_lang = LANGS[0]

#     best = best_shared_neuron_for_concept(best_idx, L)
#     if best is None:
#         print("No neuron found.")
#     else:
#         gid = best["gid"]
#         concept = concept_trans[best_idx]["EN"]
#         en_text = concept_trans[best_idx]["EN"]
#         tr_text = concept_trans[best_idx][other_lang]

#         # choose 2 NEG controls: 1 spelling-control + 1 unrelated
#         neg1 = NEG_B[0] if len(NEG_B) else NEG_ALL[0]
#         neg2 = NEG_A[0] if len(NEG_A) else (NEG_ALL[1] if len(NEG_ALL) > 1 else NEG_ALL[0])

#         seqs = [
#             ("EN", en_text),
#             (other_lang[:2].upper(), tr_text),
#             ("NEG-B", neg1),
#             ("NEG-A", neg2),
#         ]

#         print("\n[gif] Will animate:")
#         print(f"  concept: {concept}")
#         print(f"  layer:   {L}")
#         print(f"  neuron:  {gid} decode={decode_gid(gid)}")
#         for tag, t in seqs:
#             print(f"   - {tag}: {t}")

#         gif_path = "bdh_activation_explain.gif"
#         out_frames = []

#         def to_bytes_preview(text, max_len=48):
#             b = text.encode("utf-8", errors="replace")
#             # show bytes as printable-ish: hex
#             hexs = [f"{x:02x}" for x in b[:max_len]]
#             if len(b) > max_len:
#                 hexs.append("..")
#             return b, " ".join(hexs)

#         def fig_to_rgb(fig):
#             fig.canvas.draw()
#             buf = np.asarray(fig.canvas.buffer_rgba())  # (H,W,4)
#             return buf[:, :, :3].copy()

#         # build per-seq curves first so we can share y-scale
#         curves = []
#         byte_views = []
#         for tag, text in seqs:
#             vals = activation_gid_over_positions(text, L, gid)
#             curves.append(vals)
#             b, hx = to_bytes_preview(text, max_len=48)
#             byte_views.append((b, hx))

#         global_max = max([float(v.max()) for v in curves]) if curves else 1.0
#         y_max = max(global_max * 1.15, 1e-3)

#         # GIF parameters
#         fps = 12
#         max_frames = 140  # keep manageable
#         step = 1

#         for (tag, text), vals, (b, hx) in zip(seqs, curves, byte_views):
#             T = len(vals)
#             if T <= 2:
#                 continue

#             # frame sampling so long texts still animate
#             if T > max_frames:
#                 step = int(np.ceil(T / max_frames))
#             else:
#                 step = 1

#             for t in range(2, T, step):
#                 fig = plt.figure(figsize=(10, 4))

#                 # activation plot
#                 ax1 = plt.subplot(2, 1, 1)
#                 ax1.plot(np.arange(T), vals, alpha=0.25, linewidth=1)  # faint full curve
#                 ax1.plot(np.arange(t), vals[:t], linewidth=2)          # growing curve
#                 ax1.axvline(t-1, linestyle="--", linewidth=1)
#                 ax1.set_xlim(0, T-1)
#                 ax1.set_ylim(0, y_max)
#                 ax1.set_ylabel("Activation")
#                 ax1.set_title(f"Layer {L} neuron {gid} | {tag}: '{short(text, 40)}' | byte index {t-1}/{T-1}")

#                 # bytes display (hex + highlight current byte)
#                 ax2 = plt.subplot(2, 1, 2)
#                 ax2.axis("off")
#                 cur = min(t-1, len(b)-1)
#                 # show a small window around cur byte
#                 left = max(0, cur-16)
#                 right = min(len(b), cur+16)
#                 window = b[left:right]
#                 window_hex = [f"{x:02x}" for x in window]

#                 # highlight current byte in the window using brackets
#                 hi = cur - left
#                 if 0 <= hi < len(window_hex):
#                     window_hex[hi] = "[" + window_hex[hi] + "]"

#                 ax2.text(
#                     0.01, 0.55,
#                     f"bytes(hex) around current:\n" + " ".join(window_hex),
#                     fontsize=10, family="monospace"
#                 )
#                 ax2.text(
#                     0.01, 0.05,
#                     f"full bytes preview:\n{hx}",
#                     fontsize=9, family="monospace", alpha=0.85
#                 )

#                 plt.tight_layout()
#                 out_frames.append(fig_to_rgb(fig))
#                 plt.close(fig)

#         imageio.mimsave(gif_path, out_frames, fps=fps)
#         print("Saved GIF:", gif_path)

#         # More reliable looping display than IPython.display.Image in some setups:
#         # embed as HTML <img> so browser loops it smoothly
#         with open(gif_path, "rb") as f:
#             data = f.read()
#         import base64
#         b64 = base64.b64encode(data).decode("utf-8")
#         display(HTML(f"<img src='data:image/gif;base64,{b64}' loop='infinite' />"))

# print("\nDONE.")


In [None]:
%%writefile app.py
# ============================================
# Streamlit BDH Phase-2 Probe UI
# ============================================

import os, re, random, base64
import numpy as np
import torch
import streamlit as st
import matplotlib.pyplot as plt
import importlib.util
import imageio.v2 as imageio

from transformers import MarianMTModel, MarianTokenizer

# ----------------------------
# App Config
# ----------------------------
# Page config is set before any other st.* call (older Streamlit versions
# require it to be the first Streamlit command in the script).
st.set_page_config(page_title="BDH Phase-2 Probe", layout="wide")

# ----------------------------
# Constants / Defaults
# ----------------------------
# Default RNG seed: initial value of st.session_state.seed and of the sidebar
# "Seed" input; seed_all() falls back to it when called without an argument.
SEED = 1337

# Display name -> Hugging Face MarianMT checkpoint id loaded by translate().
# Judging by the ids these are English -> target-language models (en-de, ...).
ALL_LANGS = {
    "German":  "Helsinki-NLP/opus-mt-en-de",
    "French":  "Helsinki-NLP/opus-mt-en-fr",
    "Spanish": "Helsinki-NLP/opus-mt-en-es",
    "Italian": "Helsinki-NLP/opus-mt-en-it",
}

# Defaults for the sidebar path inputs. The script path presumably assumes the
# Colab working directory (/content) -- adjust when running elsewhere.
DEFAULT_BDH_SCRIPT_PATH = "/content/bdh_europarl_train_probe.py"
DEFAULT_CKPT_PATH = "checkpoints/bdh_europarl_bytes.pt"

# ----------------------------
# Utility
# ----------------------------
def short(s, n=28):
    """Return ``s`` flattened to one line and clipped to ``n`` characters.

    Newlines become spaces. When the flattened text is longer than ``n`` it is cut
    to ``n - 3`` characters and suffixed with ``"..."`` (meaningful for ``n >= 3``).
    """
    flat = s.replace("\n", " ")
    if len(flat) > n:
        return flat[: n - 3] + "..."
    return flat

def seed_all(seed=SEED):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with one value.

    The generators are seeded in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def split_lines(s: str):
    """Split ``s`` into lines, strip each one, and drop the lines that end up empty."""
    stripped = (line.strip() for line in s.splitlines())
    return [line for line in stripped if line]

# ----------------------------
# Session "Memory"
# ----------------------------
# Per-session state that survives Streamlit reruns; only initialised the first
# time the script runs in a session.
#   trans_cache: dict used by translate() as its cache (emptied by hard_reset_state).
#   seed:        RNG seed; overwritten from the sidebar "Seed" input on every run.
if "trans_cache" not in st.session_state:
    st.session_state.trans_cache = {}
if "seed" not in st.session_state:
    st.session_state.seed = SEED

def hard_reset_state(also_seed=True, clear_gpu=True):
    """Forget cached translation state and optionally free GPU cache / reseed.

    Args:
        also_seed: when true, reseed all RNGs from ``st.session_state.seed``.
        clear_gpu: when true and CUDA is available, call ``torch.cuda.empty_cache()``
            after the translation cache has been replaced.
    """
    # Replacing the dict drops this session's references to the cached objects.
    st.session_state.trans_cache = {}

    wants_gpu_flush = clear_gpu and torch.cuda.is_available()
    if wants_gpu_flush:
        torch.cuda.empty_cache()

    if not also_seed:
        return
    seed_all(st.session_state.seed)

# ----------------------------
# Cached loaders (fast)
# ----------------------------
@st.cache_resource(show_spinner=False)
def load_bdh_module(bdh_script_path: str):
    """Import the BDH script at ``bdh_script_path`` as a module named ``bdhmod``.

    Cached by Streamlit per path, so the script's top-level code runs once per
    distinct path instead of on every rerun.

    Raises:
        FileNotFoundError: if ``bdh_script_path`` is not an existing file.
        ImportError: if no import spec can be built for the path (e.g. it does not
            end in ``.py``); previously this surfaced as an opaque AttributeError.
    """
    import sys  # only needed here, for module registration

    if not os.path.isfile(bdh_script_path):
        raise FileNotFoundError(f"BDH script not found: {bdh_script_path!r}")

    spec = importlib.util.spec_from_file_location("bdhmod", bdh_script_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot build an import spec for {bdh_script_path!r} (expected a .py file)"
        )
    bdhmod = importlib.util.module_from_spec(spec)
    # Register before executing (as in the importlib "import a source file
    # directly" recipe) so code that resolves its own module via sys.modules --
    # dataclasses, pickling, typing.get_type_hints -- works inside the script.
    sys.modules[spec.name] = bdhmod
    try:
        spec.loader.exec_module(bdhmod)
    except BaseException:
        # Don't leave a half-initialised module registered.
        sys.modules.pop(spec.name, None)
        raise
    return bdhmod

@st.cache_resource(show_spinner=False)
def load_bdh_model(bdh_script_path: str, ckpt_path: str):
    """Build a BDH model from the script at ``bdh_script_path`` and load ``ckpt_path``.

    Model dimensions are inferred from the checkpoint's ``encoder`` tensor. Its
    shape is unpacked as ``(n_head, n_embd, N)``, and the MLP multiplier is
    recovered from ``N = n_embd * mult // n_head``. ``n_layer`` (6) and
    ``vocab_size`` (256) are fixed here because that tensor does not encode them.
    Cached by Streamlit per (script path, checkpoint path).

    Returns:
        dict with keys bdhmod, cfg, model (eval mode, on ``device``), device,
        N (features per head), BDHConfig, BDH, ids_for_text, neuron_id.

    Raises:
        ValueError: if the checkpoint is not a state_dict with a 3-D ``encoder``
            tensor whose shape yields an integral MLP multiplier.
    """
    bdhmod = load_bdh_module(bdh_script_path)

    BDHConfig    = bdhmod.BDHConfig
    BDH          = bdhmod.BDH
    ids_for_text = bdhmod.ids_for_text
    neuron_id    = bdhmod.neuron_id

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The path is typed into a UI that may be exposed publicly (ngrok).
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # malicious .pt file cannot execute code. A plain state_dict still loads fine.
    state = torch.load(ckpt_path, map_location=device, weights_only=True)

    encoder = state.get("encoder") if isinstance(state, dict) else None
    if encoder is None or encoder.dim() != 3:
        raise ValueError(
            f"{ckpt_path!r} is not a BDH state_dict with a 3-D 'encoder' tensor"
        )

    nh, D, N_tensor = encoder.shape
    if (N_tensor * nh) % D != 0:
        raise ValueError(
            f"encoder shape {tuple(encoder.shape)} does not give an integral "
            "mlp_internal_dim_multiplier"
        )
    mult = int((N_tensor * nh) // D)

    cfg = BDHConfig(
        vocab_size=256, n_layer=6, n_embd=D, n_head=nh,
        mlp_internal_dim_multiplier=mult, dropout=0.0
    )
    model = BDH(cfg).to(device)
    model.load_state_dict(state, strict=True)
    model.eval()

    # Features per head; equals N_tensor given the divisibility check above.
    N = (cfg.n_embd * cfg.mlp_internal_dim_multiplier) // cfg.n_head
    return dict(
        bdhmod=bdhmod, cfg=cfg, model=model, device=device, N=N,
        BDHConfig=BDHConfig, BDH=BDH, ids_for_text=ids_for_text, neuron_id=neuron_id
    )

def translate(lang, text, device):
    """Translate ``text`` into ``lang`` (a key of ALL_LANGS) with MarianMT.

    Two things are memoised in ``st.session_state.trans_cache``:

    * the tokenizer/model per language, loaded onto ``device`` on first use;
    * each translated string, keyed by ``(lang, text)``.

    Streamlit reruns the whole script on every widget interaction. Without the
    second cache, every rerun re-ran beam-search generation for the same words.
    ``hard_reset_state`` empties the whole cache.

    Returns:
        The decoded translation without special tokens.
    """
    cache = st.session_state.trans_cache

    # Tuple key: cannot collide with the plain-string per-language model entries.
    result_key = ("__translation__", lang, text)
    if result_key in cache:
        return cache[result_key]

    if lang not in cache:
        tok = MarianTokenizer.from_pretrained(ALL_LANGS[lang])
        mdl = MarianMTModel.from_pretrained(ALL_LANGS[lang]).to(device)
        cache[lang] = (tok, mdl)
    tok, mdl = cache[lang]

    batch = tok([text], return_tensors="pt", padding=True).to(device)
    out = mdl.generate(**batch, max_new_tokens=64)
    result = tok.decode(out[0], skip_special_tokens=True)

    cache[result_key] = result
    return result

# ----------------------------
# BDH math helpers
# ----------------------------
EPS = 1e-9  # keeps jaccard() defined (0.0) when both sets are empty


def jaccard(a: set, b: set) -> float:
    """Jaccard overlap |a ∩ b| / |a ∪ b|, with EPS added to the denominator."""
    overlap = len(a.intersection(b))
    union_size = len(a.union(b))
    return overlap / (union_size + EPS)

def decode_gid(gid: int, cfg, N):
    """Split a flat neuron id into ``(layer, head, feat)``.

    Inverts the layout ``gid = layer * (cfg.n_head * N) + head * N + feat``. This
    assumes ``neuron_id`` in the BDH script uses the same layout.
    """
    layer, within_layer = divmod(gid, cfg.n_head * N)
    head, feat = divmod(within_layer, N)
    return layer, head, feat

@torch.no_grad()
def topk_set_for_text(text: str, layer_idx: int, k: int, ids_for_text, model, device, cfg, N, neuron_id):
    """Global ids of the ``k`` most active neurons of layer ``layer_idx`` for ``text``.

    "Active" means the activation averaged over the sequence dimension (dim 1 of
    the per-layer sparse tensor), ranked jointly across all heads. ``k`` is
    clamped to the number of neurons in the layer.
    """
    tokens = ids_for_text(text).to(device)
    _, _, sparse = model(tokens, return_sparse=True)
    per_neuron = sparse[layer_idx][0].mean(dim=1).reshape(-1)  # flattened (heads * N)
    top = torch.topk(per_neuron, min(k, per_neuron.numel()))
    return {
        neuron_id(layer_idx, flat // N, flat % N, cfg.n_head, N)
        for flat in top.indices.tolist()
    }

@torch.no_grad()
def activation_gid_mean(text: str, layer_idx: int, gid: int, ids_for_text, model, device, cfg, N):
    """Mean activation over all positions of ``text`` for neuron ``gid`` at ``layer_idx``.

    If ``gid`` decodes to a different layer, 0.0 is returned without running the model.
    """
    gid_layer, head, feat = decode_gid(gid, cfg, N)
    if gid_layer == layer_idx:
        tokens = ids_for_text(text).to(device)
        _, _, sparse = model(tokens, return_sparse=True)
        trace = sparse[layer_idx][0][head, :, feat]  # one neuron across positions
        return float(trace.mean().item())
    return 0.0

@torch.no_grad()
def activation_gid_over_positions(text: str, layer_idx: int, gid: int, ids_for_text, model, device, cfg, N):
    """Per-position activation trace of neuron ``gid`` over ``text`` as float32 numpy.

    If ``gid`` decodes to a different layer, the one-element array ``[0.0]`` is
    returned without running the model.
    """
    gid_layer, head, feat = decode_gid(gid, cfg, N)
    if gid_layer != layer_idx:
        return np.zeros(1, dtype=np.float32)

    tokens = ids_for_text(text).to(device)
    _, _, sparse = model(tokens, return_sparse=True)
    trace = sparse[layer_idx][0][head, :, feat]
    return trace.detach().float().cpu().numpy()

@torch.no_grad()
def topk_list_fulltext(text, layer_idx, k, ids_for_text, model, device, cfg, N, neuron_id):
    """Top-``k`` neurons of layer ``layer_idx`` for ``text``, highest mean activation first.

    Returns:
        list of ``(gid, (layer, head, feat), activation)`` tuples, where activation
        is the neuron's activation averaged over the sequence.
    """
    tokens = ids_for_text(text).to(device)
    _, _, sparse = model(tokens, return_sparse=True)
    per_neuron = sparse[layer_idx][0].mean(dim=1).reshape(-1)  # flattened (heads * N)
    top = torch.topk(per_neuron, min(k, per_neuron.numel()))

    ranked = []
    for score, flat in zip(top.values.tolist(), top.indices.tolist()):
        head, feat = divmod(flat, N)
        gid = neuron_id(layer_idx, head, feat, cfg.n_head, N)
        ranked.append((gid, (layer_idx, head, feat), float(score)))
    return ranked

@torch.no_grad()
def best_shared_neuron_for_concept(concept_trans, idx, layer_idx, LANGS, TOPK_SET, MAX_CAND,
                                  NEG_ALL, ids_for_text, model, device, cfg, N, neuron_id):
    """Pick the layer-``layer_idx`` neuron that best separates one concept from NEG texts.

    POS texts are ``concept_trans[idx]["EN"]`` plus its translations for ``LANGS``.

    Candidates are chosen as follows:

    * first, the neurons in the top-``TOPK_SET`` (by sequence-averaged activation)
      of *every* POS text;
    * then the remaining English top-``TOPK_SET`` neurons;
    * both groups are sorted for determinism and truncated to ``MAX_CAND``.

    Putting the shared group first means truncation never drops a cross-lingual
    neuron in favour of an English-only one.

    Each candidate gets a score ``(z, sep, sel)``, compared lexicographically:

    * ``sel = mean(POS) - mean(NEG)``
    * ``z   = sel / (std(NEG) + 1e-6)``
    * ``sep = min(POS) - max(NEG)`` (> 0: every POS value above every NEG value)

    Every distinct text is run through the model once. All top-K sets and
    candidate activations are read from that cache, rather than doing one
    forward pass per (candidate, text) pair.

    Returns:
        dict with keys gid, pos_vals, neg_vals (float64 arrays), pos_mean, neg_mean,
        neg_std, sel, z, sep, score -- or None when there are no candidates.

    Raises:
        ValueError: if ``NEG_ALL`` is empty (NEG statistics would be undefined).
    """
    if not NEG_ALL:
        raise ValueError("NEG_ALL must contain at least one negative text")

    trans = concept_trans[idx]
    POS = [trans["EN"]] + [trans[lang] for lang in LANGS]

    # text -> flattened (n_head * N,) activations averaged over the sequence.
    mean_acts = {}

    def flat_means(text):
        if text not in mean_acts:
            x = ids_for_text(text).to(device)
            _, _, sparse = model(x, return_sparse=True)
            mean_acts[text] = sparse[layer_idx][0].mean(dim=1).reshape(-1)
        return mean_acts[text]

    def topk_flat(text):
        """Flat (head * N + feat) indices of the top-TOPK_SET neurons for text."""
        flat = flat_means(text)
        return set(torch.topk(flat, min(TOPK_SET, flat.numel())).indices.tolist())

    pos_sets = [topk_flat(t) for t in POS]
    en_set = pos_sets[0]
    shared = set.intersection(*pos_sets)  # subset of en_set, since POS[0] is EN
    cand_flat = (sorted(shared) + sorted(en_set - shared))[:MAX_CAND]
    if not cand_flat:
        return None

    def activation_matrix(texts):
        """(len(texts), len(cand_flat)) float64 array of candidate activations."""
        rows = []
        for t in texts:
            flat = flat_means(t)
            cand_ix = torch.as_tensor(cand_flat, dtype=torch.long, device=flat.device)
            rows.append(flat.index_select(0, cand_ix).to(torch.float64).cpu())
        return torch.stack(rows).numpy()

    pos_mat = activation_matrix(POS)
    neg_mat = activation_matrix(NEG_ALL)

    best = None
    for j, flat_ix in enumerate(cand_flat):
        head, feat = divmod(flat_ix, N)
        gid = neuron_id(layer_idx, head, feat, cfg.n_head, N)
        pos_vals = pos_mat[:, j].copy()
        neg_vals = neg_mat[:, j].copy()

        pos_mean = float(pos_vals.mean())
        neg_mean = float(neg_vals.mean())
        neg_std  = float(neg_vals.std() + 1e-6)  # epsilon: avoid /0 for constant NEG

        sel = pos_mean - neg_mean
        z   = sel / neg_std
        sep = float(pos_vals.min() - neg_vals.max())  # > 0 strong

        score = (z, sep, sel)
        if (best is None) or (score > best["score"]):
            best = dict(gid=gid, pos_vals=pos_vals, neg_vals=neg_vals,
                        pos_mean=pos_mean, neg_mean=neg_mean, neg_std=neg_std,
                        sel=sel, z=z, sep=sep, score=score)
    return best

# ----------------------------
# UI
# ----------------------------
st.title("BDH Phase-2 Probe — Streamlit UI")

# Sidebar: paths, run settings and reset buttons. Streamlit re-executes this on
# every rerun, so the values below are always the current widget values.
with st.sidebar:
    st.header("Paths")
    bdh_script_path = st.text_input("BDH script path", DEFAULT_BDH_SCRIPT_PATH)
    ckpt_path = st.text_input("Checkpoint path", DEFAULT_CKPT_PATH)

    st.header("Run settings")
    st.session_state.seed = st.number_input("Seed", value=SEED, step=1)
    # Reseeded on every rerun, so RNG-dependent work later in the script (e.g. the
    # dataset probe's baseline permutations) is reproducible for a given seed.
    seed_all(st.session_state.seed)

    TOPK_PRINT = st.slider("Top-K print (demo)", 1, 20, 5)  # rows per text in the Demo tab
    TOPK_SET = st.slider("Top-K set (probe stability)", 50, 500, 200, step=50)  # K for overlap sets
    MAX_CAND = st.slider("Max candidates", 100, 2000, 800, step=100)  # neurons scored per concept

    st.header("Reset")
    colR1, colR2 = st.columns(2)
    if colR1.button("Reset cache"):
        hard_reset_state(also_seed=False, clear_gpu=False)
        st.success("Cleared translation cache.")
    if colR2.button("Reset + reseed"):
        hard_reset_state(also_seed=True, clear_gpu=True)
        st.success("Cleared cache + reseeded (+ cleared GPU cache if available).")

# Load model. load_bdh_model is st.cache_resource-cached per (script, checkpoint)
# path pair; on any failure, show the error and halt the rest of the page.
try:
    pack = load_bdh_model(bdh_script_path, ckpt_path)
except Exception as e:
    st.error(f"Could not load BDH model/module. Check paths.\n\n{e}")
    st.stop()

# Module-level handles shared by all tabs below.
bdhmod = pack["bdhmod"]
cfg = pack["cfg"]
model = pack["model"]
device = pack["device"]
N = pack["N"]  # features per head
ids_for_text = pack["ids_for_text"]
neuron_id = pack["neuron_id"]

st.caption(f"Device: {device} | Layers: {cfg.n_layer} | Embd: {cfg.n_embd} | Heads: {cfg.n_head} | N/head: {N}")

tabs = st.tabs(["Demo", "Dataset probe", "GIF builder"])

# ----------------------------
# Demo tab
# ----------------------------
# Demo tab: translate one English word/sentence, then list the top TOPK_PRINT
# neurons (by sequence-averaged activation) for the English text and for each
# translation.
with tabs[0]:
    st.subheader("Demo: Top neurons for a text and its translations")

    sentence = st.text_input("English word/sentence", value="lion")
    chosen_langs = st.multiselect("Languages", options=list(ALL_LANGS.keys()), default=["German"])

    if st.button("Run demo"):
        # No language selected -> fall back to all available languages.
        if not chosen_langs:
            chosen_langs = list(ALL_LANGS.keys())

        translations = {}
        with st.spinner("Translating..."):
            for lang in chosen_langs:
                translations[lang] = translate(lang, sentence, device)

        st.write("**Translations**")
        st.write("EN:", sentence)
        for lang, txt in translations.items():
            st.write(f"{lang[:2].upper()}: {txt}")

        # Only layers 4 and 5 are probed (hard-coded), limited to layers the model has.
        layers = [4, 5]
        layers = [L for L in layers if 0 <= L < cfg.n_layer]

        for L in layers:
            st.markdown(f"### Layer {L}")
            # One row per (text, top neuron): EN first, then each translation.
            rows = []
            for gid, trip, act in topk_list_fulltext(sentence, L, TOPK_PRINT, ids_for_text, model, device, cfg, N, neuron_id):
                rows.append(("EN", gid, trip, act))
            for lang, txt in translations.items():
                for gid, trip, act in topk_list_fulltext(txt, L, TOPK_PRINT, ids_for_text, model, device, cfg, N, neuron_id):
                    rows.append((lang[:2].upper(), gid, trip, act))

            st.dataframe(
                [{"TEXT": r[0], "NEURON_ID": r[1], "(layer,head,feat)": str(r[2]), "ACT": round(r[3], 4)} for r in rows],
                use_container_width=True
            )

# ----------------------------
# Dataset Probe tab
# ----------------------------
def _run_dataset_probe(concept_text, chosen_langs, neg_a_text, neg_b_text, neg_c_text, n_shuffles, top_show):
    """Run the dataset probe for the given widget values and render the results.

    For each probed layer it does three things:

    * plots the per-concept Jaccard overlap between the English and translated
      top-TOPK_SET neuron sets, against a shuffled-pairing baseline;
    * finds the best POS-vs-NEG neuron for the ``top_show`` concepts with the
      largest margin over the baseline mean;
    * plots that neuron's activations.

    Invalid input is reported with ``st.error`` followed by an early return. It
    does not call ``st.stop()``, which would also abort rendering of the
    GIF-builder tab and the page footer.

    Uses the module-level model/cfg/device handles and the sidebar's TOPK_SET /
    MAX_CAND values.
    """
    concepts = split_lines(concept_text)
    NEG_A = split_lines(neg_a_text)
    NEG_B = split_lines(neg_b_text)
    NEG_C = split_lines(neg_c_text)

    if len(concepts) < 5:
        st.error("Need at least 5 concept words.")
        return
    if len(NEG_A) < 5 or len(NEG_B) < 5 or len(NEG_C) < 5:
        st.error("Each NEG group must have at least 5 words.")
        return

    # No language selected -> probe all available languages.
    LANGS = chosen_langs or list(ALL_LANGS.keys())

    with st.spinner("Translating concepts..."):
        concept_trans = []
        for w in concepts:
            trans = {"EN": w}
            for lang in LANGS:
                trans[lang] = translate(lang, w, device)
            concept_trans.append(trans)

    NEG_ALL = NEG_A + NEG_B + NEG_C

    def topk_sets_for_layer(layer_idx):
        """Map every EN word and translation to its top-TOPK_SET neuron-id set.

        One forward pass per distinct text per layer. compute_actual and every
        baseline shuffle then only look the sets up, instead of re-running the
        model for each shuffle.
        """
        sets = {}
        for ct in concept_trans:
            for text in [ct["EN"]] + [ct[lang] for lang in LANGS]:
                if text not in sets:
                    sets[text] = topk_set_for_text(
                        text, layer_idx, TOPK_SET, ids_for_text, model, device, cfg, N, neuron_id
                    )
        return sets

    def compute_actual(sets):
        """Per concept: mean over LANGS of Jaccard(EN set, translation set)."""
        names = [ct["EN"] for ct in concept_trans]
        actual = [
            float(np.mean([jaccard(sets[ct["EN"]], sets[ct[lang]]) for lang in LANGS]))
            for ct in concept_trans
        ]
        return names, np.array(actual)

    def compute_shuffle_baseline(sets, n_shuffles=20):
        """Mean Jaccard between each EN set and the translations of a permuted concept."""
        base = []
        for _ in range(n_shuffles):
            perm = np.random.permutation(len(concept_trans))
            vals = []
            for ct, j in zip(concept_trans, perm):
                S_en = sets[ct["EN"]]
                js = [jaccard(S_en, sets[concept_trans[j][lang]]) for lang in LANGS]
                vals.append(float(np.mean(js)))
            base.append(np.mean(vals))
        return np.array(base)

    layers = [4, 5]
    layers = [L for L in layers if 0 <= L < cfg.n_layer]

    for L in layers:
        sets = topk_sets_for_layer(L)
        names, actual_arr = compute_actual(sets)
        base_dist = compute_shuffle_baseline(sets, n_shuffles=n_shuffles)

        st.markdown(f"## Layer {L}")
        st.write(
            f"actual_mean={actual_arr.mean():.3f} | baseline_mean={base_dist.mean():.3f} | baseline_std={base_dist.std():.3f}"
        )

        # Plot
        fig = plt.figure(figsize=(10, 3.2))
        x = np.arange(len(names))
        plt.bar(x, actual_arr, label="Actual (EN vs translations)")
        plt.axhline(base_dist.mean(), linestyle="--", label="Baseline mean (shuffled)")
        plt.axhspan(base_dist.mean()-base_dist.std(), base_dist.mean()+base_dist.std(), alpha=0.2, label="Baseline ±1σ")
        plt.xticks(x, names, rotation=45, ha="right")
        plt.ylim(0, 1)
        plt.ylabel("Jaccard overlap")
        plt.title(f"Layer {L}: Actual vs Shuffle Baseline (TopK={TOPK_SET})")
        plt.legend()
        plt.tight_layout()
        st.pyplot(fig)
        plt.close(fig)

        # Best neuron proof for the concepts with the largest margin over the baseline mean.
        margins = actual_arr - base_dist.mean()
        order = np.argsort(-margins)[:min(top_show, len(names))]

        for idx in order:
            concept = concept_trans[idx]["EN"]
            best = best_shared_neuron_for_concept(
                concept_trans, idx, L, LANGS, TOPK_SET, MAX_CAND, NEG_ALL,
                ids_for_text, model, device, cfg, N, neuron_id
            )
            if best is None:
                st.write(f"Concept '{concept}': no candidates.")
                continue

            gid = best["gid"]
            # PASS: every POS above every NEG and at least 2 NEG std-devs of selectivity.
            verdict = "PASS" if (best["sep"] > 0.0 and best["z"] >= 2.0) else "WEAK"

            st.markdown(f"### Concept: `{concept}` | neuron `{gid}` decode={decode_gid(gid, cfg, N)} | **{verdict}**")
            st.code(
                f"pos_mean={best['pos_mean']:.3f}  neg_mean={best['neg_mean']:.3f}  sel={best['sel']:.3f}\n"
                f"neg_std={best['neg_std']:.3f}  z={best['z']:.2f}  sep(minPOS-maxNEG)={best['sep']:.3f}"
            )

            POS_labels = ["EN"] + [lang[:2].upper() for lang in LANGS]
            NEG_labels = (
                [f"NEG-A:{w}" for w in NEG_A] +
                [f"NEG-B:{w}" for w in NEG_B] +
                [f"NEG-C:{w}" for w in NEG_C]
            )
            all_labels = POS_labels + NEG_labels
            all_vals = np.concatenate([best["pos_vals"], best["neg_vals"]], axis=0)

            fig2 = plt.figure(figsize=(12, 3.2))
            plt.bar(np.arange(len(all_labels)), all_vals)
            plt.xticks(np.arange(len(all_labels)), [short(x, 18) for x in all_labels], rotation=45, ha="right")
            plt.ylabel("Mean activation")
            plt.title(f"Layer {L} neuron {gid} | concept='{concept}' | {verdict}")
            plt.tight_layout()
            st.pyplot(fig2)
            plt.close(fig2)


with tabs[1]:
    st.subheader("Dataset probe: Actual vs shuffle baseline + best neuron proof")

    st.markdown("Enter one word per line.")
    concept_text = st.text_area("CONCEPT words (>=5)", value="dog\nlion\nplant\nrain\nfire", height=130)

    chosen_langs = st.multiselect("Languages (probe)", options=list(ALL_LANGS.keys()), default=["German"])

    st.markdown("### NEG sets (fresh each run)")
    neg_a_text = st.text_area("NEG-A (unrelated, >=5)", value="laptop\npune\ndance\nmountain\nriver", height=120)
    neg_b_text = st.text_area("NEG-B (spelling/byte-similar, >=5)", value="li0n\nl1on\nlioness\nlions\nlioning", height=120)
    neg_c_text = st.text_area("NEG-C (same domain but different, >=5)", value="tiger\ncheetah\nleopard\npanther\nhyena", height=120)

    n_shuffles = st.slider("Baseline shuffles", 5, 60, 20, step=5)
    top_show = st.slider("Show top concepts", 1, 6, 3)

    if st.button("Run dataset probe"):
        _run_dataset_probe(concept_text, chosen_langs, neg_a_text, neg_b_text, neg_c_text, n_shuffles, top_show)

# ----------------------------
# GIF Builder tab
# ----------------------------
def _make_activation_gif(L, concept_choice, chosen_lang, neg_a, neg_b, topk_set, max_frames, fps):
    """Select the best neuron for ``concept_choice`` at layer ``L`` and show a GIF of it.

    The GIF animates the neuron's activation over byte positions for four texts:
    the English concept, its ``chosen_lang`` translation, and the non-empty NEG
    words. Each frame also shows a hex window around the current byte.

    Only the chosen concept is translated, because neuron selection only reads
    that entry. Problems are reported with ``st.error`` followed by an early
    return, not ``st.stop()``, so the page footer still renders.
    """
    if L is None:
        st.error("No layer available to probe for this model.")
        return

    # Blank text boxes would otherwise be fed to the model as empty sequences.
    neg_seqs = [(tag, w) for tag, w in (("NEG-A", neg_a), ("NEG-B", neg_b)) if w.strip()]
    if not neg_seqs:
        st.error("Enter at least one NEG word.")
        return

    LANGS = [chosen_lang]
    concept_trans = [{"EN": concept_choice, chosen_lang: translate(chosen_lang, concept_choice, device)}]
    NEG_ALL = [w for _, w in neg_seqs]  # only for scoring selection in GIF builder

    best = best_shared_neuron_for_concept(
        concept_trans, 0, L, LANGS, topk_set, 800, NEG_ALL,
        ids_for_text, model, device, cfg, N, neuron_id
    )
    if best is None:
        st.error("No neuron found for this concept/layer.")
        return

    gid = best["gid"]
    en_text = concept_trans[0]["EN"]
    tr_text = concept_trans[0][chosen_lang]

    seqs = [("EN", en_text), (chosen_lang[:2].upper(), tr_text)] + neg_seqs

    st.write("Neuron:", gid, "decode:", decode_gid(gid, cfg, N))
    st.write("Animating:", seqs)

    curves = [activation_gid_over_positions(text, L, gid, ids_for_text, model, device, cfg, N) for _, text in seqs]
    # Shared y-scale across all sequences, with 15% headroom.
    global_max = max(float(v.max()) for v in curves)
    y_max = max(global_max * 1.15, 1e-3)

    def fig_to_rgb(fig):
        """Rasterise a matplotlib figure to an (H, W, 3) uint8 array."""
        fig.canvas.draw()
        buf = np.asarray(fig.canvas.buffer_rgba())
        return buf[:, :, :3].copy()

    out_frames = []
    for (tag, text), vals in zip(seqs, curves):
        T = len(vals)
        if T <= 2:
            continue
        # Subsample positions so long texts stay within max_frames frames.
        step = int(np.ceil(T / max_frames)) if T > max_frames else 1

        b = text.encode("utf-8", errors="replace")

        for t in range(2, T, step):
            fig = plt.figure(figsize=(10, 4))

            ax1 = plt.subplot(2, 1, 1)
            ax1.plot(np.arange(T), vals, alpha=0.25, linewidth=1)  # faint full curve
            ax1.plot(np.arange(t), vals[:t], linewidth=2)          # growing curve
            ax1.axvline(t-1, linestyle="--", linewidth=1)
            ax1.set_xlim(0, T-1)
            ax1.set_ylim(0, y_max)
            ax1.set_ylabel("Activation")
            ax1.set_title(f"Layer {L} neuron {gid} | {tag}: '{short(text, 40)}' | byte index {t-1}/{T-1}")

            ax2 = plt.subplot(2, 1, 2)
            ax2.axis("off")

            # Hex window of up to 32 bytes around the current byte, which is bracketed.
            cur = min(t-1, len(b)-1)
            left = max(0, cur-16)
            right = min(len(b), cur+16)
            window = b[left:right]
            window_hex = [f"{x:02x}" for x in window]
            hi = cur - left
            if 0 <= hi < len(window_hex):
                window_hex[hi] = "[" + window_hex[hi] + "]"

            ax2.text(0.01, 0.55, "bytes(hex) around current:\n" + " ".join(window_hex),
                     fontsize=10, family="monospace")

            plt.tight_layout()
            out_frames.append(fig_to_rgb(fig))
            plt.close(fig)

    if not out_frames:
        # mimsave cannot write a GIF with zero frames.
        st.error("All sequences are too short to animate (need at least 3 positions).")
        return

    gif_path = "bdh_activation_explain.gif"
    # loop=0 writes the GIF "loop forever" extension. Browsers decide looping from
    # the file itself; the <img loop=...> attribute below has no effect.
    imageio.mimsave(gif_path, out_frames, fps=fps, loop=0)

    st.success(f"Saved GIF: {gif_path}")

    # Embed as a data: URI so the animation plays inline.
    with open(gif_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    st.markdown(f"<img src='data:image/gif;base64,{b64}' loop='infinite' />", unsafe_allow_html=True)


with tabs[2]:
    st.subheader("GIF builder (explainable): EN + translation + 2 NEG controls")

    st.markdown("This makes a GIF that animates activation over byte positions and shows the current byte window in hex.")

    layers = [4, 5]
    layers = [L for L in layers if 0 <= L < cfg.n_layer]
    L = st.selectbox("Layer", layers, index=0)

    concept_text = st.text_area("Concepts (one per line, used for selection)", value="dog\nlion\nplant\nrain\nfire", height=120)
    concepts = split_lines(concept_text)
    if not concepts:
        concepts = ["lion"]

    concept_choice = st.selectbox("Choose EN concept", concepts, index=0)

    chosen_lang = st.selectbox("Translation language", list(ALL_LANGS.keys()), index=0)

    # NEG selectors (user-controlled)
    neg_a = st.text_input("NEG-A word (unrelated)", value="laptop")
    neg_b = st.text_input("NEG-B word (spelling/byte control)", value="li0n")

    TOPK_SET_gif = st.slider("TopK for selecting best neuron (GIF)", 50, 500, 200, step=50)
    MAX_FRAMES = st.slider("Max frames per sequence", 60, 250, 140, step=10)
    FPS = st.slider("FPS", 5, 20, 12)

    if st.button("Make GIF"):
        _make_activation_gif(L, concept_choice, chosen_lang, neg_a, neg_b, TOPK_SET_gif, MAX_FRAMES, FPS)

st.caption("Tip: If translations feel 'stuck', hit Reset cache in the sidebar. NEG words are never stored—only what you type is used.")


In [None]:
import os
from getpass import getpass

from pyngrok import ngrok

# Never hard-code the ngrok auth token in the notebook. A bare token is a
# SyntaxError, and a committed token is a leaked credential that should be
# revoked in the ngrok dashboard.
# Read it from the NGROK_AUTHTOKEN environment variable. Otherwise prompt for it;
# getpass does not echo the input or store it in the notebook output.
NGROK_AUTHTOKEN = os.environ.get("NGROK_AUTHTOKEN") or getpass("ngrok auth token: ")
ngrok.set_auth_token(NGROK_AUTHTOKEN)


In [None]:
import socket
import subprocess
import sys
import time

from IPython.display import IFrame, display
from pyngrok import ngrok

PORT = 8501


def wait_for_port(port, proc, timeout=60.0, host="127.0.0.1"):
    """Block until something accepts TCP connections on ``host:port``.

    Raises:
        RuntimeError: if ``proc`` exits before the port opens (e.g. app crash).
        TimeoutError: if nothing is listening after ``timeout`` seconds.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if proc.poll() is not None:
            raise RuntimeError(f"streamlit exited early with code {proc.returncode}")
        try:
            with socket.create_connection((host, port), timeout=1.0):
                return
        except OSError:
            time.sleep(0.5)
    raise TimeoutError(f"nothing listening on {host}:{port} after {timeout:.0f}s")


# `python -m streamlit` runs the Streamlit installed for this kernel's interpreter.
process = subprocess.Popen(
    [sys.executable, "-m", "streamlit", "run", "app.py",
     "--server.port", str(PORT), "--server.headless", "true"]
)

# Poll instead of a fixed sleep: startup time varies (first import of torch and
# transformers is slow), and a crash should be reported instead of tunnelled.
wait_for_port(PORT, process)

public_url = ngrok.connect(PORT).public_url
print(" Streamlit running at:", public_url)

display(IFrame(public_url, width=1200, height=750))
