# SplatNLP Colab Demo

This notebook downloads pretrained artifacts from DigitalOcean Spaces and runs:

- Token-space inference (multi-label set completion)
- Constraint-aware beam search to produce a legal build
- (Optional) Ultra + SAE hook introspection


In [1]:
import os
import subprocess
import sys
from pathlib import Path

REPO_URL = "https://github.com/cesaregarza/SplatNLP.git"
REPO_DIR = Path("/content/SplatNLP")
SRC_DIR = REPO_DIR / "src"

# Clone only once; re-running the cell reuses the existing checkout.
if not REPO_DIR.exists():
    subprocess.run(
        ["git", "clone", "--depth", "1", REPO_URL, str(REPO_DIR)],
        check=True,
    )

os.chdir(REPO_DIR)
# Guard so re-executing the cell does not stack duplicate sys.path entries.
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))
print("Repo ready:", REPO_DIR)


Cloning into '/content/SplatNLP'...


Repo ready: /content/SplatNLP


In [2]:
!pip -q install requests tqdm



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


## Download artifacts (Full model)

This grabs `model.pth`, `model_params.json`, `vocab.json`, `weapon_vocab.json`, and `model_info.json`.

If you want the Ultra + SAE demo later in this notebook, set `DOWNLOAD_ULTRA_SAE = True` in the code cell below before running it.

Optional: if you upload feature labels to the same dataset directory (either `feature_labels_ultra.json` or `consolidated_ultra.json`), the Ultra section will display human-readable names (and optionally notes) for top features.


In [3]:
from pathlib import Path

from splatnlp.utils.download_artifacts import (
    CORE_ARTIFACTS,
    ULTRA_ARTIFACTS,
    ULTRA_LABELS_ARTIFACTS,
    ULTRA_SAE_ARTIFACTS,
    download_artifacts,
)

BASE_URL = "https://splat-nlp.nyc3.cdn.digitaloceanspaces.com"
DATASET_DIR = "dataset_v2"
OUT_DIR = Path("saved_models") / DATASET_DIR

# Set to True to additionally fetch the Ultra model, its SAE and label files.
DOWNLOAD_ULTRA_SAE = False

# Core artifacts are always needed; the Ultra/SAE bundle is opt-in.
artifacts = [*CORE_ARTIFACTS]
if DOWNLOAD_ULTRA_SAE:
    artifacts.extend(
        [*ULTRA_ARTIFACTS, *ULTRA_SAE_ARTIFACTS, *ULTRA_LABELS_ARTIFACTS]
    )

download_artifacts(
    base_url=BASE_URL,
    dataset_dir=DATASET_DIR,
    out_dir=OUT_DIR,
    artifacts=artifacts,
    force=False,
    timeout_s=180,
    quiet=False,
    dry_run=False,
)

print("Downloaded to:", OUT_DIR)
for artifact_path in sorted(OUT_DIR.iterdir()):
    print(" -", artifact_path.name)


get   https://splat-nlp.nyc3.cdn.digitaloceanspaces.com/dataset_v2/model.pth


model.pth: 317MB [00:30, 10.8MB/s] 


get   https://splat-nlp.nyc3.cdn.digitaloceanspaces.com/dataset_v2/model_info.json


model_info.json: 82.0B [00:00, 36.6kB/s]


get   https://splat-nlp.nyc3.cdn.digitaloceanspaces.com/dataset_v2/model_params.json


model_params.json: 272B [00:00, 119kB/s]


get   https://splat-nlp.nyc3.cdn.digitaloceanspaces.com/dataset_v2/vocab.json


vocab.json: 3.68kB [00:00, 3.42MB/s]


get   https://splat-nlp.nyc3.cdn.digitaloceanspaces.com/dataset_v2/weapon_vocab.json


weapon_vocab.json: 3.01kB [00:00, 3.00MB/s]

Downloaded to: saved_models/dataset_v2
 - model.pth
 - model_info.json
 - model_params.json
 - vocab.json
 - weapon_vocab.json





## Load model + run token-space inference

This model is **multi-label set completion**: it outputs an independent probability for every token (not an autoregressive next-token LM).

Tokenization note:
- Standard abilities are represented as cumulative threshold tokens (e.g., 12 AP `run_speed_up` → `run_speed_up_3`, `_6`, `_12`).
- Beam search below runs on capstones (highest tier per ability family), then expands them back to cumulative tokens when calling the model.


In [4]:
import json

import torch

