# ASA HF Checkpoint + Canon Probes + Mini Fact Finetune

CPU-only Colab notebook: load the HF checkpoint, run canonical probes, generate samples, run a mini finetune (synthetic + optional WikiText mix), and re-run probes/generations.

**Expected runtime:** a few minutes on CPU for probes and a short finetune.

**Notes:**
- Uses the ASA/ASM generation helpers and mini-alignment loop from `building_blocks/working _example.py`.
- Saves JSON artifacts in `artifacts/` for quick inspection.


## Section 0 — Setup
This installs the repo and minimal dependencies, seeds the run, and creates an artifacts folder.

In [None]:
import os, sys, subprocess, platform, json, time, random
from pathlib import Path

repo_dir = 'ASA'
# This cell ends with the kernel's working directory inside the checkout
# (see os.chdir below). If the cell is run again, the relative path 'ASA' no
# longer exists from there, which would clone and enter a nested ASA/ASA copy.
# Detect "already inside the checkout" and skip clone/chdir in that case.
inside_checkout = Path.cwd().name == repo_dir and Path('.git').exists()
if not inside_checkout:
    if not Path(repo_dir).exists():
        subprocess.run(['git','clone','https://github.com/digitaldaimyo/ASA.git'], check=True)
    os.chdir(repo_dir)

# Install the repo in editable mode, then the extra packages the notebook imports.
subprocess.run([sys.executable,'-m','pip','install','-e','.'], check=True)
subprocess.run([sys.executable,'-m','pip','install','-q','huggingface_hub','safetensors','transformers','datasets'], check=True)

# Imported after the pip installs above so freshly installed packages resolve.
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# The notebook is CPU-only by design.
device = torch.device('cpu')
print('Python:', platform.python_version())
print('Torch:', torch.__version__)
print('Device:', device)
try:
    commit = subprocess.check_output(['git','rev-parse','HEAD']).decode().strip()
    print('Repo commit:', commit)
except Exception:
    # Best-effort only: the commit hash is informational.
    print('Repo commit: unavailable')

# Seed Python's and torch's RNGs so sampling and random probes are repeatable.
seed = 1337
random.seed(seed)
torch.manual_seed(seed)

# JSON artifacts written by later cells land here (relative to the checkout).
artifacts_dir = Path('artifacts')
artifacts_dir.mkdir(exist_ok=True)


## Section 1 — Load base model from Hugging Face (Baseline)
Loads the public checkpoint and verifies a forward pass.

In [None]:
from asa.load_pretrained import load_pretrained

# Hugging Face repo id and checkpoint filename of the public baseline model.
HF_REPO = 'DigitalShogun/ASA-ASM-wikitext103-raw'
DEFAULT_CKPT = 'ASA_ASM_wt103-rawv1_gpt2_T1024_L21_D384_H8_K16_M32_ropek1_alibi1_gamma1_step75000_best.pt'

# load_pretrained returns the model, a load report (dict) and the config object.
model, report, cfg_obj = load_pretrained(HF_REPO, DEFAULT_CKPT, variant='baseline', device='cpu')
print('Loaded model with vocab_size:', cfg_obj.vocab_size)
print('Checkpoint source:', report['state_dict_source'])
# Only the counts of the allow-listed state-dict discrepancies are printed.
print('Allowlisted gaps:', {
    'missing': len(report['allowed_missing']),
    'unexpected': len(report['allowed_unexpected']),
    'mismatched': len(report['allowed_mismatched']),
})

# Smoke test: one forward pass on random token ids (batch 1, 32 positions).
input_ids = torch.randint(0, cfg_obj.vocab_size, (1, 32))
with torch.no_grad():
    # The model returns a 2-tuple here; only the logits are needed.
    logits, _ = model(input_ids)
print('Logits shape:', tuple(logits.shape))
assert logits.shape == (1, 32, cfg_obj.vocab_size)
assert torch.isfinite(logits).all()

# Record what was loaded so the artifacts folder is self-describing.
# NOTE(review): json.dumps raises TypeError if cfg_obj.__dict__ holds values
# that are not JSON-serializable — confirm the config only contains plain types.
run_metadata = {
    'repo': HF_REPO,
    'checkpoint': DEFAULT_CKPT,
    'state_dict_source': report['state_dict_source'],
    'seed': seed,
    'timestamp': time.time(),
    'config': cfg_obj.__dict__,
}
(artifacts_dir / 'run_metadata.json').write_text(json.dumps(run_metadata, indent=2))


## Section 2 — Canon Probes (BEFORE finetune)
Runs a small set of Paris/London margin probes and captures routing stats when available.

In [None]:
# GPT-2 tokenizer; the probe prompts and target tokens are encoded with it.
tokenizer = AutoTokenizer.from_pretrained('gpt2', use_fast=True)

# Fixed probe prompts. For each one, run_canon_probes measures the next-token
# logit margin between ' Paris' and ' London'.
PROMPTS = [
    'The capital of France is',
    "France's capital city is",
    'Paris is the capital of',
    'The capital of the UK is',
    'London is the capital of',
    'A major city in France is',
]

def get_token_id(text):
    """Return the id of `text`, which must encode to exactly one token.

    Uses the module-level `tokenizer`. Raises ValueError when the text maps
    to zero or several tokens.
    """
    encoded = tokenizer.encode(text)
    if len(encoded) == 1:
        return encoded[0]
    raise ValueError(f'Expected single token for {text}, got {encoded}')

# Single-token ids for the two probe targets. The leading space is part of
# the string that is encoded.
paris_id = get_token_id(' Paris')
london_id = get_token_id(' London')

def run_canon_probes(model, tag, out_dir):
    """Run the Paris/London margin probes and write the results as JSON.

    For every prompt in the module-level PROMPTS, computes the next-token
    logit margin ``logit(' Paris') - logit(' London')`` and the decoded top-10
    next tokens. Then runs one forward pass on random token ids and, when the
    returned info exposes them, records routing statistics.

    The model is put in eval mode for the duration of the probes and its
    previous train/eval mode is restored afterwards, so probes re-run after a
    finetune are not perturbed by train-mode behavior such as dropout.

    Args:
        model: callable as ``model(ids, return_info=True) -> (logits, info)``.
        tag: label stored in the results; also the prefix of the output file.
        out_dir: directory for ``{tag}_probes.json`` (created if missing).

    Returns:
        dict with keys 'tag', 'margins', 'mean_margin', 'min_margin',
        'top10_tokens' and 'routing_stats'.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        margins = []
        top10 = []
        for prompt in PROMPTS:
            ids = tokenizer.encode(prompt)
            input_ids = torch.tensor([ids])
            with torch.no_grad():
                logits, _ = model(input_ids, return_info=True)
            last = logits[0, -1]  # logits at the final prompt position
            margin = (last[paris_id] - last[london_id]).item()
            margins.append(margin)
            top_ids = torch.topk(last, k=10).indices.tolist()
            top10.append([tokenizer.decode([i]) for i in top_ids])
        mean_margin = float(sum(margins) / len(margins))
        min_margin = float(min(margins))

        # Routing stats are best-effort: any failure is recorded, not raised.
        routing_stats = {}
        try:
            sample = torch.randint(0, cfg_obj.vocab_size, (2, 16))
            with torch.no_grad():
                _, info = model(sample, return_info=True)
            # info may be a list of per-entry dicts or a single dict; use the first.
            if isinstance(info, list) and info:
                info0 = info[0] or {}
            else:
                info0 = info or {}
            if info0.get('read_weights') is not None:
                # Clamp before log so zero weights do not produce NaN.
                p = info0['read_weights'].float().clamp_min(1e-8)
                entropy = -(p * p.log()).sum(dim=-1).mean().item()
                # Share of positions whose argmax lands on the most-used index.
                top = p.argmax(dim=-1).reshape(-1)
                hist = torch.bincount(top, minlength=p.shape[-1]).float()
                top1freq = (hist.max() / hist.sum().clamp_min(1.0)).item()
                routing_stats['routing_entropy'] = entropy
                routing_stats['routing_top1freq'] = top1freq
            for key in ('content_read_gamma_mean','slotspace_gate_mean','slotspace_delta_norm'):
                if key in info0:
                    routing_stats[key] = float(torch.as_tensor(info0[key]).mean().item())
        except Exception as exc:
            routing_stats['error'] = str(exc)
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)

    results = {
        'tag': tag,
        'margins': margins,
        'mean_margin': mean_margin,
        'min_margin': min_margin,
        'top10_tokens': top10,
        'routing_stats': routing_stats,
    }
    out_path = out_dir / f'{tag}_probes.json'
    out_path.write_text(json.dumps(results, indent=2))
    print('Probe summary:', tag)
    print('  mean margin:', mean_margin)
    print('  min margin:', min_margin)
    return results

# Baseline probe results; the 'before_finetune' tag also names the JSON file.
baseline_results = run_canon_probes(model, 'before_finetune', artifacts_dir)


## Section 3 — Generation helpers (ASA-aware)
Defines the ASA/ASM generation utilities from the working example, including router-aware resampling and repetition controls.

In [None]:
import math

from typing import Any, Dict, Optional, Tuple, Union, List


import torch

import torch.nn.functional as F


# Helpers

def _top_k_top_p_filtering(
    logits: torch.Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    min_tokens_to_keep: int = 1,
) -> torch.Tensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p).
    logits: [V]
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        kth = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth, float("-inf"))

    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        cumprobs = probs.cumsum(dim=-1)

        # Remove tokens with cumulative prob above threshold
        cutoff = cumprobs > top_p
        # Keep at least min_tokens_to_keep
        cutoff[:min_tokens_to_keep] = False

        sorted_logits = sorted_logits.masked_fill(cutoff, float("-inf"))
        logits = logits.scatter(0, sorted_idx, sorted_logits)

    return logits


def _apply_repetition_penalty(
    logits: torch.Tensor,
    generated_ids: torch.Tensor,
    penalty: float,
) -> torch.Tensor:
    """
    Classic repetition penalty (GPT-2 style): penalize logits of previously generated tokens.
    logits: [V], generated_ids: [t]
    """
    if penalty is None or penalty == 1.0 or generated_ids.numel() == 0:
        return logits
    uniq = torch.unique(generated_ids)
    # If logit > 0: divide by penalty; else multiply by penalty
    l = logits[uniq]
    logits[uniq] = torch.where(l > 0, l / penalty, l * penalty)
    return logits



def _no_repeat_ngram_ban(
    logits: torch.Tensor,
    generated_ids: torch.Tensor,
    no_repeat_ngram_size: int,
) -> torch.Tensor:
    """
    Ban tokens that would create a repeated n-gram of size N in the generated sequence.
    logits: [V], generated_ids: [t]
    """
    n = int(no_repeat_ngram_size or 0)
    if n <= 1 or generated_ids.numel() < n - 1:
        return logits

    seq = generated_ids.tolist()
    prefix = seq[-(n - 1):]  # length n-1
    # Build set of next tokens seen after this prefix in the past
    banned = set()
    for i in range(len(seq) - n + 1):
        if seq[i:i + n - 1] == prefix:
            banned.add(seq[i + n - 1])

    if banned:
        banned = torch.tensor(list(banned), device=logits.device, dtype=torch.long)
        logits[banned] = float("-inf")
    return logits


# -----------------------------------------------------------------------------
# ASA/ASM-specific generation
# -----------------------------------------------------------------------------
@torch.no_grad()
def asa_generate(
    prompt: Union[str, List[int], torch.Tensor],
    model: torch.nn.Module,
    gen: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Autoregressive generation for ASA/ASM models.

      - Samples by default; greedy decoding via gen["do_sample"]=False.
      - Optionally requests the model's telemetry (return_info=True) and logs
        per-step stats from the last entry of the returned info list.
      - Router-aware fallback: when early EOS risk is high, or the token
        distribution is too "branchy", re-filters with rescue settings.
      - Supports standard sampling controls + mild anti-repetition.
      - Puts the model in eval mode and runs under torch.no_grad().

    Only row 0 of the batch is decoded and a single token is appended per
    step, so the prompt is effectively limited to batch size 1.

    Args
    ----
    prompt:
        - str: requires gen["tokenizer"] providing encode/decode
        - List[int] / 1D or 2D torch.Tensor: token ids
    model:
        Module returning logits, or a tuple/list whose first element is the
        logits. With return_info=True it must return (logits, infos).
    gen params (dict):
        Required (if prompt is str):
          tokenizer: a HF tokenizer with encode/decode
        Common:
          max_new_tokens: int (default 128)
          temperature: float (default 0.8)
          top_p: float (default 0.9)
          top_k: int (default 50)
          min_new_tokens: int (default 0)
          eos_token_id: int (default tokenizer.eos_token_id if available)
          do_sample: bool (default True)
          repetition_penalty: float (default 1.05)
          no_repeat_ngram_size: int (default 3)
          device: torch.device or str (default model device)
        ASA-aware controls:
          asa_info: bool (default True) -> request return_info and log it
          eos_risk_threshold: float (default 0.25)
          early_steps: int (default 24) -> window in which rescues may trigger
          branchy_entropy_threshold: float (default None) -> if set, an entropy
              above it within the early window also triggers a rescue
          rescue_mode: str in {"none","scaffold","resample"} (default "resample");
              the value is passed through str(...).lower(), so None acts as "none"
              - "resample": on a trigger, re-filter with min(temp, rescue_temp),
                min(top_p, rescue_top_p) and max(top_k, rescue_top_k)
              - "scaffold": if a tokenizer is provided and a str prompt starts
                with a known template, append a short scaffold once at the start
          rescue_temp: float (default 0.65)
          rescue_top_p: float (default 0.85)
          rescue_top_k: int (default 80)
          max_resample_tries: int (default 4)
        Return:
          return_text: bool (default True if tokenizer present else False)

    Returns
    -------
    dict with:
      "input_ids": [1, T+new] prompt (incl. scaffold, if any) plus new tokens
      "generated_ids": [1, new]
      "text": decoded "input_ids" (only if return_text and a tokenizer is given)
      "info_trace": list of per-step stats (only if asa_info=True)

    Raises
    ------
    ValueError: str prompt without tokenizer, or prompt tensor not 1D/2D.
    TypeError: unsupported prompt type.
    """
    model.eval()

    tokenizer = gen.get("tokenizer", None)
    device = gen.get("device", None)
    if device is None:
        device = next(model.parameters()).device

    # --- tokenize prompt ---
    if isinstance(prompt, str):
        if tokenizer is None:
            raise ValueError("prompt is str but gen['tokenizer'] was not provided.")
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    elif isinstance(prompt, list):
        input_ids = torch.tensor(prompt, device=device, dtype=torch.long).unsqueeze(0)
    elif isinstance(prompt, torch.Tensor):
        if prompt.dim() == 1:
            input_ids = prompt.to(device=device, dtype=torch.long).unsqueeze(0)
        elif prompt.dim() == 2:
            input_ids = prompt.to(device=device, dtype=torch.long)
        else:
            raise ValueError("prompt tensor must be 1D or 2D token ids.")
    else:
        raise TypeError("prompt must be str, List[int], or torch.Tensor of token ids.")

    # --- read generation settings ---
    max_new = int(gen.get("max_new_tokens", 128))
    min_new = int(gen.get("min_new_tokens", 0))
    do_sample = bool(gen.get("do_sample", True))

    temperature = float(gen.get("temperature", 0.8))
    top_p = float(gen.get("top_p", 0.9))
    top_k = int(gen.get("top_k", 50))

    repetition_penalty = float(gen.get("repetition_penalty", 1.05))
    no_repeat_ngram_size = int(gen.get("no_repeat_ngram_size", 3))

    eos_token_id = gen.get("eos_token_id", None)
    if eos_token_id is None and tokenizer is not None:
        eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = -1  # disable EOS logic if unknown

    asa_info = bool(gen.get("asa_info", True))
    eos_risk_threshold = float(gen.get("eos_risk_threshold", 0.25))
    early_steps = int(gen.get("early_steps", 24))
    branchy_entropy_threshold = gen.get("branchy_entropy_threshold", None)
    rescue_mode = str(gen.get("rescue_mode", "resample")).lower()
    rescue_temp = float(gen.get("rescue_temp", 0.65))
    rescue_top_p = float(gen.get("rescue_top_p", 0.85))
    rescue_top_k = int(gen.get("rescue_top_k", 80))
    max_resample_tries = int(gen.get("max_resample_tries", 4))

    # Optional scaffold injection: append a short continuation to known templates.
    if rescue_mode == "scaffold" and tokenizer is not None and isinstance(prompt, str):
        # Very small, conservative scaffold set; extend as needed.
        scaffolds = [
            ("The capital of", " the city of"),
            ("Albert Einstein was born", " in"),
            ("The scientific method involves", " the process of"),
            ("The algorithm proceeds as follows", " 1."),
        ]
        for k, s in scaffolds:
            if prompt.strip().startswith(k) and not prompt.strip().endswith(s.strip()):
                input_ids = tokenizer.encode(prompt + s, return_tensors="pt").to(device)
                break

    info_trace: List[Dict[str, float]] = []

    # Generation loop
    cur_ids = input_ids
    for step in range(max_new):
        # Model forward
        if asa_info:
            out = model(cur_ids, return_info=True)
            logits, infos = out
            # Take the last entry of the info list and copy its scalar stats.
            last = infos[-1] if isinstance(infos, list) and len(infos) > 0 else None
            stat = {}
            if isinstance(last, dict):
                for k in ["entropy_mean", "top1freq_mean", "content_read_gamma_mean", "slotspace_gate_mean", "slotspace_delta_norm"]:
                    if k in last and last[k] is not None:
                        try:
                            stat[k] = float(last[k].item())
                        except Exception:
                            # Telemetry is best-effort; skip values that are not scalars.
                            pass
        else:
            raw = model(cur_ids, return_info=False)
            # The model may return (logits, extras) even without return_info;
            # indexing the tuple directly as a tensor would fail.
            logits = raw[0] if isinstance(raw, (tuple, list)) else raw
            stat = None

        next_logits = logits[0, -1, :]  # [V]

        # Basic constraints: forbid EOS until min_new tokens were produced.
        if step < min_new and eos_token_id >= 0:
            next_logits = next_logits.clone()
            next_logits[eos_token_id] = float("-inf")

        # Anti-repetition: the penalty looks at newly generated tokens only,
        # the n-gram ban at the whole sequence including the prompt.
        gen_so_far = cur_ids[0, input_ids.shape[1]:]
        next_logits = _apply_repetition_penalty(next_logits, gen_so_far, repetition_penalty)
        next_logits = _no_repeat_ngram_ban(next_logits, cur_ids[0], no_repeat_ngram_size)

        # Router-aware rescue (early EOS / excessive branchiness)
        tries = 0
        used_temp, used_top_p, used_top_k = temperature, top_p, top_k
        while True:
            l = next_logits
            if used_temp and used_temp > 0:
                l = l / used_temp

            l = _top_k_top_p_filtering(l, top_k=used_top_k, top_p=used_top_p)

            probs = F.softmax(l, dim=-1)
            p_eos = float(probs[eos_token_id].item()) if eos_token_id >= 0 else 0.0
            ent = float(-(probs.clamp_min(1e-12) * probs.clamp_min(1e-12).log()).sum().item())

            # Condition: early EOS risk is too high
            eos_risky = (eos_token_id >= 0) and (step < early_steps) and (p_eos > eos_risk_threshold)

            # Condition: branchy token distribution (optional)
            branchy = False
            if branchy_entropy_threshold is not None and step < early_steps:
                branchy = ent > float(branchy_entropy_threshold)

            if (eos_risky or branchy) and rescue_mode == "resample" and tries < max_resample_tries:
                rescued = (
                    min(used_temp, rescue_temp),
                    min(used_top_p, rescue_top_p),
                    max(used_top_k, rescue_top_k),
                )
                tries += 1
                if rescued != (used_temp, used_top_p, used_top_k):
                    used_temp, used_top_p, used_top_k = rescued
                    continue
                # Settings are already at the rescue values: filtering is
                # deterministic, so another pass would give identical probs.

            # Choose token
            if do_sample:
                next_id = torch.multinomial(probs, num_samples=1)
            else:
                next_id = torch.argmax(probs, dim=-1, keepdim=True)

            break

        # Log trace
        if asa_info:
            rec = {"step": float(step), "token_entropy": float(ent), "p_eos": float(p_eos)}
            if stat:
                for k, v in stat.items():
                    rec[k] = float(v)
            # record the sampling settings actually used (after any rescue)
            rec["temp_used"] = float(used_temp)
            rec["top_p_used"] = float(used_top_p)
            rec["top_k_used"] = float(used_top_k)
            info_trace.append(rec)

        # Append token
        cur_ids = torch.cat([cur_ids, next_id.view(1, 1)], dim=1)

        # Stop on EOS
        if eos_token_id >= 0 and int(next_id.item()) == int(eos_token_id) and step >= min_new:
            break

    generated_ids = cur_ids[:, input_ids.shape[1]:]

    out: Dict[str, Any] = {
        "input_ids": cur_ids,
        "generated_ids": generated_ids,
    }
    if asa_info:
        out["info_trace"] = info_trace

    return_text = bool(gen.get("return_text", tokenizer is not None))
    if return_text and tokenizer is not None:
        out["text"] = tokenizer.decode(cur_ids[0].tolist(), skip_special_tokens=False)

    return out


# =========================
# PATCH 1: wrappers for your crafted asa_generate
# =========================

@torch.no_grad()
def asa_greedy_suffix(
    prompt: str,
    model: torch.nn.Module,
    gen: dict,
    max_new_tokens: int = 8,
    strip: bool = True,
) -> str:
    """
    Greedy-decode with asa_generate and return only the text after `prompt`.

    `gen` is copied, so the caller's dict is not modified. If the decoded
    text does not contain the prompt at all, the whole decoded text is
    returned. With `strip`, newlines become spaces and surrounding
    whitespace is removed. Raises ValueError when no decoded text is
    available and no tokenizer was supplied.
    """
    settings = dict(gen)
    settings["do_sample"] = False
    settings["max_new_tokens"] = int(max_new_tokens)

    result = asa_generate(prompt, model, settings)
    decoded = result.get("text", None)
    if decoded is None:
        # asa_generate produced no text; decode the ids ourselves.
        tok = settings.get("tokenizer", None)
        if tok is None:
            raise ValueError("No decoded text available; provide gen['tokenizer'].")
        decoded = tok.decode(result["input_ids"][0].tolist(), skip_special_tokens=False)

    # str.find returns 0 when the text starts with the prompt, so a single
    # lookup covers both the prefix case and the "somewhere inside" case.
    at = decoded.find(prompt)
    suffix = decoded if at < 0 else decoded[at + len(prompt):]

    return suffix.replace("\n", " ").strip() if strip else suffix


@torch.no_grad()
def asa_generate_many(
    prompts: list,
    model: torch.nn.Module,
    gen: dict,
    do_sample: bool = False,
    max_new_tokens: int = 8,
) -> list:
    """
    Run asa_generate once per prompt, in order, and collect the decoded texts.

    `gen` is copied before `do_sample` and `max_new_tokens` are overridden.
    Raises ValueError when a result carries no text and no tokenizer is set.
    """
    settings = dict(gen)
    settings["do_sample"] = bool(do_sample)
    settings["max_new_tokens"] = int(max_new_tokens)

    def _text_of(result):
        # Prefer the text asa_generate decoded; otherwise decode the ids here.
        decoded = result.get("text", None)
        if decoded is not None:
            return decoded
        tok = settings.get("tokenizer", None)
        if tok is None:
            raise ValueError("No decoded text available; provide gen['tokenizer'].")
        return tok.decode(result["input_ids"][0].tolist(), skip_special_tokens=False)

    return [_text_of(asa_generate(p, model, settings)) for p in prompts]

@torch.no_grad()
def score_next_token_rank(
    prompt: str,
    target_token: str,
    model: torch.nn.Module,
    gen: dict,
) -> dict:
    """
    Score one candidate as the token immediately following `prompt`.

    Returns {"ok": True, "p_target", "rank", "target_id"} where rank 1 is the
    highest logit, or {"ok": False, "reason", "target_ids"} when
    `target_token` does not encode to exactly one token. The tokenizer is
    read from gen["tokenizer"]; the model is switched to eval mode before the
    forward pass.
    """
    tokenizer = gen["tokenizer"]
    model_device = next(model.parameters()).device

    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(model_device)

    # The target has to be a single token to have a next-token probability.
    candidate_ids = tokenizer.encode(target_token, add_special_tokens=False)
    if len(candidate_ids) != 1:
        return {
            "ok": False,
            "reason": f"target_token maps to {len(candidate_ids)} tokens",
            "target_ids": candidate_ids,
        }
    (target_id,) = candidate_ids

    model.eval()
    output = model(prompt_ids)
    # The model may hand back (logits, extras); keep only the logits.
    all_logits = output[0] if isinstance(output, (tuple, list)) else output
    final_logits = all_logits[0, -1, :]

    p_target = float(final_logits.softmax(dim=-1)[target_id].item())

    # Position of the target in the descending ordering, 1-based.
    ordering = final_logits.argsort(descending=True)
    position = (ordering == target_id).nonzero(as_tuple=False).item()

    return {"ok": True, "p_target": p_target, "rank": int(position) + 1, "target_id": target_id}

#



#@title multigen
from transformers import AutoTokenizer

# Rebinds the module-level tokenizer (same GPT-2 tokenizer as the probe cell).
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

# Generation settings passed to asa_generate as its `gen` dict.
gener = dict(
    tokenizer=tokenizer,
    max_new_tokens=32,
    #min_new_tokens=4,
    temperature=0.1,
    top_p=0.95,
    top_k=80,
    repetition_penalty=1.03,
    no_repeat_ngram_size=3, # 3
    # asa_info=False: no return_info request and no "info_trace" in the result.
    asa_info=False,
    # asa_generate lower-cases str(rescue_mode), so None behaves as "none".
    rescue_mode=None, # "resample", None
    #eos_risk_threshold=0.25,
    #early_steps=24,
    #branchy_entropy_threshold=7.5,   # optional; depends on vocab size and filtering
)

print("#"*5, "Countries", "#"*5)
finishers = ["is", "sounds like", "consists of", "is a form of", "all changed when"]
qualities = ["capital", "language", "geography", "government", "history"]
countries = ["France", "Spain", "Russia", "Italy", "Japan", "Egypt", "Germany", "Brazil"]
for country in countries:
    # zip pairs each quality with the finisher at the same index, giving
    # five prompts per country (not the full cross product).
    for quality, finisher in zip(qualities, finishers):
        out = asa_generate(f"The {quality} of {country} {finisher}", model, gener)
        print(out["text"])

print("#"*5, "People", "#"*5)
people = ["Albert Einstein", "George Patton", "Charles Darwin", "George Washington", "Winston Churchill"]
factoids = ["was born", "contributed", "accomplished", "had a strong opinion about", "died"]
# Here every person is combined with every factoid (full cross product).
for person in people:
    for factoid in factoids:
        out = asa_generate(f"{person} {factoid}", model, gener)
        print(out["text"])


# Optionally inspect router-aware trace:
# out["info_trace"][:5]


#@title Prepare Generator

from transformers import AutoTokenizer

# Re-creates the tokenizer and `gener` with the same values as the multigen
# cell above, so later cells can run without executing that cell first.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

gener = dict(
    tokenizer=tokenizer,
    max_new_tokens=32,
    #min_new_tokens=4,
    temperature=0.1,
    top_p=0.95,
    top_k=80,
    repetition_penalty=1.03,
    no_repeat_ngram_size=3, # 3
    # asa_info=False: no return_info request and no "info_trace" in the result.
    asa_info=False,
    # asa_generate lower-cases str(rescue_mode), so None behaves as "none".
    rescue_mode=None, # "resample", None
    #eos_risk_threshold=0.25,
    #early_steps=24,
    #branchy_entropy_threshold=7.5,   # optional; depends on vocab size and filtering
)



# ==========================================
#@title Expanded Mini-alignment dataset + WikiText mix + optional slot-attn-only finetune + rerun generations
# (Aligned to ASMLanguageModel + your asa_generate)
# ==========================================

import random
import math
import re
import itertools
import torch
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

# -----------------------
# 0) Repro & device
# -----------------------
# Re-seed Python's and torch's RNGs at the start of the finetune section.
SEED = 1337
random.seed(SEED)
torch.manual_seed(SEED)
# Rebinds the module-level `device` to wherever the model's parameters live.
device = next(model.parameters()).device


from datasets import load_dataset
# Use a community-hosted mirror
#dataset = load_dataset('segyges/wikitext-103', name='wikitext-103-raw-v1')


# -----------------------
# 0.5) Config knobs (NEW)
# -----------------------
# Central configuration for the mini finetune. The wiki_* keys, "max_len" and
# "batch_size" are read by the cells below; the remaining training keys are
# presumably consumed by the training loop further down — TODO confirm.
CFG = {
    # mix in WikiText (off by default; the synthetic pairs are always used)
    "use_wiki": False,
    "wiki_dataset_name": "wikitext",
    "wiki_config_candidates": ["wikitext-103-raw-v1", "wikitext-2-raw-v1"],  # tried in order; first that loads wins
    "wiki_num_samples": 1536,         # number of wiki chunks (not lines)
    "wiki_chunk_chars_min": 400,      # chunks shorter than this are dropped
    "wiki_chunk_chars_max": 1200,     # chunk size (chars) before tokenization

    # training
    "max_len": 128,                   # max tokens kept per example (the tail is kept)
    "batch_size": 16,
    "steps": 77,
    "lr": 7e-6,
    "weight_decay": 0.007,
    "grad_clip": 1.0,

    # finetune mode
    # "all" trains everything; "slot_attn_only" freezes everything except slot-space attention op
    "finetune_mode": "slot_attn_only",  # one of "all" or "slot_attn_only"

    # which params count as "slot attention" (adjust to your module names)
    # Earlier, broader pattern kept for reference:
    #"slot_train_name_regex": r"(slot|slots).*(attn|attention)|((attn|attention).*(slot|slots))",

    # Matches names ending in slot_in/slot_q/slot_k/slot_v/slot_out ".weight",
    # or ending in "_slotspace_gate_raw".
    "slot_train_name_regex":r"(^|\.)(slot_in|slot_q|slot_k|slot_v|slot_out)\.weight$|(^|\.)(_slotspace_gate_raw)$",


}

# -----------------------
# 1) Utilities (aligned to your model call style)
# -----------------------
@torch.no_grad()
def next_token_stats(prompt: str, target_token_str: str, model, tokenizer):
    """Report how the model ranks one target token right after `prompt`.

    The prompt is moved to the device of the model's own parameters rather
    than a module-level `device` global, so the result does not depend on
    which cell last rebound that global.

    Args:
        prompt: text whose next token is scored.
        target_token_str: must encode to exactly one token.
        model: returns logits, or a tuple/list whose first element is logits.
        tokenizer: provides encode(...) and decode(...).

    Returns:
        {"ok": True, "p_target", "rank" (1 = best), "top1", "target_id"}, or
        {"ok": False, "reason", "target_ids"} for a multi-token target.
    """
    model.eval()
    model_device = next(model.parameters()).device
    inp = tokenizer.encode(prompt, return_tensors="pt").to(model_device)

    tgt_ids = tokenizer.encode(target_token_str, add_special_tokens=False)
    if len(tgt_ids) != 1:
        return {"ok": False, "reason": f"target string maps to {len(tgt_ids)} tokens", "target_ids": tgt_ids}

    tgt = tgt_ids[0]

    out = model(inp)
    # The model may return (logits, extras); keep only the logits.
    logits = out[0] if isinstance(out, (tuple, list)) else out
    last = logits[0, -1, :]
    probs = torch.softmax(last, dim=-1)

    p = float(probs[tgt].item())
    # 1-based position of the target in the descending logit ordering.
    rank = int((torch.argsort(last, descending=True) == tgt).nonzero(as_tuple=False).item()) + 1
    top1_id = int(torch.argmax(last).item())
    top1 = tokenizer.decode([top1_id])

    return {"ok": True, "p_target": p, "rank": rank, "top1": top1, "target_id": tgt}

@torch.no_grad()
def greedy_suffix(prompt: str, model, gen, max_new_tokens=8):
    """Greedy-decode via asa_generate and return the cleaned text after `prompt`.

    `gen` is copied before being overridden. Newlines are replaced by spaces
    and the result is stripped. If the decoded text does not contain the
    prompt, the whole cleaned text is returned.
    """
    settings = dict(gen)
    settings["do_sample"] = False
    settings["max_new_tokens"] = int(max_new_tokens)

    decoded = asa_generate(prompt, model, settings)["text"]

    # find() yields 0 when the text starts with the prompt, so one lookup
    # handles both "prefix" and "found later in the text".
    at = decoded.find(prompt)
    tail = decoded if at < 0 else decoded[at + len(prompt):]
    return tail.replace("\n", " ").strip()

@torch.no_grad()
def eval_exact_match(examples, model, gen, max_new_tokens=8):
    """Fraction of examples whose greedy suffix starts with the gold completion.

    Each example is a dict with "prompt" and "completion". The gold text is
    normalized like the prediction (newlines to spaces, stripped). Returns
    0.0 for an empty `examples`.
    """
    model.eval()

    def _hit(example):
        predicted = greedy_suffix(example["prompt"], model, gen, max_new_tokens=max_new_tokens)
        expected = example["completion"].replace("\n", " ").strip()
        return int(predicted.startswith(expected))

    hits = sum(_hit(example) for example in examples)
    return hits / max(1, len(examples))

# -----------------------
# 2) Dataset builders
# -----------------------


def load_wikitext_chunks(tokenizer, num_samples=2048, chunk_chars_min=400, chunk_chars_max=1200):
    """
    Produces wiki training examples as plain LM text chunks:
      ex = {"prompt": "", "completion": "<wiki chunk>", "tag": "wiki"}
    Chunking is by characters only; token truncation happens later in the dataset.

    The dataset name and config candidates come from the module-level CFG;
    the line shuffle is seeded with the module-level SEED.

    Args:
        tokenizer: accepted for call-site symmetry; not used here.
        num_samples: maximum number of chunks to return.
        chunk_chars_min: chunks shorter than this are discarded.
        chunk_chars_max: a buffer is closed before adding a line would push
            it past this size. A single line longer than the limit still
            forms its own oversized chunk.

    Returns:
        List of example dicts (possibly empty). An empty list is also returned
        when `datasets` is unavailable or no config could be loaded.
    """
    try:
        from datasets import load_dataset
    except Exception:
        print("[WikiText] datasets not available; skipping WikiText mix.")
        return []

    ds = None
    used_cfg = None
    for cfg in CFG["wiki_config_candidates"]:
        try:
            ds = load_dataset(CFG["wiki_dataset_name"], cfg, split="train")
        except Exception as exc:
            # Still best-effort, but say why so a failed mix is diagnosable.
            print(f"[WikiText] Failed to load config {cfg!r}: {exc}")
            continue
        used_cfg = cfg
        break

    if ds is None:
        print("[WikiText] Could not load WikiText (tried configs:", CFG["wiki_config_candidates"], "). Skipping.")
        return []

    print(f"[WikiText] Loaded {CFG['wiki_dataset_name']} / {used_cfg} train split with {len(ds)} rows.")

    # Pull raw text field (wikitext uses 'text')
    texts = [t for t in ds["text"] if isinstance(t, str) and len(t.strip()) > 0]

    chunks = []
    buf = []
    buf_len = 0

    # Shuffle deterministically. Note that lines are shuffled BEFORE chunking,
    # so a chunk joins lines that are adjacent in shuffled order, not
    # necessarily adjacent in the source article.
    rng = random.Random(SEED)
    rng.shuffle(texts)

    for line in texts:
        line = line.strip()
        # skip heading markup lines ("= Title ="); keep normal prose
        if line.startswith("=") and line.endswith("="):
            continue
        if not line:
            continue

        # Grow the buffer while it fits; +1 accounts for the joining space.
        if buf_len + len(line) + 1 <= chunk_chars_max:
            buf.append(line)
            buf_len += len(line) + 1
        else:
            chunk = " ".join(buf).strip()
            if len(chunk) >= chunk_chars_min:
                chunks.append(chunk)
            # Start the next buffer with the line that did not fit.
            buf = [line]
            buf_len = len(line) + 1

        if len(chunks) >= num_samples:
            break

    # Flush the trailing buffer if there is still room for another chunk.
    if len(chunks) < num_samples:
        chunk = " ".join(buf).strip()
        if len(chunk) >= chunk_chars_min:
            chunks.append(chunk)

    # create examples
    wiki_examples = [{"prompt": "", "completion": c, "tag": "wiki"} for c in chunks[:num_samples]]
    print(f"[WikiText] Prepared {len(wiki_examples)} wiki chunks.")

    return wiki_examples

# -----------------------
# 3) Split synthetic (entity-holdout) + build mixed train set
# -----------------------
#pairs = build_pairs_expanded(tokenizer)

from collections import Counter

# Entity-level holdout sets: every example about one of these countries, for
# the given task, goes to the holdout split.
holdout_capitals = {
    "Spain", "Canada", "Poland", "Portugal", "Greece", "Austria",
    "Norway", "Ireland", "Romania", "Croatia", "Argentina", "Chile"
}
holdout_languages = {"Brazil", "Mexico", "Netherlands", "Sweden", "Finland"}
holdout_currencies = {"Japan", "Switzerland", "South Africa", "Thailand"}
holdout_continents = {"Kenya", "Vietnam", "Peru", "New Zealand"}

# Task prefix of ex["tag"] ("task:country") -> its holdout set.
holdout_sets_by_task = {
    "capital": holdout_capitals,
    "language": holdout_languages,
    "currency": holdout_currencies,
    "continent": holdout_continents,
}

# NOTE(review): `pairs` is not assigned in this cell; it is expected to be
# defined by an earlier cell — confirm before running.
train_examples, holdout_examples = [], []
for ex in pairs:
    # partition tolerates tags without ":" (country becomes ""), where
    # split(":", 1)[1] would raise IndexError.
    task, sep, country = ex["tag"].partition(":")

    if task in holdout_sets_by_task:
        is_holdout = country in holdout_sets_by_task[task]
    else:
        # Any other task: random 20% holdout.
        is_holdout = random.random() < 0.2

    (holdout_examples if is_holdout else train_examples).append(ex)

# Guard the percentage report against an empty `pairs`.
n_pairs = max(1, len(pairs))
print(f"\n[Synthetic] Total kept pairs: {len(pairs)} | Train: {len(train_examples)} | Holdout: {len(holdout_examples)}")
print(f"[Synthetic] Split: {len(train_examples)/n_pairs*100:.1f}% / {len(holdout_examples)/n_pairs*100:.1f}%")
holdout_by_cat = Counter(ex["tag"].split(":")[0] for ex in holdout_examples)
print("[Synthetic] Holdout by category:", dict(holdout_by_cat))

# NEW: load wiki and mix into TRAIN ONLY
# Wiki chunks are optional; with CFG["use_wiki"] False the list stays empty.
wiki_examples = []
if CFG["use_wiki"]:
    wiki_examples = load_wikitext_chunks(
        tokenizer,
        num_samples=CFG["wiki_num_samples"],
        chunk_chars_min=CFG["wiki_chunk_chars_min"],
        chunk_chars_max=CFG["wiki_chunk_chars_max"],
    )

# Wiki text is added to the training set only; the holdout stays synthetic.
mixed_train_examples = train_examples + wiki_examples
print(f"\n[Mix] Train synthetic={len(train_examples)} + wiki={len(wiki_examples)} => mixed_train={len(mixed_train_examples)}")
print(f"[Mix] Holdout (synthetic only) = {len(holdout_examples)}")

# -----------------------
# 4) Tiny finetune dataset (teacher forcing)
#    Works for BOTH: prompt+completion pairs and raw wiki chunks (prompt="")
# -----------------------
class PromptCompletionDataset(Dataset):
    """Teacher-forcing dataset over prompt/completion records.

    Each record is a mapping with "prompt" and "completion" strings; the two
    are concatenated and tokenized as one sequence. Items are returned as
    ``(inputs, targets, record)`` where ``targets`` is ``inputs`` shifted left
    by one token.
    """

    def __init__(self, examples, tokenizer, max_len=128):
        self.max_len = int(max_len)
        self.tok = tokenizer
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        record = self.examples[idx]
        token_ids = self.tok.encode(record["prompt"] + record["completion"])

        # Truncate from the left: only the final `max_len` tokens are kept, so
        # the completion at the end of the text survives truncation.
        window = token_ids[-self.max_len:]

        # Next-token objective: position i of `inputs` predicts position i of `targets`.
        inputs = torch.tensor(window[:-1], dtype=torch.long)
        targets = torch.tensor(window[1:], dtype=torch.long)
        return inputs, targets, record

def collate_pad(batch):
    """Right-pad a list of ``(x, y, example)`` items into batch tensors.

    Returns ``(X, Y, examples)``: ``X`` and ``Y`` are ``[batch, longest]`` long
    tensors moved to the module-level ``device``, and ``examples`` is the tuple
    of original records in batch order.
    """
    inputs, targets, records = zip(*batch)
    n_rows = len(inputs)
    longest = max(seq.size(0) for seq in inputs)

    # GPT-2 ships without a pad token, so EOS doubles as the input filler.
    X = torch.full((n_rows, longest), tokenizer.eos_token_id, dtype=torch.long)
    # Padded target slots hold -100, the value the training loop passes to
    # cross_entropy as ignore_index, so they do not contribute to the loss.
    Y = torch.full((n_rows, longest), -100, dtype=torch.long)

    for row, (inp, tgt) in enumerate(zip(inputs, targets)):
        length = inp.size(0)
        X[row, :length] = inp
        Y[row, :length] = tgt

    return X.to(device), Y.to(device), records

# Training data: synthetic train split plus any wiki chunks.
train_ds = PromptCompletionDataset(
    examples=mixed_train_examples,
    tokenizer=tokenizer,
    max_len=CFG["max_len"],
)

# Batch size is capped at the dataset size so a tiny dataset still forms a batch.
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=min(CFG["batch_size"], len(train_ds)),
    shuffle=True,
    collate_fn=collate_pad,
)

# -----------------------
# 5) Optional: freeze everything except slot-space attention (NEW)
# -----------------------
def configure_finetune_mode(model, mode: str, name_regex: str):
    """Set ``requires_grad`` on every parameter of ``model`` and print a summary.

    Args:
        model: module whose parameters are (un)frozen in place.
        mode: ``"all"`` makes every parameter trainable; ``"slot_attn_only"``
            freezes everything except parameters whose full dotted name
            matches ``name_regex``.
        name_regex: pattern applied with ``re.search`` (case-insensitive) to
            each parameter name; only consulted in ``"slot_attn_only"`` mode.

    Raises:
        ValueError: if ``mode`` is neither of the two supported values.
    """
    if mode not in ("all", "slot_attn_only"):
        raise ValueError(f"Unknown finetune_mode={mode}")

    total = sum(p.numel() for p in model.parameters())

    if mode == "all":
        for param in model.parameters():
            param.requires_grad = True
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"[Finetune] mode=all trainable={trainable}/{total} ({trainable/total*100:.2f}%)")
        return

    pattern = re.compile(name_regex, flags=re.IGNORECASE)

    # One pass: a parameter is trainable exactly when its name matches.
    matched = []
    for name, param in model.named_parameters():
        is_match = pattern.search(name) is not None
        param.requires_grad = is_match
        if is_match:
            matched.append((name, param.numel()))

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[Finetune] mode=slot_attn_only regex={name_regex!r}")
    print(f"[Finetune] trainable={trainable}/{total} ({trainable/total*100:.4f}%) matched_tensors={len(matched)}")

    # List the largest matched tensors first (at most 25 lines of output).
    largest_first = sorted(matched, key=lambda item: item[1], reverse=True)
    for n, k in largest_first[:25]:
        print(f"  [trainable] {k:>10}  {n}")

configure_finetune_mode(model, CFG["finetune_mode"], CFG["slot_train_name_regex"])

# -----------------------
# 6) Pre-eval (synthetic only, as before)
# -----------------------
print("\n" + "="*80)
print("PRE-TRAINING EVALUATION (synthetic only)")
print("="*80)

# FIX: Enable asa_info to handle model's tuple return type correctly
gener['asa_info'] = True

pre_acc_train = eval_exact_match(train_examples, model, gener, max_new_tokens=8)
pre_acc_hold  = eval_exact_match(holdout_examples, model, gener, max_new_tokens=8)

# The correct-count is recovered from the accuracy fraction. round() is used
# instead of int() because the float product can land just below the true
# integer (e.g. 0.29 * 100 == 28.999999999999996), which int() would truncate.
print(f"\n[PRE] Exact-match accuracy:")
print(f"  Train:   {pre_acc_train:.3f} ({round(pre_acc_train*len(train_examples))}/{len(train_examples)})")
print(f"  Holdout: {pre_acc_hold:.3f} ({round(pre_acc_hold*len(holdout_examples))}/{len(holdout_examples)})")

# A fixed random sample of pairs; the same `sample_for_stats` list is reused
# after training so the before/after numbers are comparable.
print("\n[PRE] Next-token stats for sample of single-token targets (synthetic only):")
sample_for_stats = random.sample(pairs, min(30, len(pairs)))
for ex in sample_for_stats:
    stats = next_token_stats(ex["prompt"], ex["completion"], model, tokenizer)
    if stats["ok"]:
        print(f"  {ex['tag']:<25} P={stats['p_target']:.4f} rank={stats['rank']:>5} top1={stats['top1']!r}")
    else:
        print(f"  {ex['tag']:<25} (skip) {stats['reason']}")

# -----------------------
# 7) Light training (mixed: synthetic + wiki)
# -----------------------
print("\n" + "="*80)
print("TRAINING (mixed synthetic + wiki)")
print("="*80)

model.train()

# IMPORTANT: optimizer must only see trainable params (esp for slot_attn_only)
trainable_params = [p for p in model.parameters() if p.requires_grad]
if not trainable_params:
    raise RuntimeError("No trainable parameters. Check CFG['finetune_mode'] and regex.")

opt = torch.optim.AdamW(
    trainable_params,
    lr=CFG["lr"],
    betas=(0.9, 0.95),
    weight_decay=CFG["weight_decay"]
)

steps = int(CFG["steps"])
grad_clip = float(CFG["grad_clip"])

print(f"Training for {steps} steps with batch_size={train_dl.batch_size}")
print(f"Total mixed_train_examples={len(mixed_train_examples)} | synthetic={len(train_examples)} | wiki={len(wiki_examples)}\n")


def _endless_batches(loader):
    """Yield batches from `loader` forever, starting a fresh pass each time it is exhausted.

    itertools.cycle is deliberately NOT used here: it stores a copy of every
    batch from the first pass and replays those copies, so the loader would
    never reshuffle after the first epoch and all batches would stay in memory.
    Re-iterating the loader gives a new shuffle on every pass.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            # Guard against spinning forever on an empty loader.
            raise RuntimeError("train_dl produced no batches; nothing to train on.")


# One continuous batch stream for the whole run (no iter(train_dl) per step).
batch_iter = _endless_batches(train_dl)

for step in range(steps):
    X, Y, _ = next(batch_iter)
    opt.zero_grad(set_to_none=True)

    # The model may return logits alone or a tuple/list whose first item is the logits.
    logits = model(X)
    logits = logits[0] if isinstance(logits, (tuple, list)) else logits

    # Flatten to [batch*time, vocab] vs [batch*time]; padded targets (-100) are skipped.
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        Y.view(-1),
        ignore_index=-100
    )
    loss.backward()
    torch.nn.utils.clip_grad_norm_(trainable_params, grad_clip)
    opt.step()

    if (step + 1) % 50 == 0:
        print(f"  [train] step {step+1:>4}/{steps} loss={float(loss.item()):.4f}")

model.eval()

# -----------------------
# 8) Post-eval (synthetic only, as before)
# -----------------------
print("\n" + "="*80)
print("POST-TRAINING EVALUATION (synthetic only)")
print("="*80)

post_acc_train = eval_exact_match(train_examples, model, gener, max_new_tokens=8)
post_acc_hold  = eval_exact_match(holdout_examples, model, gener, max_new_tokens=8)

# round() rather than int(): the accuracy-times-count product is a float that
# can fall a hair below the true integer, and int() would then under-report by one.
print(f"\n[POST] Exact-match accuracy:")
print(f"  Train:   {post_acc_train:.3f} ({round(post_acc_train*len(train_examples))}/{len(train_examples)})")
print(f"  Holdout: {post_acc_hold:.3f} ({round(post_acc_hold*len(holdout_examples))}/{len(holdout_examples)})")

print(f"\n[DELTA] Accuracy change:")
print(f"  Train:   {pre_acc_train:.3f} -> {post_acc_train:.3f} (Δ={post_acc_train-pre_acc_train:+.3f})")
print(f"  Holdout: {pre_acc_hold:.3f} -> {post_acc_hold:.3f} (Δ={post_acc_hold-pre_acc_hold:+.3f})")

# Same examples as the pre-training stats; entries whose stats are not "ok" are omitted here.
print("\n[POST] Next-token stats for same sample (synthetic only):")
for ex in sample_for_stats:
    stats = next_token_stats(ex["prompt"], ex["completion"], model, tokenizer)
    if stats["ok"]:
        print(f"  {ex['tag']:<25} P={stats['p_target']:.4f} rank={stats['rank']:>5} top1={stats['top1']!r}")

# -----------------------
# 9) Generations (synthetic categories only)
# -----------------------
print("\n" + "="*80)
print("GENERATION SAMPLES (greedy decoding) (synthetic only)")
print("="*80)

# Group examples by category (the tag text before the first ":"), keeping dataset order.
by_category = {}
for ex in pairs:
    cat = ex["tag"].split(":")[0]
    by_category.setdefault(cat, []).append(ex)

# Up to two examples per category, categories in sorted order, capped at 25 overall.
generation_samples = []
for cat, exs in sorted(by_category.items()):
    generation_samples.extend(exs[:2])

generation_samples = generation_samples[:25]

# Text appended to the prompt for the "scaffolded" generation. Categories that
# are not listed get no scaffold (their scaffold prompt equals the raw prompt).
scaffold_suffix_by_category = {
    "capital": " the city of",
    "language": " primarily",
}

for ex in generation_samples:
    raw = greedy_suffix(ex["prompt"], model, gener, max_new_tokens=12)

    tag_base = ex["tag"].split(":")[0]
    scaffold_prompt = ex["prompt"] + scaffold_suffix_by_category.get(tag_base, "")
    has_scaffold = scaffold_prompt != ex["prompt"]

    # Only run the second generation when a scaffold was actually added; for
    # the other categories it would repeat the raw prompt and its output was
    # never printed. (Assumes greedy_suffix has no side effects worth keeping
    # -- it is defined outside this cell.)
    sca = greedy_suffix(scaffold_prompt, model, gener, max_new_tokens=12) if has_scaffold else None

    print(f"\n{'─'*80}")
    print(f"CATEGORY: {ex['tag']:<25} TARGET: {ex['completion']!r}")
    print(f"PROMPT:   {ex['prompt']!r}")
    print(f"RAW:      {raw[:100]}")
    if has_scaffold:
        print(f"SCAFFOLD: {sca[:100]}")

print("\n" + "="*80)
print("EVALUATION COMPLETE")
print("="*80)


# ASA/ASM-specific generation

@torch.no_grad()
def asa_generate(
    prompt: Union[str, List[int], torch.Tensor],
    model: torch.nn.Module,
    gen: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Generation crafted for ASA/ASM models:
      - Uses soft sampling by default (hard routing variants are unstable per your ablations).
      - Optionally requests ASA internal telemetry (return_info=True) and records it per step.
      - Applies *router-aware fallback* when EOS-risk is high early, or when the
        next-token distribution is pathologically branchy.
      - Supports standard sampling controls + mild anti-repetition.
      - Keeps inference dropout off (puts the model in eval mode and leaves it there).
      - Runs under torch.no_grad(), so no autograd graph is built while decoding,
        including when this function is called directly rather than via a wrapper.

    Decoding is single-sequence: only row 0 of the logits is read at each step.

    Args
    ----
    prompt:
        - str: requires gen["tokenizer"] providing encode/decode
        - List[int] / 1D torch.Tensor: token ids
        - 2D torch.Tensor: token ids, used as-is
    model:
        ASMLanguageModel (or compatible) returning logits or (logits, infos) if return_info=True
    gen params (dict):
        Required (if prompt is str):
          tokenizer: a HF tokenizer with encode/decode
        Common:
          max_new_tokens: int (default 128)
          temperature: float (default 0.8); 0 disables temperature scaling
          top_p: float (default 0.9)
          top_k: int (default 50)
          min_new_tokens: int (default 0)
          eos_token_id: int (default tokenizer.eos_token_id if available;
                        EOS handling is disabled when it cannot be determined)
          pad_token_id: int (optional; not read by this function)
          do_sample: bool (default True); False picks the argmax token
          repetition_penalty: float (default 1.05)
          no_repeat_ngram_size: int (default 3)
          device: torch.device or str (default model device)
        ASA-aware controls:
          asa_info: bool (default True) -> call the model with return_info=True and log stats
          eos_risk_threshold: float (default 0.25)
          early_steps: int (default 24) -> window in which to apply EOS-risk mitigations
          branchy_entropy_threshold: float (default None) -> if set, triggers extra sharpening
          rescue_mode: str in {"none","scaffold","resample"} (default "resample")
              The value is stringified and lower-cased, so passing None behaves like "none".
              - "resample": if EOS risk triggers, resample with lower temp / higher top_k keep
              - "scaffold": if tokenizer provided and prompt looks like a known template,
                            inject a short scaffold (see below) once at the start
          rescue_temp: float (default 0.65)
          rescue_top_p: float (default 0.85)
          rescue_top_k: int (default 80)
          max_resample_tries: int (default 4)
        Return:
          return_text: bool (default True if tokenizer present else False)

    Returns
    -------
    dict with:
      "input_ids": [1, T+new] -- prompt (plus scaffold, if injected) followed by new tokens
      "generated_ids": [1, new] -- only the newly generated tokens
      "text": optional, the decoded full "input_ids" sequence (special tokens kept)
      "info_trace": optional list of per-step stats (present only if asa_info=True)

    Raises
    ------
    ValueError: prompt is a str without gen["tokenizer"], or a tensor that is not 1D/2D.
    TypeError: prompt is not a str, list, or torch.Tensor.
    """
    model.eval()

    tokenizer = gen.get("tokenizer", None)
    device = gen.get("device", None)
    if device is None:
        device = next(model.parameters()).device

    # --- tokenize prompt ---
    if isinstance(prompt, str):
        if tokenizer is None:
            raise ValueError("prompt is str but gen['tokenizer'] was not provided.")
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    elif isinstance(prompt, list):
        input_ids = torch.tensor(prompt, device=device, dtype=torch.long).unsqueeze(0)
    elif isinstance(prompt, torch.Tensor):
        if prompt.dim() == 1:
            input_ids = prompt.to(device=device, dtype=torch.long).unsqueeze(0)
        elif prompt.dim() == 2:
            input_ids = prompt.to(device=device, dtype=torch.long)
        else:
            raise ValueError("prompt tensor must be 1D or 2D token ids.")
    else:
        raise TypeError("prompt must be str, List[int], or torch.Tensor of token ids.")

    # --- read generation settings ---
    max_new = int(gen.get("max_new_tokens", 128))
    min_new = int(gen.get("min_new_tokens", 0))
    do_sample = bool(gen.get("do_sample", True))

    temperature = float(gen.get("temperature", 0.8))
    top_p = float(gen.get("top_p", 0.9))
    top_k = int(gen.get("top_k", 50))

    repetition_penalty = float(gen.get("repetition_penalty", 1.05))
    no_repeat_ngram_size = int(gen.get("no_repeat_ngram_size", 3))

    eos_token_id = gen.get("eos_token_id", None)
    if eos_token_id is None and tokenizer is not None:
        eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = -1  # disable EOS logic if unknown

    asa_info = bool(gen.get("asa_info", True))
    eos_risk_threshold = float(gen.get("eos_risk_threshold", 0.25))
    early_steps = int(gen.get("early_steps", 24))
    branchy_entropy_threshold = gen.get("branchy_entropy_threshold", None)
    rescue_mode = str(gen.get("rescue_mode", "resample")).lower()
    rescue_temp = float(gen.get("rescue_temp", 0.65))
    rescue_top_p = float(gen.get("rescue_top_p", 0.85))
    rescue_top_k = int(gen.get("rescue_top_k", 80))
    max_resample_tries = int(gen.get("max_resample_tries", 4))

    # Optional scaffold injection (architecture-aware: helps route trajectory)
    if rescue_mode == "scaffold" and tokenizer is not None and isinstance(prompt, str):
        # Very small, conservative scaffold set—extend as you like
        scaffolds = [
            ("The capital of", " the city of"),
            ("Albert Einstein was born", " in"),
            ("The scientific method involves", " the process of"),
            ("The algorithm proceeds as follows", " 1."),
        ]
        for k, s in scaffolds:
            # First matching template wins; skip if the prompt already ends with the scaffold.
            if prompt.strip().startswith(k) and not prompt.strip().endswith(s.strip()):
                input_ids = tokenizer.encode(prompt + s, return_tensors="pt").to(device)
                break

    info_trace: List[Dict[str, float]] = []

    # Generation loop: the full sequence is re-fed to the model at every step.
    cur_ids = input_ids
    for step in range(max_new):
        # Model forward
        if asa_info:
            logits, infos = model(cur_ids, return_info=True)
            # infos is list per layer; take last block's light stats if present
            last = infos[-1] if isinstance(infos, list) and len(infos) > 0 else None
            stat = {}
            if isinstance(last, dict):
                # these are CPU tensors in your module; cast to float if present
                for k in ["entropy_mean", "top1freq_mean", "content_read_gamma_mean", "slotspace_gate_mean", "slotspace_delta_norm"]:
                    if k in last and last[k] is not None:
                        try:
                            stat[k] = float(last[k].item())
                        except (AttributeError, TypeError, ValueError, RuntimeError):
                            # Telemetry is best-effort: a value that is not a
                            # one-element tensor is simply left out of the trace.
                            pass
        else:
            logits = model(cur_ids, return_info=False)
            stat = None

        next_logits = logits[0, -1, :]  # [V]

        # Basic constraints: forbid EOS until min_new tokens have been produced
        if step < min_new and eos_token_id >= 0:
            next_logits = next_logits.clone()
            next_logits[eos_token_id] = float("-inf")

        # Anti-repetition (mild, usually good for ASA because content-read is self-referential)
        gen_so_far = cur_ids[0, input_ids.shape[1]:]  # only newly generated, if any
        next_logits = _apply_repetition_penalty(next_logits, gen_so_far, repetition_penalty)
        next_logits = _no_repeat_ngram_ban(next_logits, cur_ids[0], no_repeat_ngram_size)

        # Router-aware rescue (early EOS / excessive branchiness)
        # Use next-token EOS risk; optionally sharpen if branchy.
        tries = 0
        used_temp, used_top_p, used_top_k = temperature, top_p, top_k
        while True:
            scaled = next_logits
            if used_temp and used_temp > 0:
                scaled = scaled / used_temp

            scaled = _top_k_top_p_filtering(scaled, top_k=used_top_k, top_p=used_top_p)

            probs = F.softmax(scaled, dim=-1)
            p_eos = float(probs[eos_token_id].item()) if eos_token_id >= 0 else 0.0
            # Shannon entropy of the filtered distribution; clamp avoids log(0).
            ent = float(-(probs.clamp_min(1e-12) * probs.clamp_min(1e-12).log()).sum().item())

            # Condition: early EOS risk is too high
            eos_risky = (eos_token_id >= 0) and (step < early_steps) and (p_eos > eos_risk_threshold)

            # Condition: branchy token distribution (optional) -> reduce temperature a bit
            branchy = False
            if branchy_entropy_threshold is not None and step < early_steps:
                branchy = ent > float(branchy_entropy_threshold)

            if (eos_risky or branchy) and rescue_mode == "resample" and tries < max_resample_tries:
                tries += 1
                rescued = (
                    min(used_temp, rescue_temp),
                    min(used_top_p, rescue_top_p),
                    max(used_top_k, rescue_top_k),
                )
                if rescued != (used_temp, used_top_p, used_top_k):
                    used_temp, used_top_p, used_top_k = rescued
                    continue
                # The rescue settings are already in effect, so another pass
                # would recompute exactly the same distribution; fall through
                # and pick a token from the current one.

            # Choose token
            if do_sample:
                next_id = torch.multinomial(probs, num_samples=1)
            else:
                next_id = torch.argmax(probs, dim=-1, keepdim=True)

            break

        # Log trace
        if asa_info:
            rec = {"step": float(step), "token_entropy": float(ent), "p_eos": float(p_eos)}
            if stat:
                for k, v in stat.items():
                    rec[k] = float(v)
            # record rescue adjustments
            rec["temp_used"] = float(used_temp)
            rec["top_p_used"] = float(used_top_p)
            rec["top_k_used"] = float(used_top_k)
            info_trace.append(rec)

        # Append token
        cur_ids = torch.cat([cur_ids, next_id.view(1, 1)], dim=1)

        # Stop on EOS
        if eos_token_id >= 0 and int(next_id.item()) == int(eos_token_id) and step >= min_new:
            break

    generated_ids = cur_ids[:, input_ids.shape[1]:]

    out: Dict[str, Any] = {
        "input_ids": cur_ids,
        "generated_ids": generated_ids,
    }
    if asa_info:
        out["info_trace"] = info_trace

    return_text = bool(gen.get("return_text", tokenizer is not None))
    if return_text and tokenizer is not None:
        out["text"] = tokenizer.decode(cur_ids[0].tolist(), skip_special_tokens=False)

    return out


# =========================
# PATCH 1: wrappers for your crafted asa_generate
# =========================

@torch.no_grad()
def asa_greedy_suffix(
    prompt: str,
    model: torch.nn.Module,
    gen: dict,
    max_new_tokens: int = 8,
    strip: bool = True,
) -> str:
    """
    Greedy-decode with asa_generate and return only the text that follows `prompt`.

    `gen` is copied, so the caller's dict is not modified; the copy forces
    do_sample=False and the requested max_new_tokens. The suffix is found by
    string matching on the decoded text (best-effort): if `prompt` does not
    occur in the decoded text at all, the whole decoded text is returned.
    With strip=True, newlines become spaces and surrounding whitespace is removed.

    Raises ValueError when asa_generate returned no text and no tokenizer is
    available to decode the ids.
    """
    settings = dict(gen)
    settings.update(do_sample=False, max_new_tokens=int(max_new_tokens))

    result = asa_generate(prompt, model, settings)
    decoded = result.get("text", None)
    if decoded is None:
        # asa_generate did not decode; do it here from the full id sequence.
        tokenizer = settings.get("tokenizer", None)
        if tokenizer is None:
            raise ValueError("No decoded text available; provide gen['tokenizer'].")
        decoded = tokenizer.decode(result["input_ids"][0].tolist(), skip_special_tokens=False)

    # find() returns 0 when the text starts with the prompt, so a single lookup
    # handles both "prompt is a prefix" and "prompt appears later in the text".
    position = decoded.find(prompt)
    suffix = decoded if position < 0 else decoded[position + len(prompt):]

    if not strip:
        return suffix
    return suffix.replace("\n", " ").strip()


@torch.no_grad()
def asa_generate_many(
    prompts: list,
    model: torch.nn.Module,
    gen: dict,
    do_sample: bool = False,
    max_new_tokens: int = 8,
) -> list:
    """
    Run asa_generate once per prompt (sequentially) and collect the decoded texts.

    `gen` is copied before do_sample and max_new_tokens are overridden, so the
    caller's dict is left as it was. Each returned string is the full decoded
    sequence for that prompt, in the same order as `prompts`.

    Raises ValueError when a result carries no text and no tokenizer is
    available to decode its ids.
    """
    settings = dict(gen)
    settings.update(do_sample=bool(do_sample), max_new_tokens=int(max_new_tokens))

    texts = []
    for one_prompt in prompts:
        result = asa_generate(one_prompt, model, settings)
        decoded = result.get("text", None)
        if decoded is not None:
            texts.append(decoded)
            continue

        # No text in the result: decode the full id sequence ourselves.
        tokenizer = settings.get("tokenizer", None)
        if tokenizer is None:
            raise ValueError("No decoded text available; provide gen['tokenizer'].")
        texts.append(tokenizer.decode(result["input_ids"][0].tolist(), skip_special_tokens=False))
    return texts

@torch.no_grad()
def score_next_token_rank(
    prompt: str,
    target_token: str,
    model: torch.nn.Module,
    gen: dict,
) -> dict:
    """
    Score `target_token` as the token immediately following `prompt`.

    Uses gen["tokenizer"] to encode both strings and puts the model in eval mode.

    Returns
    -------
    On success: {"ok": True, "p_target": softmax probability of the target,
                 "rank": 1-based rank among all logits (1 = highest),
                 "target_id": the target's token id}.
    When `target_token` does not encode to exactly one id:
                {"ok": False, "reason": <message>, "target_ids": <encoded ids>}.
    """
    tokenizer = gen["tokenizer"]
    model_device = next(model.parameters()).device

    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(model_device)

    # Only single-token targets can be scored as "the next token".
    candidate_ids = tokenizer.encode(target_token, add_special_tokens=False)
    if len(candidate_ids) != 1:
        return {"ok": False, "reason": f"target_token maps to {len(candidate_ids)} tokens", "target_ids": candidate_ids}
    target_id = candidate_ids[0]

    model.eval()
    # The model may return logits alone or a tuple/list with logits first.
    output = model(prompt_ids)
    logits = output[0] if isinstance(output, (tuple, list)) else output
    last_step = logits[0, -1, :]

    p_target = float(torch.softmax(last_step, dim=-1)[target_id].item())

    # Position of the target in the descending ordering of logits, made 1-based.
    descending_ids = torch.argsort(last_step, descending=True)
    position = int((descending_ids == target_id).nonzero(as_tuple=False).item())

    return {"ok": True, "p_target": p_target, "rank": position + 1, "target_id": target_id}

#




## Section 4 — Sample generations (pre-finetune)
Runs a small batch of prompts to show baseline text generation behavior.

In [None]:
# Decoding settings shared by all baseline samples below (read by asa_generate).
gener = {
    "tokenizer": tokenizer,
    "max_new_tokens": 32,
    "temperature": 0.1,
    "top_p": 0.95,
    "top_k": 80,
    "repetition_penalty": 1.03,
    "no_repeat_ngram_size": 3,
    "asa_info": False,
    # asa_generate lower-cases str(rescue_mode), so None selects its "none" mode.
    "rescue_mode": None,
}

# Country prompts: each quality is paired with the finisher at the same index.
print('#' * 5, 'Countries', '#' * 5)
finishers = ['is', 'sounds like', 'consists of', 'is a form of', 'all changed when']
qualities = ['capital', 'language', 'geography', 'government', 'history']
countries = ['France', 'Spain', 'Russia', 'Italy', 'Japan', 'Egypt', 'Germany', 'Brazil']
for country in countries:
    for quality, finisher in zip(qualities, finishers):
        country_prompt = f'The {quality} of {country} {finisher}'
        out = asa_generate(country_prompt, model, gener)
        print(out['text'])

# People prompts: every person is combined with every factoid.
print('#' * 5, 'People', '#' * 5)
people = ['Albert Einstein', 'George Patton', 'Charles Darwin', 'George Washington', 'Winston Churchill']
factoids = ['was born', 'contributed', 'accomplished', 'had a strong opinion about', 'died']
for person in people:
    for factoid in factoids:
        person_prompt = f'{person} {factoid}'
        out = asa_generate(person_prompt, model, gener)
        print(out['text'])


## Section 5 — Synthetic mini-alignment dataset
Builds a large prompt/completion set (capitals, languages, currencies, etc.) used for the mini finetune loop.

In [None]:
def is_single_token(s: str, tokenizer) -> bool:
    """Return True when `tokenizer` encodes `s` (without special tokens) to exactly one id."""
    return len(tokenizer.encode(s, add_special_tokens=False)) == 1

def build_pairs_expanded(tokenizer):
    """
    Massively expanded dataset generation with WikiText-103 style templates.
    Includes geographical, historical, scientific, cultural, and biographical facts.
    """
    pairs = []

    # ========================================
    # GEOGRAPHY SECTION (Massively Expanded)
    # ========================================

    # ---- Capitals (comprehensive list)
    capitals = {
        # Europe
        "France": " Paris",
        "Germany": " Berlin",
        "Italy": " Rome",
        "Spain": " Madrid",
        "Portugal": " Lisbon",
        "Greece": " Athens",
        "Austria": " Vienna",
        "Poland": " Warsaw",
        "Norway": " Oslo",
        "Sweden": " Stockholm",
        "Finland": " Helsinki",
        "Denmark": " Copenhagen",
        "Ireland": " Dublin",
        "Belgium": " Brussels",
        "Netherlands": " Amsterdam",
        "Switzerland": " Bern",
        "Czech Republic": " Prague",
        "Hungary": " Budapest",
        "Romania": " Bucharest",
        "Bulgaria": " Sofia",
        "Croatia": " Zagreb",
        "Serbia": " Belgrade",
        "Slovakia": " Bratislava",
        "Slovenia": " Ljubljana",
        "Lithuania": " Vilnius",
        "Latvia": " Riga",
        "Estonia": " Tallinn",
        "Iceland": " Reykjavik",
        "Luxembourg": " Luxembourg",
        "Malta": " Valletta",
        "Cyprus": " Nicosia",

        # Asia
        "Japan": " Tokyo",
        "China": " Beijing",
        "India": " Delhi",
        "South Korea": " Seoul",
        "North Korea": " Pyongyang",
        "Thailand": " Bangkok",
        "Vietnam": " Hanoi",
        "Indonesia": " Jakarta",
        "Philippines": " Manila",
        "Malaysia": " Kuala",
        "Singapore": " Singapore",
        "Myanmar": " Naypyidaw",
        "Cambodia": " Phnom",
        "Laos": " Vientiane",
        "Bangladesh": " Dhaka",
        "Pakistan": " Islamabad",
        "Afghanistan": " Kabul",
        "Iran": " Tehran",
        "Iraq": " Baghdad",
        "Saudi Arabia": " Riyadh",
        "Turkey": " Ankara",
        "Israel": " Jerusalem",
        "Jordan": " Amman",
        "Lebanon": " Beirut",
        "Syria": " Damascus",
        "Yemen": " Sanaa",
        "Oman": " Muscat",
        "Kuwait": " Kuwait",
        "Qatar": " Doha",
        "Bahrain": " Manama",
        "United Arab Emirates": " Abu",
        "Nepal": " Kathmandu",
        "Sri Lanka": " Colombo",
        "Mongolia": " Ulaanbaatar",
        "Kazakhstan": " Astana",
        "Uzbekistan": " Tashkent",

        # Africa
        "Egypt": " Cairo",
        "South Africa": " Pretoria",
        "Nigeria": " Abuja",
        "Kenya": " Nairobi",
        "Ethiopia": " Addis",
        "Morocco": " Rabat",
        "Algeria": " Algiers",
        "Tunisia": " Tunis",
        "Libya": " Tripoli",
        "Sudan": " Khartoum",
        "Ghana": " Accra",
        "Tanzania": " Dodoma",
        "Uganda": " Kampala",
        "Angola": " Luanda",
        "Mozambique": " Maputo",
        "Zimbabwe": " Harare",
        "Zambia": " Lusaka",
        "Senegal": " Dakar",
        "Ivory Coast": " Yamoussoukro",
        "Cameroon": " Yaounde",

        # Americas
        "United States": " Washington",
        "Canada": " Ottawa",
        "Mexico": " Mexico",
        "Brazil": " Brasilia",
        "Argentina": " Buenos",
        "Chile": " Santiago",
        "Colombia": " Bogota",
        "Peru": " Lima",
        "Venezuela": " Caracas",
        "Ecuador": " Quito",
        "Bolivia": " La",
        "Paraguay": " Asuncion",
        "Uruguay": " Montevideo",
        "Cuba": " Havana",
        "Jamaica": " Kingston",
        "Costa Rica": " San",
        "Panama": " Panama",
        "Guatemala": " Guatemala",
        "Honduras": " Tegucigalpa",
        "Nicaragua": " Managua",

        # Oceania
        "Australia": " Canberra",
        "New Zealand": " Wellington",
        "Papua New Guinea": " Port",
        "Fiji": " Suva",

        # Former USSR
        "Russia": " Moscow",
        "Ukraine": " Kyiv",
        "Belarus": " Minsk",
        "Georgia": " Tbilisi",
        "Armenia": " Yerevan",
        "Azerbaijan": " Baku",
    }

    for c, cap in capitals.items():
        pairs.append({"prompt": f"The capital of {c} is", "completion": cap, "tag": f"capital:{c}"})
        pairs.append({"prompt": f"{c}'s capital city is", "completion": cap, "tag": f"capital:{c}"})
        pairs.append({"prompt": f"{c} has its capital in", "completion": cap, "tag": f"capital:{c}"})

    # ---- Languages (comprehensive)
    languages = {
        "France": " French",
        "Germany": " German",
        "Italy": " Italian",
        "Japan": " Japanese",
        "Spain": " Spanish",
        "Russia": " Russian",
        "Brazil": " Portuguese",
        "Portugal": " Portuguese",
        "Egypt": " Arabic",
        "China": " Chinese",
        "India": " Hindi",
        "Mexico": " Spanish",
        "Argentina": " Spanish",
        "Netherlands": " Dutch",
        "Greece": " Greek",
        "Poland": " Polish",
        "Turkey": " Turkish",
        "Iran": " Persian",
        "Israel": " Hebrew",
        "Sweden": " Swedish",
        "Norway": " Norwegian",
        "Denmark": " Danish",
        "Finland": " Finnish",
        "Czech Republic": " Czech",
        "Hungary": " Hungarian",
        "Romania": " Romanian",
        "Thailand": " Thai",
        "Vietnam": " Vietnamese",
        "South Korea": " Korean",
    }
    for c, lang in languages.items():
        pairs.append({"prompt": f"The language of {c} is", "completion": lang, "tag": f"language:{c}"})
        pairs.append({"prompt": f"The official language of {c} is", "completion": lang, "tag": f"language:{c}"})
        pairs.append({"prompt": f"People in {c} speak", "completion": lang, "tag": f"language:{c}"})

    # ---- Currencies (comprehensive)
    currencies = {
        "Japan": " yen",
        "Russia": " ruble",
        "India": " rupee",
        "Mexico": " peso",
        "China": " yuan",
        "United Kingdom": " pound",
        "United States": " dollar",
        "Canada": " dollar",
        "Australia": " dollar",
        "Germany": " euro",
        "France": " euro",
        "Italy": " euro",
        "Spain": " euro",
        "Portugal": " euro",
        "Greece": " euro",
        "Austria": " euro",
        "Netherlands": " euro",
        "Belgium": " euro",
        "Poland": " zloty",
        "Czech Republic": " koruna",
        "Sweden": " krona",
        "Norway": " krone",
        "Denmark": " krone",
        "Switzerland": " franc",
        "Brazil": " real",
        "South Africa": " rand",
        "Turkey": " lira",
        "Thailand": " baht",
        "Indonesia": " rupiah",
    }
    for c, cur in currencies.items():
        pairs.append({"prompt": f"The currency of {c} is the", "completion": cur, "tag": f"currency:{c}"})
        pairs.append({"prompt": f"{c} uses the", "completion": cur, "tag": f"currency:{c}"})

    # ---- Continents (expanded with variations)
    continents = {
        "France": " Europe",
        "Germany": " Europe",
        "Italy": " Europe",
        "Spain": " Europe",
        "Poland": " Europe",
        "Greece": " Europe",
        "Sweden": " Europe",
        "Norway": " Europe",
        "Russia": " Europe",
        "Egypt": " Africa",
        "Nigeria": " Africa",
        "Kenya": " Africa",
        "South Africa": " Africa",
        "Morocco": " Africa",
        "Japan": " Asia",
        "China": " Asia",
        "India": " Asia",
        "Thailand": " Asia",
        "Vietnam": " Asia",
        "Indonesia": " Asia",
        "Brazil": " South",
        "Argentina": " South",
        "Chile": " South",
        "Peru": " South",
        "Colombia": " South",
        "Canada": " North",
        "United States": " North",
        "Mexico": " North",
        "Australia": " Oceania",
        "New Zealand": " Oceania",
    }
    for c, cont in continents.items():
        pairs.append({"prompt": f"{c} is in", "completion": cont, "tag": f"continent:{c}"})
        pairs.append({"prompt": f"{c} is located in", "completion": cont, "tag": f"continent:{c}"})

    # ---- Major Rivers
    rivers = {
        "The Nile flows through": " Egypt",
        "The Amazon flows through": " Brazil",
        "The Thames flows through": " London",
        "The Seine flows through": " Paris",
        "The Danube flows through": " Europe",
        "The Rhine flows through": " Germany",
        "The Ganges flows through": " India",
        "The Yangtze flows through": " China",
        "The Mississippi flows through": " America",
        "The Nile is located in": " Africa",
        "The Amazon is in": " South",
        "The Rhine is in": " Europe",
    }
    for p, comp in rivers.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "rivers"})

    # ---- Mountain ranges and peaks
    mountains = {
        "Mount Everest is in": " Nepal",
        "The Alps are in": " Europe",
        "The Himalayas are in": " Asia",
        "The Andes are in": " South",
        "The Rocky Mountains are in": " North",
        "Mount Fuji is in": " Japan",
        "The Pyrenees are between France and": " Spain",
    }
    for p, comp in mountains.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "mountains"})

    # ---- Oceans and seas
    oceans = {
        "The Pacific Ocean is the": " largest",
        "The Atlantic Ocean is the": " second",
        "The Mediterranean Sea is in": " Europe",
        "The Caribbean Sea is in": " Central",
        "The Baltic Sea is in": " Northern",
    }
    for p, comp in oceans.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "oceans"})

    # ========================================
    # HISTORICAL FACTS (Massively Expanded)
    # ========================================

    # ---- Birth locations (expanded)
    born_in = {
        "Albert Einstein was born in": " Germany",
        "Charles Darwin was born in": " England",
        "George Washington was born in": " Virginia",
        "Winston Churchill was born in": " England",
        "Napoleon Bonaparte was born in": " Corsica",
        "Leonardo da Vinci was born in": " Italy",
        "William Shakespeare was born in": " England",
        "Isaac Newton was born in": " England",
        "Marie Curie was born in": " Poland",
        "Galileo Galilei was born in": " Italy",
        "Aristotle was born in": " Greece",
        "Plato was born in": " Athens",
        "Confucius was born in": " China",
        "Buddha was born in": " Nepal",
        "Muhammad Ali was born in": " Kentucky",
        "Martin Luther King was born in": " Georgia",
        "Abraham Lincoln was born in": " Kentucky",
        "Thomas Edison was born in": " Ohio",
        "Nikola Tesla was born in": " Croatia",
        "Sigmund Freud was born in": " Czech",
    }
    for p, comp in born_in.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "born_in"})

    # ---- Death years (single token years)
    death_years = {
        "Albert Einstein died in": " 1955",
        "Isaac Newton died in": " 1727",
        "Charles Darwin died in": " 1882",
        "Leonardo da Vinci died in": " 1519",
        "William Shakespeare died in": " 1616",
        "George Washington died in": " 1799",
        "Napoleon Bonaparte died in": " 1821",
        "Abraham Lincoln died in": " 1865",
        "Marie Curie died in": " 1934",
        "Nikola Tesla died in": " 1943",
    }
    for p, comp in death_years.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "death_year"})

    # ---- Historical events and dates
    historical_events = {
        "World War I began in": " 1914",
        "World War II began in": " 1939",
        "World War II ended in": " 1945",
        "The American Revolution began in": " 1775",
        "The French Revolution began in": " 1789",
        "The Russian Revolution was in": " 1917",
        "The fall of the Berlin Wall was in": " 1989",
        "The September 11 attacks occurred in": " 2001",
        "The moon landing was in": " 1969",
        "Christopher Columbus sailed in": " 1492",
        "The Declaration of Independence was signed in": " 1776",
        "The Civil War began in": " 1861",
        "The Great Depression began in": " 1929",
        "The Cold War began after": " 1945",
    }
    for p, comp in historical_events.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "historical_event"})

    # ---- Century associations
    centuries = {
        "The Renaissance occurred in the": " 15th",
        "The Industrial Revolution began in the": " 18th",
        "The Enlightenment was in the": " 18th",
        "The Victorian Era was in the": " 19th",
        "World War I was in the": " 20th",
    }
    for p, comp in centuries.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "century"})

    # ---- Leaders and rulers
    leaders = {
        "Julius Caesar was a": " Roman",
        "Alexander the Great was a": " Macedonian",
        "Cleopatra was the queen of": " Egypt",
        "Queen Victoria ruled": " Britain",
        "Napoleon was the emperor of": " France",
        "Peter the Great ruled": " Russia",
        "Elizabeth I was queen of": " England",
        "Henry VIII was king of": " England",
        "Louis XIV was king of": " France",
    }
    for p, comp in leaders.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "leader"})

    # ========================================
    # SCIENTIFIC FACTS (Massively Expanded)
    # ========================================

    # ---- Physics facts
    physics = {
        "The speed of light is approximately": " 300",
        "Gravity was discovered by": " Newton",
        "Einstein developed the theory of": " relativity",
        "The atomic bomb was developed during": " World",
        "Newton's laws describe": " motion",
        "Electrons have a": " negative",
        "Protons have a": " positive",
        "The Earth orbits the": " Sun",
        "The Moon orbits the": " Earth",
    }
    for p, comp in physics.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "physics"})

    # ---- Chemistry facts
    chemistry = {
        "Water is composed of hydrogen and": " oxygen",
        "The symbol for gold is": " Au",
        "The symbol for silver is": " Ag",
        "The symbol for iron is": " Fe",
        "The periodic table was created by": " Mendeleev",
        "Oxygen has atomic number": " 8",
        "Carbon has atomic number": " 6",
        "Hydrogen has atomic number": " 1",
        "Salt is composed of sodium and": " chloride",
    }
    for p, comp in chemistry.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "chemistry"})

    # ---- Biology facts
    biology = {
        "DNA stands for deoxyribonucleic": " acid",
        "Photosynthesis occurs in": " plants",
        "The heart pumps": " blood",
        "Humans have": " 46",
        "Evolution was proposed by": " Darwin",
        "Cells are the basic unit of": " life",
        "Mitochondria produce": " energy",
        "The largest organ is the": " skin",
    }
    for p, comp in biology.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "biology"})

    # ---- Mathematics facts
    mathematics = {
        "Pi is approximately": " 3",
        "A triangle has": " three",
        "A square has": " four",
        "A circle has": " 360",
        "The Pythagorean theorem relates": " triangles",
        "Calculus was invented by": " Newton",
        "Algebra originated in": " ancient",
    }
    for p, comp in mathematics.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "mathematics"})

    # ---- Astronomy facts
    astronomy = {
        "The Sun is a": " star",
        "Jupiter is a": " gas",
        "Mars is the": " red",
        "Saturn has": " rings",
        "The Solar System has": " eight",
        "The Milky Way is a": " galaxy",
        "A light year measures": " distance",
        "The nearest star to Earth is the": " Sun",
        "Pluto was reclassified as a": " dwarf",
    }
    for p, comp in astronomy.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "astronomy"})

    # ========================================
    # CULTURAL FACTS (Massively Expanded)
    # ========================================

    # ---- Literature and authors
    literature = {
        "Shakespeare wrote": " Hamlet",
        "Homer wrote the": " Odyssey",
        "Tolkien wrote The Lord of the": " Rings",
        "George Orwell wrote": " 1984",
        "Jane Austen wrote Pride and": " Prejudice",
        "Mark Twain wrote The Adventures of": " Tom",
        "Charles Dickens wrote A Tale of": " Two",
        "Ernest Hemingway wrote The Old Man and the": " Sea",
        "F. Scott Fitzgerald wrote The Great": " Gatsby",
        "Leo Tolstoy wrote War and": " Peace",
        "Fyodor Dostoevsky wrote Crime and": " Punishment",
        "Victor Hugo wrote Les": " Miserables",
        "Miguel de Cervantes wrote Don": " Quixote",
        "Dante wrote The Divine": " Comedy",
        "Virgil wrote the": " Aeneid",
    }
    for p, comp in literature.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "literature"})

    # ---- Art and artists
    art = {
        "Leonardo da Vinci painted the Mona": " Lisa",
        "Vincent van Gogh painted Starry": " Night",
        "Pablo Picasso was a": " Spanish",
        "Michelangelo painted the Sistine": " Chapel",
        "Claude Monet was an": " Impressionist",
        "Salvador Dali was a": " Surrealist",
        "Rembrandt was a": " Dutch",
        "Andy Warhol was a": " Pop",
    }
    for p, comp in art.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "art"})

    # ---- Music and composers
    music = {
        "Mozart was a": " composer",
        "Beethoven wrote": " symphonies",
        "Bach was a": " Baroque",
        "Chopin was a": " Polish",
        "Tchaikovsky was a": " Russian",
        "Wagner was a": " German",
        "Vivaldi wrote The Four": " Seasons",
        "Handel wrote": " Messiah",
    }
    for p, comp in music.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "music"})

    # ---- Sports facts
    sports = {
        "The Olympics originated in": " Greece",
        "Soccer is called football in": " Europe",
        "Basketball was invented in": " America",
        "Baseball is popular in": " America",
        "Cricket is popular in": " India",
        "The World Cup is held every": " four",
        "Tennis is played on a": " court",
        "Golf is played on a": " course",
    }
    for p, comp in sports.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "sports"})

    # ========================================
    # TECHNOLOGY AND INVENTIONS
    # ========================================

    technology = {
        "The telephone was invented by": " Bell",
        "The light bulb was invented by": " Edison",
        "The airplane was invented by the Wright": " Brothers",
        "The printing press was invented by": " Gutenberg",
        "The steam engine was invented by": " Watt",
        "The radio was invented by": " Marconi",
        "The computer was invented in the": " 20th",
        "The internet was developed in": " America",
        "Apple was founded by Steve": " Jobs",
        "Microsoft was founded by Bill": " Gates",
        "Facebook was founded by Mark": " Zuckerberg",
    }
    for p, comp in technology.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "technology"})

    # ========================================
    # ARCHITECTURE AND LANDMARKS
    # ========================================

    landmarks = {
        "The Eiffel Tower is in": " Paris",
        "The Colosseum is in": " Rome",
        "The Taj Mahal is in": " India",
        "The Great Wall is in": " China",
        "The Statue of Liberty is in": " New",
        "Big Ben is in": " London",
        "The Pyramids are in": " Egypt",
        "The Parthenon is in": " Athens",
        "The Kremlin is in": " Moscow",
        "Machu Picchu is in": " Peru",
        "Petra is in": " Jordan",
        "Angkor Wat is in": " Cambodia",
        "The Sydney Opera House is in": " Australia",
    }
    for p, comp in landmarks.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "landmarks"})

    # ========================================
    # ANIMALS AND NATURE
    # ========================================

    animals = {
        "The largest animal is the blue": " whale",
        "The fastest land animal is the": " cheetah",
        "The tallest animal is the": " giraffe",
        "Lions are found in": " Africa",
        "Pandas are native to": " China",
        "Kangaroos are native to": " Australia",
        "Penguins live in": " Antarctica",
        "Tigers are native to": " Asia",
        "Elephants are found in": " Africa",
        "Polar bears live in the": " Arctic",
    }
    for p, comp in animals.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "animals"})

    # ========================================
    # RELIGIONS AND MYTHOLOGY
    # ========================================

    religions = {
        "Christianity originated in": " Israel",
        "Islam originated in": " Saudi",
        "Buddhism originated in": " India",
        "Hinduism originated in": " India",
        "Judaism originated in": " Israel",
        "The Bible is the holy book of": " Christianity",
        "The Quran is the holy book of": " Islam",
        "Zeus was the king of the": " Greek",
        "Thor was a": " Norse",
        "Ra was an": " Egyptian",
    }
    for p, comp in religions.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "religion"})

    # ========================================
    # FOOD AND CUISINE
    # ========================================

    cuisine = {
        "Pizza originated in": " Italy",
        "Sushi originated in": " Japan",
        "Tacos originated in": " Mexico",
        "Hamburgers are popular in": " America",
        "Pasta is from": " Italy",
        "Croissants are from": " France",
        "Curry is from": " India",
        "Paella is from": " Spain",
        "Kimchi is from": " Korea",
    }
    for p, comp in cuisine.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "cuisine"})

    # ========================================
    # ECONOMIC AND POLITICAL FACTS
    # ========================================

    economics = {
        "The largest economy is": " America",
        "The European Union uses the": " euro",
        "OPEC stands for Organization of": " Petroleum",
        "The World Bank is headquartered in": " Washington",
        "The United Nations is headquartered in": " New",
        "NATO stands for North Atlantic": " Treaty",
        "GDP stands for Gross Domestic": " Product",
    }
    for p, comp in economics.items():
        pairs.append({"prompt": p, "completion": comp, "tag": "economics"})

    # ---- Filter to single-token completions
    kept = []
    dropped = []
    for ex in pairs:
        if is_single_token(ex["completion"], tokenizer):
            kept.append(ex)
        else:
            dropped.append(ex)

    # ---- Basic reporting
    from collections import Counter
    counts = Counter(ex["tag"].split(":")[0] for ex in kept)
    print("Kept counts by task:", dict(counts))
    print(f"\nTotal generated pairs: {len(pairs)}")
    print(f"Single-token completions: {len(kept)}")
    print(f"Multi-token completions (dropped): {len(dropped)}")

    if dropped:
        print(f"\nShowing first 20 dropped (multi-token) examples:")
        for ex in dropped[:20]:
            ids = tokenizer.encode(ex["completion"], add_special_tokens=False)
            print(f"  {ex['tag']:<20} {ex['prompt']:<50} -> {repr(ex['completion']):<20} token_ids={ids}")

    return kept

# GPT-2 tokenizer; build_pairs_expanded uses it to keep only single-token completions.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

# Build the synthetic prompt/completion set exactly once. The previous second,
# identical call only repeated the work and printed the same report twice.
pairs = build_pairs_expanded(tokenizer)
print('Synthetic pairs:', len(pairs))


## Section 6 — Finetune configuration and utilities
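Configures the synthetic/WikiText mix, builds the training dataloader, and defines evaluation helpers; the same cell then runs the pre-finetune evaluation, the short finetune loop, the post-finetune evaluation, and a few greedy sample generations.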

In [None]:
# Master switch for the finetune section; set to False to skip it.
DO_FINETUNE = True

# Flat configuration for the finetune run, grouped by concern.
CFG = dict(
    # --- optional WikiText mix-in ---
    use_wiki=False,
    wiki_dataset_name="wikitext",
    wiki_config_candidates=["wikitext-103-raw-v1", "wikitext-2-raw-v1"],  # later entries are fallbacks
    wiki_num_samples=1536,        # counts wiki chunks, not raw lines
    wiki_chunk_chars_min=400,     # chunks shorter than this are filtered out
    wiki_chunk_chars_max=1200,    # chunk size in characters, before tokenization

    # --- training ---
    max_len=128,                  # raised because wiki chunks are longer
    batch_size=16,
    steps=77,
    lr=7e-6,
    weight_decay=0.007,
    grad_clip=1.0,

    # --- finetune mode ---
    # "all" trains every parameter; "slot_attn_only" freezes everything except
    # the parameters whose full name matches `slot_train_name_regex`.
    finetune_mode="slot_attn_only",

    # Which parameter names count as "slot attention" (adjust to your module names).
    # Looser alternative kept for reference:
    #   r"(slot|slots).*(attn|attention)|((attn|attention).*(slot|slots))"
    slot_train_name_regex=r"(^|\.)(slot_in|slot_q|slot_k|slot_v|slot_out)\.weight$|(^|\.)(_slotspace_gate_raw)$",
)

# -----------------------
# 1) Utilities (aligned to your model call style)
# -----------------------
@torch.no_grad()
def next_token_stats(prompt: str, target_token_str: str, model, tokenizer):
    """Score a single candidate next token for `prompt` under `model`.

    The model is put in eval mode and the encoded prompt is moved to the
    module-level `device` before the forward pass.

    Returns:
        ``{"ok": False, "reason", "target_ids"}`` when `target_token_str` does
        not encode to exactly one token; otherwise ``{"ok": True, "p_target",
        "rank", "top1", "target_id"}`` where ``p_target`` is the softmax
        probability of the target at the last position, ``rank`` is its
        1-based position in descending-logit order and ``top1`` is the decoded
        argmax token.
    """
    model.eval()
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    target_ids = tokenizer.encode(target_token_str, add_special_tokens=False)
    if len(target_ids) != 1:
        return {
            "ok": False,
            "reason": f"target string maps to {len(target_ids)} tokens",
            "target_ids": target_ids,
        }
    (target_id,) = target_ids

    output = model(prompt_ids)
    # The forward pass may hand back the logits directly or as the first item
    # of a tuple/list.
    logits = output[0] if isinstance(output, (tuple, list)) else output
    final_logits = logits[0, -1, :]

    distribution = torch.softmax(final_logits, dim=-1)
    ordering = torch.argsort(final_logits, descending=True)
    position = (ordering == target_id).nonzero(as_tuple=False).item()
    best_id = int(torch.argmax(final_logits).item())

    return {
        "ok": True,
        "p_target": float(distribution[target_id].item()),
        "rank": int(position) + 1,
        "top1": tokenizer.decode([best_id]),
        "target_id": target_id,
    }

@torch.no_grad()
def greedy_suffix(prompt: str, model, gen, max_new_tokens=8):
    """Greedy-decode from `prompt` and return only the generated continuation.

    `gen` is copied, then sampling is switched off and `max_new_tokens` is
    applied before calling `asa_generate`. If the generated text contains the
    prompt, everything up to and including its first occurrence is removed;
    otherwise the whole text is kept. Newlines become spaces and the result is
    stripped.
    """
    settings = dict(gen, do_sample=False, max_new_tokens=int(max_new_tokens))
    text = asa_generate(prompt, model, settings)["text"]

    # find() returns 0 when the text starts with the prompt, so a single lookup
    # covers both the "prefix" and the "somewhere inside" cases.
    at = text.find(prompt)
    continuation = text if at < 0 else text[at + len(prompt):]
    return continuation.replace("\n", " ").strip()

@torch.no_grad()
def eval_exact_match(examples, model, gen, max_new_tokens=8):
    """Return the fraction of examples whose greedy continuation starts with the gold completion.

    Both the prediction (via `greedy_suffix`) and the gold completion have
    newlines replaced by spaces and are stripped before the prefix comparison.
    An empty `examples` yields 0 rather than dividing by zero.
    """
    model.eval()

    def _is_hit(example):
        predicted = greedy_suffix(example["prompt"], model, gen, max_new_tokens=max_new_tokens)
        expected = example["completion"].replace("\n", " ").strip()
        return int(predicted.startswith(expected))

    hits = sum(_is_hit(example) for example in examples)
    return hits / max(1, len(examples))

# -----------------------
# 2) Dataset builders
# -----------------------


def load_wikitext_chunks(tokenizer, num_samples=2048, chunk_chars_min=400, chunk_chars_max=1200, seed=None):
    """
    Produces wiki training examples as plain LM text chunks:
      ex = {"prompt": "", "completion": "<wiki chunk>", "tag": "wiki"}
    We chunk by chars first, then token-truncate later in dataset.

    The dataset name and the list of config candidates are read from the
    module-level CFG ("wiki_dataset_name", "wiki_config_candidates").

    Args:
        tokenizer: Not used here; accepted so existing call sites keep working.
        num_samples: Maximum number of chunks to return.
        chunk_chars_min: Chunks shorter than this many characters are dropped.
        chunk_chars_max: Upper bound on a chunk's length in characters.
        seed: Seed for the deterministic line shuffle. When None, the
            module-level `SEED` is used if it exists, then `seed`, then 1337.

    Returns:
        A list of example dicts. Empty when `datasets` cannot be imported or
        none of the configured WikiText configs can be loaded.
    """
    try:
        from datasets import load_dataset
    except Exception as exc:
        # Best effort: the wiki mix is optional, so report the reason and carry on.
        print("[WikiText] datasets not available; skipping WikiText mix.")
        print(f"[WikiText]   reason: {exc!r}")
        return []

    if seed is None:
        # The notebook setup defines `seed`; older cells referenced `SEED`.
        # Accept either so the shuffle cannot fail with a NameError.
        module_globals = globals()
        seed = module_globals.get("SEED", module_globals.get("seed", 1337))

    ds = None
    used_cfg = None
    last_error = None
    for cfg in CFG["wiki_config_candidates"]:
        try:
            ds = load_dataset(CFG["wiki_dataset_name"], cfg, split="train")
        except Exception as exc:
            # Remember why this candidate failed; try the next one.
            last_error = exc
            continue
        used_cfg = cfg
        break

    if ds is None:
        print("[WikiText] Could not load WikiText (tried configs:", CFG["wiki_config_candidates"], "). Skipping.")
        if last_error is not None:
            print(f"[WikiText]   last error: {last_error!r}")
        return []

    print(f"[WikiText] Loaded {CFG['wiki_dataset_name']} / {used_cfg} train split with {len(ds)} rows.")

    # Pull raw text field (wikitext uses 'text'), dropping blank rows.
    texts = [t for t in ds["text"] if isinstance(t, str) and t.strip()]

    # Shuffle deterministically.
    # NOTE(review): because lines are shuffled first, each chunk is stitched
    # from shuffled lines rather than from adjacent lines of one article.
    random.Random(seed).shuffle(texts)

    # Greedily pack lines into chunks of at most chunk_chars_max characters.
    chunks = []
    buf = []
    buf_len = 0  # running size of buf, counting one separator per line

    for line in texts:
        line = line.strip()
        # skip heading markup lines ("= Title ="); keep normal prose
        if line.startswith("=") and line.endswith("="):
            continue
        # A single paragraph can exceed the budget on its own; clip it so no
        # emitted chunk is longer than chunk_chars_max.
        line = line[:chunk_chars_max]
        if not line:
            continue

        if buf_len + len(line) + 1 <= chunk_chars_max:
            buf.append(line)
            buf_len += len(line) + 1
        else:
            # Buffer is full: emit it (if long enough) and start a new one.
            chunk = " ".join(buf).strip()
            if len(chunk) >= chunk_chars_min:
                chunks.append(chunk)
            buf = [line]
            buf_len = len(line) + 1

        if len(chunks) >= num_samples:
            break

    # Flush the trailing buffer if there is still room for one more chunk.
    if len(chunks) < num_samples:
        chunk = " ".join(buf).strip()
        if len(chunk) >= chunk_chars_min:
            chunks.append(chunk)

    wiki_examples = [{"prompt": "", "completion": c, "tag": "wiki"} for c in chunks[:num_samples]]
    print(f"[WikiText] Prepared {len(wiki_examples)} wiki chunks.")

    return wiki_examples

# -----------------------
# 3) Split synthetic (entity-holdout) + build mixed train set
# -----------------------
#pairs = build_pairs_expanded(tokenizer)

from collections import Counter

# Entity-level holdout: every example whose tag names one of these countries
# (for the matching task) is routed to the holdout set.
holdout_capitals = {
    "Spain", "Canada", "Poland", "Portugal", "Greece", "Austria",
    "Norway", "Ireland", "Romania", "Croatia", "Argentina", "Chile"
}
holdout_languages = {"Brazil", "Mexico", "Netherlands", "Sweden", "Finland"}
holdout_currencies = {"Japan", "Switzerland", "South Africa", "Thailand"}
holdout_continents = {"Kenya", "Vietnam", "Peru", "New Zealand"}

# Task prefix of the tag -> entities reserved for holdout in that task.
_holdout_entities_by_task = {
    "capital": holdout_capitals,
    "language": holdout_languages,
    "currency": holdout_currencies,
    "continent": holdout_continents,
}

train_examples = []
holdout_examples = []
for ex in pairs:
    task = ex["tag"].split(":")[0]
    reserved = _holdout_entities_by_task.get(task)
    if reserved is not None:
        # Tags for these tasks look like "<task>:<country>".
        country = ex["tag"].split(":", 1)[1]
        to_holdout = country in reserved
    else:
        # Every other task: hold out at random with probability 0.2.
        to_holdout = random.random() < 0.2
    (holdout_examples if to_holdout else train_examples).append(ex)

print(f"\n[Synthetic] Total kept pairs: {len(pairs)} | Train: {len(train_examples)} | Holdout: {len(holdout_examples)}")
print(f"[Synthetic] Split: {len(train_examples)/len(pairs)*100:.1f}% / {len(holdout_examples)/len(pairs)*100:.1f}%")
holdout_by_cat = Counter(ex["tag"].split(":")[0] for ex in holdout_examples)
print("[Synthetic] Holdout by category:", dict(holdout_by_cat))

# Wiki chunks are mixed into the TRAIN set only; the holdout stays synthetic.
wiki_examples = (
    load_wikitext_chunks(
        tokenizer,
        num_samples=CFG["wiki_num_samples"],
        chunk_chars_min=CFG["wiki_chunk_chars_min"],
        chunk_chars_max=CFG["wiki_chunk_chars_max"],
    )
    if CFG["use_wiki"]
    else []
)

mixed_train_examples = train_examples + wiki_examples
print(f"\n[Mix] Train synthetic={len(train_examples)} + wiki={len(wiki_examples)} => mixed_train={len(mixed_train_examples)}")
print(f"[Mix] Holdout (synthetic only) = {len(holdout_examples)}")

# -----------------------
# 4) Tiny finetune dataset (teacher forcing)
#    Works for BOTH: prompt+completion pairs and raw wiki chunks (prompt="")
# -----------------------
class PromptCompletionDataset(Dataset):
    """Teacher-forcing dataset over ``{"prompt", "completion", ...}`` example dicts.

    Each item is ``(x, y, example)``. The token ids of ``prompt + completion``
    are clipped to their last ``max_len`` entries; ``x`` is that window without
    its final id and ``y`` is the window without its first id, so ``y[i]`` is
    the next-token target for ``x[i]``. Raw wiki chunks (``prompt == ""``) go
    through the same path.
    """

    def __init__(self, examples, tokenizer, max_len=128):
        self.examples = examples
        self.tok = tokenizer
        self.max_len = int(max_len)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        token_ids = self.tok.encode(example["prompt"] + example["completion"])

        # Over-long sequences keep their tail, so a long wiki chunk
        # contributes its final `max_len` tokens.
        window = token_ids[-self.max_len:]

        inputs = torch.tensor(window[:-1], dtype=torch.long)
        targets = torch.tensor(window[1:], dtype=torch.long)
        return inputs, targets, example

def collate_pad(batch):
    """Right-pad a list of ``(x, y, example)`` items into batch tensors.

    Inputs are padded with the tokenizer's EOS id (GPT-2 has no pad token) and
    targets with -100. Returns ``(X, Y, examples)`` with both tensors moved to
    the module-level `device`; ``examples`` is the tuple of original dicts.
    """
    inputs, targets, examples = zip(*batch)
    n_rows = len(inputs)
    width = max(seq.size(0) for seq in inputs)

    X = torch.full((n_rows, width), tokenizer.eos_token_id, dtype=torch.long)
    Y = torch.full((n_rows, width), -100, dtype=torch.long)

    for row, (seq_in, seq_out) in enumerate(zip(inputs, targets)):
        n = seq_in.size(0)
        X[row, :n] = seq_in
        Y[row, :n] = seq_out

    return X.to(device), Y.to(device), examples

# Dataset over the mixed (synthetic + optional wiki) training examples.
train_ds = PromptCompletionDataset(mixed_train_examples, tokenizer, max_len=CFG["max_len"])

# The batch size is capped at the dataset size; batches are shuffled and padded
# by collate_pad.
_train_loader_kwargs = dict(
    batch_size=min(CFG["batch_size"], len(train_ds)),
    shuffle=True,
    collate_fn=collate_pad,
)
train_dl = DataLoader(train_ds, **_train_loader_kwargs)

# -----------------------
# 5) Optional: freeze everything except slot-space attention (NEW)
# -----------------------
def configure_finetune_mode(model, mode: str, name_regex: str):
    """Set `requires_grad` on the model's parameters according to `mode`.

    mode:
      - "all": every parameter is trainable (`name_regex` is ignored)
      - "slot_attn_only": only parameters whose full name matches `name_regex`
        (searched case-insensitively) stay trainable; all others are frozen

    Prints a summary of trainable vs. total parameter counts and, in
    "slot_attn_only" mode, up to 25 matched tensors ordered by size.

    Raises:
        ValueError: if `mode` is neither "all" nor "slot_attn_only".
    """
    def _param_counts():
        # (trainable element count, total element count)
        sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
        return sum(n for n, live in sizes if live), sum(n for n, _ in sizes)

    if mode not in ("all", "slot_attn_only"):
        raise ValueError(f"Unknown finetune_mode={mode}")

    if mode == "all":
        for param in model.parameters():
            param.requires_grad = True
        trainable, total = _param_counts()
        print(f"[Finetune] mode=all trainable={trainable}/{total} ({trainable/total*100:.2f}%)")
        return

    pattern = re.compile(name_regex, flags=re.IGNORECASE)

    # One pass: a parameter is trainable exactly when its name matches.
    matched = []
    for name, param in model.named_parameters():
        is_match = pattern.search(name) is not None
        param.requires_grad = is_match
        if is_match:
            matched.append((name, param.numel()))

    trainable, total = _param_counts()
    print(f"[Finetune] mode=slot_attn_only regex={name_regex!r}")
    print(f"[Finetune] trainable={trainable}/{total} ({trainable/total*100:.4f}%) matched_tensors={len(matched)}")

    # Largest matched tensors first.
    for name, size in sorted(matched, key=lambda item: -item[1])[:25]:
        print(f"  [trainable] {size:>10}  {name}")

# Freeze/unfreeze parameters according to the configured finetune mode.
configure_finetune_mode(model, CFG["finetune_mode"], CFG["slot_train_name_regex"])

# -----------------------
# 6) Pre-eval (synthetic only, as before)
# -----------------------
print("\n" + "="*80)
print("PRE-TRAINING EVALUATION (synthetic only)")
print("="*80)

# FIX: Enable asa_info to handle model's tuple return type correctly
gener['asa_info'] = True

pre_acc_train = eval_exact_match(train_examples, model, gener, max_new_tokens=8)
pre_acc_hold  = eval_exact_match(holdout_examples, model, gener, max_new_tokens=8)

# The hit counts are recovered from the accuracy fraction; round() instead of
# int() so float round-off (e.g. 1/49*49 == 0.9999999999999999) cannot
# under-report a hit.
print(f"\n[PRE] Exact-match accuracy:")
print(f"  Train:   {pre_acc_train:.3f} ({round(pre_acc_train*len(train_examples))}/{len(train_examples)})")
print(f"  Holdout: {pre_acc_hold:.3f} ({round(pre_acc_hold*len(holdout_examples))}/{len(holdout_examples)})")

print("\n[PRE] Next-token stats for sample of single-token targets (synthetic only):")
# Kept in a variable so the post-training report can score the same examples.
sample_for_stats = random.sample(pairs, min(30, len(pairs)))
for ex in sample_for_stats:
    stats = next_token_stats(ex["prompt"], ex["completion"], model, tokenizer)
    if stats["ok"]:
        print(f"  {ex['tag']:<25} P={stats['p_target']:.4f} rank={stats['rank']:>5} top1={stats['top1']!r}")
    else:
        print(f"  {ex['tag']:<25} (skip) {stats['reason']}")

# -----------------------
# 7) Light training (mixed: synthetic + wiki)
# -----------------------
print("\n" + "="*80)
print("TRAINING (mixed synthetic + wiki)")
print("="*80)

model.train()

# IMPORTANT: optimizer must only see trainable params (esp for slot_attn_only)
trainable_params = [p for p in model.parameters() if p.requires_grad]
if len(trainable_params) == 0:
    raise RuntimeError("No trainable parameters. Check CFG['finetune_mode'] and regex.")

opt = torch.optim.AdamW(
    trainable_params,
    lr=CFG["lr"],
    betas=(0.9, 0.95),
    weight_decay=CFG["weight_decay"]
)

steps = int(CFG["steps"])
grad_clip = float(CFG["grad_clip"])

print(f"Training for {steps} steps with batch_size={train_dl.batch_size}")
print(f"Total mixed_train_examples={len(mixed_train_examples)} | synthetic={len(train_examples)} | wiki={len(wiki_examples)}\n")


def _endless_batches(loader):
    """Yield batches forever, starting a fresh pass over `loader` each epoch."""
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            # Without this guard an empty loader would spin forever.
            raise RuntimeError("train_dl produced no batches; nothing to train on.")


# One continuous batch stream across epochs. itertools.cycle is avoided on
# purpose: it stores the first pass and replays those exact batches, so a
# shuffling DataLoader would never reshuffle and every batch would stay in memory.
batch_iter = _endless_batches(train_dl)

for step in range(steps):
    X, Y, _ = next(batch_iter)
    opt.zero_grad(set_to_none=True)

    logits = model(X)
    # The model may return the logits directly or as the first item of a tuple/list.
    logits = logits[0] if isinstance(logits, (tuple, list)) else logits

    # Padded target positions carry -100 and are excluded from the loss.
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        Y.view(-1),
        ignore_index=-100
    )
    loss.backward()
    torch.nn.utils.clip_grad_norm_(trainable_params, grad_clip)
    opt.step()

    if (step + 1) % 50 == 0:
        print(f"  [train] step {step+1:>4}/{steps} loss={float(loss.item()):.4f}")

model.eval()

# -----------------------
# 8) Post-eval (synthetic only, as before)
# -----------------------
print("\n" + "="*80)
print("POST-TRAINING EVALUATION (synthetic only)")
print("="*80)

post_acc_train = eval_exact_match(train_examples, model, gener, max_new_tokens=8)
post_acc_hold  = eval_exact_match(holdout_examples, model, gener, max_new_tokens=8)

# round() rather than int(): the hit count is recovered from a float fraction,
# and truncation can under-report it by one (e.g. 1/49*49 == 0.9999999999999999).
print(f"\n[POST] Exact-match accuracy:")
print(f"  Train:   {post_acc_train:.3f} ({round(post_acc_train*len(train_examples))}/{len(train_examples)})")
print(f"  Holdout: {post_acc_hold:.3f} ({round(post_acc_hold*len(holdout_examples))}/{len(holdout_examples)})")

print(f"\n[DELTA] Accuracy change:")
print(f"  Train:   {pre_acc_train:.3f} -> {post_acc_train:.3f} (Δ={post_acc_train-pre_acc_train:+.3f})")
print(f"  Holdout: {pre_acc_hold:.3f} -> {post_acc_hold:.3f} (Δ={post_acc_hold-pre_acc_hold:+.3f})")

print("\n[POST] Next-token stats for same sample (synthetic only):")
for ex in sample_for_stats:
    stats = next_token_stats(ex["prompt"], ex["completion"], model, tokenizer)
    if stats["ok"]:
        print(f"  {ex['tag']:<25} P={stats['p_target']:.4f} rank={stats['rank']:>5} top1={stats['top1']!r}")
    else:
        # Mirror the pre-training report so both listings have the same rows.
        print(f"  {ex['tag']:<25} (skip) {stats['reason']}")

# -----------------------
# 9) Generations (synthetic categories only)
# -----------------------
print("\n" + "="*80)
print("GENERATION SAMPLES (greedy decoding) (synthetic only)")
print("="*80)

# Take up to two examples per category (categories in sorted order), capped at 25.
generation_samples = []
by_category = {}
for ex in pairs:
    cat = ex["tag"].split(":")[0]
    by_category.setdefault(cat, []).append(ex)

for cat, exs in sorted(by_category.items()):
    generation_samples.extend(exs[:2])

generation_samples = generation_samples[:25]

# Extra words appended to the prompt for the "scaffold" variant; categories not
# listed here use the prompt unchanged and get no scaffold generation.
scaffold_suffix_by_category = {
    "capital": " the city of",
    "language": " primarily",
}

for ex in generation_samples:
    raw = greedy_suffix(ex["prompt"], model, gener, max_new_tokens=12)

    tag_base = ex["tag"].split(":")[0]
    scaffold_prompt = ex["prompt"] + scaffold_suffix_by_category.get(tag_base, "")
    has_scaffold = scaffold_prompt != ex["prompt"]

    # Only generate the scaffold continuation when it will actually be shown;
    # otherwise it would repeat the raw generation and be discarded.
    sca = greedy_suffix(scaffold_prompt, model, gener, max_new_tokens=12) if has_scaffold else None

    print(f"\n{'─'*80}")
    print(f"CATEGORY: {ex['tag']:<25} TARGET: {ex['completion']!r}")
    print(f"PROMPT:   {ex['prompt']!r}")
    print(f"RAW:      {raw[:100]}")
    if has_scaffold:
        print(f"SCAFFOLD: {sca[:100]}")

print("\n" + "="*80)

## Section 7 — Canon Probes (AFTER finetune)
Re-runs the canon probes after the finetune loop so you can compare margins and routing stats.

In [None]:
# Re-run the canon probes only when the finetune section was enabled.
if not globals().get('DO_FINETUNE'):
    print('Finetune skipped; no after-finetune probe.')
else:
    model.eval()
    after_results = run_canon_probes(model, 'after_finetune', artifacts_dir)

    # Per-probe change in margin (after minus before), paired by position.
    margin_deltas = [
        after - before
        for after, before in zip(after_results['margins'], baseline_results['margins'])
    ]

    comparison = {
        'mean_margin_before': baseline_results['mean_margin'],
        'mean_margin_after': after_results['mean_margin'],
        'margin_deltas': margin_deltas,
        'routing_stats_before': baseline_results.get('routing_stats', {}),
        'routing_stats_after': after_results.get('routing_stats', {}),
    }

    # Persist the comparison next to the other artifacts.
    comparison_path = artifacts_dir / 'comparison.json'
    comparison_path.write_text(json.dumps(comparison, indent=2))
    print('Before/After mean margin:', comparison['mean_margin_before'], '→', comparison['mean_margin_after'])


## Section 8 — Optional: Push finetuned artifact to HF
Uploads the finetuned weights if `HF_TOKEN` is set in the environment.

In [None]:
from huggingface_hub import HfApi, upload_file

# Upload is opt-in: it only runs when HF_TOKEN is present in the environment.
token = os.environ.get('HF_TOKEN')
if token:
    # Use the authenticated client for the upload so the token read above is
    # the one that is actually applied.
    api = HfApi(token=token)
    checkpoint_path = artifacts_dir / 'finetuned' / 'finetuned.pt'
    if not checkpoint_path.is_file():
        # Give a specific message instead of a generic upload failure.
        print(f'No finetuned checkpoint at {checkpoint_path}; skipping upload.')
    else:
        try:
            api.upload_file(
                path_or_fileobj=str(checkpoint_path),
                path_in_repo='finetuned/finetuned.pt',
                repo_id=HF_REPO,
                repo_type='model',
            )
            print('Uploaded finetuned checkpoint.')
        except Exception as exc:
            # Best effort: a failed upload is reported, not raised.
            print('Upload failed:', exc)
else:
    print('HF_TOKEN not set; skipping upload.')
