In [None]:
# Environment setup: testing, data, plotting and progress-bar dependencies.
!pip install 'pytest>=8.3.2' 'numpy>=1.26.4' 'pillow>=10.4.0' 'msgpack>=1.1.0' 'requests>=2.32.3' 'mediapy>=1.2.2' tqdm
# Optax/Chex/Flax stack installed with --no-deps (presumably so pip leaves the
# environment's preinstalled jax/jaxlib untouched — confirm).
!pip install --no-deps 'optax==0.2.3' 'chex==0.1.86' 'flax>=0.9.0' orbax-checkpoint tensorstore 'typing-extensions>=4.2' 'absl-py>=2.1.0' 'toolz>=1.0.0' 'etils[epy]>=1.9.4'
# NOTE(review): `wandb login` may prompt for an API key unless one is already
# configured; wandb is imported below but not used by the visible code.
!pip install wandb
!wandb login
# cax from the hu-po/cax GitHub repository, installed without its dependencies.
!git clone https://github.com/hu-po/cax.git /cax
!pip install --upgrade /cax --no-deps

In [None]:
import json
import os

import jax
import jaxlib
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm
import wandb

import flax
import flax.linen as nn
from flax import nnx
from flax.training import train_state
import optax
import cax

# Report the library versions this run uses, one "name version" line each.
for module in (jax, jaxlib, cax, flax, optax):
    print(f"{module.__name__} {module.__version__}")

In [None]:
# Run configuration read from environment variables (presumably set by an
# external "morph" runner — confirm); a missing variable raises KeyError.
# NOTE(review): none of these three names is used in the visible code.
morph, morph_nb_filepath, morph_output_dir = (
    os.environ[variable] for variable in ("MORPH", "MORPH_NB_FILEPATH", "MORPH_OUTPUT_DIR")
)

In [None]:
# Data loading and preprocessing
def load_data(path):
    """Parse the JSON file at ``path`` and report how many top-level entries it has.

    Args:
        path: filesystem path of a JSON document (here: an ARC challenges or
            solutions file).

    Returns:
        The decoded JSON object.
    """
    with open(path, 'r') as handle:
        payload = json.load(handle)
    print(f"Loaded {len(payload)} tasks from {path}")
    return payload

# Load the four ARC Prize 2024 files. Challenges hold each task's train pairs
# and test inputs; solutions hold the expected test outputs.
train_challenges, train_solutions, eval_challenges, eval_solutions = (
    load_data(json_path)
    for json_path in (
        '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json',
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json',
    )
)

def process_tasks(challenges, solutions):
    """Flatten ARC tasks into parallel per-example lists.

    Every ``train`` pair, and every ``test`` input paired with the matching
    entry of ``solutions[task_id]``, becomes one example. Tasks are numbered
    0..n-1 in the iteration order of ``challenges``.

    Args:
        challenges: mapping task_id -> {'train': [{'input', 'output'}, ...],
            'test': [{'input'}, ...]}.
        solutions: mapping task_id -> list of expected test output grids.

    Returns:
        ``(inputs, outputs, task_indices, task_id_to_index)``: two lists of
        int32 grids, the task number of each example, and the
        task_id -> task number mapping.
    """
    inputs, outputs, task_indices = [], [], []
    task_id_to_index = {}

    def _record(task_number, source_grid, target_grid):
        # One example: an input grid, its target grid and the owning task.
        inputs.append(np.array(source_grid, dtype=np.int32))
        outputs.append(np.array(target_grid, dtype=np.int32))
        task_indices.append(task_number)

    for task_number, (task_id, task) in enumerate(challenges.items()):
        task_id_to_index[task_id] = task_number
        expected_outputs = solutions[task_id]
        for pair in task['train']:
            _record(task_number, pair['input'], pair['output'])
        for test_number, test_case in enumerate(task['test']):
            _record(task_number, test_case['input'], expected_outputs[test_number])
    return inputs, outputs, task_indices, task_id_to_index

