# ASA Colab Quickstart — Functional Validation

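This notebook runs a fast, CPU-only validation pass for the ASA variants (baseline, online, and intervene).
It clones and installs the repository, then checks shapes, gradients, determinism, slot/head-count invariants, masking, routing overrides, the online variant's info output, and intervention toggles. A status dashboard at the end summarizes every check and fails if any of them did not pass.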


In [None]:
# Status Dashboard
# One dict per named check: {"name": str, "ok": bool, "details": str}.
results = []

def record(name, ok, details=""):
    """Record the outcome of a named check on the status dashboard.

    Re-recording an existing name replaces its entry in place (keeping its
    position) instead of appending a duplicate. Re-running a section cell after
    fixing a failure therefore clears the stale FAIL rather than leaving it to
    make finalize() raise forever.

    Args:
        name: unique check name shown on the dashboard.
        ok: truthy if the check passed; stored as a bool.
        details: optional short explanation shown after the name.
    """
    entry = {"name": name, "ok": bool(ok), "details": details}
    for idx, existing in enumerate(results):
        if existing["name"] == name:
            # Mutate the list in place so every reference to `results` sees the update.
            results[idx] = entry
            return
    results.append(entry)

def finalize():
    """Print every recorded check, then the overall verdict.

    An empty dashboard counts as a pass.

    Raises:
        AssertionError: if any recorded check has a falsy "ok".
    """
    print("\n=== STATUS DASHBOARD ===")
    for entry in results:
        label = "PASS" if entry["ok"] else "FAIL"
        suffix = f" - {entry['details']}" if entry["details"] else ""
        print(f"{label:>4} | {entry['name']}{suffix}")
    failed = [entry for entry in results if not entry["ok"]]
    overall_ok = not failed
    print(f"OVERALL: {'PASS' if overall_ok else 'FAIL'}")
    if failed:
        raise AssertionError("One or more checks failed.")

def print_section(title):
    """Print a section banner of the form '=== title ===', preceded by a blank line."""
    print("\n=== {} ===".format(title))

def assert_finite(tensor, name):
    """Fail fast if `tensor` holds any NaN or +/-inf value.

    Args:
        tensor: tensor to inspect (anything exposing `.isfinite()`).
        name: label used in the error message.

    Raises:
        AssertionError: if at least one element is non-finite.
    """
    all_finite = tensor.isfinite().all()
    if all_finite:
        return
    raise AssertionError(f"Non-finite values in {name}")


In [None]:
# Section 0 — Environment & Setup
import os
import sys
import subprocess
import platform
import random

print_section('Environment & Setup')
repo_dir = 'ASA'
repo_url = 'https://github.com/digitaldaimyo/ASA.git'

# Guard against re-running this cell: after the first run the cwd is already the
# repo, where the relative path 'ASA' does not exist. Without this check a
# re-run would clone a nested ASA/ASA copy and chdir into it.
already_in_repo = os.path.basename(os.getcwd()) == repo_dir and os.path.isdir('.git')
if not already_in_repo:
    if not os.path.isdir(repo_dir):
        print('Cloning repo...')
        result = subprocess.run(['git', 'clone', repo_url])
        if result.returncode != 0:
            print('Failed to clone. Next steps:')
            print(' - Check network access to github.com')
            print(' - Ensure Colab has write permissions to /content')
            print(' - Restart the runtime and try again')
            raise RuntimeError('git clone failed')
    os.chdir(repo_dir)

# Editable install of the checked-out tree (cwd is now the repo root).
# NOTE(review): if the package uses a src/ layout, an editable install made
# mid-session may not be importable until the runtime restarts — confirm.
print('Installing package (editable)...')
result = subprocess.run([sys.executable, '-m', 'pip', 'install', '-e', '.'])
if result.returncode != 0:
    raise RuntimeError('pip install failed')

import torch
print(f'Python: {platform.python_version()}')
print(f'Torch: {torch.__version__}')
print('Device: cpu')

# Seed the Python and torch RNGs; individual sections may reseed torch again.
seed = 1337
random.seed(seed)
torch.manual_seed(seed)
try:
    # Ask PyTorch to use deterministic kernels (ops without one raise instead).
    torch.use_deterministic_algorithms(True)
    print('Deterministic algorithms enabled.')
except Exception as exc:
    print(f'Warning: deterministic algorithms not fully enforced ({exc}).')
record('setup', True, 'environment ready')


In [None]:
# Section 1 — Import & API Surface Validation
print_section('Import & API Surface Validation')
import inspect
from asa import AddressedStateAttention, AddressedStateAttentionOnline, AddressedStateAttentionIntervene

# (class, short label) pairs; later sections iterate over this list.
variants = [
    (AddressedStateAttention, 'baseline'),
    (AddressedStateAttentionOnline, 'online'),
    (AddressedStateAttentionIntervene, 'intervene'),
]

try:
    for cls, label in variants:
        # Every variant must accept the shared constructor arguments.
        sig = inspect.signature(cls)
        missing = [arg for arg in ('embed_dim', 'num_heads', 'num_slots') if arg not in sig.parameters]
        if missing:
            raise AssertionError(f'{label} missing arg(s): {", ".join(missing)}')
        model = cls(embed_dim=32, num_heads=4, num_slots=8)
        x = torch.randn(2, 5, 32)
        result = model(x)
        # Check the (out, info) contract explicitly. A bare tensor would otherwise
        # "unpack" along the batch dimension and fail later with a confusing shape error.
        assert isinstance(result, (tuple, list)) and len(result) == 2, (
            f'{label}: forward should return (out, info), got {type(result).__name__}'
        )
        out, info = result
        assert isinstance(out, torch.Tensor), f'{label}: output is {type(out).__name__}, expected Tensor'
        assert out.shape == x.shape, f'{label}: output shape {tuple(out.shape)} != input shape {tuple(x.shape)}'
        assert info is None or isinstance(info, dict), (
            f'{label}: info is {type(info).__name__}, expected dict or None'
        )
    record('api surface', True, 'imports/shape ok')
except Exception as exc:
    record('api surface', False, str(exc))
    raise


In [None]:
# Section 2 — Gradients & Parameter Update Sanity
# Each variant is trained for a few optimizer steps; the check requires nonzero
# gradients and a measurable change in the flattened parameters.
print_section('Gradients & Parameter Update Sanity')
from torch.nn.utils import parameters_to_vector

def grad_and_update(cls, label):
    """Run three AdamW steps on one variant and measure gradient flow.

    Args:
        cls: variant class, constructed as cls(embed_dim=32, num_heads=4, num_slots=8).
            Its forward must return an (output, info) pair.
        label: variant name. Unused here; kept so callers can pass (cls, label) pairs.

    Returns:
        (nonzero_grads, changed):
            nonzero_grads: number of (parameter, step) pairs whose gradient had
                at least one nonzero entry.
            changed: L1 distance between the flattened parameters before and
                after the steps.
    """
    model = cls(embed_dim=32, num_heads=4, num_slots=8)
    # weight_decay=0.0: AdamW's default decoupled weight decay (0.01) shrinks every
    # parameter that has a .grad, even an all-zero one, so `changed` would be > 0
    # with no gradient flow at all. Disabling it ties parameter movement to gradients.
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)
    x = torch.randn(2, 8, 32)
    vec_before = parameters_to_vector([p.detach().clone() for p in model.parameters()])
    nonzero_grads = 0
    for _ in range(3):
        out, _ = model(x)
        loss = out.pow(2).mean() + 0.01 * out.mean()
        optim.zero_grad()
        loss.backward()
        nonzero_grads += sum(
            1
            for p in model.parameters()
            if p.grad is not None and bool(p.grad.detach().abs().sum() > 0)
        )
        optim.step()
    vec_after = parameters_to_vector([p.detach() for p in model.parameters()])
    changed = (vec_after - vec_before).abs().sum().item()
    return nonzero_grads, changed

try:
    for variant_cls, variant_label in variants:
        grad_steps, param_shift = grad_and_update(variant_cls, variant_label)
        assert grad_steps > 0, f'{variant_label}: no nonzero grads'
        assert param_shift > 0, f'{variant_label}: params did not change'
except Exception as exc:
    # Put the failure on the dashboard, then stop the run.
    record('gradients/update', False, str(exc))
    raise
else:
    record('gradients/update', True, 'nonzero grads + params updated')


In [None]:
# Section 3 — Seeded Determinism (CPU)
# Checks that each variant's outputs are fixed by the torch seed (see deterministic_check).
print_section('Seeded Determinism (CPU)')
def deterministic_check(cls, label):
    """Assert that a variant's outputs are controlled by the torch seed.

    Two models built right after seeding 123 must give matching outputs on the
    same input. A model built after seeding 124 must give a different one.

    Args:
        cls: variant class, constructed as cls(embed_dim=32, num_heads=4, num_slots=8).
        label: short name used in assertion messages.

    Raises:
        AssertionError: if either expectation does not hold.
    """
    def seeded_model(seed):
        # Reseed immediately before construction so parameter init is reproducible.
        torch.manual_seed(seed)
        return cls(embed_dim=32, num_heads=4, num_slots=8).eval()

    reference = seeded_model(123)
    # Drawn from the seed-123 stream, after the reference model's init.
    inputs = torch.randn(2, 6, 32)
    out_ref, _ = reference(inputs)

    out_same, _ = seeded_model(123)(inputs)
    assert torch.allclose(out_ref, out_same, atol=1e-6), f'{label}: deterministic mismatch'

    out_other, _ = seeded_model(124)(inputs)
    assert not torch.allclose(out_ref, out_other, atol=1e-6), f'{label}: seed did not change output'

try:
    for variant_cls, variant_label in variants:
        deterministic_check(variant_cls, variant_label)
except Exception as exc:
    # Put the failure on the dashboard, then stop the run.
    record('determinism', False, str(exc))
    raise
else:
    record('determinism', True, 'seeded outputs stable')


In [None]:
# Section 4 — Slot Count / Head Count Invariants
# Sweep head and slot counts on the baseline variant; output shape must always
# match the input shape.
print_section('Slot/Head Count Invariants')
table = []
try:
    for heads in (1, 2, 4):
        for slots in (4, 8, 16):
            model = AddressedStateAttention(embed_dim=32, num_heads=heads, num_slots=slots)
            x = torch.randn(2, 4, 32)
            out, _ = model(x)
            # Name the failing config so the dashboard entry is actionable.
            assert out.shape == x.shape, (
                f'heads={heads}, slots={slots}: output shape {tuple(out.shape)} != {tuple(x.shape)}'
            )
            table.append((heads, slots, 'ok'))
    print('heads | slots | status')
    for h, s, st in table:
        print(f'{h:>5} | {s:>5} | {st}')
    record('slot/head sweep', True, 'all configs ok')
except Exception as exc:
    record('slot/head sweep', False, str(exc))
    raise


In [None]:
# Section 5 — Masking / Control Surface Tests
print_section('Masking Controls')
try:
    # eval() + no_grad(): the "mask changed the output" check compares two forward
    # passes. A stochastic layer active in train mode (e.g. dropout, if the model
    # has any) could make them differ even if slot_mask were ignored.
    model = AddressedStateAttention(embed_dim=32, num_heads=4, num_slots=8).eval()
    if 'slot_mask' in inspect.signature(model.forward).parameters:
        x = torch.randn(2, 6, 32)
        with torch.no_grad():
            out_base, _ = model(x)
            # Presumably one flag per slot (num_slots=8), with True = slot enabled and
            # False = slot masked out. The zero-mask expectation below relies on this — confirm.
            mask_half = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0], dtype=torch.bool)
            out_mask, _ = model(x, slot_mask=mask_half)
            diff = (out_base - out_mask).abs().mean().item()
            assert diff > 1e-6, 'mask did not change output'
            # With every slot masked, the output must be finite and (numerically) zero.
            mask_zero = torch.zeros(8, dtype=torch.bool)
            out_zero, _ = model(x, slot_mask=mask_zero)
        assert_finite(out_zero, 'masked output')
        max_abs = out_zero.abs().max().item()
        assert max_abs < 1e-5, f'zero-mask output too large: {max_abs}'
        record('masking', True, 'mask affects output + zero-mask safe')
    else:
        record('masking', True, 'masking controls not exposed')
except Exception as exc:
    record('masking', False, str(exc))
    raise


In [None]:
# Section 6 — Routing Override Hook Test
print_section('Routing Override Hook')
try:
    # eval() + no_grad(): this check compares two forward passes. A stochastic
    # layer active in train mode (e.g. dropout, if the model has any) could make
    # them differ even if the override were ignored.
    model = AddressedStateAttentionIntervene(embed_dim=32, num_heads=4, num_slots=8).eval()
    if hasattr(model, 'routing_override'):
        x = torch.randn(2, 6, 32)
        with torch.no_grad():
            out_base, _ = model(x)

        def override_fn(t0, t1, read_logits, read_logits_key, read_logits_content, ctx):
            # Replace the read logits with the constant 1/K. K is presumably the slot
            # count — confirm. Constant logits presumably flatten the routing if they
            # pass through a softmax.
            k = ctx['K']
            return torch.full_like(read_logits, 1.0 / k)

        prev_override = model.routing_override
        model.routing_override = override_fn
        try:
            with torch.no_grad():
                out_override, _ = model(x)
        finally:
            # Restore whatever hook was installed before, even if the forward raises.
            model.routing_override = prev_override
        diff = (out_base - out_override).abs().mean().item()
        assert diff > 1e-6, 'override did not change output'
        record('routing override', True, 'override changes output')
    else:
        record('routing override', True, 'hook not available')
except Exception as exc:
    record('routing override', False, str(exc))
    raise


In [None]:
# Section 7 — Online Variant Specific Check
print_section('Online Variant Check')
try:
    model = AddressedStateAttentionOnline(embed_dim=32, num_heads=4, num_slots=8)
    x = torch.randn(2, 6, 32)
    out, info = model(x, return_info=True)
    assert out.shape == x.shape, f'online: output shape {tuple(out.shape)} != {tuple(x.shape)}'
    assert info is None or isinstance(info, dict), (
        f'online: info is {type(info).__name__}, expected dict or None'
    )
    # return_info=True was requested, so report honestly whether info actually came back.
    if isinstance(info, dict):
        detail = f'returned output + info ({len(info)} keys)'
    else:
        detail = 'returned output (info=None)'
    record('online info', True, detail)
except Exception as exc:
    record('online info', False, str(exc))
    raise


In [None]:
# Section 8 — Intervention Variant Toggle Behavior
print_section('Intervention Toggle')
try:
    # eval() + no_grad(): this check compares two forward passes. A stochastic
    # layer active in train mode (e.g. dropout, if the model has any) could make
    # them differ even if the intervention did nothing.
    model = AddressedStateAttentionIntervene(embed_dim=32, num_heads=4, num_slots=8).eval()
    x = torch.randn(2, 6, 32)
    with torch.no_grad():
        out_base, _ = model(x)
    if hasattr(model, '_intv_mode'):
        prev_mode = model._intv_mode
        model._intv_mode = 'orth_gate'
        try:
            with torch.no_grad():
                out_intv, _ = model(x)
        finally:
            # Restore the previous mode rather than assuming the default is 'off'.
            model._intv_mode = prev_mode
        diff = (out_base - out_intv).abs().mean().item()
        assert diff > 1e-6, 'intervention did not change output'
        assert_finite(out_intv, 'intervention output')
        record('intervention toggle', True, 'output changes + finite')
    else:
        record('intervention toggle', True, 'controls not exposed')
except Exception as exc:
    record('intervention toggle', False, str(exc))
    raise


In [None]:
# Section 9 — Final Summary
# Print every recorded check and raise AssertionError if any of them failed.
finalize()
