In [None]:
%%writefile bdh_europarl_train_probe.py
#!/usr/bin/env python3

import argparse
import math
import os
import random
import tarfile
import urllib.request
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# 1) BDH MODEL (byte-level vocab)
# ============================================================

@dataclass
class BDHConfig:
    """
    Hyperparameters of the byte-level BDH model.

    The per-head sparse feature width used by Attention and BDH is derived as
    N = mlp_internal_dim_multiplier * n_embd // n_head.
    """
    n_layer: int = 6  # number of iterations of the layer loop in BDH.forward
    n_embd: int = 256  # model width D (embedding / residual dimension)
    dropout: float = 0.1  # dropout probability applied to the gated sparse product
    n_head: int = 4  # number of heads nh
    mlp_internal_dim_multiplier: int = 128  # expansion factor that sets N (see class docstring)
    vocab_size: int = 256  # byte-level IDs in [0..255]


def get_freqs(n: int, theta: float, dtype: torch.dtype) -> torch.Tensor:
    """
    Build the RoPE-like frequency schedule used inside BDH attention.

    Indices 0..n-1 are rounded down to the nearest multiple of 2, so each
    consecutive (even, odd) pair shares one frequency. The result is
    ``1 / theta ** (index / n) / (2 * pi)``.

    Returns:
      Tensor of shape (n,) with the requested dtype.
    """
    pair_width = 2  # quantization step: both members of a pair get the same value

    positions = torch.arange(0, n, 1, dtype=dtype)
    coarse_positions = (positions / pair_width).floor() * pair_width
    exponents = coarse_positions / n
    wavelengths = theta ** exponents

    return 1.0 / wavelengths / (2 * math.pi)


class Attention(nn.Module):
    """
    BDH attention block (non-softmax, associative accumulation).

    Key differences vs. Transformer attention:
    - No softmax normalization: raw Q.K^T scores weight the values directly.
    - K is required to be the very same tensor as Q.
    - Strictly causal mask via tril(diagonal=-1): a position never sees itself.

    The module has no learnable parameters; its only state is the `freqs`
    buffer holding the rotary frequency schedule.
    """
    def __init__(self, config: BDHConfig):
        super().__init__()
        nh = config.n_head
        D = config.n_embd
        N = config.mlp_internal_dim_multiplier * D // nh

        # Registered as a persistent buffer so it moves with the module
        # (.to/.cuda) and is saved in state_dict under the key "freqs".
        # register_buffer() is used instead of torch.nn.Buffer(...) because it
        # does the same thing but also exists on PyTorch releases that predate
        # torch.nn.Buffer.
        freqs = get_freqs(N, theta=2**16, dtype=torch.float32).view(1, 1, 1, N)
        self.register_buffer("freqs", freqs)

    @staticmethod
    def phases_cos_sin(phases: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) of `phases`, where phases are expressed in turns (1.0 == 2*pi)."""
        # Keep only the fractional turn before scaling to radians.
        phases = (phases % 1) * (2 * math.pi)
        return torch.cos(phases), torch.sin(phases)

    @staticmethod
    def rope(phases: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Rotate consecutive (even, odd) feature pairs of `v` by the given phases."""
        # rotate pairs (even, odd) -> (-odd, even)
        v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size())
        phases_cos, phases_sin = Attention.phases_cos_sin(phases)
        # Cast each term back to v's dtype (the phases are float32).
        return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        """
        Inputs:
          Q, K: (B, nh, T, N); K must be the same object as Q
          V:    (B, 1,  T, D)  (broadcast over heads)

        Returns:
          (B, nh, T, D)

        Raises:
          AssertionError: if K is not Q, or the freqs buffer is not float32.
        """
        assert self.freqs.dtype == torch.float32
        assert K is Q, "This implementation expects K=Q (same sparse activations)."

        _, _, T, _ = Q.size()

        # Per-position phases for RoPE: (1,1,T,1) * (1,1,1,N) -> (1,1,T,N).
        r_phases = (
            torch.arange(0, T, device=self.freqs.device, dtype=self.freqs.dtype)
            .view(1, 1, -1, 1)
        ) * self.freqs

        # Apply RoPE to Q once; since K is Q, the rotated keys are the same tensor.
        QR = self.rope(r_phases, Q)
        KR = QR

        # Causal mask: only attend to strictly previous tokens.
        # tril(diagonal=-1) also zeroes the diagonal (no self-attention).
        scores = (QR @ KR.mT).tril(diagonal=-1)

        # No softmax: directly multiply scores with V (associative accumulation).
        return scores @ V


