In [None]:
# General-purpose dependencies, installed with normal dependency resolution.
!pip install 'pytest>=8.3.2' 'numpy>=1.26.4' 'pillow>=10.4.0' 'msgpack>=1.1.0' 'requests>=2.32.3' 'mediapy>=1.2.2' tqdm
# Training stack installed with --no-deps, so pip does not install or change
# any of these packages' own dependencies.
# NOTE(review): presumably this protects a preinstalled jax/jaxlib — confirm.
!pip install --no-deps 'optax==0.2.3' 'chex==0.1.86' 'flax>=0.9.0' orbax-checkpoint tensorstore 'typing-extensions>=4.2' 'absl-py>=2.1.0' 'toolz>=1.0.0' 'etils[epy]>=1.9.4'
# Experiment tracking: install the wandb client, then authenticate its CLI.
!pip install wandb
!wandb login
# Clone the cax repository to /cax and install it from that checkout, again
# without touching its dependencies.
!git clone https://github.com/hu-po/cax.git /cax
!pip install --upgrade /cax --no-deps

In [None]:
import json
import os

import jax
import jaxlib
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm
import wandb

import flax
from flax import nnx
import optax
import cax

# Log the version of every core library so a run can be reproduced later.
for pkg in (jax, jaxlib, cax, flax, optax):
    print(pkg.__name__, pkg.__version__)

In [None]:
# Run configuration is read from the process environment; each lookup raises
# KeyError when its variable is not set.
morph = os.environ["MORPH"]
morph_nb_filepath = os.environ["MORPH_NB_FILEPATH"]
morph_output_dir = os.environ["MORPH_OUTPUT_DIR"]

In [None]:
def load_data(path):
    """Read a JSON file and return its decoded content.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Loaded {len(data)} tasks from {path}")
    return data

# Load the four ARC JSON files (training and evaluation, each with challenges
# and solutions) in a fixed order and bind them to their names.
(
    train_challenges,
    train_solutions,
    eval_challenges,
    eval_solutions,
) = [
    load_data(json_path)
    for json_path in (
        '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json',
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json',
    )
]

def process_tasks(challenges, solutions):
    """Flatten every task into parallel lists of input and output grids.

    For each task, in the iteration order of ``challenges``, the demonstration
    pairs under ``'train'`` are emitted first. They are followed by the
    ``'test'`` inputs, each paired with the entry at the same position in that
    task's solution list.

    Args:
        challenges: Mapping of task id to a dict with ``'train'`` (a list of
            ``{'input', 'output'}`` dicts) and ``'test'`` (a list of
            ``{'input'}`` dicts).
        solutions: Mapping of task id to a list of output grids, indexed by
            test position.

    Returns:
        A tuple ``(inputs, outputs, task_indices, task_id_to_index)``. The
        first two are lists of int32 NumPy arrays, ``task_indices`` holds the
        integer task index of each sample, and ``task_id_to_index`` maps each
        task id to that integer.

    Raises:
        KeyError: If a task id is missing from ``solutions``.
        IndexError: If a task has more test inputs than solution entries.
    """
    def as_grid(cells):
        # Every grid is stored as an int32 array.
        return np.array(cells, dtype=np.int32)

    inputs, outputs, task_indices = [], [], []
    task_id_to_index = {}
    for index, (task_id, task) in enumerate(challenges.items()):
        task_id_to_index[task_id] = index
        answers = solutions[task_id]

        # Demonstration pairs come first, then test inputs matched to their
        # solutions by position.
        grid_pairs = [(pair['input'], pair['output']) for pair in task['train']]
        grid_pairs += [
            (case['input'], answers[position])
            for position, case in enumerate(task['test'])
        ]

        for raw_input, raw_output in grid_pairs:
            inputs.append(as_grid(raw_input))
            outputs.append(as_grid(raw_output))
            task_indices.append(index)
    return inputs, outputs, task_indices, task_id_to_index

def pad_grids(grids, max_size=30, pad_value=0):
    """Pad 2-D grids to a common square size and stack them into one array.

    Each grid is copied into the top-left corner of a ``max_size`` by
    ``max_size`` canvas that is pre-filled with ``pad_value``.

    Args:
        grids: Iterable of 2-D grids (NumPy arrays or nested sequences).
        max_size: Side length of the padded square canvas.
        pad_value: Value written to cells not covered by a grid.

    Returns:
        An int32 NumPy array of shape ``(len(grids), max_size, max_size)``.
        An empty ``grids`` yields an array of shape ``(0, max_size, max_size)``.

    Raises:
        ValueError: If a grid is not 2-D or is larger than ``max_size`` in
            either dimension.
    """
    grids = list(grids)
    # Allocating the whole batch up front also covers the empty case, where
    # np.stack would raise.
    padded_grids = np.full((len(grids), max_size, max_size), pad_value, dtype=np.int32)
    for position, grid in enumerate(grids):
        grid = np.asarray(grid)
        rows, cols = grid.shape  # raises ValueError when the grid is not 2-D
        if rows > max_size or cols > max_size:
            # Fail with a clear message instead of NumPy's broadcast error.
            raise ValueError(
                f"grid {position} has shape {grid.shape}, which exceeds the "
                f"{max_size}x{max_size} padding canvas"
            )
        padded_grids[position, :rows, :cols] = grid
    return padded_grids

def prepare_data(challenges, solutions):
    """Convert raw ARC tasks into padded JAX arrays.

    The tasks are flattened with ``process_tasks``, the input and output grids
    are padded with ``pad_grids`` using its defaults, and the results are
    converted to JAX arrays. The sample count is printed along the way.

    Args:
        challenges: Mapping of task id to its challenge dict.
        solutions: Mapping of task id to its list of test outputs.

    Returns:
        A tuple ``(inputs, outputs, task_indices, task_id_to_index)``. The
        first three are JAX arrays and the last is the task-id-to-index dict
        returned by ``process_tasks``.
    """
    grids_in, grids_out, sample_task_ids, task_id_to_index = process_tasks(challenges, solutions)
    print(f"\t number of samples: {len(grids_in)}")
    padded_in, padded_out = pad_grids(grids_in), pad_grids(grids_out)
    return (
        jnp.array(padded_in),
        jnp.array(padded_out),
        # Build the int32 index vector in NumPy first, then hand it to JAX.
        jnp.array(np.array(sample_task_ids, dtype=np.int32)),
        task_id_to_index,
    )

# Build the padded training arrays; the task-id mapping from this split is kept.
print("Processing train data...")
(
    train_inputs,
    train_outputs,
    train_task_indices,
    task_id_to_index,
) = prepare_data(train_challenges, train_solutions)
print(
    f"\t inputs shape: {train_inputs.shape}",
    f"\t outputs shape: {train_outputs.shape}",
    f"\t task indices shape: {train_task_indices.shape}",
    sep="\n",
)

# Same pipeline for the evaluation split; its task-id mapping is discarded.
print("Processing eval data...")
(
    eval_inputs,
    eval_outputs,
    eval_task_indices,
    _,
) = prepare_data(eval_challenges, eval_solutions)
print(
    f"\t inputs shape: {eval_inputs.shape}",
    f"\t outputs shape: {eval_outputs.shape}",
    f"\t task indices shape: {eval_task_indices.shape}",
    sep="\n",
)

In [None]:
#<code>

In [None]:
# Location of the ARC Prize 2024 test challenges file.
test_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
# NOTE(review): neither `prepare_submission` nor `ca` is defined in the visible
# cells; presumably both come from the `#<code>` placeholder cell above. Confirm
# that they exist before running this cell on a fresh kernel.
prepare_submission(ca, test_challenges_path)