In [None]:
%cd /content/mdlARC/
from pathlib import Path
import argparse
import importlib
import utils, train

importlib.reload(utils)  # pick up code changes during iteration
importlib.reload(train)

# Settings for the small smoke-test run. `args` is the raw mapping and `cfg`
# exposes the same values as attributes for the training helpers.
args = dict(
    # run config
    num_workers=0,
    device="cuda",  # 'cuda' | 'mps' | 'cpu'
    # paths - must be given as pathlib.Path objects
    save_path=Path("runs/tiny.pt"),
    checkpoint_path=None,  # Path("runs/tiny.pt"),  # or None to start from scratch
    # Alternative datasets (swap one in for `data_path` below):
    #   Path("assets/script-tests/grouped-tasks-00d62c1b/challenges.json")
    #   Path("assets/ARC-1/grouped-tasks/training/challenges.json")
    #   Path("assets/ARC-2/grouped-tasks/training/challenges.json")
    data_path=Path("assets/script-tests/grouped-tasks/challenges.json"),
    # hyperparameters
    epochs=5,
    batch_size=110,
    val_batch_size=500,
    lr=3e-4,
    weight_decay=0.01,
    grad_clip=1.0,
    seed=42,
    # Visibility toggles
    log_train_strings=False,
    log_train_limit=10,
    log_inference_prompt=False,
)
cfg = argparse.Namespace(**args)

# Build everything the training / inference cells below consume.
model, dataset, dataloader, device, data_path = train.build_model_and_data(cfg)

In [None]:
# Training only
# Pass the objects produced by `train.build_model_and_data` to the trainer
# as keyword arguments.
train_inputs = {
    "model": model,
    "dataloader": dataloader,
    "dataset": dataset,
    "device": device,
    "data_path": data_path,
}
train.train_model(cfg, **train_inputs)

In [None]:
# inference + visualisation check
import inference

# Reload the modules first and only afterwards bind names out of `utils`.
# importlib.reload() does not rebind names that were previously pulled in via
# `from utils import ...`, so importing before the reload would leave this
# cell calling the stale, pre-edit helper functions.
importlib.reload(utils)
importlib.reload(inference)
from utils import plot_grids, split_grids_from_tokens, tokens_to_string

task_ids_list = ["00d62c1b", "e0fb7511", "00576224", "3aa6fb7a"]  # always pass as list
selected_split = "test"
# selected_split = "train"
pair_idx = 0
visualise = True

# Run generation for the selected tasks; targets are requested so that the
# prediction can be compared against them below.
results = inference.run_batched_inference(
    model=model,
    dataset=dataset,
    task_ids=task_ids_list,
    device=device,
    split=selected_split,
    pair_index=pair_idx,
    include_targets=True,
)

if not results:
    print("No inference results were produced.")
for res in results:
    print(f"\nTask {res['task_id']} pair {res['pair_index']} ({selected_split})")
    print("Prompt tokens:", tokens_to_string(res["prompt_tokens"]))
    print("Generated output tokens:", tokens_to_string(res["output_tokens"]))
    # Target entries are optional in a result, hence the .get() guards.
    if res.get("target_output_tokens"):
        print("Target output tokens:", tokens_to_string(res["target_output_tokens"]))
    print("Predicted grid:")
    for row in res["output_grid"]:
        print(row)
    if res.get("target_grid"):
        print("Target grid:")
        for row in res["target_grid"]:
            print(row)
    if visualise:
        # The first grid recovered from the prompt tokens is plotted as the
        # input; fall back to an empty grid when none could be recovered.
        prompt_grids = split_grids_from_tokens(res["prompt_tokens"])
        input_grid = prompt_grids[0] if prompt_grids else []
        to_plot = [input_grid, res["output_grid"]]
        if res.get("target_grid"):
            to_plot.append(res["target_grid"])
        plot_grids(
            to_plot,
            title=f"{res['task_id']} pair {res['pair_index']} ({selected_split})",
        )

In [None]:
# Full dataset evaluation
import inference

importlib.reload(inference)

EVAL_BATCH_SIZE = 1300

