# SplatNLP Colab Demo

This notebook downloads pretrained artifacts from DigitalOcean Spaces and runs:

- Token-space inference (multi-label set completion)
- Constraint-aware beam search to produce a legal build
- (Optional) Ultra + SAE hook introspection / steering


In [None]:
import os
import subprocess
import sys
from pathlib import Path

# Source repository and the location of its checkout on the Colab VM.
REPO_URL = "https://github.com/cesaregarza/SplatNLP.git"
REPO_DIR = Path("/content/SplatNLP")

# Shallow-clone only when the checkout is missing, so re-running this cell
# reuses the existing directory instead of cloning again.
if not REPO_DIR.exists():
    clone_cmd = ["git", "clone", "--depth", "1", REPO_URL, str(REPO_DIR)]
    subprocess.run(clone_cmd, check=True)

# Work from the repo root and put its `src/` directory first on sys.path so
# the `splatnlp` imports in later cells can resolve from the checkout.
os.chdir(REPO_DIR)
sys.path.insert(0, str(REPO_DIR / "src"))
print("Repo ready:", REPO_DIR)


In [None]:
!pip -q install requests tqdm


## Download artifacts (Full model)

This grabs `model.pth`, `model_params.json`, `vocab.json`, `weapon_vocab.json`, and `model_info.json`.

If you want the Ultra + SAE demo later in this notebook, set `DOWNLOAD_ULTRA_SAE = True`.


In [None]:
from pathlib import Path

from splatnlp.utils.download_artifacts import (
    CORE_ARTIFACTS,
    ULTRA_ARTIFACTS,
    ULTRA_SAE_ARTIFACTS,
    download_artifacts,
)

# Remote location of the pretrained artifacts and the local directory that
# receives them (relative to the current working directory).
BASE_URL = "https://splat-nlp.nyc3.cdn.digitaloceanspaces.com"
DATASET_DIR = "dataset_v2"
OUT_DIR = Path("saved_models") / DATASET_DIR

# Set to True to additionally fetch the Ultra model and SAE artifacts.
DOWNLOAD_ULTRA_SAE = False

# Always fetch the core artifacts; append the optional sets on request.
artifacts = [*CORE_ARTIFACTS]
if DOWNLOAD_ULTRA_SAE:
    artifacts.extend([*ULTRA_ARTIFACTS, *ULTRA_SAE_ARTIFACTS])

download_artifacts(
    artifacts=artifacts,
    base_url=BASE_URL,
    dataset_dir=DATASET_DIR,
    out_dir=OUT_DIR,
    timeout_s=180,
    force=False,
    dry_run=False,
    quiet=False,
)

# List what is now present in the output directory.
print("Downloaded to:", OUT_DIR)
for p in sorted(OUT_DIR.iterdir()):
    print(" -", p.name)


## Load model + run token-space inference

This shows the top token predictions (not yet a legal build).


In [None]:
import json

import torch

from splatnlp.model.models import SetCompletionModel
from splatnlp.serve.tokenize import tokenize_build
from splatnlp.utils.constants import NULL, PAD

# Prefer the GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Model hyper-parameters and the two vocabularies written by the download cell.
params = json.loads((OUT_DIR / "model_params.json").read_text())
vocab = json.loads((OUT_DIR / "vocab.json").read_text())
weapon_vocab = json.loads((OUT_DIR / "weapon_vocab.json").read_text())

pad_token_id = params.get("pad_token_id", vocab[PAD])
# Reverse mapping: token id -> token string, used to label model outputs.
inv_vocab = {v: k for k, v in vocab.items()}

model = SetCompletionModel(**params)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint downloaded over the network cannot execute arbitrary
# code while being loaded.
model.load_state_dict(
    torch.load(OUT_DIR / "model.pth", map_location="cpu", weights_only=True)
)
model.to(device)
model.eval()

# Use weapon_id_8000 when the vocabulary has it, else the first weapon listed.
weapon_id = (
    "weapon_id_8000"
    if "weapon_id_8000" in weapon_vocab
    else next(iter(weapon_vocab))
)
partial_build = {
    "ink_saver_main": 6,
    "run_speed_up": 12,
    "intensify_action": 10,
}

