# In [None]:
from dataclasses import dataclass
from datetime import datetime
import json
import os
import psutil
import sys

import numpy as np

@dataclass(frozen=True)
class BaseConfig:
    """Immutable run configuration for ARC data loading, augmentation and logging.

    Environment-driven defaults (SEED, MORPH, COMPUTE_BACKEND, WANDB_ENTITY,
    WANDB_PROJECT) are read once, when this class body executes. Likewise,
    ``created_on`` is a single timestamp captured at that moment and shared by
    every instance that does not override it.
    """

    # Seed for the augmentation RNG; override with $SEED.
    seed: int = int(os.getenv("SEED", 42))

    # --- data: ARC Prize 2024 JSON files at their Kaggle mount points
    train_challenges: str = '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json'
    train_solutions: str = '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json'
    valid_challenges: str = '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json'
    valid_solutions: str = '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json'
    num_time_augs: int = 8    # training-pair reorderings generated per task
    num_color_augs: int = 32  # color permutations generated per task

    # --- logging / experiment bookkeeping
    morph: str = os.getenv("MORPH", "test")
    compute_backend: str = os.getenv("COMPUTE_BACKEND", "oop")
    wandb_entity: str = os.getenv("WANDB_ENTITY", "hug")
    wandb_project: str = os.getenv("WANDB_PROJECT", "arc-test")
    # YYYYmmddHHMMSS stamp taken when the class is defined, not per instance.
    created_on: str = datetime.now().strftime("%Y%m%d%H%M%S")

def _to_grid(values) -> np.ndarray:
    """Convert one JSON grid (list of lists of ints) into a uint8 array."""
    # uint8 comfortably holds colors 0-9 and keeps the dataset small in system memory.
    return np.array(values, dtype=np.uint8)


def load_tasks(challenges_path: str, solutions_path: str, cfg: "BaseConfig", *, verbose: bool = True):
    """Load ARC tasks from a challenges JSON file and its matching solutions JSON file.

    Each challenge entry is keyed by task id and holds two lists:
    ``train`` - two to ten input/output pairs (typically three) from which a rule
    is inferred; ``test`` - one to three inputs (typically one) to which the rule
    must be applied. The solutions file maps the same task id to the expected
    test outputs, in the same order.

    A grid is a rectangular matrix (list of lists) of integers 0-9, where 0 is the
    background color and 1-9 are pattern colors; sizes range from 1x1 to 30x30.

    Args:
        challenges_path: path to the challenges JSON file.
        solutions_path: path to the solutions JSON file.
        cfg: run configuration; not used by this function, kept for interface stability.
        verbose: if True (default), print every task id and the shape of every grid.

    Returns:
        A list with one ``(train_in, train_out, eval_in, eval_out)`` tuple per task,
        in the challenges file's key order. Every element of the four lists is a
        uint8 ``np.ndarray``.

    Raises:
        ValueError: if a task has no entry in the solutions file, has more than 10
            train pairs, or its test-input count does not match its solution count
            or exceeds 3.
    """
    with open(challenges_path, 'r', encoding='utf-8') as f:
        challenges_dict = json.load(f)
    print(f"loading challenges from {challenges_path}, found {len(challenges_dict)} challenges")
    with open(solutions_path, 'r', encoding='utf-8') as f:
        solutions_dict = json.load(f)
    print(f"loading solutions from {solutions_path}, found {len(solutions_dict)} solutions")

    tasks = []
    for task_id, challenge in challenges_dict.items():
        if verbose:
            print(f"\t task {task_id}")

        task_train_in, task_train_out = [], []
        for pair in challenge['train']:
            grid_in, grid_out = _to_grid(pair['input']), _to_grid(pair['output'])
            task_train_in.append(grid_in)
            task_train_out.append(grid_out)
            if verbose:
                print(f"shape of task_train_in {grid_in.shape}")
                print(f"shape of task_train_out {grid_out.shape}")

        # Eval inputs are stored as uint8 arrays like every other grid, so the
        # downstream augmentation sees a uniform type.
        task_eval_in = []
        for grid in challenge['test']:
            grid_in = _to_grid(grid['input'])
            task_eval_in.append(grid_in)
            if verbose:
                print(f"shape of task_eval_in {grid_in.shape}")

        if task_id not in solutions_dict:
            raise ValueError(f"task {task_id} has no entry in {solutions_path}")
        task_eval_out = []
        for grid in solutions_dict[task_id]:
            grid_out = _to_grid(grid)
            task_eval_out.append(grid_out)
            if verbose:
                print(f"shape of task_eval_out {grid_out.shape}")

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if len(task_train_in) > 10:  # maximum number of train input/output pairs
            raise ValueError(f"task {task_id} has {len(task_train_in)} train pairs (max 10)")
        if len(task_eval_in) != len(task_eval_out):
            raise ValueError(
                f"task {task_id} has {len(task_eval_in)} test inputs "
                f"but {len(task_eval_out)} solutions"
            )
        if len(task_eval_in) > 3:  # maximum number of test pairs
            raise ValueError(f"task {task_id} has {len(task_eval_in)} test inputs (max 3)")

        tasks.append((task_train_in, task_train_out, task_eval_in, task_eval_out))
    return tasks