# Read the toggle from `cfg` (the namespace every other consumer uses, including
# the checkpoint-evaluation cell) rather than from the raw `args` dict, so an
# interactive change to `cfg.log_inference_prompt` is honoured here as well.
evaluation = inference.evaluate_model_on_dataset(
    model=model,
    dataset=dataset,
    device=device,
    batch_size=EVAL_BATCH_SIZE,
    log_prompts=cfg.log_inference_prompt,
)

# Print one summary per split; missing splits / keys fall back to zeros.
for split in ("train", "test"):
    summary = evaluation.get(split, {}).get("summary", {})
    total = summary.get("total_sequences", 0)
    shape_ok = summary.get("num_shape_correct", 0)
    avg_pixel_acc = summary.get("avg_pixel_accuracy", 0.0)
    fully_correct = summary.get("num_fully_correct", 0)

    print(f"\nSplit: {split}")
    print(f"  sequences evaluated: {total}")
    print(f"  correct output grid shapes: {shape_ok} / {total}")
    # Pixel accuracy is only reported when at least one output had the right shape.
    if shape_ok > 0:
        print(f"  avg pixel accuracy (shape-correct only): {avg_pixel_acc:.4f}")
    else:
        print("  avg pixel accuracy (shape-correct only): n/a")
    print(f"  fully correct output grids: {fully_correct} / {total}")

    # For the test split, additionally list every fully correct prediction.
    if split == "test":
        correct_outputs = summary.get("fully_correct_results", [])
        print("  fully correct test outputs (task_id, pair_index, grid):")
        if not correct_outputs:
            print("    (none)")
        for res in correct_outputs:
            grid = res.get("output_grid", [])
            print(f"    - {res.get('task_id')} pair {res.get('pair_index')}: {grid}")


# Large-scale training run (periodic checkpointing and evaluation)

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive; the backup cell at the end of the
# notebook copies its zip archive into /content/drive/MyDrive.
drive.mount("/content/drive")

In [None]:
# config
%cd /content/mdlARC/
from pathlib import Path
import argparse
import importlib
import utils, train

importlib.reload(utils)  # pick up code changes during iteration
importlib.reload(train)

# Settings for the large-scale run. Epoch count and checkpoint paths are not
# part of this mapping; only the values below are defined here.
args = dict(
    # run config
    num_workers=0,
    device="cuda",  # 'cuda' | 'mps' | 'cpu'
    # paths - must be given as pathlib.Path objects
    train_log_file=Path("runs/train_log.txt"),
    # Alternative datasets (swap one in for `data_path` below):
    #   Path("assets/script-tests/grouped-tasks-00d62c1b/challenges.json")
    #   Path("assets/script-tests/grouped-tasks/challenges.json")
    #   Path("assets/ARC-1/grouped-tasks/training/challenges.json")
    data_path=Path("assets/ARC-2/grouped-tasks/training/challenges.json"),
    # hyperparameters
    batch_size=140,
    val_batch_size=500,
    lr=3e-4,
    weight_decay=0.01,
    grad_clip=1.0,
    seed=42,
    # Visibility toggles
    log_train_strings=False,
    log_train_limit=10,
    log_inference_prompt=False,
)
# Attribute-style view of the same values.
cfg = argparse.Namespace(**args)


In [None]:
# Training with periodic checkpointing

cfg.epochs = 500
n_cycles = 10

# Make base 0 if training from scratch. Otherwise, last checkpoint number
cfg.base = 600

for cycle in range(n_cycles):
    # Epoch counts naming the checkpoint this cycle resumes from and the one it writes.
    start_epoch = cycle * cfg.epochs + cfg.base
    end_epoch = start_epoch + cfg.epochs
    cfg.save_path = Path(f"runs/tiny-{end_epoch}.pt")

    starting_fresh = cfg.base == 0 and cycle == 0
    if starting_fresh:
        # Very first cycle of a from-scratch run: nothing to resume from.
        cfg.checkpoint_path = None
        reuse_dataset = None
    else:
        cfg.checkpoint_path = Path(f"runs/tiny-{start_epoch}.pt")
        try:
            reuse_dataset = dataset
        except NameError:
            # `dataset` is not defined (e.g. it was deleted from memory)
            reuse_dataset = None

    model, dataset, dataloader, device, data_path = train.build_model_and_data(
        cfg, reuse_dataset=reuse_dataset
    )
    train.train_model(
        cfg,
        model=model,
        dataloader=dataloader,
        dataset=dataset,
        device=device,
        data_path=data_path,
    )

