In [9]:
import os
import sys
import json
from pathlib import Path

import torch
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [15]:
def load_jsons_from_folder(folder_path):
    """Load every well-formed task JSON file found directly inside a folder.

    A file is kept only when its top-level value is a dict with ``'train'``
    and ``'test'`` keys, each holding a list of dicts that contain both an
    ``'input'`` and an ``'output'`` key. Any other file is reported with a
    printed message and skipped, so a single malformed file cannot abort
    the whole load.

    Args:
        folder_path: Directory to scan. ``~`` is expanded and the path is
            made absolute. Sub-directories are not searched.

    Returns:
        dict mapping each accepted file's name (without its extension) to
        the parsed JSON object, inserted in sorted file-name order.

    Raises:
        FileNotFoundError: If ``folder_path`` is not an existing directory.
    """
    folder_path = os.path.abspath(os.path.expanduser(folder_path))

    if not os.path.isdir(folder_path):
        raise FileNotFoundError(f"Folder not found: {folder_path}")

    data = {}

    for name in sorted(os.listdir(folder_path)):
        if not name.endswith(".json"):
            continue

        full_path = os.path.join(folder_path, name)
        try:
            # JSON text is UTF-8; do not depend on the platform's default encoding.
            with open(full_path, "r", encoding="utf-8") as f:
                obj = json.load(f)
        except Exception as e:
            # Deliberately best-effort: report the unreadable file and move on.
            print("Failed to read:", full_path)
            print(" Error:", e)
            continue

        # The isinstance guard keeps scalar / list top-level values from
        # raising TypeError in the membership tests below.
        if (not isinstance(obj, dict)) or ('train' not in obj) or ('test' not in obj):
            print("Skipping (no train/test):", full_path)
            continue

        # basic shape check: each split is a list of dicts with input/output
        ok = all(
            isinstance(obj[split], list)
            and all(
                isinstance(pairs, dict) and ('input' in pairs) and ('output' in pairs)
                for pairs in obj[split]
            )
            for split in ('train', 'test')
        )
        if not ok:
            print("Skipping (bad format):", full_path)
            continue

        key = os.path.splitext(name)[0]
        data[key] = obj

    return data




def get_metrics(data):
    """Compute dataset-wide maxima of pair counts and grid dimensions.

    Args:
        data: dict whose values are task dicts with ``'train'`` and
            ``'test'`` lists; every list entry holds an ``'input'`` and an
            ``'output'`` grid given as a list of rows.

    Returns:
        dict with, for each split, the largest number of pairs
        (``max_<split>_len``) and the largest grid height / row width seen
        among inputs and outputs (``max_<split>_<input|output>_<height|width>``).
        Every entry is 0 when ``data`` is empty.
    """
    metric_dict = {
        "max_train_len": 0,
        "max_test_len": 0,
        "max_train_input_height": 0,
        "max_test_input_height": 0,
        "max_train_output_height": 0,
        "max_test_output_height": 0,
        "max_train_input_width": 0,
        "max_test_input_width": 0,
        "max_train_output_width": 0,
        "max_test_output_width": 0
    }

    def _raise_to(key, candidate):
        # Keep whichever is larger: the recorded maximum or the candidate.
        if candidate > metric_dict[key]:
            metric_dict[key] = candidate

    for task in data.values():
        for split in ("train", "test"):
            pair_list = task[split]
            _raise_to(f"max_{split}_len", len(pair_list))
            for pair in pair_list:
                for role in ("input", "output"):
                    grid = pair[role]
                    # Height is the row count; width is the longest row.
                    _raise_to(f"max_{split}_{role}_height", len(grid))
                    for row in grid:
                        _raise_to(f"max_{split}_{role}_width", len(row))

    return metric_dict



def _grid_size(grid):
    """Return ``(height, width)`` of a grid; width is the longest row's length."""
    return (len(grid), max((len(row) for row in grid), default=0))


def _pad_grid_to_square(grid, size, pad_value):
    """Pad ``grid`` in place with ``pad_value`` until it is ``size`` x ``size``."""
    # Every appended row is a fresh list so padded rows never alias each other.
    grid.extend([pad_value] * size for _ in range(size - len(grid)))
    for row in grid:
        # A non-positive repeat count yields [], so full rows are left alone.
        row.extend([pad_value] * (size - len(row)))


