In [5]:
import os
import sys
import json
from pathlib import Path

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ----------------------------
# Load all JSONs (recursive)
# ----------------------------
def load_jsons_from_folder(dir_path):
    """
    Read every ``.json`` file under ``dir_path`` (recursively).

    Args:
        dir_path: Directory to scan; ``~`` is expanded and the path resolved.

    Returns:
        dict mapping each file's path relative to ``dir_path`` -- without the
        ``.json`` suffix and always with forward slashes (e.g. ``"subdir/file"``)
        -- to the parsed JSON content.

    Raises:
        FileNotFoundError: if no ``.json`` file exists under ``dir_path``, or
            if none of the files found could be read and parsed.

    Files that cannot be read or decoded are reported on stderr and skipped,
    so one corrupt file does not abort loading the rest of the folder.
    """
    root = Path(dir_path).expanduser().resolve()
    # sorted() gives a deterministic insertion (and therefore key) order.
    files = sorted(p for p in root.rglob("*.json") if p.is_file())

    if not files:
        raise FileNotFoundError(f"No .json files found under: {root}")

    data = {}
    for p in files:
        # as_posix() keeps keys in the documented "subdir/file" form on every OS.
        key = p.relative_to(root).with_suffix("").as_posix()
        try:
            with p.open("r", encoding="utf-8") as fh:
                data[key] = json.load(fh)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # anything else is a genuine bug and should propagate.
            print(f"Failed to read {p}: {e}", file=sys.stderr)

    if not data:
        raise FileNotFoundError(f"Unable to load any .json files under: {root}")

    return data


# ----------------------------
# Preprocess helpers
# ----------------------------
def _add_one_to_all_values_in_place(data):
    """
    Adds +1 to every scalar value in each input/output grid across all samples.
    Done BEFORE padding so pad_value=0 remains 0.
    """
    for sample in data.values():
        for split in ["train", "test"]:
            for pairs in sample.get(split, []):
                # input grid
                r = 0
                while r < len(pairs["input"]):
                    c = 0
                    row = pairs["input"][r]
                    while c < len(row):
                        row[c] = row[c] + 1
                        c += 1
                    r += 1
                # output grid
                r = 0
                while r < len(pairs["output"]):
                    c = 0
                    row = pairs["output"][r]
                    while c < len(row):
                        row[c] = row[c] + 1
                        c += 1
                    r += 1


def get_metrics(data):
    """
    Compute dataset-wide maxima used to size the padding.

    Args:
        data: mapping sample id -> sample dict with ``"train"`` / ``"test"``
            lists of ``{"input": grid, "output": grid}`` pairs (grids are
            lists of row lists). A missing split is treated as empty, matching
            the other preprocessing helpers.

    Returns:
        dict with, for each split S in {"train", "test"}:
          - ``"max_S_len"``: most pairs any single sample has in S,
          - ``"max_S_input_height"`` / ``"max_S_output_height"``: most rows
            of any input/output grid in S,
          - ``"max_S_input_width"`` / ``"max_S_output_width"``: longest row
            of any input/output grid in S.
        Every value is 0 when nothing contributes to it.
    """
    metric_dict = {
        "max_train_len": 0,
        "max_test_len": 0,
        "max_train_input_height": 0,
        "max_test_input_height": 0,
        "max_train_output_height": 0,
        "max_test_output_height": 0,
        "max_train_input_width": 0,
        "max_test_input_width": 0,
        "max_train_output_width": 0,
        "max_test_output_width": 0
    }

    for sample in data.values():
        for split in ("train", "test"):
            split_pairs = sample.get(split, [])
            len_key = f"max_{split}_len"
            metric_dict[len_key] = max(metric_dict[len_key], len(split_pairs))
            for pairs in split_pairs:
                for side in ("input", "output"):
                    grid = pairs[side]
                    # Rows may be ragged, so the width is the longest row.
                    widest = max((len(row) for row in grid), default=0)
                    h_key = f"max_{split}_{side}_height"
                    w_key = f"max_{split}_{side}_width"
                    metric_dict[h_key] = max(metric_dict[h_key], len(grid))
                    metric_dict[w_key] = max(metric_dict[w_key], widest)
    return metric_dict