def pad_grids(grids, max_size=30, pad_value=0):
    """Pad each 2-D grid at the bottom/right to ``max_size x max_size`` and stack them.

    Args:
        grids: iterable of 2-D arrays (or nested lists) of any height/width up
            to ``max_size``.
        max_size: side length of the padded square.
        pad_value: value written into the padded cells.

    Returns:
        NumPy array of shape ``(len(grids), max_size, max_size)``; each
        original grid occupies the top-left corner of its slice.

    Raises:
        ValueError: if a grid is not 2-D or is larger than ``max_size`` in
            either dimension (reported with the offending grid's position and
            shape), or if ``grids`` is empty (from ``np.stack``).
    """
    padded_grids = []
    for position, grid in enumerate(grids):
        grid = np.asarray(grid)
        if grid.ndim != 2:
            raise ValueError(f"grid {position} must be 2-D, got shape {grid.shape}")
        rows, cols = grid.shape
        # np.pad would otherwise fail with an opaque "negative values" error.
        if rows > max_size or cols > max_size:
            raise ValueError(
                f"grid {position} has shape {grid.shape}, larger than max_size={max_size}"
            )
        padded = np.pad(grid, ((0, max_size - rows), (0, max_size - cols)),
                        mode='constant', constant_values=pad_value)
        padded_grids.append(padded)
    return np.stack(padded_grids)

def prepare_data(challenges, solutions):
    """Turn an ARC challenges/solutions pair into padded device arrays.

    Examples are flattened with ``process_tasks``, padded to 30x30 with
    ``pad_grids`` and converted to JAX arrays.

    Returns:
        ``(inputs, outputs, task_indices, task_id_to_index)``: two stacked
        grid arrays, an int32 array of per-example task numbers, and the
        task_id -> task number mapping.
    """
    raw_inputs, raw_outputs, raw_task_indices, task_id_to_index = process_tasks(challenges, solutions)
    print(f"\t number of samples: {len(raw_inputs)}")

    device_inputs = jnp.array(pad_grids(raw_inputs))
    device_outputs = jnp.array(pad_grids(raw_outputs))
    device_task_indices = jnp.array(np.array(raw_task_indices, dtype=np.int32))
    return device_inputs, device_outputs, device_task_indices, task_id_to_index

# Build padded arrays for both splits. Only the training task-id mapping is
# kept. The eval split's task indices count the evaluation tasks from 0 and
# do not refer to this mapping.
print("Processing train data...")
(
    train_inputs,
    train_outputs,
    train_task_indices,
    task_id_to_index,
) = prepare_data(train_challenges, train_solutions)
print(f"\t inputs shape: {train_inputs.shape}")
print(f"\t outputs shape: {train_outputs.shape}")
print(f"\t task indices shape: {train_task_indices.shape}")

print("Processing eval data...")
(
    eval_inputs,
    eval_outputs,
    eval_task_indices,
    _,
) = prepare_data(eval_challenges, eval_solutions)
print(f"\t inputs shape: {eval_inputs.shape}")
print(f"\t outputs shape: {eval_outputs.shape}")
print(f"\t task indices shape: {eval_task_indices.shape}")

# Data Augmentation
def augment_data(inputs, outputs, task_indices=None, rng=None):
    """Expand every (input, output) pair into six training examples.

    For each pair, in this order:
      1. the original pair;
      2. a left-right flip;
      3. an up-down flip;
      4. a 90-degree rotation;
      5. a 180-degree rotation (2-5 apply the same transform to input and output);
      6. a copy whose input gets integer noise in {-1, 0, +1}, clipped to
         [0, 9], while the output stays clean.

    The six variants of a pair are emitted consecutively, so augmented example
    ``k`` comes from source pair ``k // 6``.

    NOTE(review): the transforms act on the whole array. For grids that were
    already padded, the content therefore moves out of the top-left corner,
    and the noise also perturbs padding cells.

    Args:
        inputs, outputs: sequences (or arrays) of 2-D grids, paired by position.
        task_indices: optional per-pair task ids. When given, they are repeated
            to stay aligned with the augmented examples and returned as a third
            element.
        rng: optional ``np.random.Generator`` for reproducible noise. When
            None, the global ``np.random`` state is used, as before.

    Returns:
        ``(augmented_inputs, augmented_outputs)`` as NumPy arrays, or
        ``(augmented_inputs, augmented_outputs, augmented_task_indices)``
        when ``task_indices`` is provided.

    Raises:
        ValueError: if ``task_indices`` does not have one entry per pair.
    """
    variants_per_pair = 6

    def _noise(shape):
        if rng is None:
            return np.random.randint(-1, 2, shape)
        return rng.integers(-1, 2, shape)

    augmented_inputs = []
    augmented_outputs = []
    num_pairs = 0
    for input_grid, output_grid in zip(inputs, outputs):
        num_pairs += 1
        # Work on NumPy copies so JAX device rows are converted exactly once.
        src = np.asarray(input_grid)
        dst = np.asarray(output_grid)
        noisy_src = np.clip(src + _noise(src.shape), 0, 9).astype(src.dtype)
        variants = (
            (src, dst),                                # original
            (np.fliplr(src), np.fliplr(dst)),          # horizontal flip
            (np.flipud(src), np.flipud(dst)),          # vertical flip
            (np.rot90(src), np.rot90(dst)),            # 90-degree rotation
            (np.rot90(src, 2), np.rot90(dst, 2)),      # 180-degree rotation
            (noisy_src, dst),                          # noisy input, clean target
        )
        for aug_in, aug_out in variants:
            augmented_inputs.append(aug_in)
            augmented_outputs.append(aug_out)

    result = (np.array(augmented_inputs), np.array(augmented_outputs))
    if task_indices is None:
        return result

    task_indices = np.asarray(task_indices)
    if len(task_indices) != num_pairs:
        raise ValueError(
            f"task_indices has {len(task_indices)} entries for {num_pairs} pairs"
        )
    return result + (np.repeat(task_indices, variants_per_pair),)

