In [None]:
%cd /content/mdlARC/
from pathlib import Path
import argparse
import importlib
import utils, train

# Re-import project modules so code edits are picked up during iteration.
# Order is significant: train imports utils, so utils is reloaded first and the
# re-executed train module then binds to the already-reloaded utils.
for _module in (utils, train):
    importlib.reload(_module)

# Editable arguments. Each run cell below copies this dict, tweaks it, and turns it
# into an argparse.Namespace via make_namespace(), so every key becomes an attribute.
args = dict(
    # Dataset (alternatives kept for quick switching):
    # data_path=Path("assets/script-tests/grouped-tasks-00d62c1b/challenges.json"),
    # data_path=Path("assets/script-tests/grouped-tasks/challenges.json"),
    # data_path=Path("assets/ARC-1/grouped-tasks/training/challenges.json"),
    data_path=Path("assets/ARC-2/grouped-tasks/training/challenges.json"),
    # Optimisation / runtime
    batch_size=110,
    epochs=200,
    lr=3e-3,
    weight_decay=0.01,
    grad_clip=1.0,
    num_workers=0,
    device="cuda",  # "cuda" | "mps" | "cpu"
    seed=42,
    # Checkpointing
    save_path=Path("runs/tiny.pt"),
    checkpoint_path=None,  # Path("runs/tiny.pt") to load saved weights, else None
    eval_only=False,
    # Single-example inference (other ids tried: "3aa6fb7a", "00d62c1b", "00576224")
    inference_task_id="e0fb7511",
    inference_pair_index=0,
    # Visibility toggles
    log_train_strings=False,
    log_train_limit=10,
    log_inference_prompt=True,
    log_eval_strings=True,
    log_eval_limit=10,
    plot_inference_grids=True,
)


def make_namespace(d, path_keys=("data_path", "save_path", "checkpoint_path")):
    """Build an ``argparse.Namespace`` from a config dict.

    Values stored under ``path_keys`` that are present and not ``None`` are coerced
    to ``pathlib.Path``. The conversion happens on a shallow copy, so the caller's
    dict is never mutated.

    Args:
        d: Mapping of argument names to values (e.g. a copy of ``args``).
        path_keys: Keys whose non-None values should be coerced to ``Path``.
            Missing keys are ignored.

    Returns:
        argparse.Namespace with one attribute per key of ``d``.
    """
    converted = dict(d)  # copy: never rewrite the caller's config in place
    for key in path_keys:
        value = converted.get(key)
        if value is not None and not isinstance(value, Path):
            converted[key] = Path(value)
    return argparse.Namespace(**converted)


In [None]:
# Training only: build a fresh model/data pipeline from `args`, then run training.
cfg = {**args, "eval_only": False}
ns = make_namespace(cfg)
model, dataset, dataloader, device, data_path = train.build_model_and_data(ns)
train.train_model(
    ns,
    dataset=dataset,
    dataloader=dataloader,
    model=model,
    data_path=data_path,
    device=device,
)


In [None]:
# Optional: single-example inference by task id and pair index.
# Set args["inference_task_id"] (e.g. "00576224", "0520fde7", "3aa6fb7a", "00d62c1b")
# and args["inference_pair_index"] in the first cell, then run this cell.
# NOTE(review): build_model_and_data rebuilds `model` from ns, so weights trained in an
# earlier cell are not reused here; presumably args["checkpoint_path"] must point at
# saved weights for meaningful predictions — confirm.
cfg = dict(args)
cfg["eval_only"] = True
# Validate before the (potentially slow) model/data build so a missing id fails fast.
# The task id is taken from args only, so the value checked here is the one that runs.
assert cfg["inference_task_id"] is not None, "Set inference_task_id in args first."
ns = make_namespace(cfg)
model, dataset, dataloader, device, data_path = train.build_model_and_data(ns)

train.run_inference(
    model=model,
    dataset=dataset,
    task_id=cfg["inference_task_id"],
    pair_index=cfg["inference_pair_index"],
    device=device,
    log_prompt=cfg["log_inference_prompt"],
    plot_grids_flag=cfg["plot_inference_grids"],
)


In [None]:
# Test generation on a train pair (input -> predicted output), compare to ground-truth
%cd /content/mdlARC/
from pathlib import Path
import torch

# Local imports from this repo
from train import load_checkpoint, resolve_device, greedy_generate
from tinytransformer import TinyTransformer, TinyTransformerConfig
from utils import (
    ARCExampleDataset,
    MAX_SEQ_LEN,
    IO_SEPARATOR_TOKEN_ID,
    extract_output_tokens,
    tokens_to_grid,
    split_grids_from_tokens,
    tokens_to_string,
    plot_grids,
)

# ---- Configuration (edit these) ----
CHECKPOINT_PATH = "runs/tiny.pt"  # e.g., "runs/ckpt.pt"; or None to test random weights
# NOTE(review): args["data_path"] trains on ARC-2; if the checkpoint came from that run,
# DATA_PATH likely needs to match so example ids line up with the trained embeddings — confirm.
DATA_PATH = "assets/ARC-1/grouped-tasks/training/challenges.json"  # or "assets/ARC-2/grouped-tasks/training/challenges.json"
DEVICE = "cuda"  # "cuda", "mps", or "cpu" (auto-fallback if unavailable)
TASK_ID = "00d62c1b"  # e.g., "00d62c1b"; if None, picks the first available train pair
PAIR_INDEX = 0
PLOT = True
# ------------------------------------