In [None]:
# clearing gpu memory before running inference
import gc
import torch

# Collect unreachable Python objects first, then ask PyTorch's CUDA caching
# allocator to release the cached blocks it is no longer using.
# NOTE(review): objects still referenced from the notebook namespace (e.g.
# `model`, `dataloader`) presumably keep their GPU memory -- confirm whether
# they should be dropped before this cell runs.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# evaluate all the checkpoints
# For every checkpoint written by the training loop: load it, evaluate it,
# dump the per-task results next to the checkpoint as JSON and append a
# summary to the shared evaluation log.
import json
import inference

importlib.reload(inference)

EVAL_BATCH_SIZE = 1000
eval_log_path = Path("runs/eval_log.txt")

# only for debugging. keep commented out
# cfg.base = 200
# cfg.epochs = 400
# n_cycles = 2

try:
    reuse_dataset = dataset
except NameError:
    reuse_dataset = None  # `dataset` is not defined in this kernel

for i in range(n_cycles):
    # Same numbering scheme the training loop used for `cfg.save_path`.
    checkpoint_number = (i + 1) * cfg.epochs + cfg.base
    cfg.checkpoint_path = Path(f"runs/tiny-{checkpoint_number}.pt")
    model, dataset, dataloader, device, data_path = train.build_model_and_data(
        cfg, reuse_dataset=reuse_dataset
    )
    # Feed the dataset back into the next iteration (as the training loop does)
    # so it is not rebuilt for every checkpoint when it was missing up front.
    reuse_dataset = dataset

    evaluation = inference.evaluate_model_on_dataset(
        model=model,
        dataset=dataset,
        device=device,
        batch_size=EVAL_BATCH_SIZE,
        log_prompts=cfg.log_inference_prompt,
    )

    formatted_eval = inference.group_eval_sequences_by_task(evaluation)

    # runs/tiny-<N>.pt -> runs/tiny-<N>.json
    json_path = cfg.checkpoint_path.with_suffix(".json")
    with open(json_path, "w") as f:
        json.dump(formatted_eval, f, indent=2)
    print(f"Saved grouped task sequences to: {json_path}")

    # Append mode: one section per checkpoint accumulates in the same log file.
    with open(eval_log_path, "a") as f:
        f.write(f"--- Checkpoint: {checkpoint_number} ---\n")
        for split_name, data in evaluation.items():
            summary = data["summary"]
            log_line = (
                f"Split: {split_name:<6} | "
                f"Total: {summary['total_sequences']:<4} | "
                f"Shape Correct: {summary['num_shape_correct']:<4} | "
                f"Fully Correct: {summary['num_fully_correct']:<4} | "
                f"Pixel Acc: {summary['avg_pixel_accuracy']:.4f}"
            )
            print(log_line)
            f.write(log_line + "\n")
            # For the test split, also list every fully correct prediction.
            if split_name == "test":
                correct_outputs = summary.get("fully_correct_results", [])
                log_line = "  fully correct test outputs (task_id, pair_index, grid):"
                if not correct_outputs:
                    log_line = log_line + "\n    (none)"
                for res in correct_outputs:
                    grid = res.get("output_grid", [])
                    log_line = log_line + (
                        f"\n    - {res.get('task_id')} pair {res.get('pair_index')}: {grid}"
                    )
                print(log_line)
                f.write(log_line + "\n")

        f.write("\n\n")
        print("\n")

In [None]:
import shutil
from datetime import datetime

# Day-month-year plus time of day, used to give each backup a unique name.
timestamp = f"{datetime.now():%d%m%y-%H%M%S}"

src_dir = "/content/mdlARC/runs"  # folder to back up
zip_base = "/content/mdlARC/runs_compressed"  # archive path without the '.zip' suffix
dst_zip = f"/content/drive/MyDrive/run_{timestamp}.zip"

# Zip the runs folder locally; make_archive returns the path of the file it wrote.
local_zip = shutil.make_archive(zip_base, "zip", src_dir)

# Copy zip into Drive
shutil.copy2(local_zip, dst_zip)