class BDH(nn.Module):
    """
    Byte-level BDH language model with sparse, ReLU-gated intermediate features.

    A single set of weights (encoder, encoder_v, decoder) is reused on every
    one of the config.n_layer iterations in forward(); only the activations
    differ from layer to layer.

    During probing, forward(return_sparse=True) additionally returns x_sparse
    (post-ReLU) for each layer.
    """
    def __init__(self, config: BDHConfig):
        super().__init__()
        assert config.vocab_size is not None
        self.config = config

        nh = config.n_head
        D = config.n_embd
        N = config.mlp_internal_dim_multiplier * D // nh

        # Encoder/decoder weights are raw Parameters rather than nn.Linear
        # modules, so _init_weights() never touches them and they keep this
        # normal(std=0.02) initialisation.
        self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02))
        self.encoder = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
        self.encoder_v = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))

        self.attn = Attention(config)
        # Parameter-free LayerNorm, shared by every normalisation step in forward().
        self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False)
        self.embed = nn.Embedding(config.vocab_size, D)
        self.drop = nn.Dropout(config.dropout)

        # Output projection from model width to vocabulary logits.
        self.lm_head = nn.Parameter(torch.zeros((D, config.vocab_size)).normal_(std=0.02))

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise Linear/Embedding submodules with normal(0, 0.02) weights and zero bias."""
        # Standard init for any Linear/Embedding if they exist.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, return_sparse: bool = False):
        """
        Run the model on a batch of byte sequences.

        idx: (B,T) integer byte IDs
        targets: (B,T) next-byte targets (shifted); when given, a loss is computed
        return_sparse: if True, also return per-layer sparse activations (x_sparse)

        Returns:
          logits: (B,T,vocab_size)
          loss: scalar cross-entropy, or None when targets is None
          sparse_cache: list of detached tensors, one per layer, each (B,nh,T,N)
                        (third element is only returned if return_sparse)
        """
        C = self.config
        B, T = idx.size()
        D = C.n_embd
        nh = C.n_head
        N = D * C.mlp_internal_dim_multiplier // nh

        # Embed bytes to (B,1,T,D). The singleton dim broadcasts against the head dim below.
        x = self.embed(idx).unsqueeze(1)
        x = self.ln(x)

        sparse_cache = []  # store x_sparse per layer for probing

        # Same weights on every iteration; the loop index itself is unused.
        for _layer in range(C.n_layer):
            # Project to latent features per head, then apply ReLU to get sparse activations.
            x_latent = x @ self.encoder                  # (B,nh,T,N)
            x_sparse = F.relu(x_latent)                  # (B,nh,T,N)

            if return_sparse:
                # Detached: the cache is for inspection only, not for backprop.
                sparse_cache.append(x_sparse.detach())

            # Attention mixes information over time using sparse Q/K and dense V (=x).
            yKV = self.attn(Q=x_sparse, K=x_sparse, V=x)  # (B,nh,T,D)
            yKV = self.ln(yKV)

            # Second sparse gating path (encoder_v), then elementwise product of the two sparse codes.
            y_latent = yKV @ self.encoder_v              # (B,nh,T,N)
            y_sparse = F.relu(y_latent)                  # (B,nh,T,N)

            xy_sparse = x_sparse * y_sparse              # (B,nh,T,N)
            xy_sparse = self.drop(xy_sparse)

            # Merge heads, decode back to model width and apply the residual update.
            yMLP = xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh) @ self.decoder  # (B,1,T,D)
            y = self.ln(yMLP)
            x = self.ln(x + y)

        logits = x.view(B, T, D) @ self.lm_head          # (B,T,256)

        loss = None
        if targets is not None:
            # Flatten batch and time so every position is one classification example.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        if return_sparse:
            return logits, loss, sparse_cache
        return logits, loss


# ============================================================
# 2) EUROPARL PIPELINE (download -> extract -> build train.txt)
# ============================================================

# Maps an archive key to its download URL. The key doubles as the local file
# name inside --data_dir and as an accepted value of the train --pairs option.
EUROPARL_URLS = {
    # Europarl v7 pairs (English with EU languages)
    "de-en.tgz": "https://www.statmt.org/europarl/v7/de-en.tgz",
    "fr-en.tgz": "https://www.statmt.org/europarl/v7/fr-en.tgz",
    # Add more if needed (build_train_txt only picks up .en/.de/.fr files):
    # "es-en.tgz": "https://www.statmt.org/europarl/v7/es-en.tgz",
}


def download(url: str, out_path: str) -> None:
    """
    Download `url` to `out_path` unless `out_path` already exists.

    The payload is first written to `out_path + ".part"` and only renamed to
    `out_path` once the transfer has finished. An interrupted or failed
    download therefore never leaves a truncated file that a later run would
    mistake for a complete archive.

    Args:
      url: source URL (any scheme urllib understands).
      out_path: destination file; its parent directory is created if needed.

    Raises:
      Whatever urllib.request.urlretrieve raises (e.g. urllib.error.URLError);
      the partial file is removed before the exception propagates.
    """
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if os.path.exists(out_path):
        print(f"[download] exists: {out_path}")
        return

    print(f"[download] {url}")
    tmp_path = out_path + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic within one filesystem: out_path is either absent or complete.
        os.replace(tmp_path, out_path)
    except BaseException:
        # BaseException so Ctrl-C (KeyboardInterrupt) also cleans up.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    print(f"[download] saved: {out_path}")


def _safe_extract_tgz(tgz_path: str, out_dir: str) -> None:
    """
    Extract a .tgz archive into `out_dir`, refusing members that could escape it.

    Every member is validated before anything is written:
    - its own path must resolve inside `out_dir` (blocks "../" and absolute names);
    - if it is a symlink or hardlink, the link target must also resolve inside
      `out_dir` (a link pointing outside could otherwise be used to write
      through it).

    On Python versions that provide tarfile extraction filters, the stdlib
    "data" filter is applied as an additional safeguard.

    Raises:
      RuntimeError: if any member path or link target points outside `out_dir`.
    """
    def is_within_directory(directory: str, target: str) -> bool:
        abs_directory = os.path.abspath(directory)
        abs_target = os.path.abspath(target)
        return os.path.commonpath([abs_directory]) == os.path.commonpath([abs_directory, abs_target])

    with tarfile.open(tgz_path, "r:gz") as tar:
        for member in tar.getmembers():
            member_path = os.path.join(out_dir, member.name)
            if not is_within_directory(out_dir, member_path):
                raise RuntimeError(f"Unsafe tar member path detected: {member.name}")

            if member.issym() or member.islnk():
                # Symlink targets are relative to the link's own directory;
                # hardlink targets are relative to the archive root.
                link_base = os.path.dirname(member_path) if member.issym() else out_dir
                link_target = os.path.join(link_base, member.linkname)
                if not is_within_directory(out_dir, link_target):
                    raise RuntimeError(
                        f"Unsafe tar link target detected: {member.name} -> {member.linkname}"
                    )

        if hasattr(tarfile, "data_filter"):
            # Extraction filters are available: let the stdlib enforce its own checks too.
            tar.extractall(out_dir, filter="data")
        else:
            tar.extractall(out_dir)


def extract_tgz(tgz_path: str, out_dir: str) -> None:
    """
    Extract `tgz_path` into `out_dir` at most once.

    A hidden marker file named ".extracted_<archive name>" is written into
    `out_dir` after a successful extraction; when that marker is already
    present the archive is left alone.
    """
    archive_name = os.path.basename(tgz_path)
    marker = os.path.join(out_dir, ".extracted_" + archive_name)

    if not os.path.exists(marker):
        print(f"[extract] {tgz_path}")
        os.makedirs(out_dir, exist_ok=True)
        _safe_extract_tgz(tgz_path, out_dir)

        # Only reached when extraction did not raise.
        with open(marker, "w", encoding="utf-8") as marker_file:
            marker_file.write("ok\n")
        return

    print(f"[extract] already extracted: {tgz_path}")


def iter_text_lines(path: str):
    """
    Yield the usable text lines of a UTF-8 file, one stripped string at a time.

    A line is dropped when, after stripping surrounding whitespace, it is
    empty or starts with '<' (Europarl markup such as <CHAPTER ...>).
    Undecodable bytes are ignored rather than raising.
    """
    with open(path, "r", encoding="utf-8", errors="ignore") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text and not text.startswith("<"):
                yield text


def build_train_txt(
    data_dir: str,
    out_txt: str,
    max_lines_per_file: int = 200_000,
    seed: int = 0,
) -> None:
    """
    Create a single training file by concatenating up to max_lines_per_file
    cleaned lines from each extracted Europarl file under `data_dir`.

    Why:
    - BDH training here is byte-level and language-agnostic, so we can mix languages.
    - Mixing EN/DE/FR can encourage reuse of sparse features across languages.

    Reproducibility:
    - Candidate files are sorted before shuffling, so the result depends only
      on `seed` and the set of files, not on the order os.walk() returns them.
    - A private random.Random(seed) is used, so the process-wide `random`
      state (seeded in main()) is left untouched.

    Raises:
      RuntimeError: if no matching Europarl text file is found under `data_dir`.
    """
    rng = random.Random(seed)

    files: List[str] = []
    for root, _, fnames in os.walk(data_dir):
        for fn in fnames:
            # Europarl extracted files typically: europarl-v7.de-en.en / .de etc.
            if "europarl" in fn and fn.endswith((".en", ".de", ".fr")):
                files.append(os.path.join(root, fn))

    if not files:
        raise RuntimeError(f"No europarl text files found under {data_dir}. Did extraction work?")

    # Directory listing order is filesystem dependent; fix it before shuffling.
    files.sort()
    rng.shuffle(files)

    out_dir = os.path.dirname(out_txt)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(out_txt, "w", encoding="utf-8") as out:
        for fp in files:
            for count, ln in enumerate(iter_text_lines(fp), start=1):
                out.write(ln + "\n")
                if count >= max_lines_per_file:
                    break

    print(f"[dataset] wrote: {out_txt}")
    print(f"[dataset] included files: {len(files)}")


# ============================================================
# 3) BYTE DATASET (next-byte prediction)
# ============================================================

class ByteDataset(torch.utils.data.Dataset):
    """
    Byte-level language modeling dataset.

    Given a long byte array b[0..M-1], item i is the window pair:
      x = b[i : i+block_size]
      y = b[i+1 : i+block_size+1]

    This trains next-byte prediction: p(b[t+1] | b[:t]).

    There are exactly M - block_size such windows (i = 0 .. M-block_size-1),
    so a corpus of block_size+1 bytes yields one training example.
    """
    def __init__(self, train_txt: str, block_size: int, max_bytes: int):
        """
        Args:
          train_txt: path of the UTF-8 text file to read (undecodable bytes are dropped).
          block_size: number of bytes in each x / y window.
          max_bytes: keep only the first max_bytes bytes of the encoded text.
        """
        with open(train_txt, "r", encoding="utf-8", errors="ignore") as f:
            text = f.read()
        b = text.encode("utf-8", errors="ignore")[:max_bytes]

        if b:
            # Convert straight from the byte buffer; going through list(b)
            # would first build a Python list with one entry per byte.
            # bytearray() provides the writable buffer frombuffer expects, and
            # .to(torch.long) copies, so self.data owns its memory.
            self.data = torch.frombuffer(bytearray(b), dtype=torch.uint8).to(torch.long)
        else:
            # frombuffer rejects empty buffers.
            self.data = torch.empty(0, dtype=torch.long)
        self.block_size = block_size

    def __len__(self) -> int:
        # The last valid start index is M - block_size - 1, hence M - block_size windows.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, i: int):
        """Return (x, y): the window starting at i and the same window shifted by one byte."""
        x = self.data[i : i + self.block_size]
        y = self.data[i + 1 : i + self.block_size + 1]
        return x, y


# ============================================================
# 4) NEURON PROBING (sparse features -> top-k -> intersection)
# ============================================================

def neuron_id(layer: int, head: int, feat: int, nh: int, N: int) -> int:
    """
    Flatten a (layer, head, feat) coordinate into one global integer ID.

    IDs are laid out layer-major, then head, then feature, with `nh` heads of
    `N` features per layer. decode_neuron() is the inverse mapping.
    """
    features_per_layer = nh * N
    offset_in_layer = head * N + feat
    return layer * features_per_layer + offset_in_layer


def ids_for_text(text: str) -> torch.Tensor:
    """
    Encode `text` as UTF-8 and return its byte values as a (1, T) long tensor.

    An empty encoding is replaced by a single space byte so the model never
    receives a zero-length sequence.
    """
    raw = text.encode("utf-8", errors="ignore") or b" "
    return torch.tensor([list(raw)], dtype=torch.long)


@torch.no_grad()
def top_neurons_for_input(
    model: BDH,
    text: str,
    topk: int = 200,
    aggregate: str = "mean",  # "mean" over positions is more stable than "last"
) -> List[List[Tuple[int, float, int, int, int]]]:
    """
    Find the strongest sparse features per layer for one input text.

    The text is run through the model with return_sparse=True. For each layer
    the (nh, T, N) activations are reduced over the T byte positions: the
    last position when aggregate == "last", otherwise the mean. Averaging is
    the default because a word spans several bytes at byte level.

    Returns:
      hits[layer] = [(global_neuron_id, activation, layer, head, feat), ...]
      holding at most `topk` entries, sorted by activation descending.
    """
    cfg = model.config
    nh = cfg.n_head
    N = cfg.n_embd * cfg.mlp_internal_dim_multiplier // nh

    # Run on whatever device the model's parameters live on.
    device = next(model.parameters()).device
    _, _, sparse_cache = model(ids_for_text(text).to(device), return_sparse=True)

    hits_per_layer = []
    for layer, x_sparse in enumerate(sparse_cache):
        acts = x_sparse[0]  # drop the batch dim: (nh, T, N)

        # Collapse the time axis to one score per (head, feature).
        pooled = acts[:, -1, :] if aggregate == "last" else acts.mean(dim=1)

        flat = pooled.reshape(-1)  # (nh*N,), head-major
        vals, idxs = torch.topk(flat, k=min(topk, flat.numel()))

        layer_hits = []
        for v, ix in zip(vals.tolist(), idxs.tolist()):
            # Undo the head-major flattening.
            head, feat = divmod(ix, N)
            gid = neuron_id(layer, head, feat, nh, N)
            layer_hits.append((gid, float(v), layer, head, feat))

        hits_per_layer.append(layer_hits)

    return hits_per_layer


def shared_neurons_across_texts(
    all_hits: List[List[List[Tuple[int, float, int, int, int]]]],
    topk_intersection: int,
) -> List[List[int]]:
    """
    Intersect the top neurons of several inputs, layer by layer.

    all_hits[text_i][layer] is the (activation-sorted) hit list for that
    text and layer, as produced by top_neurons_for_input().

    For each layer:
      take the first topk_intersection neuron IDs of every input,
      intersect those sets across all inputs,
      keep the shared neuron IDs.

    The number of layers is taken from the first input.

    Returns:
      shared[layer] = sorted list of global neuron IDs.
      An empty `all_hits` yields an empty list.
    """
    if not all_hits:
        # No inputs means no layers to report (previously an IndexError).
        return []

    n_layers = len(all_hits[0])
    shared: List[List[int]] = []

    for layer in range(n_layers):
        top_sets = [
            {hit[0] for hit in text_hits[layer][:topk_intersection]}
            for text_hits in all_hits
        ]
        shared.append(sorted(set.intersection(*top_sets)))

    return shared


def decode_neuron(global_id: int, cfg: BDHConfig) -> Tuple[int, int, int]:
    """
    Split a global neuron ID back into (layer, head, feat).

    Inverse of neuron_id(), using the head count and per-head feature width
    implied by `cfg`.
    """
    N = cfg.n_embd * cfg.mlp_internal_dim_multiplier // cfg.n_head

    # Peel off the layer first, then the head within that layer.
    layer, within_layer = divmod(global_id, cfg.n_head * N)
    head, feat = divmod(within_layer, N)
    return layer, head, feat


# ============================================================
# 5) TRAIN / PROBE ENTRYPOINTS
# ============================================================

def train(args) -> None:
    """
    Train BDH as a byte-level next-byte LM on the mixed Europarl text and save it.

    Steps, in order:
      1. optionally download + extract the requested Europarl archives
      2. build train.txt when it is missing or --rebuild_txt is given
      3. wrap train.txt in a ByteDataset and a shuffling DataLoader
      4. run args.steps AdamW updates (gradient norm clipped to 1.0)
      5. write model.state_dict() to args.ckpt_path

    Raises:
      ValueError: a --pairs key is not present in EUROPARL_URLS.
      RuntimeError: the dataset contains no training window.
    """
    use_cuda = torch.cuda.is_available() and not args.cpu
    device = "cuda" if use_cuda else "cpu"
    print(f"[train] device={device}")

    # (Optional) fetch and unpack the raw Europarl archives.
    if args.download:
        os.makedirs(args.data_dir, exist_ok=True)
        for key in args.pairs:
            if key not in EUROPARL_URLS:
                raise ValueError(f"Unknown pair key '{key}'. Available: {list(EUROPARL_URLS.keys())}")
            archive_path = os.path.join(args.data_dir, key)
            download(EUROPARL_URLS[key], archive_path)
            extract_tgz(archive_path, args.data_dir)

    # (Re)create the flat training text when forced or when it does not exist yet.
    if args.rebuild_txt or not os.path.exists(args.train_txt):
        build_train_txt(
            data_dir=args.data_dir,
            out_txt=args.train_txt,
            max_lines_per_file=args.max_lines_per_file,
            seed=args.seed,
        )

    # Data pipeline.
    ds = ByteDataset(args.train_txt, block_size=args.block_size, max_bytes=args.max_bytes)
    if len(ds) <= 0:
        raise RuntimeError("Dataset too small. Increase --max_bytes or ensure train.txt is non-empty.")

    dl = torch.utils.data.DataLoader(
        ds, batch_size=args.batch_size, shuffle=True, num_workers=0
    )

    # Model.
    cfg = BDHConfig(
        vocab_size=256,
        n_layer=args.n_layer,
        n_embd=args.n_embd,
        n_head=args.n_head,
        mlp_internal_dim_multiplier=args.mult,
        dropout=args.dropout,
    )
    model = BDH(cfg).to(device)

    # Continue from an earlier checkpoint only when asked to and one exists.
    if args.resume and os.path.exists(args.ckpt_path):
        model.load_state_dict(torch.load(args.ckpt_path, map_location=device))
        print(f"[train] resumed from {args.ckpt_path}")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

    def endless_batches():
        # Keep starting a fresh pass over the DataLoader whenever one runs out.
        while True:
            yield from dl

    model.train()
    # zip() consults range() first, so no batch is fetched beyond the last step.
    for step, (x, y) in zip(range(args.steps), endless_batches()):
        x, y = x.to(device), y.to(device)

        _, loss = model(x, y)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        if step % args.log_every == 0:
            print(f"[train] step={step} loss={float(loss):.4f}")

    # Persist the weights only; the optimizer state and step counter are not stored.
    ckpt_dir = os.path.dirname(args.ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(model.state_dict(), args.ckpt_path)
    print(f"[train] saved checkpoint: {args.ckpt_path}")


def probe(args) -> None:
    """
    Probe a trained model for sparse neurons shared by several input texts.

    Steps, in order:
      1. rebuild the model from the CLI hyperparameters and load the checkpoint
      2. collect the top-k sparse features per layer for every input text
      3. intersect the top sets across inputs, layer by layer
      4. print the shared neuron IDs (and, with --print_top, each input's own top hits)

    Raises:
      FileNotFoundError: the checkpoint file does not exist.
      ValueError: fewer than two texts were supplied.
    """
    use_cuda = torch.cuda.is_available() and not args.cpu
    device = "cuda" if use_cuda else "cpu"
    print(f"[probe] device={device}")

    cfg = BDHConfig(
        vocab_size=256,
        n_layer=args.n_layer,
        n_embd=args.n_embd,
        n_head=args.n_head,
        mlp_internal_dim_multiplier=args.mult,
        dropout=0.0,  # important: disable dropout at probe time
    )
    model = BDH(cfg).to(device)

    if not os.path.exists(args.ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {args.ckpt_path}. Train first.")
    state = torch.load(args.ckpt_path, map_location=device)
    model.load_state_dict(state)
    model.eval()

    texts = args.texts
    if len(texts) < 2:
        raise ValueError("Provide at least 2 texts/words to find shared neurons.")

    # One per-layer hit list for every input text, in input order.
    all_hits = [
        top_neurons_for_input(model, t, topk=args.topk, aggregate=args.aggregate)
        for t in texts
    ]

    shared = shared_neurons_across_texts(all_hits, topk_intersection=args.topk_intersection)

    print("\nInputs:")
    for i, t in enumerate(texts):
        print(f"  [{i}] {t}")

    print("\nShared neuron IDs per layer (intersection of top-k):")
    for layer, ids in enumerate(shared):
        print(f"\nLayer {layer}: {len(ids)} shared IDs")
        shown = ids[:args.show]
        if shown:
            for gid in shown:
                L, H, Fidx = decode_neuron(gid, cfg)
                print(f"  neuron_id={gid}  (layer={L}, head={H}, feat={Fidx})")
        else:
            print("  (none) -> try: train longer, increase --topk/--topk_intersection, or probe with short sentences.")

    if not args.print_top:
        return

    print("\nTop neurons per input (first few layers):")
    for i, (t, hits) in enumerate(zip(texts, all_hits)):
        print(f"\n=== Input [{i}] {t} ===")
        n_layers_shown = min(args.layers_print, len(hits))
        for layer in range(n_layers_shown):
            print(f"  Layer {layer}:")
            for gid, val, _L, H, Fidx in hits[layer][:args.show]:
                print(f"    neuron_id={gid} act={val:.4f} (head={H}, feat={Fidx})")


def main() -> None:
    """
    Command-line entry point.

    Builds an argument parser with the two sub-commands `train` and `probe`,
    seeds the torch and Python RNGs, and dispatches to the chosen sub-command.
    """
    parser = argparse.ArgumentParser()
    subcommands = parser.add_subparsers(dest="cmd", required=True)

    def add_model_hparams(sub_parser) -> None:
        # Architecture flags shared by both sub-commands; the values used for
        # probing must match the ones used for training.
        sub_parser.add_argument("--n_layer", type=int, default=6)
        sub_parser.add_argument("--n_embd", type=int, default=256)
        sub_parser.add_argument("--n_head", type=int, default=4)
        sub_parser.add_argument("--mult", type=int, default=128)

    # ---- train ----
    train_p = subcommands.add_parser("train")
    train_p.add_argument("--data_dir", type=str, default="data_europarl")
    train_p.add_argument("--train_txt", type=str, default="data_europarl/train.txt")
    train_p.add_argument("--download", action="store_true", help="download+extract Europarl")
    train_p.add_argument("--pairs", nargs="+", default=["de-en.tgz", "fr-en.tgz"], help="which Europarl tgz keys")
    train_p.add_argument("--rebuild_txt", action="store_true")
    train_p.add_argument("--max_lines_per_file", type=int, default=200_000)
    train_p.add_argument("--max_bytes", type=int, default=50_000_000)
    train_p.add_argument("--block_size", type=int, default=256)
    train_p.add_argument("--batch_size", type=int, default=16)
    train_p.add_argument("--steps", type=int, default=5000)
    train_p.add_argument("--lr", type=float, default=3e-4)
    train_p.add_argument("--log_every", type=int, default=100)
    train_p.add_argument("--seed", type=int, default=0)
    train_p.add_argument("--ckpt_path", type=str, default="checkpoints/bdh_europarl_bytes.pt")
    train_p.add_argument("--resume", action="store_true")
    train_p.add_argument("--cpu", action="store_true")
    add_model_hparams(train_p)
    train_p.add_argument("--dropout", type=float, default=0.1)

    # ---- probe ----
    probe_p = subcommands.add_parser("probe")
    probe_p.add_argument("--texts", nargs="+", required=True, help="words/sentences you provide at test time")
    probe_p.add_argument("--ckpt_path", type=str, default="checkpoints/bdh_europarl_bytes.pt")
    probe_p.add_argument("--topk", type=int, default=300, help="top neurons to compute per layer per input")
    probe_p.add_argument("--topk_intersection", type=int, default=200, help="intersection over this many top neurons")
    probe_p.add_argument("--show", type=int, default=30, help="how many shared IDs to print per layer")
    probe_p.add_argument("--aggregate", choices=["mean", "last"], default="mean", help="aggregate over bytes")
    probe_p.add_argument("--print_top", action="store_true", help="also print top neurons per input")
    probe_p.add_argument("--layers_print", type=int, default=2)
    probe_p.add_argument("--cpu", action="store_true")
    add_model_hparams(probe_p)

    args = parser.parse_args()

    # Fixed seeds for reproducibility (weight init, data shuffling).
    torch.manual_seed(1337)
    random.seed(1337)

    if args.cmd == "train":
        train(args)
    elif args.cmd == "probe":
        probe(args)


if __name__ == "__main__":
    main()

In [None]:
!wc -l bdh_europarl_train_probe.py
!ls -lh bdh_europarl_train_probe.py


In [None]:
import torch, gc
# Housekeeping cell: collect unreachable Python objects first, then ask
# PyTorch to release its cached CUDA memory, before launching the training runs.
gc.collect()
torch.cuda.empty_cache()
print("cleared")


In [None]:
!python -u bdh_europarl_train_probe.py train --download \
  --steps 200 --log_every 10 \
  --batch_size 2 --block_size 128 \
  --n_layer 4 --n_embd 128 --n_head 4 --mult 16 \
  --max_lines_per_file 20000 --max_bytes 5000000


In [None]:
!python -u bdh_europarl_train_probe.py train \
  --steps 2000 --log_every 50 \
  --batch_size 4 --block_size 128 \
  --n_layer 6 --n_embd 128 --n_head 4 --mult 16


In [None]:
!python -u bdh_europarl_train_probe.py train --resume \
  --steps 12000 --log_every 200 \
  --batch_size 4 --block_size 128 \
  --n_layer 6 --n_embd 128 --n_head 4 --mult 16


In [None]:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


In [None]:
APP_CODE = r'''

import os, re, importlib.util
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import streamlit as st
from transformers import (MarianMTModel, MarianTokenizer,
                          AutoModelForCausalLM, AutoTokenizer)

st.set_page_config(
    page_title="BDH Monosemanticity Probe | KRITI 2026",
    layout="wide",
    initial_sidebar_state="expanded"
)

# --------------------------------------------------------------
# CSS  —  dark space theme
# --------------------------------------------------------------
st.markdown(r"""
<style>
@import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;600;700&display=swap');
html,body,[class*="css"]{ font-family:'Space Grotesk',sans-serif; }
.stApp{
  background:#04070f;
  background-image:
    radial-gradient(ellipse at 12% 4%,rgba(18,55,180,.22) 0%,transparent 52%),
    radial-gradient(ellipse at 88% 96%,rgba(80,18,160,.18) 0%,transparent 52%);
}
[data-testid="stSidebar"]{background:#060c1a!important;border-right:1px solid rgba(60,110,255,.20);}
[data-testid="stSidebar"] label{color:#c8dcff!important;font-size:.88rem;}
[data-testid="stSidebar"] h3{
  font-family:'JetBrains Mono',monospace!important;color:#7aa8ff!important;
  font-size:.72rem!important;letter-spacing:.16em;text-transform:uppercase;
  margin-top:1.8rem!important;padding-bottom:5px;border-bottom:1px solid rgba(70,120,255,.22);
}
h1{font-family:'JetBrains Mono',monospace!important;font-size:2.0rem!important;
   font-weight:700!important;color:#f0f6ff!important;}
h2{font-family:'JetBrains Mono',monospace!important;color:#c4dcff!important;
   font-size:1.18rem!important;margin-top:2rem!important;
   border-left:4px solid #4a78f0;padding-left:.9rem;}
h3{font-family:'JetBrains Mono',monospace!important;color:#9ec4ff!important;font-size:.98rem!important;}
p,li{color:#c8dcff;font-size:.97rem;line-height:1.72;}
.stMarkdown p{color:#c8dcff!important;}
[data-testid="stMetric"]{background:rgba(10,18,48,.95)!important;
  border:1px solid rgba(65,115,255,.30)!important;border-radius:10px;padding:.9rem 1.1rem!important;}
[data-testid="stMetric"] label{font-family:'JetBrains Mono',monospace!important;
  color:#7aa8ff!important;font-size:.70rem!important;letter-spacing:.14em;text-transform:uppercase;}
[data-testid="stMetric"] [data-testid="stMetricValue"]{font-family:'JetBrains Mono',monospace!important;
  color:#f0f6ff!important;font-size:1.60rem!important;font-weight:700!important;}
.stButton>button{background:linear-gradient(135deg,#162870 0%,#261668 100%)!important;
  border:1px solid rgba(110,165,255,.55)!important;color:#ddeeff!important;
  font-family:'JetBrains Mono',monospace!important;font-size:.78rem!important;
  letter-spacing:.10em;border-radius:7px!important;padding:.70rem 1.6rem!important;
  transition:all .18s ease!important;}
.stButton>button:hover{background:linear-gradient(135deg,#1f3aa8 0%,#321f9e 100%)!important;
  box-shadow:0 0 20px rgba(80,130,255,.4)!important;}
.stTextInput>div>div>input,.stTextArea>div>div>textarea,.stNumberInput>div>div>input{
  background:rgba(6,12,32,.98)!important;border:1px solid rgba(80,130,255,.42)!important;
  color:#f0f6ff!important;border-radius:7px!important;
  font-family:'JetBrains Mono',monospace!important;font-size:.88rem!important;}
.stTabs [data-baseweb="tab-list"]{background:rgba(6,10,28,.94);border-bottom:1px solid rgba(60,100,230,.22);}
.stTabs [data-baseweb="tab"]{font-family:'JetBrains Mono',monospace!important;
  font-size:.76rem!important;letter-spacing:.09em;color:#6882b8!important;
  padding:.76rem 1.30rem!important;text-transform:uppercase;}
.stTabs [aria-selected="true"]{color:#a8caff!important;
  border-bottom:2px solid #4a78f0!important;background:rgba(22,42,110,.30)!important;}
.info-strip{background:rgba(10,20,52,.90);border:1px solid rgba(60,100,230,.24);
  border-radius:11px;padding:1.1rem 1.4rem;margin:.8rem 0 1.3rem 0;
  font-size:.95rem;color:#c8dcff;line-height:1.78;}
.token-hdr{font-family:'JetBrains Mono',monospace;font-size:.84rem;color:#5290ff;
  letter-spacing:.11em;text-transform:uppercase;margin:1.5rem 0 .55rem 0;
  padding:.55rem 1rem;border-left:4px solid #5290ff;
  background:rgba(18,38,100,.32);border-radius:0 7px 7px 0;}
.badge-pass{display:inline-block;background:rgba(18,155,75,.20);
  border:1px solid rgba(35,210,100,.50);color:#6dffa8;
  font-family:'JetBrains Mono',monospace;font-size:.70rem;padding:3px 11px;
  border-radius:7px;letter-spacing:.12em;text-transform:uppercase;font-weight:700;}
.badge-weak{display:inline-block;background:rgba(200,120,18,.20);
  border:1px solid rgba(240,160,40,.50);color:#ffd060;
  font-family:'JetBrains Mono',monospace;font-size:.70rem;padding:3px 11px;
  border-radius:7px;letter-spacing:.12em;text-transform:uppercase;font-weight:700;}
.concept-row{font-family:'JetBrains Mono',monospace;font-size:1.0rem;
  color:#f0f6ff;font-weight:700;}
.neuron-tag{display:inline-block;background:rgba(38,82,215,.24);
  border:1px solid rgba(100,165,255,.52);color:#c4dcff;
  font-family:'JetBrains Mono',monospace;font-size:.70rem;padding:3px 11px;
  border-radius:7px;margin-left:8px;font-weight:700;}
.subtitle{font-family:'JetBrains Mono',monospace;font-size:.78rem;color:#6882b8;
  letter-spacing:.16em;text-transform:uppercase;margin-bottom:1.3rem;}
.compare-box{background:rgba(10,20,55,.88);border:1px solid rgba(65,110,255,.28);
  border-radius:12px;padding:1.4rem 1.6rem;margin:1rem 0;}
.bdh-label{color:#38e090;font-family:'JetBrains Mono',monospace;font-weight:700;font-size:1.0rem;}
.tfm-label{color:#ff5a72;font-family:'JetBrains Mono',monospace;font-weight:700;font-size:1.0rem;}
hr{border-color:rgba(60,100,230,.14)!important;margin:1.8rem 0!important;}

/* Summary scorecard styles */
.score-card{background:linear-gradient(135deg,rgba(8,16,48,.98) 0%,rgba(12,8,40,.98) 100%);
  border:1px solid rgba(65,110,255,.35);border-radius:16px;
  padding:2rem 2.2rem;margin:1.2rem 0;overflow:hidden;}
.score-metric-row{display:grid;grid-template-columns:30% 30% 30% 10%;
  align-items:center;padding:.85rem 1rem;margin:.4rem 0;
  background:rgba(16,28,72,.60);border-radius:10px;
  border:1px solid rgba(55,90,200,.20);}
.score-metric-row:hover{background:rgba(20,38,95,.80);
  border-color:rgba(80,130,255,.35);}
.score-metric-name{font-family:'JetBrains Mono',monospace;
  font-size:.82rem;color:#8aaedf;line-height:1.5;}
.score-metric-sub{font-size:.70rem;color:#445878;margin-top:2px;}
.score-val-win{font-family:'JetBrains Mono',monospace;font-size:1.05rem;
  font-weight:700;color:#38e090;}
.score-val-lose{font-family:'JetBrains Mono',monospace;font-size:1.05rem;
  font-weight:700;color:#ff5a72;}
.score-val-neutral{font-family:'JetBrains Mono',monospace;font-size:1.05rem;
  font-weight:700;color:#c8dcff;}
.score-val-sub{font-size:.72rem;margin-top:2px;}
.score-badge-win{display:inline-flex;align-items:center;justify-content:center;
  background:rgba(25,180,90,.18);border:1px solid rgba(56,224,144,.55);
  color:#38e090;font-family:'JetBrains Mono',monospace;font-size:.66rem;
  padding:4px 10px;border-radius:20px;letter-spacing:.10em;font-weight:700;}
.score-badge-lose{display:inline-flex;align-items:center;justify-content:center;
  background:rgba(180,25,55,.18);border:1px solid rgba(255,90,114,.55);
  color:#ff5a72;font-family:'JetBrains Mono',monospace;font-size:.66rem;
  padding:4px 10px;border-radius:20px;letter-spacing:.10em;font-weight:700;}
.score-badge-tie{display:inline-flex;align-items:center;justify-content:center;
  background:rgba(120,120,50,.18);border:1px solid rgba(255,200,60,.45);
  color:#ffc040;font-family:'JetBrains Mono',monospace;font-size:.66rem;
  padding:4px 10px;border-radius:20px;letter-spacing:.10em;font-weight:700;}
.score-bar-wrap{height:8px;background:rgba(30,50,110,.50);
  border-radius:4px;margin-top:6px;overflow:hidden;}
.score-bar-fill{height:100%;border-radius:4px;}
.score-hdr{font-family:'JetBrains Mono',monospace;font-size:.68rem;
  color:#445878;letter-spacing:.14em;text-transform:uppercase;
  padding:.5rem 1rem;margin-bottom:.5rem;}
.verdict-strip{text-align:center;padding:1.4rem;border-radius:12px;
  margin-top:1.4rem;font-family:'JetBrains Mono',monospace;
  font-size:1.0rem;font-weight:700;letter-spacing:.04em;}
.expected-box{background:rgba(10,18,50,.85);border:1px solid rgba(55,90,200,.30);
  border-radius:10px;padding:1.1rem 1.4rem;margin:.6rem 0;
  font-family:'JetBrains Mono',monospace;font-size:.82rem;color:#8aaedf;}
.expected-box strong{color:#c4dcff;}
.expected-box span.win{color:#38e090;}
.expected-box span.lose{color:#ff5a72;}
</style>
""", unsafe_allow_html=True)


# ==============================================================
# MATPLOTLIB THEME
# ==============================================================
# Colour palette (hex). PAL, C_POS and C_NEG are not used in the part of the
# file visible here -- presumably series / positive / negative colours; confirm at call sites.
PAL        = ["#5a9cff","#ff5a72","#38e090","#ffc040","#c878ff","#38d4f8","#ff9040"]
C_POS      = "#5a9cff"
C_NEG      = "#ff5060"
C_BDH      = "#38e090"  # same green as the .bdh-label CSS class
C_TFM      = "#ff5a72"  # same red as the .tfm-label CSS class
C_AX       = "#04070f"  # figure and axes background
C_EDGE     = "#243666"  # axes edge colour
C_GRID     = "#151e3a"  # grid line colour
C_TITLE    = "#f0f6ff"  # axes title colour
C_LABEL    = "#e0eeff"  # axis label and default text colour
C_TICK     = "#c4dcff"  # tick colour
C_LEG      = "#f0f6ff"  # legend text and legend title colour

# Font sizes (points) fed into rcParams and the styling helpers below.
_TS  = 28  # axes title
_LS  = 21  # axis labels
_KS  = 18  # tick labels
_AS  = 16  # not referenced in the visible code -- TODO confirm usage
_LEG = 17  # legend entries
_LTI = 15  # legend title

# White rounded text box; not referenced in the visible code -- TODO confirm usage.
_BBOX = dict(facecolor="#ffffff", alpha=0.95, edgecolor="#90aadd",
             pad=4, boxstyle="round,pad=0.30")
_TC   = "#060d22"  # tick-label text colour applied by _tick_boxes
_TW   = "bold"  # tick-label font weight applied by _tick_boxes

# Small constant; presumably guards divisions / logs against zero -- confirm at call sites.
EPS = 1e-9

# Sparsity threshold for Transformer (|GELU| below this is counted as inactive).
# GELU produces negative values for negative pre-activations; these are
# functionally suppressed and should count toward "inactive" neurons.
TFM_SPARSE_THRESH = 0.02


def _apply_theme():
    """Install the dashboard's dark matplotlib theme into ``plt.rcParams``."""
    theme = {
        # Figure and output resolution
        "figure.facecolor": C_AX,
        "figure.dpi": 160,
        "savefig.dpi": 240,
        "font.family": "monospace",
        "text.color": C_LABEL,
        # Axes frame, title and labels
        "axes.facecolor": C_AX,
        "axes.edgecolor": C_EDGE,
        "axes.linewidth": 1.5,
        "axes.labelcolor": C_LABEL,
        "axes.labelsize": _LS,
        "axes.titlecolor": C_TITLE,
        "axes.titlesize": _TS,
        "axes.titlepad": 24,
        "axes.spines.top": False,
        "axes.spines.right": False,
        # Horizontal grid lines only
        "axes.grid": True,
        "axes.grid.axis": "y",
        "grid.color": C_GRID,
        "grid.alpha": .9,
        # Ticks
        "xtick.color": C_TICK,
        "ytick.color": C_TICK,
        "xtick.labelsize": _KS,
        "ytick.labelsize": _KS,
        "xtick.major.pad": 10,
        "ytick.major.pad": 8,
        # Legend
        "legend.fontsize": _LEG,
        "legend.title_fontsize": _LTI,
        "legend.facecolor": "#0a1428",
        "legend.edgecolor": "#2e4880",
    }
    plt.rcParams.update(theme)

_apply_theme()


def _tick_boxes(ax):
    """Put every x/y tick label of ``ax`` on a white rounded box with dark bold text."""
    # Draw first so the tick label objects exist before they are restyled.
    ax.figure.canvas.draw()
    tick_labels = [*ax.get_xticklabels(), *ax.get_yticklabels()]
    for label in tick_labels:
        label.set_color(_TC)
        label.set_fontweight(_TW)
        label.set_bbox({
            "facecolor": "#ffffff",
            "alpha": 0.92,
            "edgecolor": "#9ab0e0",
            "pad": 3,
            "boxstyle": "round,pad=0.22",
        })
        label.set_fontsize(_KS)


def _style(ax):
    """Apply theme colours, sizes and translucent dark backing to the title and axis labels."""
    heading = ax.title
    heading.set_color(C_TITLE)
    heading.set_fontsize(_TS)
    heading.set_bbox({"facecolor": "#04070f", "alpha": .60,
                      "edgecolor": "none", "pad": 10})

    for axis_label in (ax.xaxis.label, ax.yaxis.label):
        axis_label.set_color(C_LABEL)
        axis_label.set_fontsize(_LS)
        axis_label.set_bbox({"facecolor": "#04070f", "alpha": .40,
                             "edgecolor": "none", "pad": 6})


def _legend(leg):
    """Recolour and resize a legend's entries and title; a falsy ``leg`` is ignored."""
    if not leg:
        return

    for entry in leg.get_texts():
        entry.set_color(C_LEG)
        entry.set_fontsize(_LEG)

    heading = leg.get_title()
    if heading:
        heading.set_color(C_LEG)
        heading.set_fontsize(_LTI)


def _tight(fig, ax):
    """Render ``fig`` once, box the tick labels of ``ax``, then tighten the layout."""
    canvas = fig.canvas
    canvas.draw()
    _tick_boxes(ax)
    fig.tight_layout(pad=3.2)


def _fig(w=30, h=12):
    """Re-apply the theme and return a new figure of ``w`` x ``h`` inches."""
    _apply_theme()
    size_inches = (w, h)
    return plt.figure(figsize=size_inches)


def _ann_bars(ax, bars, vals, vmax=None):
    """
    Write each bar's value (3 decimals) just above the bar.

    The vertical offset is 1.4% of a reference height: ``vmax`` when truthy,
    otherwise the largest value, or 1 when ``vals`` is empty.  Values not
    greater than 5e-4 are left unlabelled.
    """
    if vmax:
        reference = vmax
    elif len(vals):
        reference = max(vals)
    else:
        reference = 1

    for rect, value in zip(bars, vals):
        if not value > 5e-4:
            continue
        x_mid = rect.get_x() + rect.get_width() / 2
        y_text = rect.get_height() + reference * .014
        label_box = dict(facecolor="#ffffff", alpha=.93,
                         edgecolor="#90aadd", pad=3, boxstyle="round,pad=0.22")
        ax.text(
            x_mid, y_text, f"{value:.3f}",
            ha="center", va="bottom", fontsize=_AS,
            fontfamily="monospace", color=_TC, fontweight="bold",
            bbox=label_box,
        )


# ==============================================================
# TRANSLATION UTILITIES
# ==============================================================
# Target language name -> pretrained checkpoint id passed to
# MarianTokenizer / MarianMTModel.from_pretrained() in get_translator().
ALL_LANGS = {
    "German":  "Helsinki-NLP/opus-mt-en-de",
    "French":  "Helsinki-NLP/opus-mt-en-fr",
    "Spanish": "Helsinki-NLP/opus-mt-en-es",
    "Italian": "Helsinki-NLP/opus-mt-en-it",
}
# Target language name -> short two-letter tag (same keys as ALL_LANGS).
LANG_TAGS = {"German": "DE", "French": "FR", "Spanish": "ES", "Italian": "IT"}


@st.cache_resource(show_spinner=False)
def get_translator(lang, device):
    """
    Load (and cache via Streamlit) the Marian tokenizer and model for ``lang``.

    ``lang`` must be a key of ALL_LANGS.  The model is moved to ``device``
    and put in eval mode.  Returns ``(tokenizer, model)``.
    """
    checkpoint = ALL_LANGS[lang]
    tokenizer = MarianTokenizer.from_pretrained(checkpoint)
    model = MarianMTModel.from_pretrained(checkpoint).to(device)
    model.eval()
    return tokenizer, model


def translate(lang, text, device):
    """
    Translate ``text`` into ``lang`` using the cached Marian model.

    Args:
        lang: key of ALL_LANGS selecting the translation model.
        text: source string to translate.
        device: torch device the model and the encoded batch are placed on.

    Returns:
        The decoded translation with special tokens removed.  Generation is
        capped at 96 new tokens.
    """
    tok, mdl = get_translator(lang, device)
    # truncation=True clips over-long input to the tokenizer's maximum length
    # instead of feeding the model more positions than it supports; short
    # inputs are unaffected.  This matches tfm_token_acts, which also truncates.
    batch = tok([text], return_tensors="pt", padding=True,
                truncation=True).to(device)
    with torch.no_grad():
        out = mdl.generate(**batch, max_new_tokens=96)
    return tok.decode(out[0], skip_special_tokens=True)


# ==============================================================
# BDH MODEL LOADER
# ==============================================================
@st.cache_resource(show_spinner=True)
def load_bdh(script_path, ckpt_path):
    """
    Import the BDH training script as a module and restore a checkpoint.

    The model geometry (heads, embedding width, internal multiplier) is
    recovered from the shape of the checkpoint's ``"encoder"`` tensor; the
    layer count and vocabulary size are fixed at 6 and 256.

    Returns ``(device, cfg, model, N, ids_for_text, neuron_id)`` where ``N``
    is the number of neurons per head and the last two items are helpers
    taken from the imported script.

    NOTE(review): torch.load unpickles the checkpoint file — only point this
    at checkpoints you trust.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Execute the training script as an ad-hoc module named "bdhmod".
    module_spec = importlib.util.spec_from_file_location("bdhmod", script_path)
    bdhmod = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(bdhmod)

    text_to_ids = bdhmod.ids_for_text
    make_neuron_id = bdhmod.neuron_id

    weights = torch.load(ckpt_path, map_location=device)
    n_head, n_embd, per_head = weights["encoder"].shape
    multiplier = int((per_head * n_head) // n_embd)

    cfg = bdhmod.BDHConfig(
        vocab_size=256,
        n_layer=6,
        n_embd=n_embd,
        n_head=n_head,
        mlp_internal_dim_multiplier=multiplier,
        dropout=0.0,
    )
    model = bdhmod.BDH(cfg).to(device)
    model.load_state_dict(weights, strict=True)
    model.eval()

    neurons_per_head = (cfg.n_embd * cfg.mlp_internal_dim_multiplier) // cfg.n_head
    return device, cfg, model, neurons_per_head, text_to_ids, make_neuron_id


# ==============================================================
# TRANSFORMER (distilgpt2) LOADER
# ==============================================================
@st.cache_resource(show_spinner=True)
def load_transformer(device):
    """
    Load (and cache via Streamlit) distilgpt2 and its tokenizer.

    The model is moved to ``device`` and put in eval mode.  When the
    tokenizer has no pad token, the EOS token is reused for padding.
    Returns ``(model, tokenizer)``.
    """
    checkpoint = "distilgpt2"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
    model.eval()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


# ==============================================================
# ACTIVATION EXTRACTION
# ==============================================================

@torch.no_grad()
def bdh_token_acts(text, layer_idx, device, model, ids_for_text):
    """
    Per-token BDH activations of one layer as a NumPy matrix.

    An empty/None ``text`` is replaced by a single space.  The third output
    of ``model(x, return_sparse=True)`` is indexed by layer; its first entry
    is treated as a (n_head, T, N_per_head) tensor and flattened so that each
    row holds all heads' neurons for one token.

    Returns a float32 array of shape (T, n_head * N_per_head).
    """
    tokens = ids_for_text(text or " ").to(device)
    _, _, sparse = model(tokens, return_sparse=True)

    per_head = sparse[layer_idx][0]       # (n_head, T, N_per_head)
    n_tokens = per_head.shape[1]
    # Token-major layout: one row per token, heads concatenated along columns.
    token_major = per_head.permute(1, 0, 2).reshape(n_tokens, -1)
    return token_major.detach().float().cpu().numpy()


@torch.no_grad()
def tfm_token_acts(tfm_mdl, tfm_tok, text, layer_idx, device):
    """
    Return per-token MLP activations (post-GELU) for one transformer layer.

    The output of ``transformer.h[layer_idx].mlp.c_fc`` is captured with a
    temporary forward hook and passed through ``F.gelu``.  For distilgpt2 the
    result has shape (T, 3072).

    Args:
        tfm_mdl: causal LM exposing ``transformer.h[i].mlp.c_fc``.
        tfm_tok: tokenizer; input is truncated to 128 tokens.
        text: input string; empty/None is replaced by a single space.
        layer_idx: index of the transformer block to probe.
        device: torch device for the encoded inputs.

    Returns:
        NumPy array with one row per token.

    IMPORTANT: GELU(x) is negative for negative x (GELU(-x) = -x * Phi(-x)).
    For sparsity analysis, |activation| < TFM_SPARSE_THRESH is treated as
    inactive, since those neurons contribute negligible magnitude regardless
    of sign.

    NOTE(review): F.gelu defaults to the exact (erf) form; confirm this
    matches the activation the hooked model applies after c_fc.
    """
    text = text or " "
    enc  = tfm_tok(text, return_tensors="pt",
                   truncation=True, max_length=128).to(device)
    buf  = {}

    def _hook(module, inp, out):
        buf["h"] = out.detach()

    handle = tfm_mdl.transformer.h[layer_idx].mlp.c_fc.register_forward_hook(_hook)
    try:
        tfm_mdl(**enc)
    finally:
        # Always detach the hook: the model is a cached, shared object, so a
        # hook left behind by a failed forward pass would fire on every
        # later call.
        handle.remove()

    raw = buf["h"].squeeze(0)             # (T, hidden)
    return F.gelu(raw).cpu().numpy()      # (T, hidden)


def bdh_sparsity_pct(token_acts_matrix):
    """
    Percentage of (token, neuron) entries that are strictly positive.

    A strict ``> 0`` test is used because ReLU outputs are exactly zero when
    inactive.  Returns a float in [0, 100].
    """
    is_active = token_acts_matrix > 0.0
    active_fraction = float(np.mean(is_active))
    return active_fraction * 100.0


def tfm_sparsity_pct(token_acts_matrix, thresh=TFM_SPARSE_THRESH):
    """
    Percentage of (token, neuron) entries whose magnitude exceeds ``thresh``.

    The absolute value is taken because GELU yields small negative outputs
    for negative pre-activations; without it those entries would be judged
    purely by sign.  Returns a float in [0, 100].
    """
    magnitudes = np.abs(token_acts_matrix)
    is_active = magnitudes > thresh
    return float(np.mean(is_active)) * 100.0


def activation_entropy(acts_1d):
    """
    Shannon entropy, in bits, of the normalised activation magnitudes.

    Magnitudes are divided by their sum (plus EPS) to form a distribution;
    entries at or below 1e-12 are dropped before taking logs.
    Low entropy  => mass concentrated in few neurons (monosemantic).
    High entropy => mass spread over many neurons (polysemantic).
    """
    magnitudes = np.abs(acts_1d)
    probs = magnitudes / (magnitudes.sum() + EPS)
    support = probs[probs > 1e-12]
    bits = -np.sum(support * np.log2(support))
    return float(bits)


def gini_coefficient(acts_1d):
    """
    Gini coefficient of the activation magnitudes.

    Values near 0 mean the magnitudes are evenly spread (dense); values near
    1 mean they are concentrated in few entries (sparse).  Returns 0.0 for
    empty input or when the total magnitude is below EPS.
    """
    magnitudes = np.sort(np.abs(acts_1d))
    count = len(magnitudes)
    if count == 0:
        return 0.0

    total = magnitudes.sum()
    if total < EPS:
        return 0.0

    # Rank-weighted sum over the ascending-sorted magnitudes.
    ranks = np.arange(1, count + 1)
    weighted = (ranks * magnitudes).sum()
    return float((2 * weighted) / (count * total + EPS) - (count + 1) / count)


# ==============================================================
# BDH HELPER FUNCTIONS
# ==============================================================

def jaccard(a, b):
    """Jaccard overlap |a & b| / |a | b| of two sets (EPS in the denominator guards empty sets)."""
    shared = len(a & b)
    combined = len(a | b)
    return shared / (combined + EPS)


def decode_gid(gid, cfg, N):
    """
    Split a global neuron id into ``(layer, head, feature)``.

    Ids are laid out layer-major, then head, then feature, with
    ``cfg.n_head * N`` neurons per layer and ``N`` per head.
    """
    neurons_per_layer = cfg.n_head * N
    layer, within_layer = divmod(gid, neurons_per_layer)
    head, feature = divmod(within_layer, N)
    return layer, head, feature


def wsplit(text):
    """Split ``text`` into its whitespace-separated tokens, dropping empties."""
    words = []
    for match in re.finditer(r"\S+", text):
        word = match.group(0).strip()
        if word:
            words.append(word)
    return words


def clean_token(tok_str):
    """Strip leading and trailing non-word characters (and whitespace) from a token."""
    edge_junk = re.compile(r"^\W+|\W+$", flags=re.UNICODE)
    trimmed = edge_junk.sub("", tok_str)
    return trimmed.strip()


def pos_wrap(c):
    """Embed concept ``c`` in three fixed carrier sentences and return them as a list."""
    templates = (
        "The text is about {c}.",
        "This sentence mentions {c}.",
        "I saw a {c} yesterday.",
    )
    return [template.format(c=c) for template in templates]


@torch.no_grad()
def _flat_mean(text, layer_idx, device, model, ids_for_text):
    """Mean over tokens of the flattened per-token BDH activations of one layer."""
    per_token = bdh_token_acts(text, layer_idx, device, model, ids_for_text)
    return per_token.mean(axis=0)


@torch.no_grad()
def topk_list(text, layer_idx, k, device, cfg, model, N, ids_for_text, neuron_id):
    """
    Ids of the ``k`` neurons with the highest token-mean activation in a layer.

    ``k`` is capped at the number of neurons.  Each flat index is split into
    (head, feature) using ``N`` neurons per head and mapped through
    ``neuron_id``.  The result is ordered from most to least active.
    """
    mean_acts = _flat_mean(text, layer_idx, device, model, ids_for_text)
    limit = min(int(k), mean_acts.shape[0])
    strongest = np.argsort(-mean_acts)[:limit]

    ids = []
    for flat_index in strongest:
        head, feature = divmod(int(flat_index), N)
        ids.append(neuron_id(layer_idx, head, feature, cfg.n_head, N))
    return ids


@torch.no_grad()
def topk_set(text, layer_idx, k, device, cfg, model, N, ids_for_text, neuron_id):
    """Same selection as ``topk_list`` but returned as an unordered set of ids."""
    ranked_ids = topk_list(
        text, layer_idx, k, device, cfg, model, N, ids_for_text, neuron_id
    )
    return set(ranked_ids)


@torch.no_grad()
def act_mean(text, layer_idx, gid, device, cfg, model, N, ids_for_text):
    """
    Mean activation over tokens of the single neuron identified by ``gid``.

    Returns 0.0 without running the model when ``gid`` does not belong to
    ``layer_idx``.  An empty/None ``text`` is replaced by a single space.
    """
    layer, head, feature = decode_gid(gid, cfg, N)
    if layer == layer_idx:
        tokens = ids_for_text(text or " ").to(device)
        _, _, sparse = model(tokens, return_sparse=True)
        neuron_trace = sparse[layer_idx][0][head, :, feature]
        return float(neuron_trace.mean().item())
    return 0.0


# ==============================================================
# PLOT FUNCTIONS  —  Tabs 1, 2, 3
# ==============================================================

def plot_line(df, tags, title):
    """
    Line chart of per-word overlap counts, one line per language tag.

    For each tag the column ``"<tag>_overlap"`` of ``df`` is plotted against
    the row index, lightly filled underneath, and every positive point is
    labelled with its integer value.  X tick labels come from ``df["EN_word"]``.
    Returns the figure.
    """
    fig = _fig()
    ax = fig.add_subplot(111)
    positions = np.arange(len(df))

    for order, tag in enumerate(tags):
        colour = PAL[order % len(PAL)]
        overlap = df[f"{tag}_overlap"]
        ax.plot(positions, overlap, marker="o", lw=4.0, color=colour,
                label=tag, ms=12, markerfacecolor=C_AX,
                markeredgewidth=3.5, markeredgecolor=colour)
        ax.fill_between(positions, overlap, alpha=.12, color=colour)

        for x_pos, y_val in zip(positions, overlap):
            if not y_val > 0:
                continue
            ax.annotate(
                str(int(y_val)), (x_pos, y_val),
                textcoords="offset points", xytext=(0, 14),
                ha="center", fontsize=_AS + 1, color=_TC,
                fontfamily="monospace", fontweight="bold",
                bbox=dict(facecolor="#fff", alpha=.94,
                          edgecolor="#90aadd", pad=3,
                          boxstyle="round,pad=0.26"),
            )

    ax.set_xticks(positions)
    ax.set_xticklabels(df["EN_word"], rotation=22, ha="right")
    ax.set_ylabel("Shared features with EN (best-match)", labelpad=14)
    ax.set_title(title)
    _style(ax)
    legend = ax.legend(framealpha=.97, title="Language", title_fontsize=_LTI)
    _legend(legend)
    _tight(fig, ax)
    return fig


def plot_heat(df, tags, title):
    """
    Heatmap of overlap counts: one row per language tag, one column per word.

    Rows are built from the ``"<tag>_overlap"`` columns of ``df``; every cell
    is labelled with its integer value.  The colour scale starts at 0 and
    ends at the largest count (at least 1).  Returns the figure.
    """
    n_rows, n_cols = len(tags), len(df)
    heat = np.vstack([df[f"{tag}_overlap"].values for tag in tags])

    fig = _fig(30, 4 + 2.2 * n_rows)
    ax = fig.add_subplot(111)
    image = ax.imshow(heat, aspect="auto", cmap="plasma",
                      interpolation="nearest", vmin=0, vmax=max(heat.max(), 1))

    ax.set_yticks(np.arange(n_rows))
    ax.set_yticklabels(tags)
    ax.set_xticks(np.arange(n_cols))
    ax.set_xticklabels(df["EN_word"], rotation=22, ha="right")

    # Row-major walk over the cells, writing the count into each one.
    for row, col in np.ndindex(n_rows, n_cols):
        ax.text(col, row, str(int(heat[row, col])),
                ha="center", va="center", fontsize=_AS + 2,
                color=_TC, fontfamily="monospace", fontweight="bold",
                bbox=dict(facecolor="#fff", alpha=.85,
                          edgecolor="none", pad=2.5,
                          boxstyle="round,pad=0.22"))

    colorbar = fig.colorbar(image, ax=ax, fraction=.025, pad=.02)
    colorbar.set_label("Overlap count", color=C_LABEL, fontsize=_LS - 1)
    for tick_label in colorbar.ax.get_yticklabels():
        tick_label.set_color(_TC)
        tick_label.set_fontweight(_TW)
        tick_label.set_bbox(_BBOX)

    ax.set_title(title)
    _style(ax)
    _tight(fig, ax)
    return fig


def plot_baseline(names, actual, base, title):
    """
    Bar chart of actual Jaccard overlaps against a shuffle baseline.

    ``actual`` is drawn as one bar per entry of ``names``; the mean of
    ``base`` is drawn as a dashed line with a +/- one standard deviation
    band.  The y axis is fixed to [0, 1].  Returns the figure.
    """
    baseline_mean = base.mean()
    baseline_std = base.std()

    fig = _fig()
    ax = fig.add_subplot(111)
    slots = np.arange(len(names))

    bars = ax.bar(slots, actual, color=C_POS, alpha=.92, edgecolor="none",
                  width=.65, label="Actual (EN vs translations)")
    ax.axhline(baseline_mean, color=C_NEG, ls="--", lw=3.5,
               label=f"Shuffle baseline mean={baseline_mean:.3f}")
    ax.axhspan(baseline_mean - baseline_std, baseline_mean + baseline_std,
               alpha=.14, color=C_NEG, label="Baseline +/- 1 sigma")
    _ann_bars(ax, bars, actual, vmax=1.0)

    ax.set_xticks(slots)
    ax.set_xticklabels(names, rotation=22, ha="right")
    ax.set_ylim(0, 1)
    ax.set_ylabel("Jaccard overlap", labelpad=14)
    ax.set_title(title)
    _style(ax)
    _legend(ax.legend(framealpha=.97))
    _tight(fig, ax)
    return fig


def plot_pn_bar(pl, pv, nl, nv, title):
    """
    Bar chart of mean activations for positive then negative examples.

    ``pl``/``pv`` are the labels and values of the positive group, ``nl``/``nv``
    those of the negative group.  Positive bars come first, separated from
    the negative bars by a dashed vertical line.  Returns the figure.
    """
    labels = pl + nl
    values = np.concatenate([pv, nv])
    n_pos = len(pl)
    bar_colors = [C_POS] * n_pos + [C_NEG] * len(nl)
    # 28% headroom above the tallest bar so value labels fit inside the axes.
    y_top = max(float(values.max()) * 1.28, 1e-3)

    fig = _fig()
    ax = fig.add_subplot(111)
    slots = np.arange(len(labels))
    bars = ax.bar(slots, values, color=bar_colors,
                  alpha=.92, edgecolor="none", width=.72)
    _ann_bars(ax, bars, values, vmax=y_top)
    ax.axvline(n_pos - .5, color="#8aaae0", lw=3.0, alpha=.8, ls="--")

    ax.set_xticks(slots)
    ax.set_xticklabels(labels, rotation=20, ha="right")
    ax.set_ylim(0, y_top)
    ax.set_ylabel("Mean activation", labelpad=14)
    ax.set_title(title)
    _style(ax)

    group_handles = [
        mpatches.Patch(color=C_POS, label="POS"),
        mpatches.Patch(color=C_NEG, label="NEG"),
    ]
    _legend(ax.legend(handles=group_handles, framealpha=.97))
    _tight(fig, ax)
    return fig


def plot_pn_box(pv, nv, title):
    """
    Box plot comparing the positive (``pv``) and negative (``nv``) activation samples.

    Means are shown as diamonds, medians as white lines.  Returns the figure.
    """
    fig = _fig(18, 12)
    ax = fig.add_subplot(111)

    median_style = dict(color="white", lw=4.0)
    mean_style = dict(marker="D", markerfacecolor="#ffe070",
                      markeredgecolor="none", markersize=14)
    flier_style = dict(marker="o", markerfacecolor="#c4dcff",
                       markersize=9, alpha=.7)
    parts = ax.boxplot(
        [pv, nv], labels=["POS", "NEG"], showmeans=True, patch_artist=True,
        medianprops=median_style, meanprops=mean_style, flierprops=flier_style,
    )

    # First box is the positive group, second the negative group.
    pos_box, neg_box = parts["boxes"][0], parts["boxes"][1]
    pos_box.set_facecolor(C_POS)
    pos_box.set_alpha(.55)
    neg_box.set_facecolor(C_NEG)
    neg_box.set_alpha(.55)

    for line in parts["whiskers"] + parts["caps"]:
        line.set_color("#a8c0e8")
        line.set_lw(3.0)

    ax.set_ylabel("Mean activation", labelpad=14)
    ax.tick_params(axis="x", labelsize=_KS + 5)
    ax.set_title(title)
    _style(ax)
    _tight(fig, ax)
    return fig


# ==============================================================
# COMPARISON PLOT FUNCTIONS  —  Tab 4
# ==============================================================

def plot_sparsity_comparison(bdh_token_mat, tfm_token_mat, title, tfm_thresh):
    """
    Side-by-side sorted activation profiles showing per-token sparsity.

    Each panel plots the (at most) 500 largest token-mean activations in
    descending order.

    BDH panel: ReLU exact-zero boundary marked in yellow.
    Transformer panel: |GELU| < tfm_thresh boundary marked.

    The two sparsity metrics are NOT the same scale:
      - BDH uses strict threshold=0 (ReLU zeros are exact by design).
      - Transformer uses |activation| > tfm_thresh because GELU produces
        small-magnitude negative values that are functionally suppressed.

    The percentages in the titles are computed over all (token, neuron)
    entries of the input matrices, whereas the bars show per-neuron means
    over tokens.

    Args:
        bdh_token_mat: (tokens, neurons) BDH activation matrix.
        tfm_token_mat: (tokens, neurons) Transformer activation matrix.
        title: figure-level title.
        tfm_thresh: magnitude threshold for the Transformer panel.

    Returns  (fig, pct_bdh_active, pct_tfm_active).
    Lower percentage = sparser = better monosemanticity.
    """
    _apply_theme()
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(34, 13))

    # ---- BDH panel (exact ReLU zeros) ----
    bdh_mean = bdh_token_mat.mean(0)
    pct_bdh  = bdh_sparsity_pct(bdh_token_mat)
    # Descending sort, keeping only the 500 largest means for display.
    y_bdh    = np.sort(bdh_mean)[::-1][:500]
    ax1.bar(np.arange(len(y_bdh)), y_bdh, color=C_BDH,
            alpha=.88, edgecolor="none", width=1.0)
    # Rank at which the sorted means reach zero, clamped to the 500 shown.
    nonzero_count = int(np.sum(bdh_mean > 0))
    ax1.axvline(min(nonzero_count, 500), color="#ffe070", lw=3.5, ls="--",
                label="ReLU zero boundary (exact)")
    ax1.set_title(
        f"BDH  —  {pct_bdh:.1f}% neurons active per token  (ReLU > 0)",
        color=C_BDH, fontsize=_TS, pad=20
    )
    ax1.set_xlabel("Neuron rank (by mean activation across tokens)", labelpad=12)
    ax1.set_ylabel("Mean activation", labelpad=12)
    # The "SPARSE" caption is fixed text; only the percentage is data-driven.
    ax1.annotate(
        f"SPARSE\n{pct_bdh:.1f}% active\nReLU exact zeros",
        xy=(len(y_bdh) * .58, float(y_bdh.max()) * .55),
        fontsize=_AS + 4, color=_TC, fontfamily="monospace",
        fontweight="bold", ha="center",
        bbox=dict(facecolor="#ffffff", alpha=.93,
                  edgecolor=C_BDH, pad=8, boxstyle="round,pad=0.4")
    )
    _style(ax1); _tick_boxes(ax1)
    _legend(ax1.legend(framealpha=.97))

    # ---- Transformer panel (|GELU| > threshold) ----
    tfm_abs  = np.abs(tfm_token_mat)
    tfm_mean = tfm_abs.mean(0)
    pct_tfm  = tfm_sparsity_pct(tfm_token_mat, thresh=tfm_thresh)
    y_tfm    = np.sort(tfm_mean)[::-1][:500]
    ax2.bar(np.arange(len(y_tfm)), y_tfm, color=C_TFM,
            alpha=.80, edgecolor="none", width=1.0)
    # Mark the effective "inactive" threshold level on the y-axis
    ax2.axhline(tfm_thresh, color="#ffe070", lw=3.5, ls="--",
                label=f"|GELU| = {tfm_thresh:.3f} inactive threshold")
    # Shade the region below threshold.  y_tfm holds means of absolute
    # values (>= 0), so the upper bound below evaluates to tfm_thresh.
    ax2.fill_between(np.arange(len(y_tfm)),
                     0, min(tfm_thresh, float(y_tfm.min()) + tfm_thresh),
                     alpha=.20, color="#ffe070",
                     label="Near-zero region (inactive)")
    ax2.set_title(
        f"Transformer (distilGPT-2)  —  {pct_tfm:.1f}% neurons active per token"
        f"  (|GELU| > {tfm_thresh:.3f})",
        color=C_TFM, fontsize=_TS, pad=20
    )
    ax2.set_xlabel("Neuron rank (by mean |activation| across tokens)", labelpad=12)
    ax2.set_ylabel("Mean |activation|", labelpad=12)
    # The "DENSE" caption is fixed text; only the percentage is data-driven.
    ax2.annotate(
        f"DENSE\n{pct_tfm:.1f}% active\nGELU near-universal",
        xy=(len(y_tfm) * .58, float(y_tfm.max()) * .55),
        fontsize=_AS + 4, color=_TC, fontfamily="monospace",
        fontweight="bold", ha="center",
        bbox=dict(facecolor="#ffffff", alpha=.93,
                  edgecolor=C_TFM, pad=8, boxstyle="round,pad=0.4")
    )
    _style(ax2); _tick_boxes(ax2)
    _legend(ax2.legend(framealpha=.97))

    fig.suptitle(title, fontsize=_TS + 3, color=C_TITLE, y=1.01,
                 fontfamily="monospace", fontweight="bold",
                 bbox=dict(facecolor="#04070f", alpha=.6,
                           edgecolor="none", pad=10))
    fig.tight_layout(pad=3.5)
    return fig, pct_bdh, pct_tfm


def plot_entropy_comparison(concepts, bdh_entropies, tfm_entropies, title):
    """
    Grouped bar chart comparing Shannon activation entropy per concept.

    Lower entropy = monosemantic. BDH expected to have lower entropy.

    Args:
        concepts: x-axis labels, one per bar group.
        bdh_entropies: BDH entropy (bits) per concept.
        tfm_entropies: Transformer entropy (bits) per concept.
        title: axes title.

    Returns:
        The matplotlib figure.
    """
    _apply_theme()
    fig, ax = plt.subplots(figsize=(32, 13))
    x = np.arange(len(concepts))
    w = 0.38

    bars_b = ax.bar(x - w / 2, bdh_entropies, w, color=C_BDH,
                    alpha=.88, edgecolor="none",
                    label="BDH (low entropy = concentrated = monosemantic)")
    bars_t = ax.bar(x + w / 2, tfm_entropies, w, color=C_TFM,
                    alpha=.80, edgecolor="none",
                    label="Transformer (high entropy = diffuse = polysemantic)")

    # 22% headroom above the tallest bar.  `default` keeps empty input from
    # raising, and the small floor avoids a zero-height axis when every
    # entropy is 0.
    peak = max([*bdh_entropies, *tfm_entropies], default=0.0)
    vm = max(float(peak) * 1.22, 1e-3)
    _ann_bars(ax, bars_b, bdh_entropies, vmax=vm)
    _ann_bars(ax, bars_t, tfm_entropies, vmax=vm)

    ax.set_xticks(x)
    ax.set_xticklabels(concepts, rotation=20, ha="right")
    # Apply the headroom so the value labels on the tallest bars stay inside
    # the axes, as the other bar plots in this file do.
    ax.set_ylim(0, vm)
    ax.set_ylabel("Activation entropy (bits)\n"
                  "  lower = more selective / monosemantic", labelpad=14)
    ax.set_title(title)
    _style(ax)
    _legend(ax.legend(framealpha=.97))
    _tight(fig, ax)
    return fig


def plot_crosslingual_consistency(concepts, bdh_jacc, tfm_jacc, title):
    """
    Grouped bar chart of mean EN <-> translation Jaccard overlap per concept.

    One BDH bar and one Transformer bar per concept; a higher bar means the
    same neurons respond regardless of language.  Returns the figure.
    """
    _apply_theme()
    fig, ax = plt.subplots(figsize=(32, 13))

    centres = np.arange(len(concepts))
    bar_width = 0.38
    half = bar_width / 2

    bdh_bars = ax.bar(centres - half, bdh_jacc, bar_width, color=C_BDH,
                      alpha=.88, edgecolor="none",
                      label="BDH — cross-lingual neuron overlap")
    tfm_bars = ax.bar(centres + half, tfm_jacc, bar_width, color=C_TFM,
                      alpha=.80, edgecolor="none",
                      label="Transformer — cross-lingual overlap")

    # 28% headroom above the tallest bar, never below 0.1.
    tallest = float(np.concatenate([bdh_jacc, tfm_jacc]).max())
    y_top = max(tallest * 1.28, .1)
    for bars, values in ((bdh_bars, bdh_jacc), (tfm_bars, tfm_jacc)):
        _ann_bars(ax, bars, values, vmax=y_top)

    ax.set_xticks(centres)
    ax.set_xticklabels(concepts, rotation=20, ha="right")
    ax.set_ylim(0, y_top)
    ax.set_ylabel("Mean Jaccard overlap (EN <-> translations)", labelpad=14)
    ax.set_title(title)
    _style(ax)
    _legend(ax.legend(framealpha=.97))
    _tight(fig, ax)
    return fig


def plot_activation_heatmap_compare(words, bdh_vecs, tfm_vecs, n_show=60):
    """
    Side-by-side activation heatmaps: BDH (left) vs Transformer (right).

    Each row is one entry of ``words``; each column is one selected neuron.
    The titles describe the rows as middle-token activations — this function
    does not select tokens itself, so that is assumed to be done by the
    caller (TODO confirm).

    BDH panel: the first ``n_show // 2`` columns are the neurons with the
    highest maximum activation across rows; the remaining ``n_show // 2``
    columns are neurons that are exactly zero on every row (sampled with a
    fixed seed), or the least-active neurons when too few exact zeros exist.
    Showing only the top neurons would hide the sparsity, because every
    selected neuron is active by construction.

    Transformer panel: the ``n_show`` neurons with the largest absolute
    activation.

    Both matrices are scaled by their own maximum so the colour range is
    [0, 1].  Returns the figure.
    """
    _apply_theme()
    fig, (ax1, ax2) = plt.subplots(
        1, 2, figsize=(36, max(9, 3 + 2.0 * len(words)))
    )

    def _heat(ax, mat, title, cmap, color, xlabel="Neurons (active | inactive)"):
        # Draw one normalised heatmap with a colorbar and boxed row labels.
        im = ax.imshow(mat, aspect="auto", cmap=cmap,
                       interpolation="nearest", vmin=0, vmax=1)
        ax.set_yticks(np.arange(len(words)))
        ax.set_yticklabels(words, fontsize=_KS + 1)
        ax.set_xlabel(xlabel, labelpad=12)
        ax.set_title(title, color=color, fontsize=_TS, pad=18)
        fig.colorbar(im, ax=ax, fraction=.025, pad=.02)
        _style(ax)
        # Render before restyling the y tick labels.
        ax.figure.canvas.draw()
        for lab in ax.get_yticklabels():
            lab.set_color(_TC)
            lab.set_fontweight(_TW)
            lab.set_bbox(dict(facecolor="#fff", alpha=.90,
                              edgecolor="#9ab0e0", pad=3,
                              boxstyle="round,pad=0.22"))

    # ---- BDH: top-active neurons followed by zero neurons ----
    n_active_show = n_show // 2
    n_zero_show   = n_show // 2
    total_bdh     = bdh_vecs.shape[1]

    sorted_by_max = np.argsort(-bdh_vecs.max(0))          # highest → lowest
    top_active_idx = sorted_by_max[:n_active_show]

    # Zero neurons: where ALL concepts have activation = 0 (true ReLU zeros)
    zero_mask = (bdh_vecs.max(0) == 0)
    zero_indices = np.where(zero_mask)[0]
    if len(zero_indices) >= n_zero_show:
        # Fixed seed so the same zero neurons are drawn on every rerun.
        rng = np.random.default_rng(42)
        zero_sample = rng.choice(zero_indices, n_zero_show, replace=False)
    else:
        # Fall back to the least-active neurons if not enough true zeros
        zero_sample = sorted_by_max[total_bdh - n_zero_show:]

    # Concatenate: left = active (bright), right = inactive (dark)
    top_idx_b = np.concatenate([top_active_idx, zero_sample])
    B          = bdh_vecs[:, top_idx_b]
    B_norm     = B / (B.max() + EPS)

    _heat(ax1, B_norm,
          f"BDH  —  Sparse Activation Heatmap  (middle token)\n"
          f"Left {n_active_show} cols: TOP active neurons  |  "
          f"Right {n_zero_show} cols: ZERO / inactive neurons",
          "viridis", C_BDH,
          xlabel=f"← {n_active_show} top active  |  {n_zero_show} zero inactive →")
    # Divider line between active and zero halves
    ax1.axvline(n_active_show - 0.5, color="#ffe070", lw=3.0, ls="--", alpha=.9)

    # ---- Transformer: top neurons by absolute activation ----
    top_idx_t = np.argsort(-np.abs(tfm_vecs).max(0))[:n_show]
    T          = np.abs(tfm_vecs[:, top_idx_t])
    T_norm     = T / (T.max() + EPS)
    _heat(ax2, T_norm,
          f"Transformer  —  Dense Activation Heatmap  (middle token)\n"
          f"Top {n_show} neurons — ALL uniformly active (GELU never zero)",
          "plasma", C_TFM,
          xlabel=f"Top {n_show} neurons ranked by |activation|")

    fig.suptitle(
        "Activation Heatmap: BDH (sparse, left=active / right=zero) "
        "vs Transformer (dense, all active)  —  middle token only",
        fontsize=_TS + 1, color=C_TITLE, y=1.01,
        fontfamily="monospace", fontweight="bold"
    )
    fig.tight_layout(pad=3.5)
    return fig


def _bar_html(pct, color, max_pct=100):
    """
    Render a small inline progress bar as HTML.

    The fill width is ``pct`` as a percentage of ``max_pct``, clamped to the
    range [2, 100] so that even tiny values stay visible.
    """
    scaled = pct / max_pct * 100
    # Clamp to [2, 100] with the same comparison semantics as max()/min().
    width = scaled if scaled > 2 else 2
    width = width if width < 100 else 100
    fill = (f'<div class="score-bar-fill" style="width:{width:.1f}%;'
            f'background:{color};"></div>')
    return f'<div class="score-bar-wrap">' + fill + '</div>'


def render_summary_table(summary, TOPK_SET_BDH, TOPK_SET_TFM, LANGS, tfm_thresh):
    """
    Build the full summary scorecard HTML with winner badges, value bars,
    interpretation notes, and a final verdict strip.

    Args:
        summary: Mapping read for the keys ``bdh_sparsity``, ``tfm_sparsity``,
            ``bdh_entropy``, ``tfm_entropy``, ``bdh_jaccard`` and
            ``tfm_jaccard``. A missing key is treated as NaN ("not measured").
        TOPK_SET_BDH: Top-K value displayed for BDH in the Jaccard row.
        TOPK_SET_TFM: Top-K value displayed for the Transformer in the
            Jaccard row.
        LANGS: Selected target languages; their tags are listed in the
            Jaccard row. A language without an entry in ``LANG_TAGS`` falls
            back to its first two characters, upper-cased.
        tfm_thresh: Transformer inactivity threshold displayed in the
            sparsity row.

    Returns:
        The scorecard as a single HTML string.
    """
    nan = float("nan")
    bdh_sp  = summary.get("bdh_sparsity", nan)
    tfm_sp  = summary.get("tfm_sparsity", nan)
    bdh_ent = summary.get("bdh_entropy",  nan)
    tfm_ent = summary.get("tfm_entropy",  nan)
    bdh_jac = summary.get("bdh_jaccard",  nan)
    tfm_jac = summary.get("tfm_jaccard",  nan)

    def _measured(a, b):
        """True when both sides of a metric hold a real (non-NaN) value."""
        return not (np.isnan(a) or np.isnan(b))

    sp_ok  = _measured(bdh_sp,  tfm_sp)
    ent_ok = _measured(bdh_ent, tfm_ent)
    jac_ok = _measured(bdh_jac, tfm_jac)

    # Per-metric winners: True = BDH wins, False = Transformer wins,
    # None = metric not measured. Sparsity and entropy are lower-is-better,
    # Jaccard is higher-is-better; a tie goes to the Transformer.
    sp_bdh_wins  = bool(bdh_sp  < tfm_sp)  if sp_ok  else None
    ent_bdh_wins = bool(bdh_ent < tfm_ent) if ent_ok else None
    jac_bdh_wins = bool(bdh_jac > tfm_jac) if jac_ok else None

    def _badge(bdh_wins):
        """Winner/loser badge pair for one metric ("" when unmeasured)."""
        if bdh_wins is None: return ""
        if bdh_wins:
            return ('<span class="score-badge-win">BDH WINS</span>'
                    '&nbsp;<span class="score-badge-lose">TFM</span>')
        return ('<span class="score-badge-win">TFM WINS</span>'
                '&nbsp;<span class="score-badge-lose">BDH</span>')

    def _val_classes(bdh_wins):
        """CSS classes for (BDH value, Transformer value) of one metric."""
        if bdh_wins is None:
            # Nothing measured: neither side is highlighted as the winner.
            return "score-val-lose", "score-val-lose"
        if bdh_wins:
            return "score-val-win", "score-val-lose"
        return "score-val-lose", "score-val-win"

    # Sparsity: lower active% = better
    sp_max = max(bdh_sp, tfm_sp) + 1 if sp_ok else 100
    sp_b_bar = _bar_html(bdh_sp, C_BDH, sp_max)
    sp_t_bar = _bar_html(tfm_sp, C_TFM, sp_max)
    sp_b_cls, sp_t_cls = _val_classes(sp_bdh_wins)

    # Entropy: lower = better
    ent_max = max(bdh_ent, tfm_ent) + 0.1 if ent_ok else 14
    ent_b_bar = _bar_html(bdh_ent, C_BDH, ent_max)
    ent_t_bar = _bar_html(tfm_ent, C_TFM, ent_max)
    ent_b_cls, ent_t_cls = _val_classes(ent_bdh_wins)

    # Jaccard: higher = better (bars are drawn on a 0-100 scale)
    jac_max = max(bdh_jac, tfm_jac) + 0.05 if jac_ok else 1
    jac_b_bar = _bar_html(bdh_jac * 100, C_BDH, jac_max * 100)
    jac_t_bar = _bar_html(tfm_jac * 100, C_TFM, jac_max * 100)
    jac_b_cls, jac_t_cls = _val_classes(jac_bdh_wins)

    # Language tags actually selected; used both in the row subtitle and in
    # the interpretation note so the two can never disagree.
    lang_tags = [LANG_TAGS.get(l, str(l)[:2].upper()) for l in LANGS] if LANGS else []
    lang_str = ", ".join(lang_tags) if lang_tags else "—"

    # Count wins (among measured metrics only)
    measured_wins = [w for w in [sp_bdh_wins, ent_bdh_wins, jac_bdh_wins]
                     if w is not None]
    bdh_wins   = int(sum(measured_wins))
    total_meas = len(measured_wins)

    if total_meas == 0:
        # Every metric was NaN: say so rather than reporting "0/0".
        v_color  = "#ff5a72"
        v_border = "#ff5a72"
        verdict  = "No metrics were measured — Check layer selection and checkpoint."
    elif bdh_wins == total_meas and total_meas == 3:
        v_color  = "#38e090"
        v_border = "#38e090"
        verdict  = f"BDH wins all {total_meas}/{total_meas} measured metrics — Architectural monosemanticity confirmed."
    elif bdh_wins >= 2:
        v_color  = "#38e090"
        v_border = "#38e090"
        verdict  = f"BDH leads {bdh_wins}/{total_meas} measured metrics — Strong evidence of architectural monosemanticity."
    elif bdh_wins == 1:
        v_color  = "#ffc040"
        v_border = "#ffc040"
        verdict  = f"BDH leads {bdh_wins}/{total_meas} metrics — Try a different layer or increase TOPK."
    else:
        v_color  = "#ff5a72"
        v_border = "#ff5a72"
        verdict  = f"BDH leads {bdh_wins}/{total_meas} metrics — Check layer selection and checkpoint."

    sp_note    = "lower % active = sparser = better for BDH"
    ent_note   = "lower bits = energy concentrated = monosemantic"
    jac_langs  = "/".join(lang_tags) if lang_tags else "target languages"
    jac_note   = f"higher overlap = same neurons across {jac_langs}"

    # The verdict text already carries the win count, so the strip shows it
    # on its own instead of prefixing a second "BDH wins x/y" in front of it.
    html = f"""
<div class="score-card">
  <!-- Column headers -->
  <div class="score-hdr" style="display:grid;grid-template-columns:30% 30% 30% 10%;
      border-bottom:1px solid rgba(55,90,200,.30);padding-bottom:.6rem;margin-bottom:.4rem;">
    <div>METRIC</div>
    <div style="color:#38e090;">BDH (measured)</div>
    <div style="color:#ff5a72;">Transformer (measured)</div>
    <div>WINNER</div>
  </div>

  <!-- Row 1: Sparsity -->
  <div class="score-metric-row">
    <div>
      <div class="score-metric-name">Activation Sparsity</div>
      <div class="score-metric-sub">% (token, neuron) pairs active<br>
        BDH: ReLU &gt; 0 &nbsp;|&nbsp; TFM: |GELU| &gt; {tfm_thresh:.3f}</div>
      <div class="score-metric-sub" style="color:#38a0ff;margin-top:3px;">{sp_note}</div>
    </div>
    <div>
      <div class="{sp_b_cls}">{bdh_sp:.1f}%</div>
      <div class="score-metric-sub" style="color:#38e090;">ReLU — exact zeros</div>
      {sp_b_bar}
    </div>
    <div>
      <div class="{sp_t_cls}">{tfm_sp:.1f}%</div>
      <div class="score-metric-sub" style="color:#ff5a72;">GELU — near-universal</div>
      {sp_t_bar}
    </div>
    <div>{_badge(sp_bdh_wins)}</div>
  </div>

  <!-- Row 2: Entropy -->
  <div class="score-metric-row">
    <div>
      <div class="score-metric-name">Activation Entropy</div>
      <div class="score-metric-sub">Shannon H = -sum(p log2 p) bits<br>PRIMARY monosemanticity metric</div>
      <div class="score-metric-sub" style="color:#38a0ff;margin-top:3px;">{ent_note}</div>
    </div>
    <div>
      <div class="{ent_b_cls}">{bdh_ent:.2f} bits</div>
      <div class="score-metric-sub" style="color:#38e090;">concentrated</div>
      {ent_b_bar}
    </div>
    <div>
      <div class="{ent_t_cls}">{tfm_ent:.2f} bits</div>
      <div class="score-metric-sub" style="color:#ff5a72;">diffuse</div>
      {ent_t_bar}
    </div>
    <div>{_badge(ent_bdh_wins)}</div>
  </div>

  <!-- Row 3: Jaccard -->
  <div class="score-metric-row">
    <div>
      <div class="score-metric-name">Cross-lingual Jaccard</div>
      <div class="score-metric-sub">EN ↔ {lang_str} top-K overlap<br>
        BDH TopK={TOPK_SET_BDH} &nbsp;|&nbsp; TFM TopK={TOPK_SET_TFM}</div>
      <div class="score-metric-sub" style="color:#38a0ff;margin-top:3px;">{jac_note}</div>
    </div>
    <div>
      <div class="{jac_b_cls}">{bdh_jac:.3f}</div>
      <div class="score-metric-sub" style="color:#38e090;">same neurons across langs</div>
      {jac_b_bar}
    </div>
    <div>
      <div class="{jac_t_cls}">{tfm_jac:.3f}</div>
      <div class="score-metric-sub" style="color:#ff5a72;">different neurons per lang</div>
      {jac_t_bar}
    </div>
    <div>{_badge(jac_bdh_wins)}</div>
  </div>

  <!-- Row 4: Interpretability -->
  <div class="score-metric-row">
    <div>
      <div class="score-metric-name">Interpretability Source</div>
      <div class="score-metric-sub">How readable are neurons?</div>
    </div>
    <div>
      <div class="score-val-win">Architectural</div>
      <div class="score-metric-sub" style="color:#38e090;">ReLU + Hebbian — by design</div>
    </div>
    <div>
      <div class="score-val-lose">Post-hoc</div>
      <div class="score-metric-sub" style="color:#ff5a72;">requires SAE decomposition</div>
    </div>
    <div><span class="score-badge-win">BDH</span></div>
  </div>

  <!-- Row 5: Memory -->
  <div class="score-metric-row">
    <div>
      <div class="score-metric-name">Memory Scaling</div>
      <div class="score-metric-sub">With sequence length T</div>
    </div>
    <div>
      <div class="score-val-win">O(n &times; d)</div>
      <div class="score-metric-sub" style="color:#38e090;">constant — Hebbian table</div>
    </div>
    <div>
      <div class="score-val-lose">O(T&sup2;)</div>
      <div class="score-metric-sub" style="color:#ff5a72;">KV-cache grows with T</div>
    </div>
    <div><span class="score-badge-win">BDH</span></div>
  </div>

  <!-- Verdict -->
  <div class="verdict-strip" style="border:2px solid {v_border};
      background:rgba(8,16,48,.90);margin-top:1.4rem;">
    <span style="color:{v_color};">
      {verdict}
    </span>
  </div>
</div>"""
    return html


# ==============================================================
# SIDEBAR
# ==============================================================
# Sidebar controls, created one by one through the ``st.sidebar`` object.
# The values read here configure every tab further down the script.
st.sidebar.markdown("""
<div style="margin-bottom:.6rem;">
  <div style="font-family:'JetBrains Mono',monospace;font-weight:700;
      font-size:.96rem;color:#c4dcff;">BDH MONOSEMANTICITY</div>
  <div style="font-size:.66rem;color:#304880;letter-spacing:.14em;margin-top:2px;">
      KRITI 2026 | PATH B | INTERPRETABILITY</div>
</div>""", unsafe_allow_html=True)
st.sidebar.markdown("---")

# -- Where to find the BDH script and its checkpoint --
st.sidebar.markdown("### Model Paths")
BDH_SCRIPT_PATH = st.sidebar.text_input(
    "BDH script path",
    value="/content/bdh_europarl_train_probe.py",
)
CKPT_PATH = st.sidebar.text_input(
    "Checkpoint path",
    value="checkpoints/bdh_europarl_bytes.pt",
)

# -- Which layers to probe and how many neurons to print in the demo --
st.sidebar.markdown("### Probe Parameters")
layers_in = st.sidebar.text_input("Layers (comma-separated)", value="4,5")
TOPK_PRINT = st.sidebar.number_input(
    "TOPK_PRINT (Demo TopK)",
    min_value=1,
    max_value=50,
    value=5,
)

# -- Top-K set sizes and selectivity options --
st.sidebar.markdown("### TopK Settings")
TOPK_SET_BDH = st.sidebar.number_input(
    "TopK neurons — BDH",
    min_value=10,
    max_value=5000,
    value=200,
    help="Number of top-K BDH neurons used for Jaccard overlap and candidate selection.",
)
TOPK_SET_TFM = st.sidebar.number_input(
    "TopK neurons — Transformer",
    min_value=10,
    max_value=3072,
    value=200,
    help="Number of top-K Transformer neurons (out of 3072 MLP hidden dim).",
)
MAX_CAND = st.sidebar.number_input(
    "MAX_CAND",
    min_value=50,
    max_value=5000,
    value=800,
)
PASS_Z = st.sidebar.slider("PASS_Z threshold (Tab 3)", 0.0, 5.0, 1.5, 0.1)
USE_AUG = st.sidebar.checkbox("POS sentence augmentation", value=True)

# -- Magnitude below which a Transformer activation counts as inactive --
st.sidebar.markdown("### Sparsity Threshold")
TFM_THRESH_INPUT = st.sidebar.slider(
    "Transformer |GELU| inactive threshold",
    min_value=0.001,
    max_value=0.10,
    value=TFM_SPARSE_THRESH,
    step=0.001,
    format="%.3f",
    help=(
        "GELU produces small-magnitude negative values for negative pre-activations. "
        "Neurons with |activation| below this threshold are counted as inactive. "
        "0.02 is a sensible default; increase to count fewer neurons as active."
    ),
)

# -- Target languages (every known language is pre-selected) --
st.sidebar.markdown("### Languages")
LANGS = st.sidebar.multiselect(
    "Target languages",
    options=list(ALL_LANGS.keys()),
    default=list(ALL_LANGS.keys()),
)

# -- Layers compared against each other in the comparison tab --
st.sidebar.markdown("### Comparison Settings")
CMP_LAYER_BDH = st.sidebar.number_input(
    "BDH layer for comparison",
    min_value=0,
    max_value=5,
    value=4,
)
CMP_LAYER_TFM = st.sidebar.number_input(
    "TFM layer for comparison",
    min_value=0,
    max_value=5,
    value=3,
)

def _parse_layer_indices(raw):
    """Parse a comma-separated string of layer numbers into a list of ints.

    Entries that are not plain non-negative decimal numbers are ignored, and
    repeated numbers are kept only once (first occurrence wins) so that no
    layer is processed twice.
    """
    indices = []
    for piece in raw.split(","):
        piece = piece.strip()
        # isdecimal() rejects characters such as "²" that isdigit() accepts
        # but int() cannot convert.
        if not piece.isdecimal():
            continue
        idx = int(piece)
        if idx not in indices:
            indices.append(idx)
    return indices


LAYERS = _parse_layer_indices(layers_in)

# Keep backward-compat alias for tabs 1/2/3 that use a single TOPK_SET
TOPK_SET = TOPK_SET_BDH


# ==============================================================
# HEADER
# ==============================================================
# Page title, subtitle and introductory strip, rendered top to bottom.
# Each entry is (markdown body, whether raw HTML is allowed in it).
_HEADER_PARTS = (
    ("# BDH Monosemanticity Probe", False),
    (
        '<div class="subtitle">Kriti 2026 | Path B: Interpretability Showcases'
        ' | BDH vs Transformer Comparative Study</div>',
        True,
    ),
    ("""<div class="info-strip">
This tool demonstrates that <strong>BDH synapses are monosemantic</strong> — each synapse
responds to a single semantic concept consistently across languages — while Transformer neurons
are <strong>polysemantic</strong> (encoding many unrelated concepts simultaneously).
ReLU sparse activations produce exact zeros for inactive neurons.
GELU activations are near-universal — almost every neuron fires on every token,
making the Transformer uninterpretable without post-hoc SAE methods.
</div>""", True),
)
for _part_body, _part_is_html in _HEADER_PARTS:
    st.markdown(_part_body, unsafe_allow_html=_part_is_html)


# ==============================================================
# LOAD MODELS
# ==============================================================
# Load the BDH model. Only the load call itself sits inside the ``try`` so
# that a later rendering problem is not misreported as a load failure, and
# so that ``st.stop()`` is never raised from inside the guarded region.
try:
    device, cfg, model, N, ids_for_text, neuron_id = load_bdh(
        BDH_SCRIPT_PATH, CKPT_PATH)
except Exception as exc:
    st.error(f"BDH load failed: {exc}")
    st.info("Set the correct paths in the sidebar.")
    st.stop()

# Drop requested layers that the loaded model does not have.
LAYERS = [L for L in LAYERS if 0 <= L < cfg.n_layer]

# One-line overview of the loaded BDH configuration.
c1, c2, c3, c4, c5 = st.columns(5)
c1.metric("Device",    device.upper())
c2.metric("BDH layers", cfg.n_layer)
c3.metric("n_head",    cfg.n_head)
c4.metric("n_embd",    cfg.n_embd)
c5.metric("N/head",    N)
if not LAYERS:
    st.warning("No valid layers selected.")
    st.stop()

# The Transformer baseline is optional: on failure the app keeps running
# and the comparison tab checks ``tfm_mdl is None``.
with st.spinner("Loading distilGPT-2 for comparison..."):
    try:
        tfm_mdl, tfm_tok = load_transformer(device)
        st.success("distilGPT-2 loaded — Transformer comparison ready")
    except Exception as e:
        st.warning(f"Transformer load failed: {e}. Comparison tab limited.")
        tfm_mdl, tfm_tok = None, None

st.markdown("---")


# ==============================================================
# TABS
# ==============================================================
# Labels of the four top-level tabs, in display order.
_TAB_LABELS = [
    "DEMO   Meaning-aligned Overlap",
    "DATASET   Jaccard vs Baseline",
    "NEURONS   POS vs NEG Selectivity",
    "BDH vs TRANSFORMER   Comparative Interpretability",
]
tab_demo, tab_dataset, tab_neurons, tab_compare = st.tabs(_TAB_LABELS)


# --------------------------------------------------------------
# TAB 1 — DEMO
# --------------------------------------------------------------
with tab_demo:
    st.markdown("## Demo: Meaning-aligned Sparse Feature Overlap")
    st.markdown("""<div class="info-strip">
    Enter an English sentence. The tool translates it into selected languages, then checks
    whether the <em>same</em> BDH neurons fire for semantically aligned tokens — proving
    monosemantic cross-lingual encoding.
    </div>""", unsafe_allow_html=True)

    demo_sent = st.text_input("English word or sentence",
                               value="This girl is beautiful.", key="demo_input")
    run_demo  = st.button("Run Demo", key="btn_demo")

    # Invalid input only produces a warning. Calling st.stop() here would
    # halt the whole script and leave the tabs defined further down empty.
    if run_demo and not LANGS:
        st.warning("Select at least one language.")
    elif run_demo and not demo_sent.strip():
        st.warning("Enter an English word or sentence.")
    elif run_demo:
        with st.spinner("Translating..."):
            translations = {l: translate(l, demo_sent, device) for l in LANGS}

        st.markdown("### Full translations")
        tcols = st.columns(1 + len(LANGS))
        tcols[0].markdown(f"**EN**\n\n`{demo_sent}`")
        for i, (lang, txt) in enumerate(translations.items()):
            tcols[i + 1].markdown(
                f"**{LANG_TAGS.get(lang, lang[:2].upper())}**\n\n`{txt}`")

        def best_match_span(lang_tokens, lang_sets, en_set,
                            i_en, n_en, max_span=3):
            """Find the translated token span that best matches one EN token.

            The EN token position ``i_en`` (out of ``n_en``) is projected
            linearly onto the translated sentence; spans of 1..``max_span``
            tokens starting within a small window around that position are
            scored by ``1000 * overlap - 20 * distance + 2 * span_len``,
            where overlap is the number of neurons shared between ``en_set``
            and the union of the span's neuron sets.

            Returns:
                ``(start, end, overlap, shared)`` with ``end`` exclusive and
                ``shared`` the (at most five) smallest shared neuron ids;
                ``(0, 0, 0, [])`` when ``lang_tokens`` is empty.
            """
            if not lang_tokens:
                return 0, 0, 0, []
            n      = len(lang_tokens)
            j0     = int(round(i_en * (n - 1) / (n_en - 1))) if n_en > 1 else 0
            window = max(2, min(5, n // 2))
            s_lo   = max(0, j0 - window)
            s_hi   = min(n - 1, j0 + window)
            best   = None
            for s in range(s_lo, s_hi + 1):
                # Union of neuron sets over the span, grown one token at a time.
                U = set()
                for span_len in range(1, max_span + 1):
                    e = s + span_len
                    if e > n:
                        break
                    U |= (lang_sets[e - 1] if lang_sets[e - 1] else set())
                    common = en_set & U
                    ov     = len(common) if en_set else 0
                    sh     = sorted(common)[:5]
                    center = s + 0.5 * (span_len - 1)
                    score  = 1000 * ov - 20 * abs(center - j0) + 2 * span_len
                    if best is None or score > best[0]:
                        best = (score, ov, span_len, s, sh)
            _, ov, span_len, s, sh = best
            return s, s + span_len, int(max(ov, 0)), sh

        def compute_demo(en_text, trans_dict, layer_idx, topk):
            """Compute per-token neuron overlap between EN and translations.

            Returns:
                ``(df_plot, tags, df_en, per_token, en_words_disp)`` where
                ``df_plot`` has one row per EN token with a
                ``<tag>_overlap`` column per language, ``df_en`` lists the
                top-5 neurons per EN token, and ``per_token`` maps the token
                *position* to ``{"en5": ..., "rows": [...]}``. Positions are
                used as keys so that a word occurring twice in the sentence
                keeps a separate entry for each occurrence.
            """
            en_words_disp  = wsplit(en_text)
            en_words_clean = [clean_token(w) or w for w in en_words_disp]
            n_en    = len(en_words_disp)
            t2t     = {LANG_TAGS[lang]: txt for lang, txt in trans_dict.items()}
            ltoks_disp  = {tag: wsplit(txt) for tag, txt in t2t.items()}
            ltoks_clean = {tag: [clean_token(w) or w for w in toks]
                           for tag, toks in ltoks_disp.items()}
            tags    = list(ltoks_disp.keys())
            en_lists = [topk_list(wc, layer_idx, topk, device, cfg, model,
                                  N, ids_for_text, neuron_id) if wc else []
                        for wc in en_words_clean]
            en_sets  = [set(l) for l in en_lists]
            lsets    = {tag: [topk_set(tc, layer_idx, topk, device, cfg,
                                       model, N, ids_for_text, neuron_id)
                               if tc else set()
                               for tc in ltoks_clean[tag]]
                        for tag in tags}
            plot_rows = []
            en_rows   = []
            per_token = {}
            for i, en_disp in enumerate(en_words_disp):
                enS  = en_sets[i]
                en5  = ", ".join(map(str, en_lists[i][:5])) if en_lists[i] else ""
                prow = {"EN_word": en_disp}
                trows = []
                for tag in tags:
                    toks_disp = ltoks_disp[tag]
                    sets      = lsets[tag]
                    if not toks_disp:
                        prow[f"{tag}_overlap"] = 0
                        trows.append({"Language": tag, "Full sentence": "—",
                                      "Matched phrase": "—", "Overlap": 0,
                                      "Shared neurons (Top-5)": "—"})
                        continue
                    s_idx, e_idx, bov, bsh = best_match_span(
                        toks_disp, sets, enS, i_en=i, n_en=n_en)
                    matched_phrase = " ".join(toks_disp[s_idx:e_idx]).strip()
                    prow[f"{tag}_overlap"] = int(bov)
                    trows.append({
                        "Language":             tag,
                        "Matched phrase":       matched_phrase,
                        "Overlap":              int(bov),
                        "Shared neurons (Top-5)":
                            ", ".join(map(str, bsh)) if bsh else "—",
                        "Full sentence":        t2t[tag],
                    })
                plot_rows.append(prow)
                en_rows.append({"EN_idx": i, "EN_token": en_disp,
                                "EN_Top5_neurons": en5})
                per_token[i] = {"en5": en5, "rows": trows}
            df_plot = pd.DataFrame(plot_rows)
            df_en   = pd.DataFrame(en_rows)
            return df_plot, tags, df_en, per_token, en_words_disp

        for L in LAYERS:
            st.markdown(f"### Layer {L}")
            with st.spinner(f"Layer {L}: computing overlaps..."):
                df_sum, tags, df_en, per_token, en_words = compute_demo(
                    demo_sent, translations, L, int(TOPK_PRINT))
            st.markdown("#### EN token — Top neurons")
            st.dataframe(df_en, use_container_width=True,
                         height=min(280, 60 + 40 * len(df_en)))
            f1 = plot_line(df_sum, tags,
                           f"Meaning-aligned token overlap  |  Layer {L}  |  TopK={TOPK_PRINT}")
            st.pyplot(f1, use_container_width=True); plt.close(f1)
            f2 = plot_heat(df_sum, tags,
                           f"Overlap heatmap  |  Layer {L}  |  TopK={TOPK_PRINT}")
            st.pyplot(f2, use_container_width=True); plt.close(f2)
            st.markdown("#### Per-token breakdown")
            for tok_idx, en_w in enumerate(en_words):
                d = per_token[tok_idx]
                st.markdown(
                    f'<div class="token-hdr">EN token: <strong>{en_w}</strong>'
                    f'&nbsp;|&nbsp;Top-5 EN neurons: {d["en5"] or "—"}</div>',
                    unsafe_allow_html=True)
                df_rows = pd.DataFrame(d["rows"])[
                    ["Language", "Matched phrase", "Overlap",
                     "Shared neurons (Top-5)", "Full sentence"]]
                st.dataframe(df_rows, use_container_width=True,
                             height=min(460, 60 + 55 * len(df_rows)))
            st.markdown("---")


# --------------------------------------------------------------
# TAB 2 — DATASET
# --------------------------------------------------------------
with tab_dataset:
    st.markdown("## Dataset Mode: Jaccard Overlap vs Shuffle Baseline")
    st.markdown('<div class="info-strip">Jaccard overlap between EN sparse feature sets '
                'and translations vs a shuffle baseline proves that BDH neurons align '
                'semantically, not randomly.</div>', unsafe_allow_html=True)

    col_c, col_n = st.columns(2)
    with col_c:
        st.markdown("**Concept words — one per line (min 5)**")
        ct = st.text_area(
            "", value="doctor\nhospital\nsurgery\npatient\nmedicine",
            height=190, key="cta")
    with col_n:
        st.markdown("**NEG unrelated words — one per line (min 5)**")
        nt = st.text_area(
            "", value="football\nbanana\nguitar\nmountain\npolitics",
            height=190, key="nta")

    n_shuf = st.number_input("Shuffle baseline trials", 5, 200, 20, key="ns")
    run_ds = st.button("Run Dataset Analysis", key="btn_ds")

    if run_ds:
        concepts = [x.strip() for x in ct.splitlines() if x.strip()]
        NEG      = [x.strip() for x in nt.splitlines() if x.strip()]
        if len(concepts) < 5 or len(NEG) < 5:
            st.warning("Provide at least 5 concept words and 5 NEG words.")
        elif not LANGS:
            st.warning("Select at least one language.")
        else:
            with st.spinner("Translating concept words..."):
                concept_trans = [
                    {"EN": w, **{l: translate(l, w, device) for l in LANGS}}
                    for w in concepts
                ]
            st.markdown("### Concept translation table")
            st.dataframe(pd.DataFrame(concept_trans)[["EN"] + list(LANGS)],
                         use_container_width=True, height=300)
            # Stored for the NEURONS tab, which reuses the same concept list.
            st.session_state["concept_trans"] = concept_trans
            st.session_state["NEG"]           = NEG

            # Per-run memo of top-K sets, keyed by (layer, text). The shuffle
            # baseline compares the same words over and over, so each set is
            # computed once instead of once per shuffle trial.
            # NOTE(review): assumes topk_set is deterministic for a given
            # text/layer and that jaccard does not mutate its arguments.
            _topk_cache = {}

            def cached_topk(text, L):
                """Return the top-K neuron set of ``text`` at layer ``L`` (memoised)."""
                key = (L, text)
                if key not in _topk_cache:
                    _topk_cache[key] = topk_set(
                        text, L, int(TOPK_SET), device, cfg, model,
                        N, ids_for_text, neuron_id)
                return _topk_cache[key]

            def compute_actual(L):
                """Mean EN-vs-translation Jaccard per concept at layer ``L``.

                Returns ``(names, values)``: the EN concept words and an array
                with, per concept, the Jaccard overlap averaged over LANGS.
                """
                names, actual = [], []
                for entry in concept_trans:
                    names.append(entry["EN"])
                    en_set = cached_topk(entry["EN"], L)
                    actual.append(float(np.mean([
                        jaccard(en_set, cached_topk(entry[l], L))
                        for l in LANGS
                    ])))
                return names, np.array(actual)

            def compute_shuffle(L, ns):
                """Shuffle baseline: ``ns`` trials pairing each EN word with
                the translations of a randomly permuted concept.

                Returns an array with the mean Jaccard of every trial.
                """
                n    = len(concept_trans)
                base = []
                for _ in range(ns):
                    perm = np.random.permutation(n)
                    vals = [
                        float(np.mean([
                            jaccard(cached_topk(concept_trans[i]["EN"], L),
                                    cached_topk(concept_trans[perm[i]][l], L))
                            for l in LANGS
                        ]))
                        for i in range(n)
                    ]
                    base.append(np.mean(vals))
                return np.array(base)

            results = {}
            for L in LAYERS:
                with st.spinner(f"Layer {L}: computing..."):
                    names, actual_arr = compute_actual(L)
                    base_dist         = compute_shuffle(L, int(n_shuf))
                    results[L]        = (names, actual_arr, base_dist)
                gap = actual_arr.mean() - base_dist.mean()
                a, b, c_ = st.columns(3)
                a.metric(f"L{L} actual mean",   f"{actual_arr.mean():.3f}")
                b.metric(f"L{L} baseline mean", f"{base_dist.mean():.3f}")
                c_.metric(f"L{L} gap",          f"{gap:+.3f}")
                fb = plot_baseline(
                    names, actual_arr, base_dist,
                    f"Layer {L}: Jaccard vs Shuffle Baseline  |  "
                    f"TopK={TOPK_SET}  |  shuffles={int(n_shuf)}")
                st.pyplot(fb, use_container_width=True); plt.close(fb)

            st.session_state["ds_results"] = results
            st.session_state["ds_layers"]  = LAYERS
            st.success("Dataset analysis complete. Proceed to NEURONS tab.")


# --------------------------------------------------------------
# TAB 3 — NEURONS
# --------------------------------------------------------------
with tab_neurons:
    st.markdown("## Neuron Selectivity: POS vs NEG Activations")
    st.markdown('<div class="info-strip">Finds the single best neuron per concept '
                '(z = (pos_mean - neg_mean) / neg_std). PASS if z >= PASS_Z.</div>',
                unsafe_allow_html=True)

    if "concept_trans" not in st.session_state:
        st.info("Run Dataset Mode first to build the concept list.")
    else:
        # Everything below works on the results saved by the DATASET tab.
        concept_trans = st.session_state["concept_trans"]
        NEG           = st.session_state["NEG"]
        results       = st.session_state.get("ds_results", {})
        ds_layers     = st.session_state.get("ds_layers", LAYERS)
        run_n         = st.button("Find Best Neurons and Plot", key="btn_neu")

        if run_n:
            TOP_SHOW = min(3, len(concept_trans))

            def best_neuron(i, L):
                """Pick the most selective neuron for concept ``i`` at layer ``L``.

                Candidates are the top-K neurons of the English word; each is
                scored by ``z = (pos_mean - neg_mean) / neg_std`` over the
                positive texts (EN word, its translations, optional
                augmentation sentences) and the negative sentences.

                Returns:
                    A dict with ``gid``, ``pos_vals``, ``neg_vals``,
                    ``pos_mean``, ``neg_mean``, ``neg_std``, ``sel``, ``z``,
                    ``score`` and ``pos_labels`` (one label per positive
                    text), or ``None`` when there is no candidate neuron.
                """
                trans = concept_trans[i]
                # The stored translations were built with the languages
                # selected at dataset time; the sidebar may have changed
                # since, so only languages present in the entry are used.
                langs_used = [l for l in LANGS if l in trans]
                POS        = [trans["EN"]] + [trans[l] for l in langs_used]
                pos_labels = ["EN"] + [LANG_TAGS[l] for l in langs_used]
                if USE_AUG:
                    aug = list(pos_wrap(trans["EN"]))
                    POS += aug
                    pos_labels += [f"P{k}" for k in range(1, len(aug) + 1)]
                NEG_S = [f"The text is about {n}." for n in NEG]

                pos_sets = [set(topk_set(t, L, int(TOPK_SET), device, cfg,
                                         model, N, ids_for_text, neuron_id))
                            for t in POS]
                # POS[0] is the English word, so its set doubles as the
                # English candidate pool (no second forward pass needed).
                # NOTE(review): assumes topk_set is deterministic per text.
                Sen    = pos_sets[0]
                shared = set.intersection(*pos_sets)
                # Neurons shared by every positive text come first, so that
                # truncation to MAX_CAND keeps them and is reproducible.
                # NOTE(review): assumes neuron ids are orderable (e.g. ints).
                cand = (sorted(shared) + sorted(Sen - shared))[:int(MAX_CAND)]
                if not cand:
                    return None
                best = None
                for gid in cand:
                    pv  = np.array([act_mean(t, L, gid, device, cfg, model,
                                            N, ids_for_text) for t in POS])
                    nv  = np.array([act_mean(t, L, gid, device, cfg, model,
                                            N, ids_for_text) for t in NEG_S])
                    pm, nm = pv.mean(), nv.mean()
                    ns_ = nv.std() + 1e-6  # epsilon avoids division by zero
                    sel = pm - nm
                    z   = sel / ns_
                    if best is None or (z, sel) > best["score"]:
                        best = dict(gid=gid, pos_vals=pv, neg_vals=nv,
                                    pos_mean=float(pm), neg_mean=float(nm),
                                    neg_std=float(ns_), sel=float(sel),
                                    z=float(z), score=(z, sel),
                                    pos_labels=pos_labels)
                return best

            for L in ds_layers:
                if L not in results:
                    st.write(f"Layer {L}: no data.")
                    continue
                names, actual_arr, base_dist = results[L]
                # Concepts with the largest margin over the shuffle baseline.
                order = np.argsort(-(actual_arr - base_dist.mean()))[:TOP_SHOW]
                st.markdown(f"### Layer {L} — top {TOP_SHOW} concepts by margin")
                for idx in order:
                    concept = concept_trans[idx]["EN"]
                    with st.spinner(f"Layer {L}: '{concept}'..."):
                        best = best_neuron(idx, L)
                    if best is None:
                        st.warning(f"'{concept}': no candidate neurons.")
                        continue
                    gid     = best["gid"]
                    verdict = "PASS" if best["z"] >= PASS_Z else "WEAK"
                    badge   = ('<span class="badge-pass">PASS</span>'
                               if verdict == "PASS"
                               else '<span class="badge-weak">WEAK</span>')
                    st.markdown(
                        f'<span class="concept-row">{concept}</span>'
                        f'<span class="neuron-tag">neuron {gid}</span>'
                        f'<span class="neuron-tag">decode {decode_gid(gid, cfg, N)}</span>'
                        f'&nbsp;&nbsp;{badge}',
                        unsafe_allow_html=True)
                    m1, m2, m3, m4 = st.columns(4)
                    m1.metric("pos_mean",    f"{best['pos_mean']:.4f}")
                    m2.metric("neg_mean",    f"{best['neg_mean']:.4f}")
                    m3.metric("selectivity", f"{best['sel']:.4f}")
                    m4.metric("z-score",     f"{best['z']:.2f}")
                    # Labels come from best_neuron so they always line up
                    # one-to-one with the positive activations.
                    pl = best["pos_labels"]
                    nl = [f"NEG:{w}" for w in NEG]
                    fb = plot_pn_bar(
                        pl, best["pos_vals"], nl, best["neg_vals"],
                        f"POS vs NEG  |  '{concept}'  |  layer {L}  "
                        f"|  neuron {gid}  |  {verdict}")
                    st.pyplot(fb, use_container_width=True); plt.close(fb)
                    fx = plot_pn_box(
                        best["pos_vals"], best["neg_vals"],
                        f"Activation distribution  |  '{concept}'  "
                        f"|  layer {L}  |  neuron {gid}  |  {verdict}")
                    st.pyplot(fx, use_container_width=True); plt.close(fx)
                    st.markdown("---")
            st.success("Neuron analysis complete.")


# ==============================================================
# TAB 4 — BDH vs TRANSFORMER
# ==============================================================
with tab_compare:
    # Tab 4 entry point: static legend, expected-results box, user inputs,
    # and input validation. The measurement sections (A to E) only run after
    # the "Run Full Comparison" button is pressed.
    st.markdown("## BDH vs Transformer: Comparative Interpretability Study")
    st.markdown("""<div class="info-strip">
    <strong>A — Sparsity:</strong>
      BDH uses ReLU — exact zeros.
      Transformer uses GELU — neurons are near-universally active.
      <em>Sparsity is measured correctly for each architecture</em>:
      BDH threshold = 0 (exact); Transformer threshold = |GELU| &gt; threshold
      (adjust in sidebar).<br>
    <strong>B — Activation Entropy:</strong>
      Shannon entropy of activation magnitudes. Lower = concentrated = monosemantic.
      This is the primary monosemanticity metric.<br>
    <strong>C — Cross-lingual Jaccard:</strong>
      Same concept, four languages. BDH retains same neurons; Transformer scatters.<br>
    <strong>D — Heatmap:</strong> Middle token only — preserves true ReLU zeros.<br>
    <strong>E — Summary Scorecard:</strong>
      All metrics with winner badges and comparison bars.
    </div>""", unsafe_allow_html=True)

    # Expected-results box (f-string: interpolates the sidebar threshold)
    st.markdown(f"""<div class="expected-box">
    <strong>Expected results with a well-trained BDH model:</strong><br>
    &bull; Sparsity &mdash; BDH: <span class="win">~5-30% active</span> &nbsp;|&nbsp;
    Transformer: <span class="lose">~80-99% active</span>
    (using |GELU| &gt; {TFM_THRESH_INPUT:.3f}) &nbsp;
    <em>Adjust the threshold slider in the sidebar if the Transformer appears too sparse.</em><br>
    &bull; Entropy &mdash; BDH: <span class="win">~3-8 bits</span> &nbsp;|&nbsp;
    Transformer: <span class="lose">~9-12 bits</span><br>
    &bull; Jaccard &mdash; BDH: <span class="win">0.70-0.90</span> &nbsp;|&nbsp;
    Transformer: <span class="lose">0.05-0.30</span><br>
    &bull; Heatmap &mdash; BDH: <span class="win">dark background, bright isolated spots</span> &nbsp;|&nbsp;
    Transformer: <span class="lose">uniformly lit, no clear structure</span>
    </div>""", unsafe_allow_html=True)

    # Nothing below can run without the Transformer baseline; st.stop()
    # halts the script run at this point.
    if tfm_mdl is None:
        st.error("Transformer model not loaded. Check your connection and re-run.")
        st.stop()

    # The visible captions come from the st.markdown calls; each widget still
    # gets a real (hidden) label because Streamlit discourages empty labels
    # for accessibility and may reject them in the future.
    col_a, col_b = st.columns(2)
    with col_a:
        st.markdown("**Concept words** (5-10 for best results)")
        cmp_concepts_raw = st.text_area(
            "Concept words",
            value="doctor\nhospital\nsurgery\npatient\nmedicine\ndiagnosis",
            height=190, key="cmp_c", label_visibility="collapsed")
    with col_b:
        st.markdown("**NEG (unrelated) words**")
        cmp_neg_raw = st.text_area(
            "NEG (unrelated) words",
            value="football\nbanana\nguitar\nmountain\npolitics\nocean",
            height=190, key="cmp_n", label_visibility="collapsed")

    st.markdown("**Section A — Sparsity sentence** (type any sentence you want to analyse)")
    sparsity_sentence = st.text_input(
        "Sparsity sentence",
        value="The doctor treated the patient carefully.",
        key="sparsity_sent",
        help="This exact sentence is passed to both BDH and Transformer for sparsity measurement. Type anything you like — try domain-specific sentences, short phrases, or random text to see how sparsity changes.",
        label_visibility="collapsed",
    )

    run_cmp = st.button("Run Full Comparison", key="btn_cmp")

    if run_cmp:
        # One entry per non-blank line, surrounding whitespace removed.
        cmp_concepts = [x.strip() for x in cmp_concepts_raw.splitlines() if x.strip()]
        # NOTE(review): cmp_neg is parsed but not referenced by any of the
        # sections A to E in this tab; confirm whether the NEG input is
        # still needed.
        cmp_neg      = [x.strip() for x in cmp_neg_raw.splitlines() if x.strip()]
        if len(cmp_concepts) < 3:
            st.warning("Enter at least 3 concept words.")
            st.stop()

        # Headline metrics recorded by the sections below and handed to the
        # scorecard renderer at the end of the tab.
        summary = {}

        # ----------------------------------------------------------
        # SECTION A — SPARSITY (architecture-aware measurement)
        # Runs one sentence through both models, plots the share of
        # active units, and records both percentages in summary.
        # ----------------------------------------------------------
        st.markdown("---")
        st.markdown("### A  —  Sparsity (Architecture-Aware Per-Token Measurement)")
        st.markdown(f"""<div class="compare-box">
        <span class="bdh-label">BDH</span>: ReLU produces exact zeros.
        A neuron is inactive iff activation = 0 exactly.
        Sparsity = fraction of (token, neuron) pairs with activation &gt; 0.<br>
        <span class="tfm-label">Transformer</span>: GELU(x) can be negative for negative x,
        not zero. Near-universal activation is the norm.
        A neuron is inactive iff |GELU| &le; {TFM_THRESH_INPUT:.3f}
        (adjust in sidebar &rarr; Sparsity Threshold).
        </div>""", unsafe_allow_html=True)

        # Fall back to a templated sentence when the input box is blank.
        sample_text = sparsity_sentence.strip() or f"The {cmp_concepts[0]} treated the patient carefully."
        with st.spinner(f"Computing per-token activations for: \"{sample_text[:60]}\"..."):
            bdh_tok_mat = bdh_token_acts(
                sample_text, int(CMP_LAYER_BDH), device, model, ids_for_text)
            # NOTE(review): the Transformer layer is clamped to 5 here while
            # the plot title shows the raw CMP_LAYER_TFM; confirm the sidebar
            # range never exceeds 5.
            tfm_tok_mat = tfm_token_acts(
                tfm_mdl, tfm_tok, sample_text,
                min(int(CMP_LAYER_TFM), 5), device)

        # Only add an ellipsis when the sentence really was cut, so short
        # sentences are not shown with a spurious trailing "...".
        if len(sample_text) > 55:
            title_snippet = sample_text[:55] + "..."
        else:
            title_snippet = sample_text

        fig_sp, pct_b, pct_t = plot_sparsity_comparison(
            bdh_tok_mat, tfm_tok_mat,
            f"Sparsity  |  \"{title_snippet}\"  "
            f"|  BDH layer {CMP_LAYER_BDH}  vs  Transformer layer {CMP_LAYER_TFM}",
            tfm_thresh=TFM_THRESH_INPUT)
        st.pyplot(fig_sp, use_container_width=True); plt.close(fig_sp)

        # Report the exact TFM / BDH ratio the label promises; when BDH has
        # no active units the ratio is undefined, so say so instead of
        # biasing every value with an additive fudge term.
        if pct_b > 0:
            ratio_text = f"{pct_t / pct_b:.1f}x"
        else:
            ratio_text = "n/a"

        sc1, sc2, sc3 = st.columns(3)
        sc1.metric("BDH active per token (ReLU > 0)",
                   f"{pct_b:.1f}%")
        sc2.metric(f"Transformer active per token (|GELU| > {TFM_THRESH_INPUT:.3f})",
                   f"{pct_t:.1f}%")
        winner_sp = "BDH sparser" if pct_b < pct_t else "TFM sparser"
        sc3.metric("Density ratio (TFM / BDH)",
                   ratio_text,
                   delta=winner_sp)
        summary["bdh_sparsity"] = pct_b
        summary["tfm_sparsity"] = pct_t

        # Warn if sparsity direction seems wrong
        if pct_b >= pct_t:
            st.warning(
                f"BDH appears denser than Transformer ({pct_b:.1f}% vs {pct_t:.1f}%). "
                f"Try: (1) increasing the Transformer threshold slider above "
                f"{TFM_THRESH_INPUT:.3f}, (2) selecting a different BDH layer, "
                f"or (3) verifying the checkpoint is fully trained."
            )

        # ----------------------------------------------------------
        # SECTION B — ACTIVATION ENTROPY
        # One entropy value per concept and per model, computed on the
        # token-averaged activation vector; means go into summary.
        # ----------------------------------------------------------
        st.markdown("---")
        st.markdown("### B  —  Activation Entropy (Primary Monosemanticity Metric)")
        st.markdown("""<div class="compare-box">
        <strong>H = -sum(p log2 p)</strong> over the normalised activation magnitude
        distribution for each concept (mean across tokens, then entropy of vector).
        Monosemantic neurons concentrate energy in very few positions: low entropy.
        Polysemantic neurons spread energy everywhere: high entropy.
        </div>""", unsafe_allow_html=True)

        with st.spinner("Computing activation entropy for all concepts..."):
            bdh_ents, tfm_ents = [], []
            for concept in cmp_concepts:
                # Average over the token axis before measuring entropy.
                bdh_mean_vec = bdh_token_acts(
                    concept, int(CMP_LAYER_BDH), device, model, ids_for_text
                ).mean(0)
                # NOTE(review): probe layer is clamped to 5 while the plot
                # title below prints the raw CMP_LAYER_TFM.
                tfm_mean_vec = tfm_token_acts(
                    tfm_mdl, tfm_tok, concept,
                    min(int(CMP_LAYER_TFM), 5), device
                ).mean(0)
                bdh_ents.append(activation_entropy(bdh_mean_vec))
                tfm_ents.append(activation_entropy(tfm_mean_vec))

        entropy_title = (
            f"Activation Entropy  (lower = monosemantic)  "
            f"|  BDH layer {CMP_LAYER_BDH}  vs  Transformer layer {CMP_LAYER_TFM}"
        )
        fig_ent = plot_entropy_comparison(
            cmp_concepts, bdh_ents, tfm_ents, entropy_title)
        st.pyplot(fig_ent, use_container_width=True)
        plt.close(fig_ent)

        # Aggregate once and reuse for the metric tiles and the summary.
        col_bdh, col_tfm, col_gap = st.columns(3)
        bdh_mean_entropy = np.mean(bdh_ents)
        tfm_mean_entropy = np.mean(tfm_ents)
        col_bdh.metric("BDH mean entropy", f"{bdh_mean_entropy:.2f} bits")
        col_tfm.metric("Transformer mean entropy", f"{tfm_mean_entropy:.2f} bits")
        col_gap.metric(
            "Entropy advantage (BDH lower by)",
            f"{tfm_mean_entropy - bdh_mean_entropy:+.2f} bits",
        )

        # Per-concept table; BDH wins a row only when strictly lower.
        per_concept_winner = [
            "BDH" if bdh_e < tfm_e else "TFM"
            for bdh_e, tfm_e in zip(bdh_ents, tfm_ents)
        ]
        entropy_table = pd.DataFrame({
            "Concept":            cmp_concepts,
            "BDH entropy (bits)": [f"{value:.2f}" for value in bdh_ents],
            "TFM entropy (bits)": [f"{value:.2f}" for value in tfm_ents],
            "Winner":             per_concept_winner,
        })
        st.dataframe(entropy_table, use_container_width=True)

        summary["bdh_entropy"] = bdh_mean_entropy
        summary["tfm_entropy"] = tfm_mean_entropy

        # ----------------------------------------------------------
        # SECTION C — CROSS-LINGUAL JACCARD
        # For every concept, compare the set of top-K neurons for the
        # English word with the top-K set for its translation in each
        # selected language, separately for BDH and the Transformer.
        # Skipped entirely when no languages are selected (LANGS empty),
        # in which case no Jaccard keys are added to summary.
        # ----------------------------------------------------------
        if LANGS:
            st.markdown("---")
            st.markdown("### C  —  Cross-lingual Consistency")
            # NOTE(review): the text below hard-codes "3072 MLP neurons";
            # confirm it matches the loaded Transformer.
            st.markdown(f"""<div class="compare-box">
            Each concept is translated into all selected languages.
            Jaccard overlap between EN top-K and translation top-K is computed
            and averaged across languages.<br>
            BDH TopK = {TOPK_SET_BDH} &nbsp;|&nbsp;
            Transformer TopK = {TOPK_SET_TFM} (out of 3072 MLP neurons).<br>
            <span class="bdh-label">BDH</span>: encodes meaning, not surface form
            &rarr; same neurons across languages.<br>
            <span class="tfm-label">Transformer</span>: encodes token statistics
            &rarr; different neurons per language.
            </div>""", unsafe_allow_html=True)

            with st.spinner("Computing cross-lingual Jaccard scores..."):
                # Per-concept means (one float per concept, averaged over languages).
                bdh_j_rows = []
                tfm_j_rows = []
                trans_table = []   # store for display
                for c in cmp_concepts:
                    # Per-language scores for this concept.
                    lang_bdh_j = []
                    lang_tfm_j = []
                    # Keys are inserted in display order; presumably the
                    # resulting DataFrame columns follow this order.
                    row = {"EN (concept)": c}
                    # English reference sets: BDH via the topk_set helper ...
                    en_set_bdh = topk_set(
                        c, int(CMP_LAYER_BDH), int(TOPK_SET_BDH),
                        device, cfg, model, N, ids_for_text, neuron_id)
                    # ... Transformer via token-averaged activations.
                    # NOTE(review): layer is clamped to 5 while the plot title
                    # further down prints the raw CMP_LAYER_TFM.
                    en_vec_tfm = tfm_token_acts(
                        tfm_mdl, tfm_tok, c,
                        min(int(CMP_LAYER_TFM), 5), device).mean(0)
                    # Indices of the K largest absolute mean activations.
                    en_topk_tfm = set(
                        np.argsort(-np.abs(en_vec_tfm))[:int(TOPK_SET_TFM)])
                    for lang in LANGS:
                        tr = translate(lang, c, device)
                        tag = LANG_TAGS[lang]
                        row[f"{tag} translation"] = tr
                        # BDH: overlap of the translation's top-K with the EN top-K.
                        tr_set = topk_set(
                            tr, int(CMP_LAYER_BDH), int(TOPK_SET_BDH),
                            device, cfg, model, N, ids_for_text, neuron_id)
                        j_bdh = jaccard(en_set_bdh, tr_set)
                        lang_bdh_j.append(j_bdh)
                        row[f"{tag} BDH Jacc"] = f"{j_bdh:.3f}"
                        # Transformer: same top-K-by-magnitude rule as for EN.
                        tr_vec = tfm_token_acts(
                            tfm_mdl, tfm_tok, tr,
                            min(int(CMP_LAYER_TFM), 5), device).mean(0)
                        tr_topk = set(
                            np.argsort(-np.abs(tr_vec))[:int(TOPK_SET_TFM)])
                        j_tfm = jaccard(en_topk_tfm, tr_topk)
                        lang_tfm_j.append(j_tfm)
                        row[f"{tag} TFM Jacc"] = f"{j_tfm:.3f}"
                    # Average over languages; BDH wins only on a strictly
                    # higher mean (ties go to TFM).
                    bdh_j_rows.append(np.mean(lang_bdh_j))
                    tfm_j_rows.append(np.mean(lang_tfm_j))
                    row["BDH mean Jacc"] = f"{np.mean(lang_bdh_j):.3f}"
                    row["TFM mean Jacc"] = f"{np.mean(lang_tfm_j):.3f}"
                    row["Winner"] = "BDH" if np.mean(lang_bdh_j) > np.mean(lang_tfm_j) else "TFM"
                    trans_table.append(row)

            # Show translation table FIRST so user can see what was translated
            st.markdown("#### Translations used + per-language Jaccard scores")
            # Table height grows with the row count but is capped at 500 px.
            st.dataframe(pd.DataFrame(trans_table), use_container_width=True,
                         height=min(500, 60 + 50 * len(trans_table)))

            fig_xl = plot_crosslingual_consistency(
                cmp_concepts, np.array(bdh_j_rows), np.array(tfm_j_rows),
                f"Cross-lingual Consistency  "
                f"|  BDH TopK={TOPK_SET_BDH}  vs  TFM TopK={TOPK_SET_TFM}"
                f"  |  layers {CMP_LAYER_BDH} / {CMP_LAYER_TFM}")
            st.pyplot(fig_xl, use_container_width=True); plt.close(fig_xl)

            # Headline numbers: mean over concepts of the per-concept means.
            cl1, cl2, cl3 = st.columns(3)
            cl1.metric("BDH mean Jaccard",        f"{np.mean(bdh_j_rows):.3f}")
            cl2.metric("Transformer mean Jaccard", f"{np.mean(tfm_j_rows):.3f}")
            cl3.metric("BDH advantage",
                       f"{np.mean(bdh_j_rows) - np.mean(tfm_j_rows):+.3f}")
            summary["bdh_jaccard"] = np.mean(bdh_j_rows)
            summary["tfm_jaccard"] = np.mean(tfm_j_rows)

        # ----------------------------------------------------------
        # SECTION D — HEATMAP
        # Picks the middle token row of each concept's activation
        # matrix (no averaging) and plots both models side by side.
        # ----------------------------------------------------------
        st.markdown("---")
        st.markdown("### D  —  Activation Heatmap (Middle Token)")
        st.markdown("""<div class="compare-box">
        Middle token of each concept phrase — not averaged — preserves ReLU exact zeros.<br>
        <span class="bdh-label">BDH</span>: Left half = top active neurons (bright spots).
        Right half = zero/inactive neurons (dark). The yellow dashed line is the boundary.
        True sparsity is only visible when inactive neurons are included alongside active ones.<br>
        <span class="tfm-label">Transformer</span>: All top neurons shown — uniformly lit
        because GELU never produces exact zeros.
        </div>""", unsafe_allow_html=True)

        with st.spinner("Building activation matrices (middle token)..."):
            bdh_vecs_mid, tfm_vecs_mid = [], []
            for concept in cmp_concepts:
                bm = bdh_token_acts(
                    concept, int(CMP_LAYER_BDH), device, model, ids_for_text)
                # NOTE(review): Transformer layer is clamped to 5 here.
                tfm_mat = tfm_token_acts(
                    tfm_mdl, tfm_tok, concept,
                    min(int(CMP_LAYER_TFM), 5), device)
                # Row index len // 2 selects the middle token of each matrix.
                bdh_vecs_mid.append(bm[bm.shape[0] // 2])
                tfm_middle_row = tfm_mat.shape[0] // 2
                tfm_vecs_mid.append(tfm_mat[tfm_middle_row])

        # One row per concept for each model.
        bdh_heat_rows = np.stack(bdh_vecs_mid)
        tfm_heat_rows = np.stack(tfm_vecs_mid)
        fig_hm = plot_activation_heatmap_compare(
            cmp_concepts, bdh_heat_rows, tfm_heat_rows, n_show=80)
        st.pyplot(fig_hm, use_container_width=True)
        plt.close(fig_hm)

        # ----------------------------------------------------------
        # SECTION E — SUMMARY SCORECARD
        # Renders the metrics gathered in summary as one HTML table.
        # ----------------------------------------------------------
        st.markdown("---")
        st.markdown("### E  —  Summary Scorecard")

        # Settings the renderer receives alongside the collected metrics.
        scorecard_options = {
            "TOPK_SET_BDH": int(TOPK_SET_BDH),
            "TOPK_SET_TFM": int(TOPK_SET_TFM),
            "LANGS": LANGS,
            "tfm_thresh": TFM_THRESH_INPUT,
        }
        st.markdown(
            render_summary_table(summary, **scorecard_options),
            unsafe_allow_html=True,
        )
        st.success("Comparison complete. All metrics measured from live model activations.")
'''

import os
# Write the generated Streamlit source to disk. The encoding is given
# explicitly because the app text contains non-ASCII characters (for
# example em dashes) and must not depend on the platform default.
with open("/content/app.py", "w", encoding="utf-8") as fh:
    fh.write(APP_CODE)

size = os.path.getsize("/content/app.py")

# Read the file back (rather than reusing APP_CODE) so the checks below
# run against what actually landed on disk; the context manager closes
# the handle that a bare open(...).read() would leave open.
with open("/content/app.py", encoding="utf-8") as fh:
    source = fh.read()

# Smoke checks: each entry maps a human-readable label to a boolean that
# is True when the expected marker text is present in (or absent from)
# the written file.
checks = {
    "four tabs present":                   "tab_compare" in source,
    "separate TOPK_SET_BDH":               "TOPK_SET_BDH" in source,
    "separate TOPK_SET_TFM":               "TOPK_SET_TFM" in source,
    "tfm sparsity uses abs()":             "np.abs(token_acts_matrix) > thresh" in source,
    "bdh sparsity exact zero":             "bdh_sparsity_pct" in source,
    "tfm threshold slider in sidebar":     "TFM_THRESH_INPUT" in source,
    "entropy function defined":            "activation_entropy" in source,
    "entropy in Tab 4":                    "bdh_ents" in source,
    "z-score removed from Tab 4":          "bdh_best_z_fair" not in source,
    "render_summary_table function":       "render_summary_table" in source,
    "winner badges in scorecard":          "score-badge-win" in source,
    "value bars in scorecard":             "score-bar-fill" in source,
    "expected-results box":                "expected-box" in source,
    "single-token heatmap":                "bm.shape[0] // 2" in source,
    "sparsity direction warning":          "BDH appears denser" in source,
    "no emojis in logic":                  all(ch not in source
                                               for ch in ["\U0001F7E2", "\U0001F534"]),
}

# Report the file size and one OK/FAIL line per check.
print(f"app.py written  —  {size:,} bytes")
for label, passed in checks.items():
    status = "OK  " if passed else "FAIL"
    print(f"  {status}  {label}")

In [None]:
!pip install pyngrok -q

from google.colab import userdata
from pyngrok import ngrok

# Read the ngrok auth token from the Colab secret named NGROK_TOKEN.
# Surrounding whitespace (easy to pick up when pasting) is removed, and an
# empty value is rejected here so the failure is reported at its cause
# instead of surfacing later as a tunnel authentication error.
NGROK_TOKEN = (userdata.get('NGROK_TOKEN') or '').strip()
if not NGROK_TOKEN:
    raise RuntimeError(
        "Colab secret 'NGROK_TOKEN' is empty. "
        "Add your ngrok auth token to the notebook secrets and re-run this cell."
    )
ngrok.set_auth_token(NGROK_TOKEN)
print('ngrok ready')

In [None]:


import os
import subprocess
import time

# --------------------------------------------------------------
# Launch cell: install dependencies, start Streamlit in the
# background, wait until it accepts connections, open an ngrok
# tunnel, and embed the app in the notebook output.
# --------------------------------------------------------------
import socket  # used below to probe the Streamlit port for readiness

# Single source of truth for the port shared by Streamlit and ngrok.
STREAMLIT_PORT = 8501
# Upper bound on how long to wait for Streamlit to start listening.
STARTUP_TIMEOUT_S = 60

# --------------------------------------------------------------
# Install Python dependencies.
# NOTE: no versions are pinned here, so pip installs whatever it
# resolves at run time; add explicit pins if runs must be
# reproducible.
# --------------------------------------------------------------
pip_status = os.system(
    "pip install -q "
    "pyngrok "
    "streamlit "
    "imageio "
    "transformers "
    "sentencepiece "
    "accelerate "
    "torch "
    "numpy "
    "pandas "
    "matplotlib "
)
if pip_status != 0:
    # os.system never raises; surface the failure but keep going, since the
    # packages may already be installed from an earlier run.
    print(f"WARNING: pip install exited with status {pip_status}; "
          "continuing with the packages already installed.")

# --------------------------------------------------------------
# Kill any previous Streamlit process that may still be running
# from an earlier cell execution, to avoid port conflicts.
# --------------------------------------------------------------
os.system("pkill -f streamlit 2>/dev/null || true")
time.sleep(2)

# --------------------------------------------------------------
# Verify that app.py was written by Cell 1 before proceeding.
# --------------------------------------------------------------
APP_PATH = "/content/app.py"
if not os.path.exists(APP_PATH):
    raise FileNotFoundError(
        "app.py not found at /content/app.py. "
        "Please run Cell 1 first to write the application file."
    )

print(f"app.py confirmed — {os.path.getsize(APP_PATH):,} bytes")

# --------------------------------------------------------------
# Start Streamlit as a background subprocess.
# --server.headless disables the browser-open prompt.
# --server.enableCORS false is required for ngrok tunnelling.
# --server.enableXsrfProtection false avoids CSRF header issues
#   in the Colab iframe environment.
# NOTE(review): CORS and XSRF protection are disabled and the app
#   is then exposed through a public tunnel; acceptable for a
#   short-lived demo, not for anything handling sensitive data.
# --------------------------------------------------------------
server_process = subprocess.Popen([
    "streamlit", "run", APP_PATH,
    "--server.port",                str(STREAMLIT_PORT),
    "--server.headless",            "true",
    "--server.address",             "0.0.0.0",
    "--server.enableCORS",          "false",
    "--server.enableXsrfProtection","false",
])

# --------------------------------------------------------------
# Wait until Streamlit accepts TCP connections instead of sleeping
# for a fixed time: this returns as soon as the server is ready and
# still tolerates slow start-ups, up to STARTUP_TIMEOUT_S seconds.
# --------------------------------------------------------------
print("Waiting for Streamlit to start...")
startup_deadline = time.monotonic() + STARTUP_TIMEOUT_S
while True:
    # A non-None poll() result means the process has already exited.
    if server_process.poll() is not None:
        raise RuntimeError(
            "Streamlit process exited unexpectedly. "
            "Check for import errors in app.py."
        )
    try:
        with socket.create_connection(("127.0.0.1", STREAMLIT_PORT), timeout=1):
            break
    except OSError:
        if time.monotonic() >= startup_deadline:
            # Do not leave an unresponsive server running in the background.
            server_process.terminate()
            raise RuntimeError(
                f"Streamlit did not start listening on port {STREAMLIT_PORT} "
                f"within {STARTUP_TIMEOUT_S} seconds."
            ) from None
        time.sleep(0.5)

print(f"Streamlit server running on port {STREAMLIT_PORT}")

# --------------------------------------------------------------
# Create an ngrok HTTPS tunnel so the app is reachable from
# outside the Colab VM.
# The auth token is taken from NGROK_TOKEN, which the previous
# cell must have set; update the stored token there if it has
# expired.
# --------------------------------------------------------------
from pyngrok import ngrok

# Best effort: stop any ngrok process left over from an earlier run.
# A failure here is reported but must not block the new tunnel.
try:
    ngrok.kill()
except Exception as exc:
    print(f"Note: could not stop a previous ngrok process ({exc}); continuing.")

ngrok.set_auth_token(NGROK_TOKEN)

try:
    tunnel = ngrok.connect(STREAMLIT_PORT, bind_tls=True, inspect=False)
except Exception:
    # Without a tunnel the server is unreachable; stop it so it does not
    # linger and hold the port for the next attempt.
    server_process.terminate()
    raise
public_url = tunnel.public_url

print(f"\nApp is running at: {public_url}")
print("Open the link above in a new browser tab, or view the inline frame below.\n")

# --------------------------------------------------------------
# Display the app inline inside the Colab output cell.
# A wider frame (1200 px) avoids horizontal scrollbars on most
# laptop screens.
# --------------------------------------------------------------
from IPython.display import IFrame, display

display(IFrame(public_url, width=1200, height=820))