def pad_data(data, metric_dict, pad_value=0):
    """Pad every grid in place to a square whose side is the split-wide maximum.

    Train grids are padded to the largest train height/width found in
    ``metric_dict`` and test grids to the largest test height/width. Before
    a grid is padded, its true ``(height, width)`` is stored on the pair
    under ``'orig_input_size'`` / ``'orig_output_size'``. These are the keys
    ``build_sample_level_dataset`` prefers over inferring a size from the
    padded grid, which cannot tell padding apart from genuine cells equal
    to ``pad_value``. Sizes that are already stored are kept, so padding the
    same data twice does not overwrite them with padded dimensions.

    Args:
        data: dict of task dicts with ``'train'`` and ``'test'`` pair lists.
        metric_dict: the ``max_<split>_<input|output>_<height|width>``
            entries as produced by ``get_metrics``.
        pad_value: value used for every padded cell.

    Returns:
        The same ``data`` object, mutated in place.
    """
    if not data:
        # Nothing to pad; also avoids requiring metric keys for empty input.
        return data

    # define max square sizes (per split); the same for every sample
    max_train_size = max(metric_dict['max_train_input_height'], metric_dict['max_train_input_width'],
                         metric_dict['max_train_output_height'], metric_dict['max_train_output_width'])
    max_test_size = max(metric_dict['max_test_input_height'], metric_dict['max_test_input_width'],
                        metric_dict['max_test_output_height'], metric_dict['max_test_output_width'])

    for sample in data.values():
        for split, size in (('train', max_train_size), ('test', max_test_size)):
            for pairs in sample[split]:
                for role in ('input', 'output'):
                    grid = pairs[role]
                    # Record the size before padding; keep an existing record.
                    pairs.setdefault(f'orig_{role}_size', _grid_size(grid))
                    _pad_grid_to_square(grid, size, pad_value)
    return data



def _print_pair_grids(pair):
    """Print one pair's input rows, then its output rows, then a blank line."""
    for heading, key in (("  INPUT:", 'input'), ("  OUTPUT:", 'output')):
        print(heading)
        for row in pair[key]:
            print("   ", row)
    print()


def print_padded_data(data):
    """Print every sample's train and test pairs as a plain-text report.

    For each ``(sample_name, sample)`` in ``data`` this prints a banner with
    the sample name, then each training pair and each test pair (numbered
    from 1) with the rows of its input and output grids, one row per line.

    Args:
        data: dict of task dicts with ``'train'`` and ``'test'`` pair lists,
            each pair holding ``'input'`` and ``'output'`` grids.
    """
    heavy_rule = "=================================================="
    light_rule = "--------------------------------------------------"

    for sample_name, sample in data.items():
        print(heavy_rule)
        print("SAMPLE:", sample_name)
        print(light_rule)

        print("TRAINING PAIRS:")
        for number, pair in enumerate(sample['train'], start=1):
            print(f"  Train Pair {number}")
            _print_pair_grids(pair)

        print(light_rule)
        print("TEST PAIRS:")
        for number, pair in enumerate(sample['test'], start=1):
            print(f"  Test Pair {number}")
            _print_pair_grids(pair)

        print(heavy_rule)
        print()


def _infer_original_size_from_padded(grid, pad_value=0):
    h = 0
    w = 0
    r = 0
    while r < len(grid):
        row = grid[r]
        any_nonpad = False
        last_nonpad = -1
        c = 0
        while c < len(row):
            if row[c] != pad_value:
                any_nonpad = True
                last_nonpad = c
            c += 1
        if any_nonpad:
            if (r + 1) > h:
                h = r + 1
            if (last_nonpad + 1) > w:
                w = last_nonpad + 1
        r += 1
    return (h, w)


def _tensorize_split(pair_list, pad_value):
    """Convert one split's pairs to long tensors and find its largest original size.

    Returns:
        ``(converted_pairs, (max_h, max_w))`` where ``converted_pairs`` is a
        list of ``{"input": tensor, "output": tensor}`` dicts and the size is
        the maximum original height / width over all inputs and outputs.
    """
    converted = []
    max_h = 0
    max_w = 0
    for pairs in pair_list:
        record = {}
        for role in ("input", "output"):
            grid = pairs[role]
            size_key = f"orig_{role}_size"
            # original sizes (prefer stored, else infer from the padded grid)
            if size_key in pairs:
                height, width = pairs[size_key]
            else:
                height, width = _infer_original_size_from_padded(grid, pad_value)
            max_h = max(max_h, height)
            max_w = max(max_w, width)
            record[role] = torch.tensor(grid).long()
        converted.append(record)
    return converted, (max_h, max_w)