def pad_data(data, metric_dict=None, pad_value=0):
    """
    Pad every grid in the dataset, in place, to a square of a split-wide size.

    All TRAIN grids (inputs and outputs) are padded to ``S_train x S_train``
    and all TEST grids to ``S_test x S_test``. Each ``S`` is the largest
    height or width of any input/output grid of that split across the whole
    dataset. Pad rows are appended at the bottom and pad columns on the right,
    so original cells keep their (row, col) coordinates.

    Args:
        data: mapping sample id -> sample dict with optional ``"train"`` /
            ``"test"`` lists of ``{"input": grid, "output": grid}`` pairs
            (grids are lists of row lists). Mutated in place.
        metric_dict: precomputed ``get_metrics(data)`` result; it is computed
            here when None.
        pad_value: value written into padded cells.

    Returns:
        The same ``data`` object, for chaining.
    """
    if metric_dict is None:
        metric_dict = get_metrics(data)

    def _split_size(split):
        # Square side = largest height or width among the split's grids.
        return max(
            metric_dict[f"max_{split}_{side}_{dim}"]
            for side in ("input", "output")
            for dim in ("height", "width")
        )

    def _pad_square(grid, size):
        # Append fresh (non-aliased) pad rows, then right-pad every row.
        grid.extend([pad_value] * size for _ in range(size - len(grid)))
        for row in grid:
            row.extend([pad_value] * (size - len(row)))

    sizes = {"train": _split_size("train"), "test": _split_size("test")}

    for sample in data.values():
        for split, size in sizes.items():
            for pairs in sample.get(split, []):
                _pad_square(pairs["input"], size)
                _pad_square(pairs["output"], size)

    return data


def _infer_original_size_from_padded(grid, pad_value=0):
    h = 0
    w = 0
    r = 0
    while r < len(grid):
        row = grid[r]
        any_nonpad = False
        last_nonpad = -1
        c = 0
        while c < len(row):
            if row[c] != pad_value:
                any_nonpad = True
                last_nonpad = c
            c += 1
        if any_nonpad:
            if (r + 1) > h:
                h = r + 1
            if (last_nonpad + 1) > w:
                w = last_nonpad + 1
        r += 1
    return (h, w)


def build_sample_level_dataset(data, pad_value=0):
    """
    Convert (padded) samples into per-sample records of long tensors and masks.

    Args:
        data: mapping sample id -> sample dict with ``"train"`` and ``"test"``
            lists of ``{"input": grid, "output": grid}`` pairs. A pair may
            carry ``"orig_input_size"`` / ``"orig_output_size"`` as (h, w).
            Otherwise the size is inferred from the grid's non-pad cells.
        pad_value: value treated as padding when building masks and
            inferring sizes.

    Returns:
        list with one dict per sample, in ``data`` iteration order. Keys:
          - ``"id"``: ``str`` of the sample key,
          - ``"train_pairs"`` / ``"test_pairs"``: lists of dicts holding long
            tensors ``"input"``, ``"output"``, ``"input_mask"`` and
            ``"output_mask"`` (masks are 1 where value != pad_value, else 0),
          - ``"train_original_size"`` / ``"test_original_size"``: the
            (max_h, max_w) over all inputs and outputs of that split.
    """

    def _original_size(pair, size_key, grid):
        # Prefer the stored pre-padding size; fall back to inference.
        if size_key in pair:
            return pair[size_key]
        return _infer_original_size_from_padded(grid, pad_value)

    def _build_split(pair_list):
        # Tensorise one split; also return its (max_h, max_w) original size.
        built = []
        max_h = 0
        max_w = 0
        for pair in pair_list:
            inp_grid = pair["input"]
            out_grid = pair["output"]

            in_h, in_w = _original_size(pair, "orig_input_size", inp_grid)
            out_h, out_w = _original_size(pair, "orig_output_size", out_grid)
            max_h = max(max_h, in_h, out_h)
            max_w = max(max_w, in_w, out_w)

            inp_tensor = torch.tensor(inp_grid).long()
            out_tensor = torch.tensor(out_grid).long()
            built.append({
                "input": inp_tensor,
                "output": out_tensor,
                "input_mask": (inp_tensor != pad_value).long(),
                "output_mask": (out_tensor != pad_value).long()
            })
        return built, (max_h, max_w)

    dataset = []
    for sample_name, sample in data.items():
        train_pairs, train_size = _build_split(sample["train"])
        test_pairs, test_size = _build_split(sample["test"])
        dataset.append({
            "id": str(sample_name),
            "train_pairs": train_pairs,
            "test_pairs": test_pairs,
            "train_original_size": train_size,
            "test_original_size": test_size
        })

    return dataset