# Augment the padded training grids. Each example yields several consecutive
# variants. NOTE: train_task_indices is not augmented here, so code indexing
# it with augmented positions must repeat it to match.
augmented_train_inputs, augmented_train_outputs = augment_data(
    inputs=train_inputs,
    outputs=train_outputs,
)

# Model Architecture
class ImprovedCAX(nn.Module):
    """Convolutional grid-to-grid regressor conditioned on a per-task embedding.

    The class is written against the Flax Linen API (``nn.Module`` and
    ``@nn.compact``) because the training code builds it with ``model.init``
    and runs it with ``model.apply``. ``flax.nnx`` offers none of these: it has
    no ``compact`` decorator, its ``Conv`` needs ``in_features``/``rngs`` at
    construction, and NNX modules have no ``init``/``apply``.

    Attributes:
        features: channel width of every hidden convolution.
        num_layers: number of 3x3 "perceive" convolutions.
    """

    features: int = 64
    num_layers: int = 2

    @nn.compact
    def __call__(self, x, task_embedding):
        """Predict one real value per grid cell.

        Args:
            x: colour grids indexed as (batch, height, width).
            task_embedding: per-sample task vectors, shape (batch, dim); a
                leading dimension of 1 is broadcast across the batch.

        Returns:
            Float array of shape (batch, height, width). Callers round it to
            obtain colours.
        """
        batch, height, width = x.shape
        # Give every grid cell a copy of its sample's task vector.
        task_map = jnp.broadcast_to(
            task_embedding[:, None, None, :],
            (batch, height, width, task_embedding.shape[-1]),
        )
        # The colour value becomes one extra (float) channel.
        h = jnp.concatenate([x[..., None].astype(task_map.dtype), task_map], axis=-1)

        # Perceive module: stacked 3x3 convolutions mix each cell's neighbourhood.
        for _ in range(self.num_layers):
            h = nn.Conv(features=self.features, kernel_size=(3, 3), padding='SAME')(h)
            h = jax.nn.relu(h)

        # Update module: per-cell 1x1 MLP down to a single output channel.
        h = nn.Conv(features=self.features, kernel_size=(1, 1))(h)
        h = jax.nn.relu(h)
        h = nn.Conv(features=1, kernel_size=(1, 1))(h)

        return h.squeeze(-1)

