## 29% using a 29M transformer
Steps to reproduce:  
1. Upload this script to Google Colab or Modal
2. (optional) If you want to save checkpoints and results, mount your Google Drive (Colab) / your volume (Modal)
3. Click run-all

This script ensures there's no data contamination. 
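This produces a `submission.json` file. The `submission.json` file is written to `runs/<eval run name>/submission.json` (`runs/eval_100color_both/submission.json` with the default settings).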


Notes:
1. The config in this notebook has been tuned for an 80GB A100  
2. Actual results were obtained by running this exact file in 2 phases.  
    - Training on a 40GB A100
    - Take the final checkpoint, and run the inference on an 80GB A100

This will work on smaller GPUs too, but will take longer to train  
For very constrained environments, keep the "do_validate" flag disabled (it is already `False` in the config below). This avoids checking the validation loss every epoch.

In [None]:
# Pick the filesystem layout for the host platform. `root_folder` is the directory the repo is
# cloned into; `mount_folder` is later handed to utils.save_run_archive / update_run_archive.
# root_folder, mount_folder = "root", "mnt/mithil-arc" # for modal - REPLACE /mithil-arc WITH YOUR VOLUME NAME
root_folder, mount_folder = "content", "content/drive/MyDrive"  # for colab

# IPython expands $root_folder inside the magics below, so the clone lands in
# /<root_folder>/mdlARC and that directory becomes the working directory for all later cells.
%cd /$root_folder/
!git clone https://github.com/mvakde/mdlARC.git # `-b <branch_name> --single-branch` if branch
%cd /$root_folder/mdlARC

In [None]:
# Build the dataset from the arc1 and conceptarc sources (train and eval splits, with solutions),
# then run the dihedral augmentation script.
# NOTE(review): presumably this is what produces assets/challenges_dihedral_both.json, which the
# training config reads — confirm against the scripts.
!python dataset_building_scripts/build_datasets.py --datasets arc1 conceptarc  --splits train eval --with-solutions --cleanup none
!python dataset_building_scripts/augment_dataset_dihedral.py

# Delete the other notebooks, the dataset-building scripts, the readme and the images from the
# clone once the dataset exists, leaving only what the remaining cells use.
!rm -rf /$root_folder/mdlARC/run-script.ipynb
!rm -rf /$root_folder/mdlARC/sanitised-env-run-script.ipynb
!rm -rf /$root_folder/mdlARC/ultra-sanitised-env-run-script.ipynb
!rm -rf /$root_folder/mdlARC/dataset_building_scripts
!rm -rf /$root_folder/mdlARC/readme.md
!rm -rf /$root_folder/mdlARC/img

In [None]:
from pathlib import Path
import argparse
import importlib
import sys

# Make <cwd>/src importable (if it exists) so the project modules below resolve.
# The working directory is expected to be the repo root at this point.
PROJECT_ROOT = Path.cwd()
SRC_DIR = PROJECT_ROOT / "src"
if SRC_DIR.exists() and str(SRC_DIR) not in sys.path:
    # Prepend so the project's modules win over any same-named installed package.
    sys.path.insert(0, str(SRC_DIR))

# Project-local modules (not third-party packages).
import utils, tinytransformer, train

importlib.reload(utils)  # pick up code changes during iteration
importlib.reload(tinytransformer)
importlib.reload(train)

# Single source of truth for the run. The same settings are exposed twice:
#   * `args` - the plain dict (written to runs/config.txt and read directly by the
#     evaluation cell for the inference_* / log_inference_prompt settings);
#   * `cfg`  - an attribute-style namespace passed to the `train` helpers. The evaluation
#     cell mutates `cfg` (checkpoint_path, data_path, eval colour-augmentation fields),
#     so `cfg` and `args` can diverge after evaluation has run.
args = {
    # run config
    "num_workers": 0,
    "device": "cuda",  # 'cuda' | 'mps' | 'cpu'
    "do_validate": False,
    "name": "arc1-cleanenv-30M-vvwide-bs32-101ep-100color-ccdb-18dec0430",  # download file name
    "GPU": "A100-noaugreg",  # just for logging purposes
    # paths - must pass as Path("<path_to_dir>")
    "train_log_file": Path("runs/training_log.txt"),
    "save_path": Path("runs/tiny.pt"),
    "checkpoint_path": None,  # Path("runs/tiny.pt"),  # or None to start from scratch
    "data_path": Path("assets/challenges_dihedral_both.json"),
    # hyperparameters
    "epochs": 101,
    "batch_size": 32,
    "val_batch_size": 300,
    "enable_color_aug_train": True,
    "max_color_augments_train": 100,
    "color_aug_seed": 42,
    "lr": 3e-4,
    "weight_decay": 0.01,
    "grad_clip": 1.0,
    "dropout": 0.1,
    "seed": 42,
    # Model Architecture
    "d_model": 768,  # 128, 256, 512, 768 | 128, 384, 640
    "n_heads": 12,  # 4, 8, 8/16, 12 | 4, 12, 10
    "d_ff": 3072,  # 512, 1024, 2048, 3072 | 512, 1536, 2560
    "n_layers": 4,  # 4, 6, 16, 16 | 24, 28, 24
    # Visibility toggles
    "log_train_strings": False,
    "log_train_limit": 10,
    "log_inference_prompt": False,
    "inference_temperature": None,
    "inference_top_k": None,
}
cfg = argparse.Namespace(**args)

# Record the run configuration as "key: value" lines in runs/config.txt (overwritten each run).
runs_dir = Path("runs")
runs_dir.mkdir(parents=True, exist_ok=True)
with (runs_dir / "config.txt").open("w") as f:
    for k, v in args.items():
        f.write(f"{k}: {v}\n")

# Build everything the training cell needs. These five names are notebook globals that the
# training and evaluation cells reuse.
model, dataset, dataloader, device, data_path = train.build_model_and_data(cfg)

In [None]:
# Training only

from time import perf_counter

# Time the whole training call with a monotonic high-resolution clock.
t_start = perf_counter()
train.train_model(cfg,model=model,dataloader=dataloader,dataset=dataset,device=device,data_path=data_path)
t_duration = perf_counter() - t_start

print(f"Training took {t_duration:.2f}s")
# Opened with "w": this starts runs/timing.txt afresh; the evaluation cell appends to it.
with open(Path("runs/timing.txt"), "w") as f:
    f.write(f"Training: {t_duration:.4f} s\n")

In [None]:
# cleaning up memory to run inference
# NOTE(review): the helper receives this notebook's globals(), so it may remove names such as
# `model`; the evaluation cell checks `"model" in globals()` before reusing it — confirm
# exactly which names are dropped against utils.cleanup_memory.
utils.cleanup_memory(globals())


In [None]:
# save data immediately in case eval fails
# The archive is identified by cfg.name and located via root_folder / mount_folder from the
# first cell; the returned state is kept in `archive_state`.
# NOTE(review): what exactly gets archived is decided inside utils.save_run_archive — verify.
archive_state = utils.save_run_archive(
    cfg.name, root_folder, mount_folder, globals_dict=globals()
)


In [None]:
import sys
from pathlib import Path

# Same path bootstrap as the training-config cell, repeated so this cell's project imports
# resolve on their own: make <cwd>/src importable if it exists and is not already on sys.path.
PROJECT_ROOT = Path.cwd()
SRC_DIR = PROJECT_ROOT / "src"
if SRC_DIR.exists() and str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))

import torch
import evaluations
import aaivr
import tinytransformer
import utils
import json
import importlib
from time import perf_counter

# Reload modules to pick up changes
# (a plain `import` does not re-execute a module that is already loaded in this kernel)
importlib.reload(tinytransformer)
importlib.reload(evaluations)
importlib.reload(aaivr)
importlib.reload(utils)

# Define your paths constants
PATH_BOTH = Path("assets/challenges_dihedral_both.json")

# Config List: (Run Name, Max Color Augments, Dataset Path)
# Each active entry is one evaluation run; the run name is also the sub-folder of runs/
# that receives that run's logs and submission.json.
# Every entry ends with a trailing comma so that commented-out entries can be enabled
# without turning two adjacent tuples into a call expression.
EVAL_CONFIGS = [
    # ("eval_125color_both", 125, PATH_BOTH),
    ("eval_100color_both", 100, PATH_BOTH),
    # ("eval_10color_both", 10, PATH_BOTH),
    # ("eval_0color_both", 0, PATH_BOTH),
    # NOTE: PATH_TRAIN is not defined in this cell; define it before enabling the entry below.
    # ("eval_0color_train", 0, PATH_TRAIN), # <--- Uses TRAIN path (No Geom TTA on Test)
]

# Global settings shared across runs
# (read as module-level globals inside run_evaluation_pipeline)
EVAL_BATCH_SIZE = 1300  # batch_size passed to evaluations.evaluate_model_on_dataset
SPLITS = ["test"]  # splits that are evaluated and summarised in the eval log
CHECKPOINT_PATH = Path("runs/tiny.pt")  # same file as "save_path" in the training config
SOLUTIONS_PRESENT = False  # forwarded to the evaluator as include_targets
EVAL_TASK_IDS = None  # Set to None to evaluate full dataset, or ["00576224", ...] for specific tasks
LOG_CORRECT_GRIDS = False  # Print the actual grid, IDs, and augmentation indices for fully correct grids


# Helper class for logging to file and console
class TeeLogger:
    """Minimal file-like object that duplicates writes to the console and a file.

    At construction it captures whatever ``sys.stdout`` currently is as
    ``terminal`` and opens *filepath* for writing (truncating it) as ``log``.
    Assigning an instance to ``sys.stdout`` then mirrors all printed output.
    """

    def __init__(self, filepath):
        self.terminal = sys.stdout
        self.log = open(filepath, "w")

    def write(self, message):
        """Write *message* to the terminal stream first, then to the log file."""
        for sink in (self.terminal, self.log):
            sink.write(message)

    def flush(self):
        """Flush the terminal stream, then the log file."""
        for sink in (self.terminal, self.log):
            sink.flush()

    def close(self):
        """Close the log file only; the captured terminal stream stays open."""
        self.log.close()


def _arc_style_score(aaivr_results):
    """Return ``(score, n_tasks)`` where each task contributes its fraction of solved pairs.

    Results are grouped by ``task_id``; a pair counts as solved when its ``pass_at_k``
    is truthy.
    """
    pairs_by_task = {}
    for res in aaivr_results:
        pairs_by_task.setdefault(res.task_id, []).append(res)

    score = 0.0
    for pairs in pairs_by_task.values():
        # Every group holds at least one result, so the division is safe.
        solved = sum(1 for p in pairs if p.pass_at_k)
        score += solved / len(pairs)
    return score, len(pairs_by_task)


def _build_submission(aaivr_results):
    """Turn AAIVR results into ``{task_id: [{"attempt_1": grid, "attempt_2": grid}, ...]}``.

    Pairs are ordered by ``original_pair_index`` within each task. When a result has no
    selected outputs the fallback grid ``[[0]]`` is used; when it has only one, that grid
    is repeated as the second attempt. Returns ``{}`` for empty/missing results.
    """
    if not aaivr_results:
        return {}

    attempts_by_task = {}
    for item in aaivr_results:
        top_grids = item.selected_outputs[:2]
        if not top_grids:
            top_grids = [[[0]]]  # Fallback
        attempts_by_task.setdefault(item.task_id, {})[item.original_pair_index] = {
            "attempt_1": top_grids[0],
            "attempt_2": top_grids[1] if len(top_grids) > 1 else top_grids[0],
        }

    return {
        t_id: [pairs_map[idx] for idx in sorted(pairs_map)]
        for t_id, pairs_map in attempts_by_task.items()
    }


def run_evaluation_pipeline(run_name, max_color_augments, dataset_path, device):
    """Evaluate the saved checkpoint for one config and write its logs and submission.

    Steps: point the shared ``cfg`` at ``CHECKPOINT_PATH`` / *dataset_path*, (re)build the
    dataset, run inference over ``SPLITS``, log per-split statistics, run AAIVR on the
    "test" results and dump ``submission.json``.

    Relies on notebook globals defined outside this function: ``cfg``, ``args``, ``train``,
    ``model`` (optional), ``CHECKPOINT_PATH``, ``EVAL_BATCH_SIZE``, ``SPLITS``,
    ``SOLUTIONS_PRESENT``, ``EVAL_TASK_IDS``, ``LOG_CORRECT_GRIDS`` and ``TeeLogger``.

    Args:
        run_name: Name of the run; outputs go to ``runs/<run_name>/``.
        max_color_augments: Number of colour augmentations used at eval time; ``0``
            disables colour augmentation.
        dataset_path: Dataset file assigned to ``cfg.data_path``.
        device: Device used as ``map_location`` when loading the checkpoint. It is then
            replaced by the device returned from ``train.build_model_and_data``.

    Side effects:
        Mutates ``cfg`` (``checkpoint_path``, ``data_path``, ``enable_color_aug_eval``,
        ``max_color_augments_eval``), may (re)bind the global ``model``, appends to
        ``eval_log.txt``, overwrites ``aaivr.txt`` and ``submission.json``, and temporarily
        replaces ``sys.stdout`` while AAIVR runs.
    """
    # Declared up front: the global `model` is both read and (in the fallback branch) rebound.
    global model

    banner = "=" * 60
    print(f"\n{banner}")
    print(f"STARTING PIPELINE: {run_name} (Color Augs: {max_color_augments})")
    print(f"{banner}\n")

    # 1. Setup Directories
    base_run_dir = Path("runs") / run_name
    base_run_dir.mkdir(parents=True, exist_ok=True)

    eval_log_path = base_run_dir / "eval_log.txt"
    aaivr_log_path = base_run_dir / "aaivr.txt"
    submission_path = base_run_dir / "submission.json"

    # 2. Update Config
    cfg.checkpoint_path = CHECKPOINT_PATH
    cfg.data_path = dataset_path
    cfg.enable_color_aug_eval = max_color_augments > 0
    cfg.max_color_augments_eval = max_color_augments

    # 3. Build/Rebuild Model & Data
    # We rebuild the dataloader every time to handle the different color augmentation settings
    print("Building model and dataloader for config...")

    # Load checkpoint explicitly to pass to build function.
    # NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints you produced.
    checkpoint = torch.load(
        cfg.checkpoint_path, map_location=device, weights_only=False
    )

    # Check if model exists in global scope to reuse weights, else create
    if "model" in globals() and model is not None:
        model.load_state_dict(
            checkpoint["model_state"] if "model_state" in checkpoint else checkpoint,
            strict=False,
        )
        model.eval()
        # Rebuild only dataset/loader; the model returned by the builder is discarded.
        _, dataset, dataloader, device, _ = train.build_model_and_data(
            cfg, checkpoint=checkpoint
        )
    else:
        model, dataset, dataloader, device, _ = train.build_model_and_data(cfg)

    # 4. Run Inference (Logic from old Cell 3)
    def log_eval(msg):
        """Print *msg* and append it to this run's eval_log.txt."""
        print(msg)
        with open(eval_log_path, "a") as f:
            f.write(msg + "\n")

    color_mappings_eval = None
    color_apply_fn = None
    if cfg.enable_color_aug_eval and cfg.max_color_augments_eval > 0:
        # NOTE(review): a falsy colour seed (None or 0) falls back to cfg.seed here, while
        # AAIVR below receives cfg.color_aug_seed unchanged — confirm both agree in that case.
        color_seed = cfg.color_aug_seed or cfg.seed
        color_mappings_eval = utils.generate_color_mapping_tensors(
            cfg.max_color_augments_eval, color_seed
        )
        color_apply_fn = lambda split: True

    evaluation = evaluations.evaluate_model_on_dataset(
        model=model,
        dataset=dataset,
        device=device,
        batch_size=EVAL_BATCH_SIZE,
        log_prompts=args["log_inference_prompt"],
        temperature=args["inference_temperature"],
        top_k=args["inference_top_k"],
        splits=SPLITS,
        color_mappings=color_mappings_eval,
        color_apply_fn=color_apply_fn,
        task_ids=EVAL_TASK_IDS,
        include_targets=SOLUTIONS_PRESENT,
    )

    # Log Inference Stats
    # Fix: the original referenced an undefined `data_path_str` in the dihedral check below.
    data_path_str = str(cfg.data_path)
    log_eval(f"\n-- {cfg.epochs}ep {max_color_augments}color --\n")
    for split in SPLITS:
        summary = evaluation.get(split, {}).get("summary", {})
        total = summary.get("total_sequences", 0)
        shape_ok = summary.get("num_shape_correct", 0)
        fully_correct = summary.get("num_fully_correct", 0)
        avg_pixel_acc = summary.get("avg_pixel_accuracy", 0.0)

        log_eval(
            f"Split: {split} | Seq: {total} | Shape OK: {shape_ok} | Fully Correct: {fully_correct} | Pixel Acc: {avg_pixel_acc:.4f}"
        )

        if LOG_CORRECT_GRIDS and fully_correct > 0:
            log_eval(f"  [Correct Grids Details for {split}]")

            # Determine if THIS split has dihedral augmentations
            # Train is augmented if "dihedral" is anywhere in the name
            # Test is augmented ONLY if "dihedral_both" is in the name
            is_dihedral_split = (split == "train" and "dihedral" in data_path_str) or (
                split == "test" and "dihedral_both" in data_path_str
            )

            correct_results = summary.get("fully_correct_results", [])
            for res in correct_results:
                raw_idx = res.get("pair_index", 0)

                # Decode indices based on split properties: with dihedral augmentation the
                # raw index packs (pair, transform) as pair * 8 + transform.
                if is_dihedral_split:
                    pair_id, dihedral_id = divmod(raw_idx, 8)
                else:
                    pair_id = raw_idx
                    dihedral_id = 0

                color_id = res.get("color_permutation_index", 0)
                grid = res.get("output_grid", [])

                log_eval(
                    f"    T:{res.get('task_id')} | Pair:{pair_id} | Dihedral:{dihedral_id} | Color:{color_id} -> {grid}"
                )

    # 5. Run AAIVR (Logic from old Cell 4)
    print(f"Running AAIVR for {run_name}...")

    # Redirect stdout for AAIVR logging. If a tee from an earlier, interrupted run is still
    # installed, unwrap it first and close its log file so the handle is not leaked.
    stale_stdout = sys.stdout
    if hasattr(stale_stdout, "log"):
        sys.stdout = stale_stdout.terminal
        stale_stdout.close()
    tee = TeeLogger(str(aaivr_log_path))
    sys.stdout = tee

    aaivr_results = []
    try:
        test_results = evaluation.get("test", {}).get("results", [])
        dataset_has_dihedral_augments = "dihedral_both" in str(cfg.data_path)

        if test_results:
            aaivr_results = aaivr.run_aaivr_on_results(
                test_results,
                is_dihedral_augmented=dataset_has_dihedral_augments,
                color_aug_seed=cfg.color_aug_seed,
                max_color_augments=cfg.max_color_augments_eval,
            )

            # Print Stats (will go to console + aaivr.txt)
            aaivr.summarize_aaivr_pass_at_k(aaivr_results)
            if aaivr_results:
                arc_score, max_score = _arc_style_score(aaivr_results)
                pct = (arc_score / max_score * 100) if max_score > 0 else 0.0
                print(
                    f"Official ARC style scoring: {arc_score:.2f}/{max_score} ({pct:.2f}%)"
                )
        else:
            print("No test results for AAIVR.")

    finally:
        # Always restore stdout, using our own handle so this works even if something
        # replaced sys.stdout in the meantime; restore before closing the log file.
        sys.stdout = tee.terminal
        tee.close()

    # 6. Generate Submission (Logic from old Cell 5)
    print(f"Generating submission.json for {run_name}...")
    submission_data = _build_submission(aaivr_results)

    with open(submission_path, "w") as f:
        json.dump(submission_data, f)

    print(f"Finished {run_name}. Submission saved to {submission_path}")


# --- Execute the Loop (Modified with Timing) ---
# Runs every active entry of EVAL_CONFIGS in order and appends each run's wall-clock duration
# to runs/timing.txt (the training cell creates that file; "a" keeps its line).
timing_path = Path("runs/timing.txt")

for name, aug_count, d_path in EVAL_CONFIGS:  # <--- Unpack 3 items
    t_start = perf_counter()

    # `device` is the notebook global produced by train.build_model_and_data in the config cell.
    run_evaluation_pipeline(name, aug_count, d_path, device)

    t_duration = perf_counter() - t_start
    print(f"Run {name} took {t_duration:.2f}s")

    with open(timing_path, "a") as f:
        f.write(f"Evaluation {name}: {t_duration:.4f} s\n")

print("\nAll evaluation runs completed.")

In [None]:
# refresh Drive zip
# Called after evaluation with the same identifiers as the earlier save (cfg.name,
# root_folder, mount_folder); rebinds `archive_state` to the helper's return value.
# NOTE(review): presumably this adds the evaluation outputs to the existing archive —
# confirm against utils.update_run_archive.
archive_state = utils.update_run_archive(
    cfg.name, root_folder, mount_folder, globals_dict=globals()
)


In [None]:
# visualisation
# EVAL_SUB_FOLDER should match one of the run names in EVAL_CONFIGS, since that is the
# directory under runs/ where the evaluation cell wrote submission.json.
EVAL_SUB_FOLDER = "eval_100color_both"
VIS_MODE = "!"  # "!" = compare vs solutions, "submission" = attempts-only
# solutions_file is supplied for the compare-vs-solutions mode.
# NOTE(review): whether it is ignored in "submission" mode depends on the helper — verify.
utils.visualize_eval_submissions(
    EVAL_SUB_FOLDER, mode=VIS_MODE, solutions_file="assets/solutions.json"
)
