In [None]:
!pip install cax
!pip install mediapy

In [None]:
import json
import os

import jax
import jaxlib
import jax.numpy as jnp
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import flax
from flax import nnx
import optax
import cax
from cax.core.ca import CA
from cax.core.perceive.depthwise_conv_perceive import DepthwiseConvPerceive
from cax.core.perceive.kernels import grad_kernel, identity_kernel
from cax.core.update.residual_update import ResidualUpdate

# Report the versions of the core libraries used by this notebook.
print(f"jax {jax.__version__}")
print(f"jaxlib {jaxlib.__version__}")
print(f"cax {cax.__version__}")
print(f"flax {flax.__version__}")
print(f"optax {optax.__version__}")

# Walk the Kaggle input directory and print the full path of every file in it.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        full_path = os.path.join(root_dir, file_name)
        print(full_path)

# Locations of the ARC Prize 2024 challenge and solution files.
training_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json'
training_solutions_path = '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json'
evaluation_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json'
evaluation_solutions_path = '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json'


def _read_json(path):
    """Open *path* for reading and return the parsed JSON document."""
    with open(path, 'r') as handle:
        return json.load(handle)


# Training split: challenges and their solutions.
training_challenges = _read_json(training_challenges_path)
training_solutions = _read_json(training_solutions_path)

# Evaluation split: challenges and their solutions.
evaluation_challenges = _read_json(evaluation_challenges_path)
evaluation_solutions = _read_json(evaluation_solutions_path)

def process_tasks(challenges, solutions):
    """Flatten a challenges/solutions pair into parallel lists of grids.

    For every task (in ``challenges`` iteration order) the train pairs are
    emitted first, followed by each test input paired with the entry at the
    same position in ``solutions[task_id]``.

    Args:
        challenges: Mapping of task id to a dict with ``'train'`` (list of
            ``{'input', 'output'}`` dicts) and ``'test'`` (list of dicts with
            an ``'input'`` key).
        solutions: Mapping of task id to a list of output grids, one per
            test input.

    Returns:
        A tuple ``(inputs, outputs, task_indices, task_id_to_index)`` where
        ``inputs`` and ``outputs`` are lists of ``int32`` NumPy arrays,
        ``task_indices`` holds the integer task index of each example, and
        ``task_id_to_index`` maps each task id to that index (numbered from 0
        in order of first appearance).
    """
    inputs, outputs, task_indices = [], [], []
    task_id_to_index = {}

    for task_id, task in challenges.items():
        # A new task receives the next free index, i.e. the current dict size.
        task_index = task_id_to_index.setdefault(task_id, len(task_id_to_index))
        solution = solutions[task_id]
        train_pairs = task['train']
        test_pairs = task['test']

        # Demonstration pairs carry their own outputs; test inputs take
        # theirs from the solutions file.
        raw_pairs = [(pair['input'], pair['output']) for pair in train_pairs]
        raw_pairs.extend(
            (test_input['input'], solution[position])
            for position, test_input in enumerate(test_pairs)
        )

        for raw_input, raw_output in raw_pairs:
            inputs.append(np.array(raw_input, dtype=np.int32))
            outputs.append(np.array(raw_output, dtype=np.int32))
            task_indices.append(task_index)

    return inputs, outputs, task_indices, task_id_to_index

# Flatten the training split and keep its task-id -> index mapping.
(
    training_inputs,
    training_outputs,
    training_task_indices,
    task_id_to_index,
) = process_tasks(training_challenges, training_solutions)

# Flatten the evaluation split; its own mapping is discarded.
# NOTE(review): evaluation task indices are numbered from 0 independently of
# the training mapping, so they are not aligned with training task indices --
# confirm this is intended wherever they are used for lookups.
(
    evaluation_inputs,
    evaluation_outputs,
    evaluation_task_indices,
    _,
) = process_tasks(evaluation_challenges, evaluation_solutions)