# ----------------------------
# Torch dataset
# ----------------------------
class ARCSampleDataset(Dataset):
    """
    Map-style dataset that yields one whole sample (all of its pairs) per index.

    Args:
        sample_list: records shaped like the output of
            ``build_sample_level_dataset``. Each holds ``"id"``,
            ``"train_pairs"`` / ``"test_pairs"`` (lists of dicts with
            ``"input"``, ``"output"``, ``"input_mask"``, ``"output_mask"``
            tensors) and ``"train_original_size"`` /
            ``"test_original_size"`` (h, w) pairs.

    ``__getitem__`` stacks each split's per-pair tensors along a new leading
    axis, so all grids within a split must share one shape (``torch.stack``
    raises otherwise).
    """

    def __init__(self, sample_list):
        self.data = sample_list

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _stack(pair_dicts, field):
        # Per-pair [H, W] tensors -> one [num_pairs, H, W] tensor.
        return torch.stack([pair[field] for pair in pair_dicts])

    def __getitem__(self, idx):
        record = self.data[idx]
        train = record["train_pairs"]
        test = record["test_pairs"]
        stack = self._stack

        return {
            "id": record["id"],
            "train_inputs": stack(train, "input"),
            "train_outputs": stack(train, "output"),
            "test_inputs": stack(test, "input"),
            "test_outputs": stack(test, "output"),
            "train_input_masks": stack(train, "input_mask"),
            "train_output_masks": stack(train, "output_mask"),
            "test_input_masks": stack(test, "input_mask"),
            "test_output_masks": stack(test, "output_mask"),
            "train_original_size": torch.tensor(record["train_original_size"], dtype=torch.long),
            "test_original_size": torch.tensor(record["test_original_size"], dtype=torch.long)
        }


def arc_collate_fn_bs1(batch):
    """
    DataLoader collate function for ``batch_size=1``.

    Samples may hold different numbers of train/test pairs, so they are not
    merged into batch tensors. The single per-sample dict is returned
    unchanged instead.

    Args:
        batch: list of samples produced by the DataLoader.

    Returns:
        ``batch[0]``.

    Raises:
        ValueError: if ``batch`` does not hold exactly one sample. This keeps
            extra samples from being silently dropped when the loader is
            misconfigured.
    """
    if len(batch) != 1:
        raise ValueError(
            f"arc_collate_fn_bs1 expects exactly 1 sample per batch, got {len(batch)}; "
            "build the DataLoader with batch_size=1"
        )
    return batch[0]


# ----------------------------
# NEW: Small pretty-printer for grids (cropped)
# ----------------------------
def _pretty_grid(tensor, max_rows=6, max_cols=10):  # NEW
    arr = tensor.tolist()
    lines = []
    r = 0
    while r < min(len(arr), max_rows):
        row = arr[r]
        row_disp = row[:max_cols]
        row_txt = str(row_disp) + (" ... " if len(row) > max_cols else "")
        lines.append(row_txt)
        r += 1
    if len(arr) > max_rows:
        lines.append("...")
    return "\n".join(lines)


