In [None]:
# Install runtime dependencies. The second install uses --no-deps, presumably so pip leaves the
# environment's preinstalled jax/jaxlib untouched — confirm before changing these pins.
!pip install 'pytest>=8.3.2' 'numpy>=1.26.4' 'pillow>=10.4.0' 'msgpack>=1.1.0' 'requests>=2.32.3' 'mediapy>=1.2.2' tqdm
!pip install --no-deps 'optax==0.2.3' 'chex==0.1.86' 'flax>=0.9.0' orbax-checkpoint tensorstore 'typing-extensions>=4.2' 'absl-py>=2.1.0' 'toolz>=1.0.0' 'etils[epy]>=1.9.4'
# NOTE(review): wandb is unpinned, and `wandb login` may prompt interactively unless an API key
# is already configured (e.g. via the WANDB_API_KEY environment variable).
!pip install wandb
!wandb login
# NOTE(review): `git clone` fails when /cax already exists and is non-empty, so this cell is not
# re-runnable on a warm kernel/filesystem as-is.
!git clone https://github.com/hu-po/cax.git /cax
!pip install --upgrade /cax --no-deps

In [None]:
from functools import partial
import json
import os
import subprocess

import jax
import jaxlib
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm
import wandb

import flax
from flax import nnx
import optax
import cax

# Report the versions of the core packages this notebook depends on.
for module in (jax, jaxlib, cax, flax, optax):
    print(module.__name__, module.__version__)

In [None]:
# Run configuration read from the environment; each lookup raises KeyError if its variable is unset.
morph, morph_nb_filepath, morph_output_dir = (
    os.environ[env_name] for env_name in ("MORPH", "MORPH_NB_FILEPATH", "MORPH_OUTPUT_DIR")
)

In [None]:

def load_data(path):
    """Load and return the decoded contents of a JSON file.

    Args:
        path: Filesystem path (str or os.PathLike) to a JSON document.

    Returns:
        The decoded JSON value. For the ARC challenge/solution files used below this is a
        dict keyed by task id.
    """
    # JSON text is UTF-8 (RFC 8259); pin the encoding instead of relying on the platform default.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Root directory of the ARC Prize 2024 data. Defaults to the Kaggle competition mount; set the
# ARC_DATA_DIR environment variable to run elsewhere.
ARC_DATA_DIR = os.environ.get('ARC_DATA_DIR', '/kaggle/input/arc-prize-2024')

train_challenges = load_data(os.path.join(ARC_DATA_DIR, 'arc-agi_training_challenges.json'))
train_solutions = load_data(os.path.join(ARC_DATA_DIR, 'arc-agi_training_solutions.json'))
eval_challenges = load_data(os.path.join(ARC_DATA_DIR, 'arc-agi_evaluation_challenges.json'))
eval_solutions = load_data(os.path.join(ARC_DATA_DIR, 'arc-agi_evaluation_solutions.json'))

# Process tasks
def process_tasks(challenges, solutions):
    """Flatten ARC tasks into (input, output, task_index) samples.

    Every train pair of a task is emitted first, then every test input paired with the
    solution at the same position in ``solutions[task_id]``. Grids become int32 numpy arrays.
    Task indices follow the iteration order of ``challenges`` and start at 0 on every call.

    Returns:
        (samples, task_id_to_index)
    """
    samples = []
    index_of = {}
    for task_index, (task_id, task) in enumerate(challenges.items()):
        index_of[task_id] = task_index
        expected_outputs = solutions[task_id]
        for pair in task['train']:
            src = np.array(pair['input'], dtype=np.int32)
            dst = np.array(pair['output'], dtype=np.int32)
            samples.append((src, dst, task_index))
        for test_pos, test_case in enumerate(task['test']):
            src = np.array(test_case['input'], dtype=np.int32)
            dst = np.array(expected_outputs[test_pos], dtype=np.int32)
            samples.append((src, dst, task_index))
    return samples, index_of