def pad_grid(grid, max_rows=30, max_cols=30, pad_value=0):
    """Place *grid* in the top-left corner of a fixed-size ``int32`` canvas.

    Args:
        grid: 2-D array to embed.
        max_rows: Number of rows of the returned canvas.
        max_cols: Number of columns of the returned canvas.
        pad_value: Value used for the cells not covered by *grid*.

    Returns:
        A new ``(max_rows, max_cols)`` ``int32`` array whose top-left
        ``grid.shape`` region holds *grid* and whose remaining cells hold
        ``pad_value``.
    """
    canvas = np.full((max_rows, max_cols), pad_value, dtype=np.int32)
    height, width = grid.shape[0], grid.shape[1]
    canvas[:height, :width] = grid
    return canvas

# Bring every training grid to the common padded size.
padded_training_inputs = list(map(pad_grid, training_inputs))
padded_training_outputs = list(map(pad_grid, training_outputs))

# Same for the evaluation grids.
padded_evaluation_inputs = list(map(pad_grid, evaluation_inputs))
padded_evaluation_outputs = list(map(pad_grid, evaluation_outputs))

# Stack the equally-sized grids into one array per split and role; the new
# leading axis indexes examples.
training_inputs_array = np.stack(padded_training_inputs)
training_outputs_array = np.stack(padded_training_outputs)
evaluation_inputs_array = np.stack(padded_evaluation_inputs)
evaluation_outputs_array = np.stack(padded_evaluation_outputs)

# One integer task index per example, aligned with the arrays above.
training_task_indices_array = np.array(training_task_indices, dtype=np.int32)
evaluation_task_indices_array = np.array(evaluation_task_indices, dtype=np.int32)

# Report the shape of every dataset array before moving it to JAX.
print(f"Training inputs shape: {training_inputs_array.shape}")
print(f"Training outputs shape: {training_outputs_array.shape}")
print(f"Training task indices shape: {training_task_indices_array.shape}")
print(f"Evaluation inputs shape: {evaluation_inputs_array.shape}")
print(f"Evaluation outputs shape: {evaluation_outputs_array.shape}")
print(f"Evaluation task indices shape: {evaluation_task_indices_array.shape}")

# Rebind each name to a JAX copy of the corresponding NumPy array.
(
    training_inputs_array,
    training_outputs_array,
    training_task_indices_array,
    evaluation_inputs_array,
    evaluation_outputs_array,
    evaluation_task_indices_array,
) = map(
    jnp.array,
    (
        training_inputs_array,
        training_outputs_array,
        training_task_indices_array,
        evaluation_inputs_array,
        evaluation_outputs_array,
        evaluation_task_indices_array,
    ),
)

# --- Randomness -------------------------------------------------------------
seed = 0
key = jax.random.PRNGKey(seed)  # explicit key threaded through the training loop
rngs = nnx.Rngs(seed)           # RNG container handed to the NNX modules

# --- Model size -------------------------------------------------------------
channel_size = 32          # channels per cell in the CA state
num_spatial_dims = 2       # 2-D grids
num_kernels = 2            # perception kernels requested per channel
hidden_size = 256          # width of the update network's hidden layer
cell_dropout_rate = 0.5

# --- Training schedule ------------------------------------------------------
batch_size = 2
num_steps = 2              # CA steps per rollout
learning_rate = 1e-3
ds_size = 30               # side length of the padded grids
num_train_steps = 4
print_interval = 2

# Tasks known from the training split, in index order.
task_list = list(task_id_to_index)
num_tasks = len(task_list)

print(f"Number of tasks: {num_tasks}")