# ----------------------------
# NEW: Data module wrapper
# ----------------------------
class ARCDataModule:
    """
    Simple wrapper to produce a DataLoader from a folder of task JSON files.

    ``prepare()`` runs the whole pipeline:
      1. load every JSON under ``dir_path``,
      2. add +1 to every cell,
      3. compute dataset-wide metrics and pad each split to its square size,
      4. convert to tensors and masks,
      5. wrap the result in ``ARCSampleDataset`` + ``DataLoader``.

    Usage:
        dm = ARCDataModule("~/path/to/training").prepare()
        loader = dm.get_loader()
        for batch in loader: ...

    Each batch is a single sample's dict (via ``arc_collate_fn_bs1``), so only
    ``batch_size=1`` is supported. Other values raise ``ValueError``.
    """
    def __init__(
        self,
        dir_path,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        pin_memory=False,
        pad_value=0,
    ):
        if batch_size != 1:
            # arc_collate_fn_bs1 returns a single sample per batch, so any
            # other batch size would lose data; fail fast instead.
            raise ValueError(
                f"ARCDataModule only supports batch_size=1 (got {batch_size!r})"
            )
        self.dir_path = Path(dir_path).expanduser().resolve()
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.pad_value = pad_value

        self.dataset = None   # ARCSampleDataset, built by prepare()
        self._loader = None   # DataLoader, built by prepare()
        self.metrics = None   # get_metrics() result for the loaded data, set by prepare()

    def prepare(self):
        """Load, preprocess and pad the data, then build the dataset and loader.

        Returns:
            self, to allow chaining.
        """
        data = load_jsons_from_folder(self.dir_path)
        # Shift all cells by +1 BEFORE padding so that with the default
        # pad_value=0 padded cells stay distinguishable from real ones.
        _add_one_to_all_values_in_place(data)

        # Dataset-wide metrics decide the per-split square pad size.
        self.metrics = get_metrics(data)
        padded = pad_data(data, metric_dict=self.metrics, pad_value=self.pad_value)

        sample_list = build_sample_level_dataset(padded, pad_value=self.pad_value)

        self.dataset = ARCSampleDataset(sample_list=sample_list)
        self._loader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            collate_fn=arc_collate_fn_bs1,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
        return self

    def get_loader(self):
        """Return the DataLoader, running prepare() first if needed."""
        if self._loader is None:
            self.prepare()
        return self._loader

    # convenience so the module itself is iterable
    def __iter__(self):
        return iter(self.get_loader())

    def __len__(self):
        # Number of samples; 0 until prepare() has run.
        return len(self.dataset) if self.dataset is not None else 0