tokens = tokenize_build(partial_build)
print("weapon_id:", weapon_id)
print("input tokens:", tokens)

# Batch of one: a single row of token ids and a single weapon id.
input_tokens = torch.tensor([[vocab[t] for t in tokens]], device=device)
input_weapons = torch.tensor([[weapon_vocab[weapon_id]]], device=device)
key_padding_mask = input_tokens == pad_token_id

with torch.no_grad():
    probs = torch.sigmoid(
        model(input_tokens, input_weapons, key_padding_mask=key_padding_mask)
    ).squeeze(0)

# Number of predictions to display.
TOP_N = 20

probs_cpu = probs.detach().cpu()
skip = {vocab.get(PAD), vocab.get(NULL)}
# Over-fetch by the number of special tokens so that, after PAD/NULL are
# dropped, a full TOP_N rows are still shown (previously the filter ran after
# selecting exactly 20 and could print fewer).
top = torch.topk(probs_cpu, k=min(TOP_N + len(skip), probs_cpu.numel()))
ranked = [
    (idx, p)
    for idx, p in zip(top.indices.tolist(), top.values.tolist())
    if idx not in skip
]
for idx, p in ranked[:TOP_N]:
    print(f"{inv_vocab[idx]:<32} {float(p):.4f}")


## Beam search: token-space → legal build-space

This uses the repo’s constraint-aware reconstruction (allocator + beam search) to produce a legal build.


In [None]:
from splatnlp.utils.reconstruct.allocator import Allocator
from splatnlp.utils.reconstruct.beam_search import reconstruct_build

# Number of vocabulary entries that predict_fn maps back to token strings.
vocab_size = len(vocab)
allocator = Allocator()


def predict_fn(current_tokens: list[str], weapon_id: str):
    """Score every vocabulary token for the given context and weapon.

    Args:
        current_tokens: Token strings forming the current context; each must
            be a key of ``vocab``.
        weapon_id: Key into ``weapon_vocab`` identifying the weapon.

    Returns:
        A dict mapping the token string of every id in ``range(vocab_size)``
        to the sigmoid of the model output at that id, as a Python float.
    """
    token_ids = torch.tensor(
        [[vocab[tok] for tok in current_tokens]], device=device
    )
    weapon_ids = torch.tensor([[weapon_vocab[weapon_id]]], device=device)
    pad_mask = token_ids == pad_token_id
    with torch.no_grad():
        logits = model(token_ids, weapon_ids, key_padding_mask=pad_mask)
        scores = torch.sigmoid(logits).squeeze(0)
    scores = scores.detach().cpu().tolist()
    return {inv_vocab[idx]: float(scores[idx]) for idx in range(vocab_size)}

# Run the constraint-aware reconstruction for the selected weapon.
# NOTE(review): the search starts from [NULL], not from the `tokens` of the
# partial build tokenized in the inference cell - confirm this is intended.
builds = reconstruct_build(
    predict_fn=predict_fn,
    allocator=allocator,
    weapon_id=weapon_id,
    initial_context=[NULL],
    max_steps=6,
    beam_size=5,
    top_k=1,
)

# Fail loudly instead of indexing into an empty result.
if not builds:
    raise RuntimeError("No valid build produced")

# The bare last expression makes the notebook render the build as cell output.
build = builds[0]
build.to_dict()


## Optional: Ultra + SAE hook demo

This section is off by default. It loads `model_ultra.pth` and `sae_model_ultra.pth`, registers a hook on the 512D pooled representation, and shows:

- Top active SAE features for a run
- A tiny “steer” example by editing one active feature and rerunning beam search

To enable: set `DOWNLOAD_ULTRA_SAE = True` in the download cell above and re-run that cell so the Ultra and SAE files are fetched, then set `RUN_ULTRA_SAE = True` below.


In [None]:
RUN_ULTRA_SAE = False

