In [None]:
!pip install wandb
!wandb login

In [None]:
from functools import partial
import json
import os
from datetime import datetime

import jax
import jaxlib
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm
import wandb

import flax
from flax import linen as nn
import optax

# Report library versions so runs can be reproduced.
print(
    *(f"{pkg.__name__} {pkg.__version__}" for pkg in (jax, jaxlib, flax, optax)),
    sep="\n",
)

# Required environment variables; a missing one raises KeyError here, before any work.
wandb_entity, wandb_project, morph, morph_output_dir = (
    os.environ[var_name]
    for var_name in ("WANDB_ENTITY", "WANDB_PROJECT", "MORPH", "MORPH_OUTPUT_DIR")
)

# Initialize WandB with a run name unique to this morph and launch time.
current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = f"{morph}_{current_datetime}"
# NOTE(review): these config values are repeated as separate literals later in the
# notebook (batch_size, learning_rate, num_epochs, augmentations); keep them in sync.
wandb.init(
    name=run_name,
    entity=wandb_entity,
    project=wandb_project,
    config={
        "morph": morph,
        "learning_rate": 1e-3,
        "batch_size": 32,
        "num_epochs": 2,
        "augmentations": ['flip_horizontal', 'flip_vertical', 'rotate_90', 'rotate_180', 'rotate_270'],
    },
)