# State initialisers: sample one example and build the initial CA state for it.
def init_state(key):
    """Sample a random training example and build its initial CA state.

    Args:
        key: JAX PRNG key used to draw the example index.

    Returns:
        ``(state, target_grid, task_index)`` where ``state`` is a zero array
        of shape ``(ds_size, ds_size, channel_size)`` whose channel 0 holds
        the sampled input grid, and the other two values are the matching
        rows of ``training_outputs_array`` and ``training_task_indices_array``.
    """
    num_examples = training_inputs_array.shape[0]
    sample = jax.random.randint(key, (), 0, num_examples)
    blank = jnp.zeros((ds_size, ds_size, channel_size))
    # Only channel 0 is seeded; every other channel starts at zero.
    seeded = blank.at[..., 0].set(training_inputs_array[sample])
    return seeded, training_outputs_array[sample], training_task_indices_array[sample]

def init_state_test(key):
    """Sample a random evaluation example and build its initial CA state.

    Args:
        key: JAX PRNG key used to draw the example index.

    Returns:
        ``(state, target_grid, task_index)`` where ``state`` is a zero array
        of shape ``(ds_size, ds_size, channel_size)`` whose channel 0 holds
        the sampled input grid, and the other two values are the matching
        rows of ``evaluation_outputs_array`` and
        ``evaluation_task_indices_array``.
    """
    num_examples = evaluation_inputs_array.shape[0]
    sample = jax.random.randint(key, (), 0, num_examples)
    blank = jnp.zeros((ds_size, ds_size, channel_size))
    # Only channel 0 is seeded; every other channel starts at zero.
    seeded = blank.at[..., 0].set(evaluation_inputs_array[sample])
    return seeded, evaluation_outputs_array[sample], evaluation_task_indices_array[sample]

# Neural cellular automaton bundled with its embedding tables.
class EmbedCA(CA):
    """``CA`` subclass that also carries an input and a task embedding.

    The embeddings are only stored on the module; ``__call__`` never looks
    them up itself and expects an already-computed task embedding.
    """

    embed_input: nnx.Embed
    embed_task: nnx.Embed

    def __init__(self, perceive, update, embed_input, embed_task):
        """Forward *perceive*/*update* to ``CA`` and keep both embeddings."""
        super().__init__(perceive, update)
        self.embed_input = embed_input
        self.embed_task = embed_task

    def __call__(self, state, task_embed, num_steps=1, all_steps=False):
        """Apply ``self.step`` *num_steps* times, feeding *task_embed* each time.

        Args:
            state: Initial CA state.
            task_embed: Extra input passed to every ``self.step`` call.
            num_steps: Number of steps to run.
            all_steps: When true, return every intermediate state stacked
                along a new leading axis instead of only the final state.

        Returns:
            The final state, or the stacked per-step states if *all_steps*.
        """
        if not all_steps:
            for _ in range(num_steps):
                state = self.step(state, task_embed)
            return state

        trajectory = []
        for _ in range(num_steps):
            state = self.step(state, task_embed)
            trajectory.append(state)
        return jnp.stack(trajectory)

# DepthwiseConvPerceive variant that logs shapes for debugging.
class MyDepthwiseConvPerceive(DepthwiseConvPerceive):
    """Perception module that prints diagnostics before convolving."""

    def __call__(self, state: jnp.ndarray) -> jnp.ndarray:
        """Print the state and convolution shapes, then return ``depthwise_conv(state)``.

        NOTE(review): when invoked inside a jitted function these prints
        presumably fire only while tracing -- confirm if per-call logging is
        expected.
        """
        conv = self.depthwise_conv
        print(f"Perceive input state shape: {state.shape}")
        print(f"Depthwise conv kernel shape: {self.depthwise_conv.kernel.value.shape}")
        print(f"Depthwise conv features: {self.depthwise_conv.features}")
        print(f"Depthwise conv feature group count: {self.depthwise_conv.feature_group_count}")
        return conv(state)

# Build the perception module (3x3 depthwise convolution).
perceive = MyDepthwiseConvPerceive(
    channel_size,
    rngs,
    num_kernels=num_kernels,
    kernel_size=(3, 3),
)