from splatnlp.model.models import SetCompletionModel
from splatnlp.serve.tokenize import tokenize_build
from splatnlp.utils.constants import (
    BUCKET_THRESHOLDS,
    MAIN_ONLY_ABILITIES,
    NULL,
    PAD,
    STANDARD_ABILITIES,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

params = json.loads((OUT_DIR / "model_params.json").read_text(encoding="utf-8"))
vocab = json.loads((OUT_DIR / "vocab.json").read_text(encoding="utf-8"))
weapon_vocab = json.loads(
    (OUT_DIR / "weapon_vocab.json").read_text(encoding="utf-8")
)

# Only consult vocab[PAD] when the params file lacks the id; dict.get would
# evaluate (and possibly fail on) the fallback even when it is not needed.
pad_token_id = (
    params["pad_token_id"] if "pad_token_id" in params else vocab[PAD]
)
inv_vocab = {v: k for k, v in vocab.items()}

model = SetCompletionModel(**params)
# The checkpoint is downloaded from a remote bucket: weights_only=True stops
# torch.load from unpickling arbitrary objects (a state dict only needs tensors).
model.load_state_dict(
    torch.load(OUT_DIR / "model.pth", map_location="cpu", weights_only=True)
)
model.to(device)
model.eval()

# Demo weapon: prefer weapon_id_8000, else fall back to the first vocab key.
weapon_id = (
    "weapon_id_8000"
    if "weapon_id_8000" in weapon_vocab
    else next(iter(weapon_vocab))
)

# Partial build (ability -> AP) used as the completion example below.
partial_build = {
    "ink_saver_main": 6,
    "run_speed_up": 12,
    "intensify_action": 12,
}

def build_to_capstones(build: dict[str, int]) -> list[str]:
    """Collapse an ability -> AP build into beam-search capstone tokens.

    Main-only abilities with a truthy value contribute their bare name.
    Each standard ability present in ``build`` contributes a single
    ``{ability}_{threshold}`` token, using the largest entry of
    ``BUCKET_THRESHOLDS`` that does not exceed its AP. Abilities below every
    threshold are dropped. Returns ``[NULL]`` when nothing survives.
    """
    mains = [name for name in MAIN_ONLY_ABILITIES if build.get(name)]

    standards: list[str] = []
    for name in STANDARD_ABILITIES:
        raw_ap = build.get(name)
        if raw_ap is None:
            continue
        ap_value = int(raw_ap)
        reached = [t for t in BUCKET_THRESHOLDS if t <= ap_value]
        if reached:
            standards.append(f"{name}_{max(reached)}")

    return mains + standards or [NULL]

def predict_probs(context_tokens: list[str]) -> torch.Tensor:
    """Score a token context with the global model and return CPU probabilities.

    Tokens are mapped through ``vocab`` (unknown tokens raise ``KeyError``).
    The global ``weapon_id`` is used, positions equal to ``pad_token_id`` are
    masked, and the leading batch dimension is squeezed away.
    """
    token_ids = torch.tensor(
        [[vocab[tok] for tok in context_tokens]], device=device
    )
    weapon_ids = torch.tensor([[weapon_vocab[weapon_id]]], device=device)
    padding_mask = token_ids.eq(pad_token_id)
    with torch.no_grad():
        logits = model(token_ids, weapon_ids, key_padding_mask=padding_mask)
        probabilities = logits.sigmoid().squeeze(0)
    return probabilities.detach().cpu()

def predict_top_tokens(context_tokens: list[str], k: int = 15):
    """Return up to ``k`` ``(token, probability)`` pairs, most probable first.

    The special PAD/NULL ids are excluded *before* ranking. The result
    therefore holds exactly ``k`` real tokens, or fewer only when the
    vocabulary has fewer real tokens. Previously the specials were dropped
    after ``topk``, which silently shortened the list.
    """
    probs_cpu = predict_probs(context_tokens)
    n_tokens = probs_cpu.numel()
    skip_ids = sorted(
        {
            i
            for i in (vocab.get(PAD), vocab.get(NULL))
            if i is not None and 0 <= i < n_tokens
        }
    )
    scores = probs_cpu.clone()
    if skip_ids:
        # Push specials to the bottom so topk never selects them.
        scores[skip_ids] = float("-inf")
    k_eff = max(0, min(k, n_tokens - len(skip_ids)))
    top = torch.topk(scores, k=k_eff)
    return [
        (inv_vocab[idx], float(probs_cpu[idx]))
        for idx in top.indices.tolist()
    ]

baseline_tokens = [NULL]
partial_tokens = tokenize_build(partial_build)
capstone_tokens = build_to_capstones(partial_build)

# Echo the inputs of the demo before showing any predictions.
for _label, _value in (
    ("weapon_id:", weapon_id),
    ("partial_build:", partial_build),
    ("partial_tokens (cumulative):", partial_tokens),
    ("capstone_tokens (for beam search):", capstone_tokens),
):
    print(_label, _value)

# Top-k predictions for the empty context and for the partial build.
for _label, _ctx in (
    ("\nTop predictions from <NULL> (baseline):", baseline_tokens),
    ("\nTop predictions from partial build tokens:", partial_tokens),
):
    print(_label)
    for tok, p in predict_top_tokens(_ctx):
        print(f"{tok:<32} {p:.4f}")

baseline_probs = predict_probs(baseline_tokens)
partial_probs = predict_probs(partial_tokens)
delta = partial_probs - baseline_probs

def top_delta_tokens(delta_vec: torch.Tensor, k: int, *, largest: bool):
    """Return up to ``k`` ``(token, delta)`` pairs with the biggest shifts.

    ``delta_vec`` holds one probability difference per vocab id. With
    ``largest=True`` the largest increases come first. Otherwise the largest
    decreases (most negative) come first. Deltas keep their original sign.
    PAD/NULL are excluded *before* ranking, so up to ``k`` real tokens are
    returned instead of a silently shortened list.
    """
    n_tokens = delta_vec.numel()
    skip_ids = sorted(
        {
            i
            for i in (vocab.get(PAD), vocab.get(NULL))
            if i is not None and 0 <= i < n_tokens
        }
    )
    # Rank on a signed score so a single topk call serves both directions.
    scores = (delta_vec if largest else -delta_vec).clone()
    if skip_ids:
        scores[skip_ids] = float("-inf")
    k_eff = max(0, min(k, n_tokens - len(skip_ids)))
    top = torch.topk(scores, k=k_eff)
    return [
        (inv_vocab[idx], float(delta_vec[idx]))
        for idx in top.indices.tolist()
    ]

# Biggest increases first, then biggest decreases, relative to <NULL>.
for _heading, _largest in (
    ("\nTop +Δ tokens (partial - baseline):", True),
    ("\nTop -Δ tokens (partial - baseline):", False),
):
    print(_heading)
    for tok, dv in top_delta_tokens(delta, k=12, largest=_largest):
        print(f"{tok:<32} {dv:+.4f}")


Device: cuda
weapon_id: weapon_id_8000
partial_build: {'ink_saver_main': 6, 'run_speed_up': 12, 'intensify_action': 12}
partial_tokens (cumulative): ['ink_saver_main_3', 'ink_saver_main_6', 'intensify_action_3', 'intensify_action_6', 'intensify_action_12', 'run_speed_up_3', 'run_speed_up_6', 'run_speed_up_12']
capstone_tokens (for beam search): ['ink_saver_main_6', 'intensify_action_12', 'run_speed_up_12']

Top predictions from <NULL> (baseline):
swim_speed_up_3                  0.7586
stealth_jump                     0.7355
swim_speed_up_6                  0.6821
quick_super_jump_3               0.6694
ink_saver_main_3                 0.5337
quick_respawn_3                  0.4763
ink_saver_main_6                 0.4574
ink_resistance_up_3              0.4455
quick_respawn_6                  0.4443
comeback                         0.3900
ink_recovery_up_3                0.3535
quick_respawn_12                 0.3513
swim_speed_up_12                 0.3208
quick_respawn_15             

## Beam search: token-space → legal build-space

We run two demos:

- Baseline build from `<NULL>` only
- Completion given a partial build (capstone tokens)

This uses the repo’s constraint-aware reconstruction (allocator + beam search).

Tip: set `SHOW_TRACE = True` to print a compact step-by-step trace.


In [5]:
from splatnlp.utils.reconstruct.allocator import Allocator
from splatnlp.utils.reconstruct.beam_search import reconstruct_build

import time

# Allocator handed to reconstruct_build by run_beam below.
allocator = Allocator()
# Number of vocab ids; predict_fn emits one probability per id in range(vocab_size).
vocab_size = len(vocab)

def predict_fn(current_tokens: list[str], weapon_id: str):
    """Beam-search scoring callback returning ``{token: probability}``.

    Runs the global ``model`` on ``current_tokens`` for ``weapon_id``, masking
    positions equal to ``pad_token_id``. Covers every id in
    ``range(vocab_size)``.
    """
    context = torch.tensor(
        [[vocab[tok] for tok in current_tokens]], device=device
    )
    weapon = torch.tensor([[weapon_vocab[weapon_id]]], device=device)
    with torch.no_grad():
        logits = model(context, weapon, key_padding_mask=context == pad_token_id)
        scores = torch.sigmoid(logits).squeeze(0)
    score_list = scores.detach().cpu().tolist()
    return {inv_vocab[idx]: float(score_list[idx]) for idx in range(vocab_size)}

def summarize_trace(trace, top_preds: int = 8) -> None:
    """Print one compact dict per beam-search frame.

    Each line shows the frame's ``step``, the capstones that first appear at
    that step (sorted), and the ``top_preds`` highest-scoring tokens from
    ``frame.logits`` (skipping ``<...>`` specials), rounded to 4 places.
    """
    already_seen: set = set()
    for frame in trace:
        current = sorted(frame.partial_caps.keys())
        newly_added = sorted(set(current) - already_seen)
        already_seen.update(current)
        ranked = sorted(
            frame.logits.items(), key=lambda item: item[1], reverse=True
        )
        preview = [
            (tok, round(score, 4))
            for tok, score in ranked
            if not tok.startswith("<")
        ]
        print(
            {"step": frame.step, "added": newly_added, "top_preds": preview[:top_preds]}
        )

def print_build(label: str, build) -> None:
    """Print a build's totals, gear layout and achieved AP (highest first)."""
    info = build.to_dict()
    ap_desc = dict(
        sorted(info["achieved_ap"].items(), key=lambda item: item[1], reverse=True)
    )
    print(f"\n{label}")
    for field in ("total_ap", "mains", "subs"):
        print(f"{field}:", info[field])
    print("achieved_ap:", ap_desc)

# Set to True to print a compact step-by-step beam-search trace.
SHOW_TRACE = False

def run_beam(
    label: str,
    initial_context: list[str],
    *,
    beam_size: int = 5,
    max_steps: int = 6,
    top_k: int = 1,
):
    """Run constraint-aware beam search and print the best legal build.

    Args:
        label: Heading used for the runtime line and the build printout.
        initial_context: Starting capstone tokens (e.g. ``[NULL]``).
        beam_size, max_steps, top_k: Forwarded to ``reconstruct_build``.
            The defaults are the values this notebook previously hard-coded.

    Returns:
        The first build returned by ``reconstruct_build``.

    Raises:
        RuntimeError: if ``reconstruct_build`` returns no builds.
    """
    record = bool(SHOW_TRACE)
    start = time.perf_counter()
    result = reconstruct_build(
        predict_fn=predict_fn,
        weapon_id=weapon_id,
        initial_context=initial_context,
        allocator=allocator,
        beam_size=beam_size,
        max_steps=max_steps,
        top_k=top_k,
        record_traces=record,
    )
    # With record_traces=True, reconstruct_build returns (builds, traces).
    builds, traces = result if record else (result, None)

    elapsed = time.perf_counter() - start
    print(f"{label} runtime: {elapsed:.3f}s")

    if not builds:
        raise RuntimeError(f"No valid build produced for: {label}")

    build = builds[0]
    print_build(label, build)

    # Guard against an empty trace list instead of raising IndexError.
    if record and traces:
        summarize_trace(traces[0])

    return build

# Beam search starting from the bare <NULL> context.
baseline_build = run_beam("Baseline build (<NULL>)", [NULL])
# Beam search seeded with the partial build's capstone tokens.
completion_build = run_beam("Completion from partial capstones", capstone_tokens)

def print_ap_delta(a, b, *, max_rows: int = 15) -> None:
    """Print per-ability AP changes from build ``a`` to build ``b``.

    Unchanged abilities are omitted. The rest are ordered by absolute change,
    largest first, with ties kept in alphabetical order. At most ``max_rows``
    lines are printed.
    """
    before = a.to_dict()["achieved_ap"]
    after = b.to_dict()["achieved_ap"]
    changes = []
    for ability in sorted(set(before) | set(after)):
        old = int(before.get(ability, 0))
        new = int(after.get(ability, 0))
        if old != new:
            changes.append((ability, old, new, new - old))

    changes.sort(key=lambda row: abs(row[3]), reverse=True)
    print("\nAP changes (completion - baseline):")
    for ability, old, new, diff in changes[:max_rows]:
        print(f"{ability:<24} {old:>2} -> {new:>2} ({diff:+d})")

def check_constraints(
    build, requested: dict[str, int], *, label: str = "completion build"
) -> bool:
    """Print whether ``build`` reaches at least the requested AP per ability.

    Args:
        build: Object whose ``to_dict()["achieved_ap"]`` maps ability -> AP.
        requested: Minimum AP per ability (e.g. the partial build).
        label: Name shown in the header line. The default keeps the original
            header text.

    Returns:
        True when every requested ability is met (vacuously True if empty).
    """
    achieved = build.to_dict()["achieved_ap"]
    all_met = True
    print(f"\nConstraint check ({label}):")
    for ability, req in requested.items():
        got = int(achieved.get(ability, 0))
        met = got >= int(req)
        all_met = all_met and met
        status = "OK" if met else "FAIL"
        print(f"{ability:<24} req={req:>2} got={got:>2} {status}")
    print("all_met:", all_met)
    return all_met

# Per-ability AP differences between the completion and baseline builds.
print_ap_delta(baseline_build, completion_build)
# Verify the completion build reaches every AP amount requested in partial_build.
check_constraints(completion_build, partial_build)


Baseline build (<NULL>) runtime: 1.014s

Baseline build (<NULL>)
total_ap: 57
mains: {'head': 'comeback', 'clothes': 'ink_saver_main', 'shoes': 'stealth_jump'}
subs: {'sub_power_up': 2, 'swim_speed_up': 2, 'special_power_up': 2, 'quick_super_jump': 1, 'ink_resistance_up': 1, 'ink_recovery_up': 1}
achieved_ap: {'comeback': 10, 'ink_saver_main': 10, 'stealth_jump': 10, 'sub_power_up': 6, 'swim_speed_up': 6, 'special_power_up': 6, 'quick_super_jump': 3, 'ink_resistance_up': 3, 'ink_recovery_up': 3}
Completion from partial capstones runtime: 0.903s

Completion from partial capstones
total_ap: 57
mains: {'head': 'intensify_action', 'clothes': 'ink_saver_main', 'shoes': 'run_speed_up'}
subs: {'ink_saver_main': 2, 'intensify_action': 2, 'run_speed_up': 2, 'quick_super_jump': 2, 'ink_resistance_up': 1}
achieved_ap: {'intensify_action': 16, 'ink_saver_main': 16, 'run_speed_up': 16, 'quick_super_jump': 6, 'ink_resistance_up': 3}

AP changes (completion - baseline):
intensify_action          0 ->

## Build ↔ token visualization

This renders a 57-AP build as **3 mains + 9 subs**, then shows how it collapses to token-space:

- **Beam-search capstones** (highest tier per ability family)
- **Cumulative threshold tokens** (what the model actually sees)

Sub-slot ordering is not unique; this is just a display assignment.


In [6]:
from collections import Counter
import html as _html

from IPython.display import HTML, display

from splatnlp.serve.tokenize import tokenize_build
from splatnlp.utils.constants import (
    BUCKET_THRESHOLDS,
    MAIN_ONLY_ABILITIES,
    NULL,
    STANDARD_ABILITIES,
)

def _abbrev(token: str) -> str:
    special = {
        "ink_recovery_up": "iru",
        "ink_resistance_up": "res",
        "ink_saver_main": "ism",
        "ink_saver_sub": "iss",
        "quick_respawn": "qr",
        "quick_super_jump": "qsj",
        "run_speed_up": "rsu",
        "sub_resistance_up": "sru",
        "special_charge_up": "scu",
        "special_power_up": "spu",
        "special_saver": "ss",
        "sub_power_up": "bpu",
        "swim_speed_up": "ssu",
        "intensify_action": "ia",
    }
    base = token
    suffix = ""
    if "_" in token and token.rsplit("_", 1)[-1].isdigit():
        base, suffix = token.rsplit("_", 1)
    short = special.get(base) or "".join(w[0] for w in base.split("_"))
    return f"{short}_{suffix}" if suffix else short

def _pretty(token: str) -> str:
    return token.replace("_", " ").title()

def _ap_to_capstones(ap: dict[str, int]) -> list[str]:
    caps: list[str] = []
    for ability in MAIN_ONLY_ABILITIES:
        if int(ap.get(ability, 0) or 0) > 0:
            caps.append(ability)
    for ability in STANDARD_ABILITIES:
        val = int(ap.get(ability, 0) or 0)
        thresh = max(
            (t for t in BUCKET_THRESHOLDS if t <= val),
            default=None,
        )
        if thresh is not None:
            caps.append(f"{ability}_{thresh}")
    return caps or [NULL]

def _palette(n: int) -> list[str]:
    base = [
        "#4c78a8",
        "#f58518",
        "#54a24b",
        "#e45756",
        "#72b7b2",
        "#b279a2",
        "#ff9da6",
        "#9d755d",
        "#bab0ac",
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
    ]
    if n <= len(base):
        return base[:n]
    return [base[i % len(base)] for i in range(n)]

def _build_color_map(*build_dicts: dict) -> dict[str, str]:
    fams: set[str] = set()
    for d in build_dicts:
        fams.update(d.get("achieved_ap", {}).keys())
        for v in d.get("mains", {}).values():
            if v is not None:
                fams.add(v)
        fams.update(d.get("subs", {}).keys())
    fam_list = sorted(fams)
    cols = _palette(len(fam_list))
    return {fam: col for fam, col in zip(fam_list, cols)}

def _expand_subs(subs: dict[str, int]) -> list[str | None]:
    items = []
    for fam, c in sorted(
        subs.items(),
        key=lambda kv: (-int(kv[1]), kv[0]),
    ):
        items.extend([fam] * int(c))
    if len(items) < 9:
        items.extend([None] * (9 - len(items)))
    return items[:9]

def _render_build_card(title: str, build, *, colors: dict[str, str]) -> str:
    d = build.to_dict()
    mains = d["mains"]
    subs = _expand_subs(d["subs"])
    subs_by_slot = {
        "head": subs[0:3],
        "clothes": subs[3:6],
        "shoes": subs[6:9],
    }

    ap = {k: int(v) for k, v in d["achieved_ap"].items()}
    cumulative = tokenize_build(ap)
    capstones = _ap_to_capstones(ap)

    def slot_box(label: str | None, kind: str) -> str:
        if label is None:
            return f"<div class='slot {kind} empty'>empty</div>"
        col = colors.get(label, "#eee")
        display_name = _html.escape(_pretty(label))
        short = _html.escape(_abbrev(label))
        full = _html.escape(label)
        return (
            f"<div class='slot {kind}' title='{full}' "
            f"style='background:{col}30;border-color:{col}88'>"
            f"<div class='slot-name'>{display_name}</div>"
            f"<div class='slot-abbrev'>{short}</div>"
            "</div>"
        )

    gear_cols = []
    for slot_name in ["head", "clothes", "shoes"]:
        main = mains.get(slot_name)
        col_html = [
            "<div class='gear-col'>",
            f"<div class='gear-title'>{slot_name.title()}</div>",
            slot_box(main, "main"),
        ]
        for sub in subs_by_slot[slot_name]:
            col_html.append(slot_box(sub, "sub"))
        col_html.append("</div>")
        gear_cols.append("\n".join(col_html))

    main_counts = Counter([m for m in mains.values() if m is not None])
    rows = []
    for fam in sorted(ap.keys()):
        ap_val = int(ap[fam])
        m = int(main_counts.get(fam, 0))
        s = int(d["subs"].get(fam, 0))
        col = colors.get(fam, "#eee")

        toks = []
        cap = None
        if fam in MAIN_ONLY_ABILITIES:
            toks = [fam]
            cap = fam
        elif fam in STANDARD_ABILITIES:
            toks = [f"{fam}_{t}" for t in BUCKET_THRESHOLDS if t <= ap_val]
            cap = toks[-1] if toks else None

        chip_html = []
        for t in toks:
            short = _html.escape(_abbrev(t))
            full = _html.escape(t)
            cls = "chip cap" if t == cap else "chip"
            bg = f"{col}26" if t == cap else f"{col}14"
            chip_html.append(
                f"<span class='{cls}' title='{full}' "
                f"style='border-color:{col}99;background:{bg}'>{short}</span>"
            )

        rows.append(
            "\n".join(
                [
                    "<tr>",
                    (
                        "<td>"
                        f"<span class='fam' title='{_html.escape(fam)}' "
                        f"style='background:{col}22;border-color:{col}88'>"
                        f"{_html.escape(_abbrev(fam).upper())}</span>"
                        f"<span class='fam-full' title='{_html.escape(fam)}'>"
                        f"{_html.escape(_pretty(fam))}</span>"
                        "</td>"
                    ),
                    f"<td>{m}</td>",
                    f"<td>{s}</td>",
                    f"<td>{ap_val}</td>",
                    f"<td>{''.join(chip_html) if chip_html else ''}</td>",
                    "</tr>",
                ]
            )
        )

    def list_as_code(items: list[str]) -> str:
        safe = ", ".join(_html.escape(x) for x in items)
        return f"<code>{safe}</code>"

    cap_n = len([t for t in capstones if t != NULL])
    cum_n = len([t for t in cumulative if t != NULL])
    return "\n".join(
        [
            "<div class='flow-card'>",
            f"<div class='flow-title'>{_html.escape(title)}</div>",
            (
                f"<div class='small'>total_ap={int(d['total_ap'])} | "
                f"capstones={cap_n} | cumulative_tokens={cum_n}</div>"
            ),
            "<div class='gear-grid'>",
            *gear_cols,
            "</div>",
            "<details class='details'><summary>Beam-search capstones</summary>",
            list_as_code(capstones),
            "</details>",
            "<details class='details'><summary>Tokens fed to model (cumulative)</summary>",
            list_as_code(cumulative),
            "</details>",
            "<div class='small' style='margin-top:8px'>",
            "Per-family breakdown (capstone token is bold):",
            "</div>",
            "<table class='tbl'>",
            "<thead><tr><th>ability</th><th>m</th><th>s</th><th>AP</th><th>tokens</th></tr></thead>",
            "<tbody>",
            *rows,
            "</tbody></table>",
            "<div class='small note'>",
            "Note: sub-slot order is not unique; this is a display assignment.",
            "</div>",
            "</div>",
        ]
    )

# CSS for the HTML build cards; it is embedded once at the top of html_out below.
# The class names here (flow-card, gear-grid, slot, chip, tbl, ...) are the ones
# emitted by _render_build_card.
STYLE = """
<style>
.flow-wrap{font-family:ui-sans-serif,system-ui,-apple-system,Segoe UI,Roboto,Arial;}
.flow-row{display:flex;gap:14px;flex-wrap:wrap;align-items:flex-start;}
.flow-card{border:1px solid #e6e6e6;border-radius:14px;padding:12px;width:520px;box-shadow:0 1px 0 rgba(0,0,0,0.02);}
.flow-title{font-weight:700;font-size:14px;margin-bottom:2px;}
.small{color:#555;font-size:12px;}
.gear-grid{display:grid;grid-template-columns:repeat(3,1fr);gap:10px;margin-top:10px;}
.gear-col{background:#fafafa;border:1px solid #eee;border-radius:12px;padding:8px;}
.gear-title{font-weight:600;font-size:12px;color:#333;margin-bottom:6px;}
.slot{border-radius:10px;padding:7px 8px;margin-top:6px;border:1px solid rgba(0,0,0,0.10);}
.slot-name{font-size:12px;font-weight:800;line-height:1.15;}
.slot.main .slot-name{font-size:13px;}
.slot.sub .slot-name{font-size:12px;font-weight:700;}
.slot-abbrev{margin-top:2px;font-size:11px;color:#222;opacity:0.85;font-family:ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",monospace;}
.slot.empty{background:#f0f0f0;color:#888;border-color:#e0e0e0;}
.details{margin-top:10px;}
.details summary{cursor:pointer;color:#333;font-size:12px;font-weight:600;}
code{background:#f6f6f6;border:1px solid #eee;padding:2px 6px;border-radius:8px;font-size:12px;}
.tbl{width:100%;border-collapse:collapse;margin-top:6px;}
.tbl th,.tbl td{border-bottom:1px solid #eee;padding:4px 6px;font-size:12px;vertical-align:top;}
.tbl th{color:#333;text-align:left;font-weight:700;}
.fam{display:inline-block;padding:2px 8px;border-radius:999px;border:1px solid rgba(0,0,0,0.12);font-weight:800;font-size:11px;}
.fam-full{margin-left:6px;font-size:12px;color:#222;font-weight:600;}
.chip{display:inline-block;margin:2px 4px 2px 0;padding:2px 8px;border-radius:999px;border:1px solid rgba(0,0,0,0.12);font-size:12px;}
.chip.cap{border-width:2px;font-weight:800;}
.note{margin-top:8px;}
</style>
"""

baseline_d = baseline_build.to_dict()
completion_d = completion_build.to_dict()
# One shared colour map so a family has the same colour on both cards.
color_map = _build_color_map(baseline_d, completion_d)

cards = [
    _render_build_card(card_title, card_build, colors=color_map)
    for card_title, card_build in (
        ("Baseline build", baseline_build),
        ("Completion build", completion_build),
    )
]
html_out = "\n".join(
    ["<div class='flow-wrap'>", STYLE, "<div class='flow-row'>", *cards, "</div>", "</div>"]
)
display(HTML(html_out))


ability,m,s,AP,tokens
CComeback,1,0,10,c
IRUInk Recovery Up,0,1,3,iru_3
RESInk Resistance Up,0,1,3,res_3
ISMInk Saver Main,1,0,10,ism_3ism_6
QSJQuick Super Jump,0,1,3,qsj_3
SPUSpecial Power Up,0,2,6,spu_3spu_6
SJStealth Jump,1,0,10,sj
BPUSub Power Up,0,2,6,bpu_3bpu_6
SSUSwim Speed Up,0,2,6,ssu_3ssu_6

ability,m,s,AP,tokens
RESInk Resistance Up,0,1,3,res_3
ISMInk Saver Main,1,2,16,ism_3ism_6ism_12ism_15
IAIntensify Action,1,2,16,ia_3ia_6ia_12ia_15
QSJQuick Super Jump,0,2,6,qsj_3qsj_6
RSURun Speed Up,1,2,16,rsu_3rsu_6rsu_12rsu_15


## Optional: Ultra + SAE hook introspection

This section is off by default. It loads `model_ultra.pth` and `sae_model_ultra.pth`, registers a hook on the 512D pooled representation, and shows:

- Hook fidelity stats (bypass vs SAE reconstruction)
- Top active SAE features for a run
- Top tokens most influenced by a feature
- (If available) human labels for features

Note: `no_change=True` captures SAE activations without changing model outputs; `no_change=False` inserts the SAE reconstruction (used in the fidelity check).

If you have label notes available, set `SHOW_LABEL_NOTES = True` in the code cell.

To enable: set `DOWNLOAD_ULTRA_SAE = True` in the download cell above, then set `RUN_ULTRA_SAE = True` below.


In [7]:
RUN_ULTRA_SAE = False

if RUN_ULTRA_SAE:
    import json

    import torch

    import torch.nn.functional as F

    from splatnlp.monosemantic_sae.hooks import register_hooks
    from splatnlp.monosemantic_sae.models import SparseAutoencoder
    from splatnlp.utils.reconstruct.allocator import Allocator
    from splatnlp.utils.reconstruct.beam_search import reconstruct_build

    ultra = SetCompletionModel(**params)
    ultra.load_state_dict(
        torch.load(OUT_DIR / "model_ultra.pth", map_location="cpu")
    )
    ultra.to(device)
    ultra.eval()

    sae_cfg = json.loads((OUT_DIR / "sae_config_ultra.json").read_text())
    sae = SparseAutoencoder(
        input_dim=int(sae_cfg.get("input_dim", 512)),
        expansion_factor=float(sae_cfg.get("expansion_factor", 48.0)),
        l1_coefficient=float(sae_cfg.get("l1_coefficient", 0.0)),
        target_usage=float(sae_cfg.get("target_usage", 0.0)),
        usage_coeff=float(sae_cfg.get("usage_coeff", 0.0)),
        dead_neuron_threshold=float(sae_cfg.get("dead_neuron_threshold", 1e-6)),
        dead_neuron_steps=int(sae_cfg.get("dead_neuron_steps", 12500)),
    )
    sae.load_state_dict(
        torch.load(OUT_DIR / "sae_model_ultra.pth", map_location="cpu")
    )
    sae.to(device)
    sae.eval()

    hook, handle = register_hooks(
        ultra,
        sae_model=sae,
        bypass=False,
        no_change=True,
    )

    vocab_size = len(vocab)

    SHOW_LABEL_NOTES = False
    MAX_NOTE_CHARS = 220

    def _compress_ws(text: str) -> str:
        return " ".join((text or "").strip().split())

    def load_ultra_feature_labels() -> dict[int, dict[str, str]]:
        candidates = [
            OUT_DIR / "consolidated_ultra.json",
            OUT_DIR / "feature_labels_ultra.json",
        ]
        for path in candidates:
            if not path.exists():
                continue

            raw = json.loads(path.read_text())
            if not isinstance(raw, dict):
                continue

            # Consolidated schema (feature_id/display_name/dashboard_notes/...)
            consolidated = any(
                isinstance(v, dict)
                and (
                    "feature_id" in v
                    or "display_name" in v
                    or "dashboard_name" in v
                    or "dashboard_notes" in v
                )
                for v in raw.values()
            )

            labels: dict[int, dict[str, str]] = {}
            if consolidated:
                for k, v in raw.items():
                    if not isinstance(v, dict):
                        continue
                    fid = v.get("feature_id")
                    if fid is None:
                        try:
                            fid = int(k)
                        except (TypeError, ValueError):
                            continue
                    labels[int(fid)] = {
                        "name": v.get("display_name")
                        or v.get("dashboard_name")
                        or "",
                        "category": v.get("dashboard_category") or "none",
                        "notes": v.get("dashboard_notes") or "",
                    }
                return labels

            # Dashboard schema (FeatureLabel: name/category/notes/timestamp)
            for k, v in raw.items():
                if not isinstance(v, dict):
                    continue
                try:
                    fid = int(k)
                except ValueError:
                    continue
                labels[fid] = {
                    "name": v.get("name") or "",
                    "category": v.get("category") or "none",
                    "notes": v.get("notes") or "",
                }
            return labels

        return {}

    feature_labels = load_ultra_feature_labels()
    if feature_labels:
        print("Loaded feature labels:", len(feature_labels))

    def forward_probs(tokens: list[str], weapon_id: str) -> torch.Tensor:
        ids = [vocab[t] for t in tokens]
        x = torch.tensor([ids], device=device)
        w = torch.tensor([[weapon_vocab[weapon_id]]], device=device)
        mask = x == pad_token_id
        with torch.no_grad():
            probs = torch.sigmoid(ultra(x, w, key_padding_mask=mask)).squeeze(0)
        return probs.detach().cpu()

    def predict_fn_ultra(current_tokens: list[str], weapon_id: str):
        probs = forward_probs(current_tokens, weapon_id).tolist()
        acts = (
            hook.last_h_post.detach().cpu().flatten()
            if hook.last_h_post is not None
            else None
        )
        return ({inv_vocab[i]: float(probs[i]) for i in range(vocab_size)}, acts)

    # Hook fidelity: compare bypassed outputs vs SAE reconstruction.
    context = [NULL]
    hook.set_mode(bypass=True)
    p_bypass = forward_probs(context, weapon_id)

    hook.set_mode(bypass=False, no_change=False)
    p_recon = forward_probs(context, weapon_id)
    p_delta = p_recon - p_bypass

    print("Hook fidelity (bypass vs recon):")
    print("prob max |Δ|:", float(p_delta.abs().max()))
    print("prob mean |Δ|:", float(p_delta.abs().mean()))
    if hook.last_in is not None and hook.last_x_recon is not None:
        x = hook.last_in.detach().cpu().flatten()
        x_recon = hook.last_x_recon.detach().cpu().flatten()
        print("x recon MSE:", float(F.mse_loss(x_recon, x)))
        print("x recon cosine:", float(F.cosine_similarity(x_recon, x, dim=0)))

    # Run a traced beam search so we can see activations.
    hook.set_mode(bypass=False, no_change=True)
    builds, traces = reconstruct_build(
        predict_fn=predict_fn_ultra,
        weapon_id=weapon_id,
        initial_context=[NULL],
        allocator=Allocator(),
        beam_size=5,
        max_steps=6,
        top_k=1,
        record_traces=True,
    )

    if not builds or not traces:
        raise RuntimeError("No build/trace produced")

    print("\nUltra build:")
    print(builds[0].to_dict())

    acts = traces[0][-1].activations
    if acts is None:
        raise RuntimeError("No SAE activations captured")

    topk = torch.topk(acts, k=10)
    print("\nTop active SAE features:")
    for i, v in zip(topk.indices.tolist(), topk.values.tolist()):
        fid = int(i)
        val = float(v)
        meta = feature_labels.get(fid)
        if not meta:
            print(fid, val)
            continue
        cat = meta.get("category", "none")
        name = (meta.get("name") or "").strip()
        if name:
            print(f"{fid:5d} {val:.4f} [{cat}] {name}")
        else:
            print(f"{fid:5d} {val:.4f} [{cat}]")
        if SHOW_LABEL_NOTES:
            notes = _compress_ws(meta.get("notes", ""))
            if notes:
                if len(notes) > MAX_NOTE_CHARS:
                    notes = notes[: MAX_NOTE_CHARS - 1] + "…"
                print("      notes:", notes)

    # Pick one feature and show what it most influences.
    feature_id = int(topk.indices[0].item())
    feature_value = float(acts[feature_id])
    print(f"\nExample feature: {feature_id} (value={feature_value:.4f})")
    if feature_labels.get(feature_id):
        meta = feature_labels[feature_id]
        name = (meta.get("name") or "").strip()
        cat = meta.get("category", "none")
        if name:
            print(f"label: [{cat}] {name}")
        if SHOW_LABEL_NOTES:
            notes = _compress_ws(meta.get("notes", ""))
            if notes:
                if len(notes) > MAX_NOTE_CHARS:
                    notes = notes[: MAX_NOTE_CHARS - 1] + "…"
                print("notes:", notes)

    decoder_norm = F.normalize(sae.decoder.weight, dim=0)
    influence = ultra.output_layer.weight @ decoder_norm[:, feature_id]
    influence = influence.detach().cpu().tolist()

    infl_pairs = []
    for i, val in enumerate(influence):
        tok = inv_vocab[i]
        if tok.startswith("<"):
            continue
        infl_pairs.append((tok, float(val)))
    infl_pairs.sort(key=lambda kv: kv[1], reverse=True)

    print("Top tokens by logit influence for this feature:")
    for tok, val in infl_pairs[:10]:
        print(f"{tok:<32} {val:+.4f}")

    handle.remove()