def load_data(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 explicitly. JSON text is UTF-8 (RFC 8259), and
    relying on the platform's default text encoding can mis-decode non-ASCII
    content on some systems.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

# ARC Prize 2024 training/evaluation challenges and their solutions, loaded in this order.
(
    train_challenges,
    train_solutions,
    eval_challenges,
    eval_solutions,
) = map(load_data, (
    '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json',
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json',
))

def process_tasks(challenges, solutions):
    """Flatten ARC tasks into ``(input_grid, output_grid, task_index)`` samples.

    Each task contributes all of its 'train' pairs, followed by its 'test'
    inputs paired with the corresponding entry of ``solutions[task_id]``.
    Grids are converted to int32 numpy arrays.

    Returns:
        A tuple ``(samples, task_id_to_index)``. The mapping gives each task id
        its position in ``challenges`` iteration order.
    """
    samples = []
    index_of = {}
    for position, (task_id, task) in enumerate(challenges.items()):
        index_of[task_id] = position
        expected = solutions[task_id]
        samples.extend(
            (
                np.array(pair['input'], dtype=np.int32),
                np.array(pair['output'], dtype=np.int32),
                position,
            )
            for pair in task['train']
        )
        samples.extend(
            (
                np.array(item['input'], dtype=np.int32),
                np.array(expected[k], dtype=np.int32),
                position,
            )
            for k, item in enumerate(task['test'])
        )
    return samples, index_of

def pad_grid(grid, max_size=30, pad_value=0, center_offset=(0, 0)):
    """Place a 2-D ``grid`` on a ``max_size`` x ``max_size`` int32 canvas.

    The grid is centered (integer division), then shifted by
    ``center_offset = (d_row, d_col)``. The final position is clipped so the
    grid stays inside the canvas. Cells not covered by the grid hold
    ``pad_value``.
    """
    canvas = np.full((max_size, max_size), pad_value, dtype=np.int32)
    n_rows, n_cols = grid.shape
    top = np.clip((max_size - n_rows) // 2 + center_offset[0], 0, max_size - n_rows)
    left = np.clip((max_size - n_cols) // 2 + center_offset[1], 0, max_size - n_cols)
    canvas[top:top + n_rows, left:left + n_cols] = grid
    return canvas

# Geometric augmentations implemented with JAX (jit-compiled).
@jax.jit
def flip_horizontal(grid):
    """Mirror ``grid`` left-to-right (reverse the column axis)."""
    return jnp.flip(grid, axis=1)


@jax.jit
def flip_vertical(grid):
    """Mirror ``grid`` top-to-bottom (reverse the row axis)."""
    return jnp.flip(grid, axis=0)


@partial(jax.jit, static_argnums=(1,))
def rotate(grid, k):
    """Rotate ``grid`` by ``k`` quarter turns counter-clockwise (``k`` is static)."""
    return jnp.rot90(grid, k)


# Name -> augmentation callable; names are the strings used in the `augmentations` lists.
AUGMENTATIONS = dict(
    flip_horizontal=flip_horizontal,
    flip_vertical=flip_vertical,
    rotate_90=partial(rotate, k=1),
    rotate_180=partial(rotate, k=2),
    rotate_270=partial(rotate, k=3),
)

def apply_random_augmentations_batch(input_grids, output_grids, rng_keys, augmentations):
    """Randomly augment each (input, output) pair in a batch, identically within a pair.

    For every sample, each named augmentation is applied (in list order) with
    probability 0.5. The coin flips come from that sample's own key in
    ``rng_keys``. The same transform sequence is applied to the input and the
    output grid, so the pair stays consistent.

    Returns:
        ``(augmented_inputs, augmented_outputs)`` with the same leading batch axis.
    """
    aug_fns = [AUGMENTATIONS[name] for name in augmentations]
    n_augs = len(augmentations)

    def _augment_pair(src, dst, key):
        # One Bernoulli(0.5) draw per augmentation for this sample.
        coin_flips = jax.random.bernoulli(key, p=0.5, shape=(n_augs,))
        pair = (src, dst)
        for i, fn in enumerate(aug_fns):
            # `fn=fn` binds the current augmentation explicitly rather than via closure lookup.
            pair = jax.lax.cond(
                coin_flips[i],
                lambda grids, fn=fn: (fn(grids[0]), fn(grids[1])),
                lambda grids: grids,
                pair,
            )
        return pair

    # Map over the leading (batch) axis of all three arguments and both outputs.
    return jax.vmap(_augment_pair)(input_grids, output_grids, rng_keys)

def data_generator(data, batch_size, augmentations=None, max_size=30, pad_value=0, center_offset=(0, 0), shuffle=True):
    """Yield ``(inputs, outputs, task_indices)`` batches forever, one full pass per epoch.

    Each pass covers ``data`` in ``batch_size`` chunks; the last chunk of a pass
    may be smaller. Grids are padded with ``pad_grid`` before stacking. When
    ``shuffle`` is set, the index order is shuffled in place at the start of
    every pass, seeded by the pass number (reproducible across runs). When
    ``augmentations`` is non-empty, random augmentations are applied with a key
    derived from the pass number and the batch's start offset.
    """
    total = len(data)
    order = np.arange(total)
    pass_number = 0
    while True:
        if shuffle:
            # Re-shuffles the previous pass's order, not a fresh arange.
            np.random.default_rng(seed=pass_number).shuffle(order)
        for offset in range(0, total, batch_size):
            samples = [data[i] for i in order[offset:offset + batch_size]]
            inputs = jnp.stack(
                [pad_grid(src, max_size, pad_value, center_offset) for src, _, _ in samples]
            )
            outputs = jnp.stack(
                [pad_grid(dst, max_size, pad_value, center_offset) for _, dst, _ in samples]
            )
            task_ids = jnp.array([task for _, _, task in samples], dtype=jnp.int32)
            if augmentations:
                batch_key = jax.random.fold_in(jax.random.PRNGKey(pass_number), offset)
                per_sample_keys = jax.random.split(batch_key, len(inputs))
                inputs, outputs = apply_random_augmentations_batch(
                    inputs, outputs, per_sample_keys, augmentations
                )
            yield inputs, outputs, task_ids
        pass_number += 1

# Training hyper-parameters for batching and augmentation.
batch_size = 32
augmentations = ['flip_horizontal', 'flip_vertical', 'rotate_90', 'rotate_180', 'rotate_270']

# Flattened (input, output, task_index) samples; only the training split keeps its id map.
train_data, task_id_to_index = process_tasks(train_challenges, train_solutions)
eval_data, _ = process_tasks(eval_challenges, eval_solutions)

# Infinite, shuffled, augmented stream of training batches.
train_generator = data_generator(train_data, batch_size, augmentations=augmentations)
# We'll create a function to get evaluation batches
def get_eval_batches(eval_data, batch_size, max_size=30, pad_value=0, center_offset=(0, 0)):
    num_samples = len(eval_data)
    indices = np.arange(num_samples)
    for start_idx in range(0, num_samples, batch_size):
        excerpt = indices[start_idx:start_idx + batch_size]
        batch_inputs = []
        batch_outputs = []
        batch_task_indices = []
        for idx in excerpt:
            input_grid, output_grid, task_index = eval_data[idx]
            input_padded = pad_grid(input_grid, max_size, pad_value, center_offset)
            output_padded = pad_grid(output_grid, max_size, pad_value, center_offset)
            batch_inputs.append(input_padded)
            batch_outputs.append(output_padded)
            batch_task_indices.append(task_index)
        batch_inputs = jnp.stack(batch_inputs)
        batch_outputs = jnp.stack(batch_outputs)
        batch_task_indices = jnp.array(batch_task_indices, dtype=jnp.int32)
        yield batch_inputs, batch_outputs, batch_task_indices

class GridModel(nn.Module):
    """Small fully-convolutional per-cell classifier over 10 colors.

    Input: a batch of 2-D grids, e.g. ``(batch_size, 30, 30)``. Output: logits
    of shape ``(batch_size, H, W, 10)``, one score per color for every cell.
    """

    @nn.compact
    def __call__(self, x):
        # Treat each grid as a single-channel image.
        h = jnp.expand_dims(x, axis=-1)
        # Three 3x3 SAME-padded conv + ReLU stages, so H and W are preserved.
        for width in (32, 64, 32):
            h = nn.relu(nn.Conv(features=width, kernel_size=(3, 3), padding='SAME')(h))
        # 1x1 projection to per-color logits.
        return nn.Conv(features=10, kernel_size=(1, 1))(h)


model = GridModel()

def compute_loss(logits, labels):
    """Mean softmax cross-entropy over every cell of every grid in the batch.

    Args:
        logits: ``(..., 10)`` per-cell color logits, e.g. ``(batch_size, 30, 30, 10)``.
        labels: integer color ids with the matching leading shape, e.g. ``(batch_size, 30, 30)``.

    Padding cells are included in the average, like any other cell.
    """
    num_classes = 10
    flat_logits = jnp.reshape(logits, (-1, num_classes))
    flat_targets = jax.nn.one_hot(jnp.ravel(labels), num_classes=num_classes)
    per_cell = optax.softmax_cross_entropy(flat_logits, flat_targets)
    return jnp.mean(per_cell)

# Adam optimizer, plus model parameters initialized from a fixed seed using a dummy batch.
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
params = model.init(jax.random.PRNGKey(0), jnp.ones((batch_size, 30, 30)))
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, batch_inputs, batch_outputs, opt_state):
    """Run one optimizer update on a batch.

    Returns:
        ``(new_params, new_opt_state, loss)``; ``loss`` is the pre-update batch loss.
    """
    def objective(p):
        return compute_loss(model.apply(p, batch_inputs), batch_outputs)

    loss_value, gradients = jax.value_and_grad(objective)(params)
    updates, new_opt_state = optimizer.update(gradients, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_value

@jax.jit
def eval_step(params, batch_inputs, batch_outputs):
    """Return ``(loss, accuracy)`` for one batch.

    Accuracy is the fraction of grid cells whose argmax color equals the label.
    It is computed over every position of the grids, so padded cells (if the
    caller padded) count too.
    """
    logits = model.apply(params, batch_inputs)
    predicted = logits.argmax(axis=-1)
    return compute_loss(logits, batch_outputs), (predicted == batch_outputs).mean()

# Evaluation function
def evaluate(params, eval_data, batch_size):
    """Return the sample-weighted mean ``(loss, accuracy)`` over all of ``eval_data``.

    Each batch's mean is weighted by the number of samples in it. The final
    batch may be smaller than ``batch_size``, and averaging unweighted batch
    means would over-weight it. All samples are padded to the same grid size,
    so this equals the per-cell mean over the whole evaluation set.

    Returns ``(nan, nan)`` when ``eval_data`` is empty.
    """
    total_loss = 0.0
    total_accuracy = 0.0
    total_samples = 0
    for batch_inputs, batch_outputs, _ in get_eval_batches(eval_data, batch_size):
        loss, accuracy = eval_step(params, batch_inputs, batch_outputs)
        n = batch_inputs.shape[0]
        total_loss += float(loss) * n
        total_accuracy += float(accuracy) * n
        total_samples += n
    if total_samples == 0:
        return float('nan'), float('nan')
    return total_loss / total_samples, total_accuracy / total_samples

# Training loop
num_epochs = 2
# data_generator yields ceil(len(train_data) / batch_size) batches per pass, and
# the last one may be partial. Consuming exactly that many per epoch keeps each
# training epoch aligned with one full generator pass. With floor() the leftover
# partial batch spilled into the next epoch, and epochs drifted from passes.
steps_per_epoch = (len(train_data) + batch_size - 1) // batch_size
print_every = 100

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train_losses = []
    for step in range(steps_per_epoch):
        batch_inputs, batch_outputs, batch_task_indices = next(train_generator)
        # A smaller final batch has a different shape, so jit compiles train_step once more for it.
        params, opt_state, loss = train_step(params, batch_inputs, batch_outputs, opt_state)
        train_losses.append(loss)

        # Log training loss every `print_every` steps, at a global step index.
        if (step + 1) % print_every == 0:
            current_step = epoch * steps_per_epoch + step + 1
            print(f"Step {step + 1}/{steps_per_epoch}, Loss: {loss:.4f}")
            wandb.log({"train_loss": loss.item()}, step=current_step)

    avg_train_loss = float(np.mean(train_losses))
    print(f"Epoch {epoch + 1} completed. Average training loss: {avg_train_loss:.4f}")
    wandb.log({"avg_train_loss": avg_train_loss}, step=(epoch + 1) * steps_per_epoch)

    # Evaluation
    avg_eval_loss, avg_eval_accuracy = evaluate(params, eval_data, batch_size)
    print(f"Validation loss: {avg_eval_loss:.4f}, Validation accuracy: {avg_eval_accuracy:.4f}")
    wandb.log({
        "val_loss": avg_eval_loss,
        "val_accuracy": avg_eval_accuracy
    }, step=(epoch + 1) * steps_per_epoch)

# Test challenges: only their test inputs are consumed for prediction.
test_challenges = load_data(
    '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
)

def process_test_tasks(challenges):
    """List ``(task_id, test_index, input_grid)`` for every test input of every task.

    Tasks keep ``challenges`` iteration order, and each task's test inputs keep
    their original order. Grids are int32 numpy arrays.
    """
    return [
        (task_id, position, np.array(item['input'], dtype=np.int32))
        for task_id, task in challenges.items()
        for position, item in enumerate(task['test'])
    ]

test_data = process_test_tasks(test_challenges)

# Generate predictions
def generate_predictions(params, test_data, batch_size, crop_to_input=False):
    """Predict an output grid for every test input, grouped by task.

    Args:
        params: Model parameters passed to ``model.apply``.
        test_data: ``(task_id, test_index, input_grid)`` tuples, as produced by
            ``process_test_tasks``.
        batch_size: Number of inputs per forward pass.
        crop_to_input: If False, both attempts are the full padded 30x30
            prediction, which was the original behavior. If True,
            ``attempt_1`` is the prediction cropped to the rows and columns the
            centered, padded input occupies, and ``attempt_2`` keeps the full
            30x30 canvas. ``pad_grid`` centers training outputs the same way as
            inputs, so for same-sized input/output tasks the crop is exactly
            the region the model was trained to fill. A full canvas can only be
            correct when the true output is 30x30.

    Returns:
        Dict mapping ``task_id`` to a list, ordered by test index, of
        ``{"attempt_1": grid, "attempt_2": grid}`` dicts with nested-list grids.
    """
    max_size = 30
    predictions = {}
    for start in range(0, len(test_data), batch_size):
        batch = test_data[start:start + batch_size]
        batch_inputs = jnp.stack(
            [pad_grid(input_grid, max_size=max_size, pad_value=0) for _, _, input_grid in batch]
        )
        logits = model.apply(params, batch_inputs)
        preds = np.array(jnp.argmax(logits, axis=-1), dtype=int)
        for (task_id, input_index, input_grid), pred_grid in zip(batch, preds):
            full_list = pred_grid.tolist()
            if crop_to_input:
                # Same offset pad_grid uses with center_offset=(0, 0) for grids up to max_size.
                rows, cols = input_grid.shape
                top = (max_size - rows) // 2
                left = (max_size - cols) // 2
                cropped_list = pred_grid[top:top + rows, left:left + cols].tolist()
                attempt = {"attempt_1": cropped_list, "attempt_2": full_list}
            else:
                attempt = {"attempt_1": full_list, "attempt_2": full_list}
            task_predictions = predictions.setdefault(task_id, [])
            # Tasks may have several test inputs; keep them at their test index.
            while len(task_predictions) <= input_index:
                task_predictions.append({})
            task_predictions[input_index] = attempt
    return predictions

predictions = generate_predictions(params, test_data, batch_size, crop_to_input=True)

# Save predictions to submission.json
with open('submission.json', 'w') as f:
    json.dump(predictions, f)

# Finish WandB run
wandb.finish()