# Training Process
def train_step(state, batch):
    """Run one optimiser step on a batch and return ``(new_state, loss)``.

    Args:
        state: a ``TrainState``-like object. Its ``params`` mapping holds the
            model parameters plus a ``'task_embeddings'`` table indexed by task
            id; it also provides ``apply_fn`` and ``apply_gradients``.
        batch: ``(inputs, outputs, task_indices)``.

    Returns:
        The updated state and the batch's mean-squared-error loss, computed
        with the pre-update parameters.
    """
    inputs, outputs, task_indices = batch

    def loss_fn(params):
        # Read the embeddings from the *differentiated* params. Taking them
        # from ``state.params`` makes them a constant inside the gradient, so
        # they would receive zero gradient and never be learned.
        task_embeddings = params['task_embeddings'][task_indices]
        predictions = state.apply_fn({'params': params}, inputs, task_embeddings)
        return jnp.mean((predictions - outputs) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss

# Evaluation and Submission
def evaluate(state, inputs, outputs, task_indices):
    """Fraction of examples whose rounded prediction matches the target exactly.

    Args:
        state: object with ``params`` (containing a ``'task_embeddings'``
            table) and ``apply_fn``.
        inputs, outputs: grids indexed as (batch, height, width).
        task_indices: per-example row indices into
            ``state.params['task_embeddings']``.

    Returns:
        Scalar JAX array in [0, 1]. An example counts only if every cell
        (including any padding) matches.
    """
    params = state.params
    per_example_embedding = params['task_embeddings'][task_indices]
    predicted = state.apply_fn({'params': params}, inputs, per_example_embedding)
    exact_match = (predicted.round() == outputs).all(axis=(1, 2))
    return exact_match.mean()

def prepare_submission(state, test_challenges_path, output_path='submission.json',
                       task_index_map=None):
    """Predict every test input of an ARC challenges file and write a submission JSON.

    Task conditioning:
        A task uses its own learned embedding when its id is in
        ``task_index_map``. The default map is the module-level
        ``task_id_to_index`` built from the training challenges.
        Tasks absent from the map (e.g. hidden test tasks) fall back to the
        mean of all learned embeddings instead of raising ``KeyError``.
        NOTE(review): the mean is a placeholder conditioning, not a
        task-specific one.

    Post-processing:
        Predictions are rounded, clipped to the valid ARC colours 0-9 and
        cropped to the *input* grid's height/width. The model cannot emit a
        different output size, so tasks whose outputs change size will be
        wrong.
        Both attempts are currently the same grid; a distinct second attempt
        could be generated instead.

    Args:
        state: trained state with ``params['task_embeddings']`` and ``apply_fn``.
        test_challenges_path: path of an ARC challenges JSON file.
        output_path: where the submission JSON is written.
        task_index_map: optional task_id -> embedding row mapping.

    Returns:
        The submission dict that was written:
        ``{task_id: [{"attempt_1": grid, "attempt_2": grid}, ...]}``.
    """
    if task_index_map is None:
        task_index_map = task_id_to_index

    with open(test_challenges_path, 'r') as f:
        test_challenges = json.load(f)

    embeddings = state.params['task_embeddings']
    fallback_embedding = jnp.mean(embeddings, axis=0)

    submission = {}
    for task_id, task in test_challenges.items():
        task_index = task_index_map.get(task_id)
        task_embedding = fallback_embedding if task_index is None else embeddings[task_index]

        task_submission = []
        for test_input in task['test']:
            grid = np.array(test_input['input'], dtype=np.int32)
            height, width = grid.shape
            input_grid = jnp.array(pad_grids([grid]))
            predictions = state.apply_fn({'params': state.params}, input_grid, task_embedding[None, ...])
            # Regression outputs can round outside the palette; clamp to 0-9.
            colours = jnp.clip(predictions.round(), 0, 9).astype(jnp.int32)
            # Trim padding back to the input's size.
            trimmed_prediction = np.asarray(colours[0, :height, :width]).tolist()

            task_submission.append({
                "attempt_1": trimmed_prediction,
                "attempt_2": trimmed_prediction,
            })

        submission[task_id] = task_submission

    with open(output_path, 'w') as f:
        json.dump(submission, f)
    return submission

# Additional Improvements

# Implement curriculum learning
def curriculum_learning_schedule(epoch):
    """Fraction of the training set to use at ``epoch``.

    Returns 0.5 before epoch 10, 0.75 for epochs 10-19, and 1.0 from epoch 20
    on. The schedule only yields a fraction; it does not rank examples by
    difficulty itself.
    """
    for boundary, fraction in ((10, 0.5), (20, 0.75)):
        if epoch < boundary:
            return fraction
    return 1.0

# Implement task embedding initialization
def initialize_task_embeddings(num_tasks, embedding_dim=64, seed=0):
    """Draw an initial ``(num_tasks, embedding_dim)`` table of standard-normal task embeddings.

    Args:
        num_tasks: number of rows (one per task index).
        embedding_dim: width of each embedding. It should match the
            task-embedding width the model was initialised with.
        seed: PRNG seed. The default 0 reproduces the previously hard-coded
            ``PRNGKey(0)``.

    Returns:
        Float JAX array of shape ``(num_tasks, embedding_dim)``.
    """
    return jax.random.normal(jax.random.PRNGKey(seed), (num_tasks, embedding_dim))

# Main training loop
def train_model(num_epochs=100, batch_size=32, patience=10):
    """Train ``ImprovedCAX`` together with a learned per-task embedding table.

    Data:
        Trains on the module-level ``augmented_train_inputs`` /
        ``augmented_train_outputs``, with ``train_task_indices`` repeated to
        line up with the augmented examples.
        Early-stops on ``eval_inputs`` / ``eval_outputs`` /
        ``eval_task_indices``.

    Parameter layout:
        ``state.params`` is the model's Linen ``'params'`` dict plus an extra
        ``'task_embeddings'`` entry of shape ``(len(task_id_to_index), 64)``.
        ``train_step``, ``evaluate`` and ``prepare_submission`` read this
        layout via ``params['task_embeddings']`` and
        ``apply_fn({'params': params}, ...)``.
        The TrainState is created only after the embeddings are added, so the
        Adam state covers them too.

    Caveats (behaviour unchanged here, flagged for follow-up):
      * ``eval_task_indices`` number the *evaluation* tasks from 0, so they
        select unrelated training embeddings; eval accuracy is a rough signal.
      * The curriculum keeps the first fraction of the augmented set in
        dataset order; it does not rank examples by difficulty.

    Args:
        num_epochs: maximum number of epochs.
        batch_size: examples per optimiser step.
        patience: consecutive epochs without eval improvement before stopping.

    Returns:
        The state with the best eval accuracy seen, or the last state if eval
        accuracy never rose above 0.
    """
    embedding_dim = 64
    model = ImprovedCAX()

    # Model weights and task embeddings share ONE flat params dict, built
    # before the optimiser state so the gradient and opt-state trees match.
    variables = model.init(
        jax.random.PRNGKey(0), jnp.zeros((1, 30, 30)), jnp.zeros((1, embedding_dim))
    )
    params = {
        **variables['params'],
        'task_embeddings': initialize_task_embeddings(len(task_id_to_index), embedding_dim),
    }

    learning_rate = optax.exponential_decay(
        init_value=1e-3,
        transition_steps=1000,
        decay_rate=0.9
    )
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optax.adam(learning_rate)
    )

    # augment_data emits the variants of each source example consecutively,
    # but train_task_indices is not augmented. Repeat it to stay aligned:
    # indexing the short array with augmented positions would silently yield
    # wrong rows (JAX does not raise on out-of-range indices).
    train_inputs_all = jnp.asarray(augmented_train_inputs)
    train_outputs_all = jnp.asarray(augmented_train_outputs)
    variants_per_example, remainder = divmod(len(train_inputs_all), len(train_task_indices))
    if remainder:
        raise ValueError(
            "augmented training set size is not a whole multiple of train_task_indices"
        )
    train_task_ids_all = jnp.repeat(jnp.asarray(train_task_indices), variants_per_example)

    step_fn = jax.jit(train_step)
    eval_fn = jax.jit(evaluate)

    best_eval_accuracy = 0.0
    best_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        # Apply curriculum learning (a growing prefix of the dataset).
        difficulty = curriculum_learning_schedule(epoch)
        num_samples = int(len(train_inputs_all) * difficulty)

        # Shuffle and batch data
        permutation = jax.random.permutation(jax.random.PRNGKey(epoch), num_samples)
        epoch_loss = 0.0
        num_batches = 0
        for start in range(0, num_samples, batch_size):
            batch_indices = permutation[start:start + batch_size]
            batch = (
                train_inputs_all[batch_indices],
                train_outputs_all[batch_indices],
                train_task_ids_all[batch_indices],
            )
            state, loss = step_fn(state, batch)
            epoch_loss += loss
            num_batches += 1

        # Evaluate on validation set; average over batches actually run
        # (includes the final partial batch, never divides by zero).
        eval_accuracy = float(eval_fn(state, eval_inputs, eval_outputs, eval_task_indices))
        mean_loss = float(epoch_loss) / max(num_batches, 1)
        print(f"Epoch {epoch}, Loss: {mean_loss}, Eval Accuracy: {eval_accuracy}")

        # Early stopping, keeping the best state seen so far.
        if eval_accuracy > best_eval_accuracy:
            best_eval_accuracy = eval_accuracy
            best_state = state
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    return state if best_state is None else best_state

# Train the model, then write predictions for the test challenges
# (prepare_submission writes them to submission.json).
final_state = train_model()

test_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
prepare_submission(final_state, test_challenges_path)