# Build the update module. The third argument adds 8 to the perception width,
# which equals the feature size of ``embed_task`` below.
# NOTE(review): presumably the task embedding is concatenated to the
# perception inside the update -- confirm against the CAX implementation.
perception_size = num_kernels * channel_size + 8
hidden_layer_sizes = (hidden_size,)
update = ResidualUpdate(
    num_spatial_dims,
    channel_size,
    perception_size,
    hidden_layer_sizes,
    rngs,
    cell_dropout_rate=cell_dropout_rate,
)

# Embedding tables: 10 input values -> 3 features, one 8-feature row per task.
embed_input = nnx.Embed(num_embeddings=10, features=3, rngs=rngs)
embed_task = nnx.Embed(num_embeddings=num_tasks, features=8, rngs=rngs)

# Assemble the full model.
ca = EmbedCA(perceive, update, embed_input, embed_task)

# Fixed base kernels for perception.
identity = identity_kernel(ndim=2)  # expected shape: (3, 3, 1)
gradient = grad_kernel(ndim=2)      # expected shape: (3, 3, 2)
print(f"identity kernel shape: {identity.shape}")
print(f"gradient kernel shape: {gradient.shape}")

# Join the base kernels along the last axis.
base_kernel = jnp.concatenate((identity, gradient), axis=-1)
print(f"base_kernel shape after concatenation: {base_kernel.shape}")

# Insert a singleton axis at position 2.
base_kernel = base_kernel[:, :, jnp.newaxis, :]
print(f"base_kernel shape after expand_dims: {base_kernel.shape}")

# Number of output features the depthwise convolution must produce.
features = num_kernels * channel_size
print(f"channel_size: {channel_size}")
print(f"num_kernels: {num_kernels}")
print(f"Total features (output channels): {features}")

# Smallest number of copies of the base kernel covering all features
# (integer ceiling division).
tiles = -(-features // base_kernel.shape[-1])
print(f"Number of tiles: {tiles}")

# Repeat the base kernel along the feature axis...
kernel = jnp.tile(base_kernel, (1, 1, 1, tiles))
print(f"kernel shape after tiling: {kernel.shape}")

# ...and drop the surplus features.
# NOTE(review): the base kernel's last-axis size need not equal num_kernels
# (the comments above expect 3 versus num_kernels = 2). If so, tiling and
# truncating hands different kernel combinations to different channels --
# confirm this is intended.
kernel = kernel[..., :features]
print(f"kernel shape after slicing: {kernel.shape}")

# Sanity-check the final shape before installing the kernel.
expected_kernel_shape = (3, 3, 1, features)
print(f"Expected kernel shape: {expected_kernel_shape}")
assert kernel.shape == expected_kernel_shape, "Kernel shape does not match expected shape."

# Replace the convolution's kernel with the hand-built one.
perceive.depthwise_conv.kernel = nnx.Param(kernel)
print(f"Assigned kernel shape: {perceive.depthwise_conv.kernel.value.shape}")
print(f"Depthwise conv features: {perceive.depthwise_conv.features}")
print(f"Feature group count: {perceive.depthwise_conv.feature_group_count}")

# Collect all parameters of the model and report how many scalars they hold.
params = nnx.state(ca, nnx.Param)
param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print("Number of params:", param_count)

# Learning rate decays linearly from `learning_rate` to a tenth of it over
# 2000 transition steps.
lr_sched = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0.1 * learning_rate,
    transition_steps=2000,
)

# Clip the global gradient norm to 1.0, then apply Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=lr_sched),
)

# Optimise only parameters whose path contains "update"; `optimizer` is then
# rebound to the NNX optimizer wrapping the optax chain.
update_params = nnx.All(nnx.Param, nnx.PathContains("update"))
optimizer = nnx.Optimizer(ca, optimizer, wrt=update_params)

# Loss and accuracy helpers.
def mse(state, target):
    """Return the mean squared error between ``state[..., :3]`` and *target*.

    Only the first three channels of *state* take part in the comparison.
    """
    residual = state[..., :3] - target
    return jnp.mean(jnp.square(residual))