def build_sample_level_dataset(data, pad_value=0):
    """Turn padded task dicts into a list of tensor-backed sample records.

    Args:
        data: dict mapping a sample name to a task dict with ``'train'`` and
            ``'test'`` pair lists. Each pair holds ``'input'`` / ``'output'``
            grids and may carry ``'orig_input_size'`` / ``'orig_output_size'``
            as ``(height, width)``.
        pad_value: padding value, used only when an original size has to be
            inferred from a padded grid.

    Returns:
        list with one dict per sample, in ``data`` order, holding:
        ``"id"`` (the sample name as ``str``), ``"train_pairs"`` and
        ``"test_pairs"`` (lists of ``{"input", "output"}`` long tensors), and
        ``"train_original_size"`` / ``"test_original_size"`` (``(max_h, max_w)``
        over that split's inputs and outputs; ``(0, 0)`` for an empty split).
    """
    dataset = []
    for sample_name, sample in data.items():
        train_pairs, train_size = _tensorize_split(sample['train'], pad_value)
        test_pairs, test_size = _tensorize_split(sample['test'], pad_value)

        # assemble sample-level record
        dataset.append({
            "id": str(sample_name),
            "train_pairs": train_pairs,
            "test_pairs": test_pairs,
            "train_original_size": train_size,
            "test_original_size": test_size
        })

    return dataset


class ARCSampleDataset(Dataset):
    """Dataset over sample-level records; each item stacks that sample's pairs.

    Each record is expected to provide ``"id"``, ``"train_pairs"``,
    ``"test_pairs"``, ``"train_original_size"`` and ``"test_original_size"``,
    where every pair is a dict with ``"input"`` and ``"output"`` tensors.
    """

    def __init__(self, sample_list):
        # Keep a reference to the caller's list; nothing is copied.
        self.data = sample_list

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _stack_role(pair_list, role):
        """Stack the ``role`` tensor of every pair along a new leading axis."""
        return torch.stack([pair[role] for pair in pair_list])

    def __getitem__(self, idx):
        record = self.data[idx]
        train_pairs = record["train_pairs"]
        test_pairs = record["test_pairs"]

        # Stacked tensors have one leading entry per pair: [num_pairs, H, W].
        item = {
            "id": record["id"],
            "train_inputs": self._stack_role(train_pairs, "input"),
            "train_outputs": self._stack_role(train_pairs, "output"),
            "test_inputs": self._stack_role(test_pairs, "input"),
            "test_outputs": self._stack_role(test_pairs, "output"),
        }
        for size_key in ("train_original_size", "test_original_size"):
            item[size_key] = torch.tensor(record[size_key], dtype=torch.long)
        return item

def arc_collate_fn_bs1(batch):
    """Collate function for a ``DataLoader`` running with ``batch_size=1``.

    Args:
        batch: list of dataset items; it must contain exactly one.

    Returns:
        The single item, unchanged (no extra batch dimension is added).

    Raises:
        ValueError: If ``batch`` does not hold exactly one item. Returning
            ``batch[0]`` for a larger batch would silently discard every
            other sample in it.
    """
    if len(batch) != 1:
        raise ValueError(
            f"arc_collate_fn_bs1 expects exactly one item per batch, got {len(batch)}; "
            "use batch_size=1 with this collate function"
        )
    return batch[0]



# Location of the training task JSON files ("~" is expanded by the loader).
folder_path = "~/ARC-AGI-Model/src/data_pipeline/ARC_data/data/training"

# Pipeline: load tasks -> measure maxima -> pad grids -> tensorize per sample.
data = load_jsons_from_folder(folder_path)
metrics = get_metrics(data)
# pad_data pads in place and returns its argument, so `padded_data` and
# `data` refer to the same object after this line.
padded_data = pad_data(data, metrics)
sample_level = build_sample_level_dataset(padded_data, pad_value=0)
arc_dataset = ARCSampleDataset(sample_list=sample_level)

# One sample (task) per batch; arc_collate_fn_bs1 hands back the item as-is.
# shuffle=True without a fixed seed means the order differs between runs.
arc_loader = DataLoader(
    arc_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=arc_collate_fn_bs1,
    num_workers=0,
    pin_memory=False
)

# Sanity check: print the id, stacked tensor shapes and original sizes of
# every sample the loader yields.
for batch in arc_loader:
    print("ID:", batch["id"])
    print("train_inputs:", batch["train_inputs"].shape)   # [num_train, H, W]
    print("train_outputs:", batch["train_outputs"].shape) # [num_train, H, W]
    print("test_inputs:", batch["test_inputs"].shape)     # [num_test, H, W]
    print("test_outputs:", batch["test_outputs"].shape)   # [num_test, H, W]
    print("train_original_size:", batch["train_original_size"])
    print("test_original_size:", batch["test_original_size"])
    print()
    


ID: 7e0986d6
train_inputs: torch.Size([2, 30, 30])
train_outputs: torch.Size([2, 30, 30])
test_inputs: torch.Size([1, 30, 30])
test_outputs: torch.Size([1, 30, 30])
train_original_size: tensor([13, 16])
test_original_size: tensor([12, 17])