# Pad grid function
def pad_grid(grid, max_size=30, pad_value=0, center_offset=(0, 0)):
    """Place a 2-D grid on a square ``max_size`` x ``max_size`` int32 canvas, centered.

    Args:
        grid: 2-D array-like of cell values (converted with ``np.asarray``; written as int32).
        max_size: side length of the square output canvas.
        pad_value: fill value for canvas cells not covered by ``grid``.
            NOTE(review): with the default 0 the padding is indistinguishable from cell value 0
            in the data — consider a sentinel outside the value range.
        center_offset: (row, col) shift applied after centering. The placement is then clamped
            so the grid always stays fully on the canvas.

    Returns:
        np.ndarray of shape (max_size, max_size) and dtype int32.

    Raises:
        ValueError: if ``grid`` is not 2-D or exceeds ``max_size`` in either dimension.
    """
    grid = np.asarray(grid)
    if grid.ndim != 2:
        raise ValueError(f"pad_grid expects a 2-D grid, got shape {grid.shape}")
    rows, cols = grid.shape
    # Without this check an oversize grid reached np.clip with inverted bounds and then failed
    # with an opaque broadcasting error.
    if rows > max_size or cols > max_size:
        raise ValueError(
            f"grid of shape {grid.shape} does not fit in a {max_size}x{max_size} canvas"
        )
    padded = np.full((max_size, max_size), pad_value, dtype=np.int32)
    # Center, shift by center_offset, then clamp into [0, max_size - extent].
    offset_row = min(max((max_size - rows) // 2 + center_offset[0], 0), max_size - rows)
    offset_col = min(max((max_size - cols) // 2 + center_offset[1], 0), max_size - cols)
    padded[offset_row:offset_row + rows, offset_col:offset_col + cols] = grid
    return padded

# Augmentation functions using JAX
@jax.jit
def flip_horizontal(grid):
    """Mirror ``grid`` left-to-right by reversing its second (column) axis."""
    return jnp.flip(grid, axis=1)

@jax.jit
def flip_vertical(grid):
    """Mirror ``grid`` top-to-bottom by reversing its first (row) axis."""
    return jnp.flip(grid, axis=0)

@partial(jax.jit, static_argnums=(1,))
def rotate(grid, k):
    """Rotate ``grid`` by ``k`` quarter turns in the plane of its first two axes.

    The direction is from axis 0 towards axis 1, i.e. counter-clockwise for a 2-D grid.
    ``k`` is static, so each distinct value is compiled separately.
    """
    return jnp.rot90(grid, k, axes=(0, 1))

# Registry of named grid transforms; data_generator / apply_random_augmentations_batch look
# augmentations up here by name. Rotations bind the static quarter-turn count via partial.
AUGMENTATIONS = dict(
    flip_horizontal=flip_horizontal,
    flip_vertical=flip_vertical,
    rotate_90=partial(rotate, k=1),
    rotate_180=partial(rotate, k=2),
    rotate_270=partial(rotate, k=3),
)

# Apply random augmentations to a batch
def apply_random_augmentations_batch(input_grids, output_grids, rng_keys, augmentations):
    """Randomly augment each (input, output) pair in a batch, applying identical ops to both grids.

    For each sample, every augmentation named in ``augmentations`` (a key of ``AUGMENTATIONS``)
    is applied independently with probability 0.5, in list order, driven by that sample's key.

    Args:
        input_grids: batched input grids, batch on axis 0.
        output_grids: batched output grids, batch on axis 0.
        rng_keys: one PRNG key per sample along axis 0 (e.g. from ``jax.random.split``).
        augmentations: sequence of augmentation names.

    Returns:
        (augmented_inputs, augmented_outputs), batched like the inputs.
    """
    def augment_one(src, dst, key):
        # One independent fair coin per augmentation for this sample.
        coins = jax.random.bernoulli(key, p=0.5, shape=(len(augmentations),))
        pair = (src, dst)
        for pos, name in enumerate(augmentations):
            fn = AUGMENTATIONS[name]
            # Bind fn as a default argument so the branch unambiguously uses this iteration's op.
            pair = jax.lax.cond(
                coins[pos],
                lambda grids, fn=fn: (fn(grids[0]), fn(grids[1])),
                lambda grids: grids,
                pair,
            )
        return pair

    batched_augment = jax.vmap(augment_one, in_axes=0, out_axes=0)
    inputs_out, outputs_out = batched_augment(input_grids, output_grids, rng_keys)
    return inputs_out, outputs_out

# Data generator
def data_generator(data, batch_size, augmentations=None, max_size=30, pad_value=0, center_offset=(0, 0), shuffle=True):
    """Yield padded (and optionally augmented) batches from ``data`` forever.

    One pass over ``data`` is an epoch and yields ceil(len(data) / batch_size) batches. The last
    batch is smaller when len(data) is not a multiple of ``batch_size``.

    Args:
        data: sequence of (input_grid, output_grid, task_index) tuples, as built by process_tasks.
        batch_size: maximum number of samples per batch; must be positive.
        augmentations: optional list of AUGMENTATIONS names applied randomly per sample.
        max_size, pad_value, center_offset: forwarded to pad_grid.
        shuffle: if True, reshuffle the sample order at the start of every epoch.

    Yields:
        (inputs, outputs, task_indices): inputs/outputs stacked from pad_grid results with shape
        (B, max_size, max_size); task_indices is an int32 array of shape (B,).

    Raises:
        ValueError: on the first ``next()`` if ``data`` is empty or ``batch_size`` is not positive.
    """
    num_samples = len(data)
    # With no samples (or a non-positive batch size) the inner loop never yields, so the outer
    # `while True` would otherwise spin forever instead of failing.
    if num_samples == 0:
        raise ValueError("data_generator needs at least one sample")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    indices = np.arange(num_samples)
    epoch = 0
    while True:
        if shuffle:
            # Seeded by epoch number, so the sequence of orders is reproducible across runs.
            rng = np.random.default_rng(seed=epoch)
            rng.shuffle(indices)
        for start_idx in range(0, num_samples, batch_size):
            excerpt = indices[start_idx:start_idx + batch_size]
            batch_inputs = []
            batch_outputs = []
            batch_task_indices = []
            for idx in excerpt:
                input_grid, output_grid, task_index = data[idx]
                batch_inputs.append(pad_grid(input_grid, max_size, pad_value, center_offset))
                batch_outputs.append(pad_grid(output_grid, max_size, pad_value, center_offset))
                batch_task_indices.append(task_index)
            # Stack on the host with NumPy, then do a single host->device conversion per array.
            batch_inputs = jnp.asarray(np.stack(batch_inputs))
            batch_outputs = jnp.asarray(np.stack(batch_outputs))
            batch_task_indices = jnp.array(batch_task_indices, dtype=jnp.int32)
            if augmentations:
                # Key depends only on (epoch, start_idx), so augmentations are reproducible.
                rng_key = jax.random.PRNGKey(epoch)
                rng_key = jax.random.fold_in(rng_key, start_idx)
                rng_keys = jax.random.split(rng_key, len(batch_inputs))
                batch_inputs, batch_outputs = apply_random_augmentations_batch(batch_inputs, batch_outputs, rng_keys, augmentations)
            yield batch_inputs, batch_outputs, batch_task_indices
        epoch += 1

# Batch configuration.
augmentations = ['flip_horizontal', 'flip_vertical', 'rotate_90', 'rotate_180', 'rotate_270']
batch_size = 32

# Flatten tasks into (input, output, task_index) samples.
train_data, task_id_to_index = process_tasks(train_challenges, train_solutions)
# NOTE(review): process_tasks numbers tasks from 0 on every call, so eval task indices overlap
# the train indices; the eval id->index mapping is discarded here.
eval_data, _ = process_tasks(eval_challenges, eval_solutions)

# Infinite batch streams: augmented + shuffled for training, fixed order for evaluation.
train_generator = data_generator(train_data, batch_size=batch_size, augmentations=augmentations)
eval_generator = data_generator(eval_data, batch_size=batch_size, shuffle=False)


In [None]:
#<code>
# Single jitted optimization step.
@jax.jit
def train_step(params, batch_inputs, batch_outputs, batch_task_indices, opt_state):
    """Differentiate the loss w.r.t. ``params`` and apply one optimizer update.

    Reads module-level ``model``, ``compute_loss`` and ``optimizer``, which are not defined in
    any cell shown in this notebook. ``batch_task_indices`` is accepted but unused.

    Returns:
        (new_params, new_opt_state)
    """
    grads = jax.grad(
        lambda p: compute_loss(model.apply(p, batch_inputs), batch_outputs)
    )(params)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), new_opt_state

# NOTE(review): `model`, `optimizer` and `compute_loss` are not defined in any cell shown here;
# define them in an earlier cell, otherwise this cell raises NameError on a fresh kernel.
# `model.init` / `model.apply` is the Linen-style API although `nnx` is imported — confirm which
# API the model uses.
# Initialize model parameters and optimizer state.
# NOTE(review): the dummy init input is float32 while data_generator yields int32 grids —
# confirm the model accepts both.
params = model.init(jax.random.PRNGKey(0), jnp.ones((batch_size, 30, 30)))
opt_state = optimizer.init(params)

# Training loop
num_epochs = 10
# data_generator also yields the final partial batch of every pass, i.e. ceil(len / batch_size)
# batches per epoch. Round up so each loop epoch lines up with one generator epoch (and its
# reshuffle); flooring made the two drift apart.
steps_per_epoch = -(-len(train_data) // batch_size)

for epoch in range(num_epochs):
    for step in tqdm(range(steps_per_epoch), desc=f"Epoch {epoch + 1}/{num_epochs}"):
        batch_inputs, batch_outputs, batch_task_indices = next(train_generator)
        # The smaller final batch has a new shape, so jit compiles train_step once more for it.
        params, opt_state = train_step(params, batch_inputs, batch_outputs, batch_task_indices, opt_state)

In [None]:
# Presumably builds the competition submission from the hidden test challenges.
# NOTE(review): neither `prepare_submission` nor `ca` is defined in any cell shown here, so this
# cell raises NameError on a fresh kernel — define them in an earlier cell before running.
# The Kaggle data root is also hard-coded here rather than shared with the loading cell.
test_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
prepare_submission(ca, test_challenges_path)