# ASA Colab Quickstart

This notebook clones the ASA repository, installs it, and runs a tiny smoke-test training loop for each ASA variant (baseline, online, intervention).


In [None]:
# Clone the ASA repository and make the checkout the working directory.
# NOTE(review): this cell is not idempotent. On a re-run, `git clone` fails
# because ./ASA already exists, and `%cd ASA` is relative to the current
# directory, so re-running from inside the checkout will not land in the same place.
!git clone https://github.com/digitaldaimyo/ASA.git
%cd ASA


In [None]:
# Install the package from the current directory (the checkout entered above)
# in editable (-e) mode, so `import asa` in the next cell resolves to this
# source tree. The -q flag quiets pip's output.
!pip -q install -e .


In [None]:
import torch
from asa import AddressedStateAttention, AddressedStateAttentionOnline, AddressedStateAttentionIntervene

def tiny_train(
    attn_cls,
    label,
    *,
    embed_dim=32,
    num_heads=4,
    num_slots=8,
    batch_size=2,
    seq_len=8,
    steps=5,
    lr=1e-3,
    seed=None,
):
    """Run a few optimizer steps on one attention variant as a smoke test.

    Builds ``attn_cls(embed_dim=..., num_heads=..., num_slots=...)`` and feeds it
    one random input batch of shape ``(batch_size, seq_len, embed_dim)``. It then
    minimizes a dummy loss (the mean of the squared output) with AdamW and prints
    the loss at every step. The loss has no task meaning; it only shows that the
    forward pass, the backward pass and the optimizer step all run end to end.

    Args:
        attn_cls: Module class to instantiate. Calling the instance on the input
            must return a 2-tuple whose first element is the output tensor; the
            second element is ignored.
        label: Human-readable name printed in the section header.
        embed_dim: Constructor argument for ``attn_cls``; also the input's last dim.
        num_heads: Constructor argument for ``attn_cls``.
        num_slots: Constructor argument for ``attn_cls``.
        batch_size: Batch dimension of the random input.
        seq_len: Sequence dimension of the random input.
        steps: Number of optimizer steps to run.
        lr: AdamW learning rate.
        seed: If not None, seeds torch's global RNG before building the module
            and the input, so runs are reproducible and comparable across variants.
            The default of None leaves the RNG untouched.

    Returns:
        A list of per-step loss values as Python floats. Each value is measured
        on the forward pass before that step's parameter update.
    """
    print(f"\n=== Running {label} ===")
    if seed is not None:
        # Seed before construction so both the parameter init and the input are fixed.
        torch.manual_seed(seed)
    attn = attn_cls(embed_dim=embed_dim, num_heads=num_heads, num_slots=num_slots)
    optim = torch.optim.AdamW(attn.parameters(), lr=lr)
    # The same batch is reused at every step, so loss changes reflect only the
    # parameter updates.
    x = torch.randn(batch_size, seq_len, embed_dim)
    losses = []
    for step in range(steps):
        out, _ = attn(x)
        loss = (out ** 2).mean()  # dummy objective: push outputs toward zero
        optim.zero_grad()
        loss.backward()
        optim.step()
        losses.append(loss.item())
        print(f"  step {step+1:2d} | loss = {losses[-1]:.4f}")
    return losses

# Smoke-test every ASA variant in turn, in the order listed below.
for variant_cls, variant_label in (
    (AddressedStateAttention, "Baseline (standard ASA)"),
    (AddressedStateAttentionOnline, "Online variant"),
    (AddressedStateAttentionIntervene, "Intervention variant"),
):
    tiny_train(variant_cls, variant_label)