# ----------------------------
# Main
# ----------------------------
if __name__ == "__main__":
    # Local ARC "training" folder; ARCDataModule expands "~" and resolves it.
    folder_path = Path("~/ARC-AGI-Model/src/data_pipeline/ARC_data/data/training")

    data_module = ARCDataModule(
        dir_path=folder_path,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        pin_memory=False,
        pad_value=0,
    ).prepare()
    arc_loader = data_module.get_loader()

    M = data_module.metrics

    def _square_side(split):
        """Largest height/width over the inputs and outputs of ``split``."""
        return max(
            M[f"max_{split}_{side}_{dim}"]
            for side in ("input", "output")
            for dim in ("height", "width")
        )

    # Side length every padded grid of each split is expected to have.
    GLOBAL_TRAIN_SIZE = _square_side("train")
    GLOBAL_TEST_SIZE = _square_side("test")

    print("=== DATASET-WIDE PAD SIZES ===")
    print(f"TRAIN -> {GLOBAL_TRAIN_SIZE}x{GLOBAL_TRAIN_SIZE}")
    print(f"TEST  -> {GLOBAL_TEST_SIZE}x{GLOBAL_TEST_SIZE}")

    def _report(batch, ordinal):
        """Print a readable summary of one sample dict yielded by the loader."""
        n_train = int(batch["train_inputs"].shape[0])
        n_test = int(batch["test_inputs"].shape[0])

        # Per-split (max h, max w) before padding.
        tr_oh, tr_ow = map(int, batch["train_original_size"].tolist())
        te_oh, te_ow = map(int, batch["test_original_size"].tolist())

        # Actual (H, W) of each stacked tensor after padding.
        dims = {
            name: (batch[name].shape[1], batch[name].shape[2])
            for name in ("train_inputs", "train_outputs", "test_inputs", "test_outputs")
        }
        (ti_h, ti_w), (to_h, to_w) = dims["train_inputs"], dims["train_outputs"]
        (ei_h, ei_w), (eo_h, eo_w) = dims["test_inputs"], dims["test_outputs"]

        train_ok = all(d == GLOBAL_TRAIN_SIZE for d in (ti_h, to_h, ti_w, to_w))
        test_ok = all(d == GLOBAL_TEST_SIZE for d in (ei_h, eo_h, ei_w, eo_w))

        print(f"\n=== SUMMARY (sample {ordinal}) ===")
        print(f"id: {batch['id']}")
        print(f"#train: {n_train} | #test: {n_test}")
        print(f"Train original max: ({tr_oh}, {tr_ow})")
        print(f"Test  original max: ({te_oh}, {te_ow})")
        print(
            f"Padded sizes — train_in: ({ti_h}, {ti_w}), "
            f"train_out: ({to_h}, {to_w}), "
            f"test_in: ({ei_h}, {ei_w}), "
            f"test_out: ({eo_h}, {eo_w})"
        )
        print(f"Matches global TRAIN size? {train_ok} | Matches global TEST size? {test_ok}")

        # Cropped view of the first pair of every non-empty split.
        examples = (
            ("\n--- Example TRAIN pair [0] (cropped) ---", "train_inputs", "train_outputs", n_train),
            ("\n--- Example TEST pair [0] (cropped) ---", "test_inputs", "test_outputs", n_test),
        )
        for header, inp_key, out_key, count in examples:
            if count > 0:
                print(header)
                print("input:\n" + _pretty_grid(batch[inp_key][0], 6, 10))
                print("output:\n" + _pretty_grid(batch[out_key][0], 6, 10))

    # Show at most 10 samples.
    printed = 0
    for batch in arc_loader:
        _report(batch, printed + 1)
        printed += 1
        if printed >= 10:
            break

    print("\nDataLoader type:", type(arc_loader))


=== DATASET-WIDE PAD SIZES ===
TRAIN -> 30x30
TEST  -> 30x30

=== SUMMARY (sample 1) ===
id: 1acc24af
#train: 4 | #test: 1
Train original max: (12, 12)
Test  original max: (12, 12)
Padded sizes — train_in: (30, 30), train_out: (30, 30), test_in: (30, 30), test_out: (30, 30)
Matches global TRAIN size? True | Matches global TEST size? True

--- Example TRAIN pair [0] (cropped) ---
input:
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1] ... 
[1, 1, 2, 2, 2, 2, 1, 1, 2, 2] ... 
[1, 1, 2, 1, 1, 2, 1, 1, 2, 1] ... 
[2, 2, 2, 1, 1, 2, 2, 2, 2, 1] ... 
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1] ... 
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1] ... 
...
output:
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1] ... 
[1, 1, 2, 2, 2, 2, 1, 1, 2, 2] ... 
[1, 1, 2, 1, 1, 2, 1, 1, 2, 1] ... 
[2, 2, 2, 1, 1, 2, 2, 2, 2, 1] ... 
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1] ... 
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1] ... 
...

--- Example TEST pair [0] (cropped) ---
input:
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1] ... 
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1] ... 
[1, 2, 2, 2, 1, 1, 1, 2, 2, 2] ... 
[1, 2, 