In [None]:
!pip install cax
!pip install mediapy

In [None]:
import json
import os

import jax
import jaxlib
import jax.numpy as jnp
import mediapy
import optax
import cax
from cax.core.ca import CA
from cax.core.perceive.depthwise_conv_perceive import DepthwiseConvPerceive
from cax.core.perceive.kernels import grad_kernel, identity_kernel
from cax.core.update.residual_update import ResidualUpdate
import flax
from flax import nnx
from tqdm.auto import tqdm

# Report the versions of the core libraries so results can be tied to an environment.
for _lib_label, _lib_module in (
    ("jax", jax),
    ("jaxlib", jaxlib),
    ("cax", cax),
    ("flax", flax),
    ("optax", optax),
):
    print(f"{_lib_label} {_lib_module.__version__}")

import numpy as np
import pandas as pd

# Print the full path of every file found (recursively) under the Kaggle input mount.
input_file_paths = (
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for input_file_path in input_file_paths:
    print(input_file_path)

# Paths to the ARC Prize 2024 dataset files (Kaggle input mount)
arc_data_dir = '/kaggle/input/arc-prize-2024'
training_challenges_path = os.path.join(arc_data_dir, 'arc-agi_training_challenges.json')
training_solutions_path = os.path.join(arc_data_dir, 'arc-agi_training_solutions.json')
evaluation_challenges_path = os.path.join(arc_data_dir, 'arc-agi_evaluation_challenges.json')
evaluation_solutions_path = os.path.join(arc_data_dir, 'arc-agi_evaluation_solutions.json')


def _load_json(path):
    """Parse and return the JSON document stored at *path*.

    The file is decoded explicitly as UTF-8 (the encoding JSON interchange
    requires) instead of the locale-dependent default that a bare ``open``
    would use.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Load the training data
training_challenges = _load_json(training_challenges_path)
training_solutions = _load_json(training_solutions_path)

# Load the evaluation data
evaluation_challenges = _load_json(evaluation_challenges_path)
evaluation_solutions = _load_json(evaluation_solutions_path)

def process_tasks(challenges, solutions):
    """Collect every training (input, output) grid pair across all tasks.

    Args:
        challenges: Mapping of task id -> task dict. Each task's ``'train'``
            entry is a list of dicts holding an ``'input'`` and an ``'output'``
            grid (anything ``np.array`` can turn into an int array).
        solutions: Accepted for interface compatibility but not read. Only the
            ``'train'`` pairs are collected, and they already carry their
            outputs, so no solution lookup is needed.

    Returns:
        ``(inputs, outputs)``: two parallel lists of ``np.int32`` arrays,
        ordered by task (mapping iteration order), then by pair within a task.
    """
    train_pairs = [pair for task in challenges.values() for pair in task['train']]
    inputs = [np.array(pair['input'], dtype=np.int32) for pair in train_pairs]
    outputs = [np.array(pair['output'], dtype=np.int32) for pair in train_pairs]
    return inputs, outputs

# Flatten the train pairs of every task in each split into parallel input/output lists.
(training_inputs, training_outputs), (evaluation_inputs, evaluation_outputs) = (
    process_tasks(training_challenges, training_solutions),
    process_tasks(evaluation_challenges, evaluation_solutions),
)

def get_max_grid_size(grids):
    """Return ``(max_rows, max_cols)`` over all *grids*.

    Rows are counted with ``len(grid)``; columns with the length of a grid's
    first row, or 0 for an empty grid. Raises ``ValueError`` when *grids* is
    empty (there is nothing to take a maximum of).
    """
    row_counts = [len(g) for g in grids]
    col_counts = [len(g[0]) if len(g) > 0 else 0 for g in grids]
    return max(row_counts), max(col_counts)

# Largest grid dimensions actually present in the data (training + evaluation)
max_input_rows, max_input_cols = get_max_grid_size(training_inputs + evaluation_inputs)
max_output_rows, max_output_cols = get_max_grid_size(training_outputs + evaluation_outputs)

# Set a fixed grid size (e.g., 30x30 as per dataset description)
fixed_rows, fixed_cols = 30, 30

# Verify the observed sizes fit the fixed canvas, so an oversized grid fails here
# with a clear message instead of as a NumPy broadcast error while padding.
largest_rows = max(max_input_rows, max_output_rows)
largest_cols = max(max_input_cols, max_output_cols)
if largest_rows > fixed_rows or largest_cols > fixed_cols:
    raise ValueError(
        f"largest grid is {largest_rows}x{largest_cols}, which exceeds the fixed "
        f"{fixed_rows}x{fixed_cols} padding size"
    )

def pad_grid(grid, max_rows=fixed_rows, max_cols=fixed_cols, pad_value=0):
    """Place *grid* in the top-left corner of a ``(max_rows, max_cols)`` int32 canvas.

    Cells outside the original grid are filled with *pad_value*. With the
    default ``pad_value=0``, padding cannot be told apart from genuine
    0-valued cells; pass a value that never occurs in the data if that
    distinction matters.

    Args:
        grid: 2-D grid given as a list of rows or a 2-D array; may be empty.
        max_rows: Canvas height (default: ``fixed_rows`` as bound when this
            function was defined).
        max_cols: Canvas width (default: ``fixed_cols`` as bound when this
            function was defined).
        pad_value: Fill value for the padding cells.

    Returns:
        A new ``np.int32`` array of shape ``(max_rows, max_cols)``.

    Raises:
        ValueError: If *grid* is larger than the canvas in either dimension.
    """
    rows = len(grid)
    cols = len(grid[0]) if rows > 0 else 0
    # Check explicitly: otherwise NumPy reports an opaque "could not broadcast" error.
    if rows > max_rows or cols > max_cols:
        raise ValueError(
            f"grid of size {rows}x{cols} does not fit in a {max_rows}x{max_cols} canvas"
        )
    padded_grid = np.full((max_rows, max_cols), pad_value, dtype=np.int32)
    padded_grid[:rows, :cols] = grid
    return padded_grid

# Pad every grid onto the fixed canvas, keeping one list per split.
padded_training_inputs = list(map(pad_grid, training_inputs))
padded_training_outputs = list(map(pad_grid, training_outputs))
padded_evaluation_inputs = list(map(pad_grid, evaluation_inputs))
padded_evaluation_outputs = list(map(pad_grid, evaluation_outputs))

# Stack each split into a single array whose first axis indexes the grids.
(
    training_inputs_array,
    training_outputs_array,
    evaluation_inputs_array,
    evaluation_outputs_array,
) = (
    np.stack(padded)
    for padded in (
        padded_training_inputs,
        padded_training_outputs,
        padded_evaluation_inputs,
        padded_evaluation_outputs,
    )
)

for split_label, split_array in (
    ("Training inputs", training_inputs_array),
    ("Training outputs", training_outputs_array),
    ("Evaluation inputs", evaluation_inputs_array),
    ("Evaluation outputs", evaluation_outputs_array),
):
    print(f"{split_label} shape: {split_array.shape}")

# Now the data is ready for JAX/Flax training
# You can proceed to define your model and training loop

# Example: a simple shuffled mini-batch iterator over the stacked NumPy arrays
def data_generator(inputs, outputs, batch_size, *, drop_last=True, rng=None):
    """Yield shuffled ``(inputs_batch, outputs_batch)`` pairs covering one epoch.

    A fresh permutation of the sample indices is drawn when iteration starts,
    and both arrays are indexed with the same indices so pairs stay aligned.

    Args:
        inputs: Array indexed by sample along axis 0.
        outputs: Array indexed by sample along axis 0; assumed to have the same
            length as *inputs* (NOTE(review): not checked here).
        batch_size: Number of samples per batch; must be positive.
        drop_last: If True (default, the original behavior), a trailing partial
            batch is discarded, so nothing is yielded when there are fewer than
            ``batch_size`` samples. If False, the remainder is yielded as a
            final, smaller batch.
        rng: Optional ``np.random.Generator`` used for the shuffle, giving a
            reproducible order without touching NumPy's global random state.
            When None (default), the global ``np.random`` state is used.

    Raises:
        ValueError: If ``batch_size`` is not positive (raised on first ``next``).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_samples = inputs.shape[0]
    indices = np.arange(num_samples)
    if rng is None:
        np.random.shuffle(indices)
    else:
        rng.shuffle(indices)
    # With drop_last, the last start index must leave room for a full batch.
    stop = num_samples - batch_size + 1 if drop_last else num_samples
    for start_idx in range(0, stop, batch_size):
        batch_indices = indices[start_idx:start_idx + batch_size]
        yield inputs[batch_indices], outputs[batch_indices]

# Demo: draw one shuffled batch from the training split and report its shapes.
batch_size = 32
train_data_gen = data_generator(
    training_inputs_array,
    training_outputs_array,
    batch_size,
)
batch_inputs, batch_outputs = next(train_data_gen)
for batch_label, batch_array in (
    ("Batch inputs", batch_inputs),
    ("Batch outputs", batch_outputs),
):
    print(f"{batch_label} shape: {batch_array.shape}")