if RUN_ULTRA_SAE:
    import json

    from splatnlp.monosemantic_sae.hooks import register_hooks
    from splatnlp.monosemantic_sae.models import SparseAutoencoder

    # NOTE(review): the Ultra model is built from the same `params` as the
    # core model - confirm model_ultra.pth shares that architecture.
    ultra = SetCompletionModel(**params)
    # weights_only=True keeps a tampered downloaded checkpoint from executing
    # arbitrary code during unpickling.
    ultra.load_state_dict(
        torch.load(
            OUT_DIR / "model_ultra.pth",
            map_location="cpu",
            weights_only=True,
        )
    )
    ultra.to(device)
    ultra.eval()

    # Build the SAE from its JSON config, falling back to the listed defaults
    # for any key the config omits.
    sae_cfg = json.loads((OUT_DIR / "sae_config_ultra.json").read_text())
    sae = SparseAutoencoder(
        input_dim=int(sae_cfg.get("input_dim", 512)),
        expansion_factor=float(sae_cfg.get("expansion_factor", 48.0)),
        l1_coefficient=float(sae_cfg.get("l1_coefficient", 0.0)),
        target_usage=float(sae_cfg.get("target_usage", 0.0)),
        usage_coeff=float(sae_cfg.get("usage_coeff", 0.0)),
        dead_neuron_threshold=float(sae_cfg.get("dead_neuron_threshold", 1e-6)),
        dead_neuron_steps=int(sae_cfg.get("dead_neuron_steps", 12500)),
    )
    sae.load_state_dict(
        torch.load(
            OUT_DIR / "sae_model_ultra.pth",
            map_location="cpu",
            weights_only=True,
        )
    )
    sae.to(device)
    sae.eval()

    hook, handle = register_hooks(ultra, sae_model=sae, bypass=False, no_change=True)

    def predict_fn_ultra(current_tokens: list[str], weapon_id: str):
        """Score every vocabulary token with the Ultra model.

        Returns:
            A ``(probabilities, activations)`` pair: a dict mapping each token
            string to its sigmoid probability, and the flattened CPU copy of
            ``hook.last_h_post`` from this forward pass (``None`` when the
            hook captured nothing).
        """
        ids = [vocab[t] for t in current_tokens]
        x = torch.tensor([ids], device=device)
        w = torch.tensor([[weapon_vocab[weapon_id]]], device=device)
        mask = x == pad_token_id
        with torch.no_grad():
            probs = torch.sigmoid(ultra(x, w, key_padding_mask=mask)).squeeze(0)
        probs = probs.detach().cpu().tolist()
        acts = (
            hook.last_h_post.detach().cpu().flatten()
            if hook.last_h_post is not None
            else None
        )
        return ({inv_vocab[i]: float(probs[i]) for i in range(vocab_size)}, acts)

    def run_ultra_search(record_traces: bool):
        """Run the beam search with the Ultra model and a fresh Allocator."""
        return reconstruct_build(
            predict_fn=predict_fn_ultra,
            weapon_id=weapon_id,
            initial_context=[NULL],
            allocator=Allocator(),
            beam_size=5,
            max_steps=6,
            top_k=1,
            record_traces=record_traces,
        )

    # try/finally guarantees the hook is detached even when one of the checks
    # below raises; otherwise it would stay registered on `ultra`.
    try:
        builds, traces = run_ultra_search(record_traces=True)

        if not builds or not traces or not traces[0]:
            raise RuntimeError("No build/trace produced")

        print("Ultra build:")
        print(builds[0].to_dict())

        # Activations recorded at the final step of the first trace.
        last = traces[0][-1]
        acts = last.activations
        if acts is None or acts.numel() == 0:
            raise RuntimeError("No SAE activations captured")

        # Clamp k so an SAE with fewer than 10 features does not make topk fail.
        topk = torch.topk(acts, k=min(10, acts.numel()))
        print("\nTop active SAE features (id -> value):")
        for i, v in zip(topk.indices.tolist(), topk.values.tolist()):
            print(int(i), float(v))

        # Simple steer: ablate the most-active feature
        # NOTE(review): the hook was registered with no_change=True - confirm
        # that update_neuron edits still affect the forward pass in that mode.
        steer_feature = int(topk.indices[0].item())
        hook.update_neuron(steer_feature, 0.0)

        builds2 = run_ultra_search(record_traces=False)

        print("\nSteered build (feature", steer_feature, "-> 0.0):")
        print(builds2[0].to_dict() if builds2 else None)
    finally:
        handle.remove()