@nnx.jit
def accuracy_fn(state, target):
    """Return the fraction of cells whose predicted class equals *target*.

    The prediction for a cell is the argmax over the first three channels of
    *state*.

    NOTE(review): an argmax over three channels can only yield 0, 1 or 2;
    confirm this decoding matches the value range of *target*.
    """
    predicted = jnp.argmax(state[..., :3], axis=-1)
    num_correct = jnp.sum(predicted == target)
    return num_correct / target.size

@nnx.jit
def loss_fn(ca, state, target, task_index):
    """Roll the CA out on a batch and return the MSE against the embedded target.

    Args:
        ca: Model providing ``embed_input``, ``embed_task`` and the rollout.
        state: Batched initial states; channel 0 is read as the input grid.
        target: Batched target grids, embedded with ``ca.embed_input``.
        task_index: Batched task indices, embedded with ``ca.embed_task``.

    Returns:
        The scalar loss from ``mse`` between the rolled-out state and the
        embedded target.
    """
    # Channel 0 carries the raw grid values; cast to int32 for the lookup.
    input_grid = state[..., 0]
    input_embed = ca.embed_input(jnp.asarray(input_grid, dtype=jnp.int32))
    task_embed = ca.embed_task(jnp.asarray(task_index, dtype=jnp.int32))
    # Overwrite the first three channels (including the raw grid in channel 0)
    # with the input embedding.
    state = state.at[..., :3].set(input_embed)
    # The target goes through the same embedding table as the input.
    target_embed = ca.embed_input(jnp.asarray(target, dtype=jnp.int32))
    # Map RNG state over axis 0 (one stream per example); everything else in
    # the model is shared across the batch.
    state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
    # Debug output of the shapes going into the rollout.
    # NOTE(review): under nnx.jit these prints presumably run only at trace time.
    print(f"State shape before CA: {state.shape}")
    print(f"Task embed shape: {task_embed.shape}")
    # Split the RNG streams into `batch_size` copies, then vmap the
    # `num_steps`-step rollout over the leading axis of state and task_embed.
    state = nnx.split_rngs(splits=batch_size)(
        nnx.vmap(
            lambda ca, state, task_embed: ca(state, task_embed, num_steps=num_steps),
            in_axes=(state_axes, 0, 0),
        )
    )(ca, state, task_embed)
    loss = mse(state, target_embed)
    return loss

@nnx.jit
def train_step(ca, optimizer, key):
    """Run one optimisation step on a freshly sampled training batch.

    Args:
        ca: Model to train.
        optimizer: NNX optimizer whose ``update`` receives the gradients.
        key: JAX PRNG key; split into ``batch_size`` keys to sample examples.

    Returns:
        The loss computed by ``loss_fn`` on the sampled batch.
    """
    sample_keys = jax.random.split(key, batch_size)
    state, target, task_index = jax.vmap(init_state)(sample_keys)
    # Differentiate with respect to argument 0 (the model), restricted to the
    # parameters selected by `update_params`.
    grad_fn = nnx.value_and_grad(loss_fn, argnums=nnx.DiffState(0, update_params))
    loss, grads = grad_fn(ca, state, target, task_index)
    optimizer.update(grads)
    return loss