# Load checkpoint (optional)
ckpt = load_checkpoint(Path(CHECKPOINT_PATH)) if CHECKPOINT_PATH else None
device = resolve_device(DEVICE)

# Keep dataset tasks aligned with the checkpoint (so example_embedding size matches).
# dict.get already returns None for a missing key, so no separate membership test is needed.
task_whitelist = ckpt.get("task_ids") if ckpt else None

# Build dataset (includes outputs so we can compare)
dataset = ARCExampleDataset(
    json_path=Path(DATA_PATH),
    splits=("train", "test"),
    include_outputs=True,
    max_seq_len=MAX_SEQ_LEN,
    task_whitelist=task_whitelist,
)

# Build model config (prefer the one saved in checkpoint). Named `model_cfg` so it does
# not collide with the `cfg` run-argument dicts used by the other cells.
if ckpt and "config" in ckpt:
    model_cfg = TinyTransformerConfig(**ckpt["config"])
else:
    model_cfg = TinyTransformerConfig(num_examples=dataset.num_examples)

model = TinyTransformer(model_cfg).to(device)
if ckpt:
    # strict=False tolerates architecture drift, but would otherwise silently leave any
    # unmatched parameters at their initial values — report them instead.
    load_result = model.load_state_dict(ckpt["model_state"], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Checkpoint load: missing keys={load_result.missing_keys}, "
            f"unexpected keys={load_result.unexpected_keys}"
        )
# Generation below should run in inference mode (disables dropout, if any);
# harmless if greedy_generate also does this.
model.eval()

# Select a train example: the first one with an output when TASK_ID is None,
# otherwise the pair matching TASK_ID / PAIR_INDEX.
train_examples = dataset.iter_examples(split="train", has_output=True)
if TASK_ID is None:
    ex = next(train_examples, None)
else:
    ex = next(
        (e for e in train_examples if e.task_id == TASK_ID and e.pair_index == PAIR_INDEX),
        None,
    )
if ex is None:
    # A bare next() would surface as an opaque StopIteration; say what was not found.
    raise ValueError(
        f"No train example with outputs for task_id={TASK_ID!r}, pair_index={PAIR_INDEX} "
        f"in {DATA_PATH} (also check the checkpoint's task whitelist)."
    )

# Build prompt: input tokens up to and including the separator
seq = ex.tokens.tolist()
try:
    sep_ix = seq.index(IO_SEPARATOR_TOKEN_ID)
except ValueError:
    raise RuntimeError("Selected train example is missing <input_output_separator>.") from None
prompt_tokens = seq[: sep_ix + 1]

# Generate
generated = greedy_generate(
    model=model,
    prompt_tokens=torch.tensor(prompt_tokens, dtype=torch.long),
    example_id=ex.example_id,
    device=device,
)

# Decode prediction and reference. Grid 0 of the example is treated as the input and
# grid 1 (when present) as the ground-truth output.
gen_tokens_after_sep = extract_output_tokens(generated.tolist())
predicted_grid = tokens_to_grid(gen_tokens_after_sep)
all_grids = split_grids_from_tokens(seq)
reference_grid = all_grids[1] if len(all_grids) > 1 else []

def _print_grid(heading, grid):
    """Print `heading`, then each row of `grid`, or "<empty>" when the grid is empty."""
    print(heading)
    if not grid:
        print("<empty>")
        return
    for grid_row in grid:
        print(grid_row)


# Log results: task header, prompt/generated token strings, then both grids.
print(f"Task: {ex.task_id} | Pair: {ex.pair_index}")
for heading, token_seq in (
    ("\nPrompt (string):", prompt_tokens),
    ("\nGenerated output (string):", gen_tokens_after_sep),
):
    print(heading)
    print(tokens_to_string(token_seq))

_print_grid("\nPredicted grid:", predicted_grid)
_print_grid("\nReference grid:", reference_grid)

print("\nExact grid match:", predicted_grid == reference_grid)

# Optional plotting: the example's first grid (input) next to the predicted output.
# Best-effort: a plotting failure is reported but does not abort the cell.
if PLOT:
    try:
        input_grid = all_grids[0] if all_grids else []
        plot_grids([input_grid, predicted_grid], title=f"task {ex.task_id} pair {ex.pair_index}")
    except Exception as e:
        print(f"Plotting failed: {e}")


In [None]:
# Eval-only across test pairs.
# This cell rebuilds `model` via build_model_and_data, so weights trained in an earlier
# cell are NOT reused; set args["checkpoint_path"] to evaluate saved weights.
cfg = dict(args)
cfg["eval_only"] = True
if cfg["checkpoint_path"] is None:
    print(
        "Warning: checkpoint_path is None; the freshly built model may be untrained. "
        "Set args['checkpoint_path'] (e.g. Path('runs/tiny.pt')) to evaluate saved weights."
    )
ns = make_namespace(cfg)
model, dataset, dataloader, device, data_path = train.build_model_and_data(ns)
train.evaluate_model(
    ns, model=model, dataset=dataset, device=device, data_path=data_path
)


In [None]:
# Train + Eval combo (convenience): train a fresh model, then evaluate the in-memory
# result directly, without reloading from disk.
cfg = {**args, "eval_only": False}
ns = make_namespace(cfg)
model, dataset, dataloader, device, data_path = train.build_model_and_data(ns)
train.train_model(
    ns,
    dataset=dataset,
    dataloader=dataloader,
    model=model,
    data_path=data_path,
    device=device,
)
train.evaluate_model(
    ns,
    dataset=dataset,
    model=model,
    data_path=data_path,
    device=device,
)
