# ASA HF Checkpoint + Canon Probes + Mini Fact Finetune

CPU-only Colab notebook: load the HF checkpoint, run canonical probes, optionally
do a tiny finetune on a handful of examples, and re-run probes to show before/after deltas.

**Expected runtime:** a few minutes on CPU.

**Notes:**
- Uses a tiny synthetic QA dataset for the finetune step (can be skipped).
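- Saves JSON artifacts (and, if the finetune step runs, a `finetuned.pt` checkpoint) under `artifacts/` inside the cloned `ASA/` directory for quick inspection.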


In [None]:
# Section 0 — Setup
import os, sys, subprocess, platform, json, time, random
from pathlib import Path

repo_dir = 'ASA'
# Make the cell safe to re-run: after the first execution the working directory
# is already the checkout, so looking up the relative path 'ASA' again would
# clone a nested ASA/ASA and chdir into it.
already_in_repo = Path.cwd().name == repo_dir and Path('.git').exists()
if not already_in_repo:
    if not Path(repo_dir).exists():
        subprocess.run(['git','clone','https://github.com/digitaldaimyo/ASA.git'], check=True)
    os.chdir(repo_dir)

# Editable install of the repo itself, then the Hub/tokenizer dependencies
# imported further down in this cell.
subprocess.run([sys.executable,'-m','pip','install','-e','.'], check=True)
subprocess.run([sys.executable,'-m','pip','install','-q','huggingface_hub','safetensors','transformers'], check=True)

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# Everything in this notebook runs on CPU.
device = torch.device('cpu')

# Environment report, printed for reproducibility.
for _label, _value in (
    ('Python:', platform.python_version()),
    ('Torch:', torch.__version__),
    ('Device:', device),
):
    print(_label, _value)

# The commit hash is informational only; any failure to obtain it is tolerated.
try:
    commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode().strip()
except Exception:
    print('Repo commit: unavailable')
else:
    print('Repo commit:', commit)

# Seed both Python's and torch's global RNGs with the same value.
seed = 1337
for _seed_fn in (random.seed, torch.manual_seed):
    _seed_fn(seed)

# Output directory for the JSON artifacts and checkpoint written by later cells.
artifacts_dir = Path('artifacts')
artifacts_dir.mkdir(exist_ok=True)


In [None]:
# Section 1 — Load base model from Hugging Face (Baseline)
from asa.load_pretrained import load_pretrained

# Hub repository and checkpoint file that make up the baseline model.
HF_REPO = 'DigitalShogun/ASA-ASM-wikitext103-raw'
DEFAULT_CKPT = 'ASA_ASM_wt103-rawv1_gpt2_T1024_L21_D384_H8_K16_M32_ropek1_alibi1_gamma1_step75000_best.pt'

model, report, cfg_obj = load_pretrained(HF_REPO, DEFAULT_CKPT, variant='baseline', device='cpu')
print('Loaded model with vocab_size:', cfg_obj.vocab_size)
print('Checkpoint source:', report['state_dict_source'])
# Count how many allow-listed state-dict discrepancies the loader reported.
gap_counts = {
    kind: len(report[f'allowed_{kind}'])
    for kind in ('missing', 'unexpected', 'mismatched')
}
print('Allowlisted gaps:', gap_counts)

# Smoke test: one forward pass on random token ids must give finite logits
# of shape (batch, sequence, vocab).
smoke_shape = (1, 32)
input_ids = torch.randint(0, cfg_obj.vocab_size, smoke_shape)
with torch.no_grad():
    logits, _ = model(input_ids)
print('Logits shape:', tuple(logits.shape))
assert logits.shape == (*smoke_shape, cfg_obj.vocab_size)
assert torch.isfinite(logits).all()

# Record what was loaded so the other artifacts can be traced back to it.
run_metadata = dict(
    repo=HF_REPO,
    checkpoint=DEFAULT_CKPT,
    state_dict_source=report['state_dict_source'],
    seed=seed,
    timestamp=time.time(),
    config=cfg_obj.__dict__,
)
(artifacts_dir / 'run_metadata.json').write_text(json.dumps(run_metadata, indent=2))


In [None]:
# Section 2 — Canon Probes (BEFORE finetune)
# GPT-2 tokenizer from the Hugging Face Hub; used to encode the probe prompts
# below and the finetune examples in Section 3.
# NOTE(review): presumably matches the tokenizer the checkpoint was trained
# with (the checkpoint filename contains "gpt2") — confirm.
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Prompts scored by run_canon_probes. The same Paris-minus-London logit margin
# is computed for every prompt, including the UK/London ones, so keep that in
# mind when reading the mean and min margins.
PROMPTS = [
    'The capital of France is',
    "France's capital city is",
    'Paris is the capital of',
    'The capital of the UK is',
    'London is the capital of',
    'A major city in France is',
]

def get_token_id(text):
    """Return the id of the one token that `text` encodes to.

    Raises:
        ValueError: if the tokenizer splits `text` into zero or several tokens.
    """
    ids = tokenizer.encode(text)
    if len(ids) == 1:
        return ids[0]
    raise ValueError(f'Expected single token for {text}, got {ids}')

# Token ids of the two answer candidates compared by the probes. The leading
# space is part of the encoded text; get_token_id raises if either string does
# not encode to exactly one token.
paris_id = get_token_id(' Paris')
london_id = get_token_id(' London')

def _collect_routing_stats(model):
    """Best-effort routing diagnostics from one forward pass on a fixed random batch.

    Returns a dict that may contain ``routing_entropy``, ``routing_top1freq``
    and the mean of a few scalar info entries, depending on which keys the
    model's info output provides. Any exception is recorded under ``error``
    instead of being raised, so a missing diagnostic never aborts the probes.
    """
    routing_stats = {}
    try:
        # Dedicated, locally seeded generator: every call sees the same batch
        # (so before/after stats are comparable) and the global RNG state that
        # drives training is left untouched.
        generator = torch.Generator().manual_seed(1337)
        sample = torch.randint(0, cfg_obj.vocab_size, (2, 16), generator=generator)
        with torch.no_grad():
            _, info = model(sample, return_info=True)
        # The info output may be a list (its first entry is used) or a single
        # mapping; None/empty falls back to an empty dict.
        if isinstance(info, list) and info:
            info0 = info[0] or {}
        else:
            info0 = info or {}
        if info0.get('read_weights') is not None:
            # Clamp before the log so zero weights do not produce -inf/NaN.
            p = info0['read_weights'].float().clamp_min(1e-8)
            entropy = -(p * p.log()).sum(dim=-1).mean().item()
            # Share of positions whose arg-max falls on the most popular index.
            top = p.argmax(dim=-1).reshape(-1)
            hist = torch.bincount(top, minlength=p.shape[-1]).float()
            top1freq = (hist.max() / hist.sum().clamp_min(1.0)).item()
            routing_stats['routing_entropy'] = entropy
            routing_stats['routing_top1freq'] = top1freq
        for key in ('content_read_gamma_mean', 'slotspace_gate_mean', 'slotspace_delta_norm'):
            if key in info0:
                routing_stats[key] = float(torch.as_tensor(info0[key]).mean().item())
    except Exception as exc:
        routing_stats['error'] = str(exc)
    return routing_stats


def run_canon_probes(model, tag, out_dir):
    """Run the Paris-vs-London logit probes and collect routing diagnostics.

    For every prompt in ``PROMPTS`` the last-position logits are read, the
    margin ``logit(' Paris') - logit(' London')`` is recorded, and the ten
    highest-scoring next tokens are decoded. Routing statistics are gathered
    from a separate forward pass on a fixed random batch.

    The model is put in eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so every probe run is
    measured under the same conditions regardless of the mode it arrives in.

    Args:
        model: Module called as ``model(input_ids, return_info=True)`` and
            returning a ``(logits, info)`` pair.
        tag: Label stored in the results and used as the output file prefix.
        out_dir: Directory (created if missing) that receives
            ``<tag>_probes.json``.

    Returns:
        dict with keys ``tag``, ``margins``, ``mean_margin``, ``min_margin``,
        ``top10_tokens`` and ``routing_stats``.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        margins = []
        top10 = []
        for prompt in PROMPTS:
            input_ids = torch.tensor([tokenizer.encode(prompt)])
            with torch.no_grad():
                logits, _ = model(input_ids, return_info=True)
            last = logits[0, -1]  # next-token logits after the final prompt token
            margins.append((last[paris_id] - last[london_id]).item())
            top_ids = torch.topk(last, k=10).indices.tolist()
            top10.append([tokenizer.decode([i]) for i in top_ids])
        routing_stats = _collect_routing_stats(model)
    finally:
        # Leave the caller's train/eval mode exactly as it was found.
        model.train(was_training)

    mean_margin = float(sum(margins) / len(margins))
    min_margin = float(min(margins))

    results = {
        'tag': tag,
        'margins': margins,
        'mean_margin': mean_margin,
        'min_margin': min_margin,
        'top10_tokens': top10,
        'routing_stats': routing_stats,
    }
    out_path = out_dir / f'{tag}_probes.json'
    out_path.write_text(json.dumps(results, indent=2))
    print('Probe summary:', tag)
    print('  mean margin:', mean_margin)
    print('  min margin:', min_margin)
    return results

# Baseline probe run on the freshly loaded checkpoint. Writes
# before_finetune_probes.json into artifacts_dir; the returned dict is compared
# against the post-finetune run in Section 4.
baseline_results = run_canon_probes(model, 'before_finetune', artifacts_dir)


In [None]:
# Section 3 — Mini Fact-Answering Finetune (tiny synthetic QA set only)
from torch.utils.data import DataLoader

DO_FINETUNE = True  # set False to skip the finetune step

if DO_FINETUNE:
    # Tiny synthetic QA set: each (prompt, answer) pair is concatenated into a
    # single training sequence.
    examples = [
        ('Q: What is the capital of France?\nA:', ' Paris'),
        ('Q: What is the capital of the UK?\nA:', ' London'),
        ('Q: What city is the capital of Germany?\nA:', ' Berlin'),
        ('Q: What city is the capital of Italy?\nA:', ' Rome'),
    ]

    synthetic = [
        torch.tensor(tokenizer.encode(prompt + answer), dtype=torch.long)
        for prompt, answer in examples
    ]

    # Right-pad the inputs with EOS so they stack into one tensor. The label
    # tensor carries IGNORE_INDEX at the padded positions so that padding
    # contributes nothing to the loss (otherwise the model is trained to emit
    # runs of EOS after every answer).
    IGNORE_INDEX = -100
    max_len = max(len(x) for x in synthetic)
    data = torch.full((len(synthetic), max_len), tokenizer.eos_token_id, dtype=torch.long)
    labels = torch.full((len(synthetic), max_len), IGNORE_INDEX, dtype=torch.long)
    for row, seq in enumerate(synthetic):
        data[row, :len(seq)] = seq
        labels[row, :len(seq)] = seq

    # A list of (input, label) pairs is a valid map-style dataset; the default
    # collate function stacks each component into a batch tensor.
    loader = DataLoader(list(zip(data, labels)), batch_size=2, shuffle=True)

    model.train()
    optim = torch.optim.AdamW(model.parameters(), lr=5e-5)
    losses = []
    steps = 80
    batches = iter(loader)
    for step in range(steps):
        # Walk through whole epochs instead of rebuilding the iterator and
        # taking only its first batch on every step.
        try:
            batch, targets = next(batches)
        except StopIteration:
            batches = iter(loader)  # epoch exhausted: reshuffle and continue
            batch, targets = next(batches)
        logits, _ = model(batch)
        # Next-token objective: position t predicts token t+1.
        loss = torch.nn.functional.cross_entropy(
            logits[:, :-1, :].reshape(-1, logits.size(-1)),
            targets[:, 1:].reshape(-1),
            ignore_index=IGNORE_INDEX,
        )
        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        losses.append(loss.item())
        if (step + 1) % 20 == 0:
            print(f'step {step+1}/{steps} loss={loss.item():.4f}')

    # Persist the finetuned weights together with the config and loss curve.
    finetune_dir = artifacts_dir / 'finetuned'
    finetune_dir.mkdir(parents=True, exist_ok=True)
    torch.save(
        {'model': model.state_dict(), 'cfg': cfg_obj.__dict__, 'losses': losses},
        finetune_dir / 'finetuned.pt',
    )
    (finetune_dir / 'losses.json').write_text(json.dumps(losses, indent=2))
    print('Saved finetuned checkpoint to', finetune_dir)
else:
    print('Skipping finetune step; DO_FINETUNE=False')


In [None]:
# Section 4 — Canon Probes (AFTER finetune)
# Re-run the probes on the finetuned weights and store a before/after summary.
if not DO_FINETUNE:
    print('Finetune skipped; no after-finetune probe.')
else:
    model.eval()
    after_results = run_canon_probes(model, 'after_finetune', artifacts_dir)

    # Per-prompt change in the Paris-minus-London margin (after minus before).
    margin_deltas = [
        after - before
        for after, before in zip(after_results['margins'], baseline_results['margins'])
    ]
    comparison = dict(
        mean_margin_before=baseline_results['mean_margin'],
        mean_margin_after=after_results['mean_margin'],
        margin_deltas=margin_deltas,
        routing_stats_before=baseline_results.get('routing_stats', {}),
        routing_stats_after=after_results.get('routing_stats', {}),
    )
    comparison_path = artifacts_dir / 'comparison.json'
    comparison_path.write_text(json.dumps(comparison, indent=2))
    print('Before/After mean margin:', comparison['mean_margin_before'], '→', comparison['mean_margin_after'])


In [None]:
# Section 5 — Optional: Push finetuned artifact to HF (if token present)
from huggingface_hub import HfApi, upload_file

# Upload is opt-in: it only happens when HF_TOKEN is set and a finetuned
# checkpoint was actually written by Section 3.
token = os.environ.get('HF_TOKEN')
checkpoint_path = artifacts_dir / 'finetuned' / 'finetuned.pt'
if not token:
    print('HF_TOKEN not set; skipping upload.')
elif not checkpoint_path.is_file():
    print('No finetuned checkpoint at', checkpoint_path, '- skipping upload.')
else:
    # Go through the authenticated client so the token read above is the one
    # used for the request.
    api = HfApi(token=token)
    try:
        api.upload_file(
            path_or_fileobj=str(checkpoint_path),
            path_in_repo='finetuned/finetuned.pt',
            repo_id=HF_REPO,
            repo_type='model',
        )
        print('Uploaded finetuned checkpoint.')
    except Exception as exc:
        # Best-effort step: report the failure (e.g. no write access to the
        # repo) without aborting the notebook.
        print('Upload failed:', exc)