def _map_grids(fn, task):
    """Apply ``fn`` to every grid of a (train_in, train_out, eval_in, eval_out) task."""
    return tuple([fn(np.asarray(grid)) for grid in grids] for grids in task)


def augmentation(tasks, cfg: "BaseConfig"):
    """Pre-augment ARC tasks in memory.

    For every input task the result contains, in this order:

    1. the original task, unchanged;
    2. four "spatial" variants - left-right flip, up-down flip, 90 and 270 degree
       rotation - each applied to every grid of the task;
    3. if the task has more than one training pair, ``cfg.num_time_augs`` "time"
       variants in which the training pairs are reordered. Each input stays
       paired with its own output, and the eval grids are unchanged;
    4. ``cfg.num_color_augs`` "channel" variants in which colors 1-9 are
       permuted consistently across all grids of the task. Color 0, the
       background, is kept.

    Pre-augmenting grows the dataset in system memory. The dataset is small, so
    trading that memory for less GPU compute at train time is worth it.

    Args:
        tasks: list of ``(train_in, train_out, eval_in, eval_out)`` tuples of grid
            lists, as produced by ``load_tasks``.
        cfg: provides ``seed``, ``num_time_augs`` and ``num_color_augs``.

    Returns:
        A flat list of augmented task tuples. Color-permuted grids are uint8.
    """
    # A private legacy RandomState runs the same algorithm as np.random.seed /
    # permutation / shuffle, so results are reproducible from cfg.seed without
    # reseeding numpy's global RNG behind the caller's back.
    rng = np.random.RandomState(cfg.seed)
    spatial_augs = (np.fliplr, np.flipud, lambda x: np.rot90(x, 1), lambda x: np.rot90(x, 3))

    augmented_tasks = []
    for task in tasks:
        train_in, train_out, eval_in, eval_out = task
        augmented_tasks.append(task)  # start with the original task

        for aug in spatial_augs:
            augmented_tasks.append(_map_grids(aug, task))

        if len(train_in) > 1:
            for _ in range(cfg.num_time_augs):
                # One shared permutation, so input i stays paired with output i.
                order = rng.permutation(len(train_in))
                augmented_tasks.append((
                    [train_in[i] for i in order],
                    [train_out[i] for i in order],
                    eval_in,
                    eval_out,
                ))

        for _ in range(cfg.num_color_augs):
            # A uint8 lookup table keeps remapped grids uint8. An int64 table
            # would silently promote every grid to 8 bytes per cell.
            color_map = np.arange(10, dtype=np.uint8)
            rng.shuffle(color_map[1:])  # in-place shuffle of the view; 0 stays background
            augmented_tasks.append(_map_grids(lambda grid: color_map[grid], task))

    return augmented_tasks

# Module-level run configuration used by the loading/augmentation cells below.
# Its environment-variable defaults and `created_on` stamp were evaluated when
# the BaseConfig class body ran, not at this call.
cfg = BaseConfig()

def get_memory_usage():
    """Take a snapshot of host (system) memory via psutil.

    Returns:
        ``(used_mb, total_mb, percent_used)``: used and total memory in MiB
        (bytes / 1024**2), plus psutil's own ``percent`` figure.

    NOTE(review): per the psutil docs, ``percent`` is derived from ``available``
    rather than ``used``, so it need not equal used/total.
    """
    bytes_per_mb = 1024 ** 2
    snapshot = psutil.virtual_memory()
    return snapshot.used / bytes_per_mb, snapshot.total / bytes_per_mb, snapshot.percent

def print_memory_status(stage: str):
    """Print a one-line host-memory snapshot labelled with ``stage``."""
    used_mb, total_mb, pct = get_memory_usage()
    message = f"[{stage}] Memory usage: {used_mb:.2f} MB / {total_mb:.2f} MB ({pct:.2f}% used)"
    print(message)

def print_array_size(arr, name):
    """Print the data size, shape and dtype of a NumPy array (or array-like).

    The reported size is ``arr.nbytes``, the bytes occupied by the elements.
    ``sys.getsizeof`` is shown alongside for reference only. It is misleading for
    views such as the results of ``np.fliplr`` / ``np.rot90``: a view does not own
    its buffer, so ``getsizeof`` counts just the array header.

    Args:
        arr: a NumPy array or anything ``np.asarray`` accepts.
        name: label used in the printed message.

    Returns:
        The element-buffer size in bytes (``arr.nbytes``).
    """
    arr = np.asarray(arr)
    size = arr.nbytes
    print(
        f"Array {name} has size {size} bytes (sys.getsizeof: {sys.getsizeof(arr)}), "
        f"shape {arr.shape}, dtype {arr.dtype}"
    )
    return size

# Load both ARC splits, with host-memory snapshots before and after.
print_memory_status("Before loading tasks")
train_tasks, valid_tasks = (
    load_tasks(challenges, solutions, cfg)
    for challenges, solutions in (
        (cfg.train_challenges, cfg.train_solutions),
        (cfg.valid_challenges, cfg.valid_solutions),
    )
)
print_memory_status("After loading tasks")

# Pre-augment the training split only; the validation split is left as loaded.
train_tasks = augmentation(train_tasks, cfg)
print(f"Augmented train tasks: {len(train_tasks)}")
print_memory_status("After augmenting train tasks")