@nnx.jit
def eval_step(ca, key):
    """Evaluate the CA on a randomly sampled evaluation batch.

    The batch is embedded and rolled out exactly as in ``loss_fn`` before the
    accuracy is computed. Previously the model was never applied, so the
    accuracy was measured on the un-evolved initial state and ``ca`` was
    unused.

    Args:
        ca: Model to evaluate.
        key: JAX PRNG key; split into ``batch_size`` keys to sample examples.

    Returns:
        The accuracy reported by ``accuracy_fn`` for the rolled-out states.

    NOTE(review): evaluation task indices come from a mapping built
    separately from the training one, so the task embedding looked up here
    may not correspond to the evaluation task -- confirm.
    """
    keys = jax.random.split(key, batch_size)
    state, target, task_index = jax.vmap(init_state_test)(keys)

    # Embed the raw grid (channel 0) and the task, mirroring loss_fn.
    input_embed = ca.embed_input(jnp.asarray(state[..., 0], dtype=jnp.int32))
    task_embed = ca.embed_task(jnp.asarray(task_index, dtype=jnp.int32))
    state = state.at[..., :3].set(input_embed)

    # One RNG stream per example; all other model state is shared.
    state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
    state = nnx.split_rngs(splits=batch_size)(
        nnx.vmap(
            lambda ca, state, task_embed: ca(state, task_embed, num_steps=num_steps),
            in_axes=(state_axes, 0, 0),
        )
    )(ca, state, task_embed)

    return accuracy_fn(state, target)

# Training loop: one optimisation step per iteration, with a periodic report
# of the recent average loss and evaluation accuracy.
pbar = tqdm(range(num_train_steps), desc="Training", unit="train_step")
losses = []
eval_accuracies = []

for i in pbar:
    # Derive independent keys for training and evaluation so that no PRNG key
    # is consumed twice (the same subkey used to be passed to both steps).
    key, train_key, eval_key = jax.random.split(key, 3)
    loss = train_step(ca, optimizer, train_key)
    losses.append(loss)

    # Report on every `print_interval`-th step and on the final step.
    if i % print_interval == 0 or i == num_train_steps - 1:
        recent_losses = losses[-print_interval:]
        avg_loss = sum(recent_losses) / len(recent_losses)
        pbar.set_postfix({"Average Loss": f"{avg_loss:.6f}"})

        accuracy = eval_step(ca, eval_key)
        eval_accuracies.append(accuracy)
        recent_accuracies = eval_accuracies[-print_interval:]
        avg_accuracy = sum(recent_accuracies) / len(recent_accuracies)
        print(f"Step {i}, Average Loss: {avg_loss:.6f}, Eval Accuracy: {avg_accuracy:.4f}")

# Prepare submission: run the trained CA on every test input and write the
# predictions to submission.json.
test_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'

with open(test_challenges_path, 'r') as f:
    test_challenges = json.load(f)

submission = {}
for task_id, task in test_challenges.items():
    test_pairs = task['test']

    # The task embedding is identical for every test input of a task, so look
    # it up once. Tasks missing from the training mapping fall back to index 0.
    task_index = task_id_to_index.get(task_id, 0)
    task_embed = ca.embed_task(jnp.asarray(task_index, dtype=jnp.int32))

    outputs = []
    for test_input in test_pairs:
        input_grid = np.array(test_input['input'], dtype=np.int32)
        padded_input = pad_grid(input_grid)

        # Initial state: zeros everywhere except the first three channels,
        # which hold the embedding of the padded input grid.
        input_embed = ca.embed_input(jnp.asarray(padded_input, dtype=jnp.int32))
        state = jnp.zeros((ds_size, ds_size, channel_size), dtype=jnp.float32)
        state = state.at[..., :3].set(input_embed)

        # Roll the CA out on this single, unbatched example.
        state = ca(state, task_embed, num_steps=num_steps)

        # Decode a class per cell from the first three channels.
        output_grid = jnp.argmax(state[..., :3], axis=-1).astype(int)

        # Crop the padding away, using the input's size as the output size.
        rows, cols = input_grid.shape
        prediction = output_grid[:rows, :cols].tolist()

        # The competition expects two attempts per test input under the keys
        # "attempt_1" and "attempt_2"; the single prediction is used for both.
        outputs.append({"attempt_1": prediction, "attempt_2": prediction})
    submission[task_id] = outputs

# Save the submission
with open('submission.json', 'w') as f:
    json.dump(submission, f)
