In [None]:
from functools import partial
import json
import os
from datetime import datetime
import io

import jax
import jax.numpy as jnp
import numpy as np
import wandb
from PIL import Image
import drawsvg

import flax
from flax import linen as nn
import optax

import arckit.vis as vis
from arckit import draw_task

# --- Optimisation schedule ---
learning_rate = 1e-3
num_epochs = 8
batch_size = 8  # kept small because every sample bundles several grid pairs
print_every = 100  # number of steps between progress reports

# --- Grid canvas ---
max_size = 30  # side length of the square canvas grids are padded to
pad_value = 0  # value written into canvas cells outside the grid

# --- Augmentation ---
augmentations = [
    'flip_horizontal',
    'flip_vertical',
    'rotate_90',
    'rotate_180',
    'rotate_270',
]
aug_prob = 0.1  # probability used for each augmentation draw

# --- Model sizes ---
conv_features = 32
rnn_hidden_size = 64

# --- Visualisation ---
num_visualization_samples = 3

# Initialize WandB
# The run name carries the launch timestamp so repeated runs stay distinguishable.
current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
run_name = f"contextual_rnn_convnet.{current_datetime}"
wandb.login()
wandb.init(
    project="arc-puzzle-solver",
    name=run_name,
    # Track every hyperparameter that shapes training; aug_prob belongs with
    # the augmentation list, otherwise runs differing only in it look identical.
    config={
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "augmentations": augmentations,
        "aug_prob": aug_prob,
        "max_size": max_size,
        "pad_value": pad_value,
        "conv_features": conv_features,
        "rnn_hidden_size": rnn_hidden_size,
    }
)

def load_data(path):
    """Read a JSON file and return the decoded object.

    Args:
        path: location of the JSON file.

    Returns:
        Whatever ``json.load`` produces for the file contents.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the contents are not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Load datasets: challenge/solution files for the training and evaluation splits.
train_challenges, train_solutions, eval_challenges, eval_solutions = (
    load_data(json_path)
    for json_path in (
        '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json',
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json',
    )
)

def process_tasks(challenges, solutions):
    """Convert raw challenge/solution dicts into per-task lists of grids.

    Args:
        challenges: mapping ``task_id -> {'train': [...], 'test': [...]}``.
            Train pairs are normally dicts with ``'input'`` and ``'output'``
            keys; positional ``(input, output)`` pairs are also accepted.
            Test entries are dicts with an ``'input'`` key.
        solutions: mapping ``task_id -> list`` of expected test output grids.

    Returns:
        A list with one ``(train_inputs, train_outputs, test_inputs,
        test_outputs)`` tuple per task, each element being a list of
        ``np.uint8`` arrays.

    Raises:
        KeyError: if a task has no entry in ``solutions`` or a pair lacks an
            expected key.
    """
    def extract(pairs, key, position):
        # Look dict pairs up by key; indexing them with 0/1 raises KeyError.
        return [
            np.array(pair[key] if isinstance(pair, dict) else pair[position], dtype=np.uint8)
            for pair in pairs
        ]

    data = []
    for task_id, task in challenges.items():
        solution = solutions[task_id]
        train_inputs = extract(task['train'], 'input', 0)
        train_outputs = extract(task['train'], 'output', 1)
        test_inputs = [np.array(test['input'], dtype=np.uint8) for test in task['test']]
        test_outputs = [np.array(output, dtype=np.uint8) for output in solution]
        data.append((train_inputs, train_outputs, test_inputs, test_outputs))
    return data

# Turn both splits into lists of per-task grid tuples (training split first).
train_data, eval_data = (
    process_tasks(train_challenges, train_solutions),
    process_tasks(eval_challenges, eval_solutions),
)

def pad_grid(grid, max_size=max_size, pad_value=pad_value):
    """Centre a 2-D grid on a square ``max_size`` x ``max_size`` int32 canvas.

    Args:
        grid: 2-D array to embed.
        max_size: side length of the canvas.
        pad_value: value for canvas cells not covered by ``grid``.

    Returns:
        A new ``np.int32`` array of shape ``(max_size, max_size)``. When the
        spare space is odd, the extra row/column goes to the bottom/right.
    """
    height, width = grid.shape
    top = (max_size - height) // 2
    left = (max_size - width) // 2
    canvas = np.full((max_size, max_size), pad_value, dtype=np.int32)
    canvas[top:top + height, left:left + width] = grid
    return canvas

def unpad_grid(grid, original_shape, max_size=max_size):
    """Cut the centred ``original_shape`` window back out of a padded grid.

    Uses the same offsets as the centring done when padding.

    Args:
        grid: padded 2-D array.
        original_shape: ``(rows, cols)`` of the grid before padding.
        max_size: side length the grid was padded to.

    Returns:
        The ``rows`` x ``cols`` slice of ``grid`` starting at the centring offsets.
    """
    height, width = original_shape
    top, left = (max_size - height) // 2, (max_size - width) // 2
    row_window = slice(top, top + height)
    col_window = slice(left, left + width)
    return grid[row_window, col_window]

@jax.jit
def flip_horizontal(grid):
    """Mirror ``grid`` left-to-right (reverse the order of its columns)."""
    mirrored = jnp.flip(grid, axis=1)
    return mirrored

@jax.jit
def flip_vertical(grid):
    """Mirror ``grid`` top-to-bottom (reverse the order of its rows)."""
    mirrored = jnp.flip(grid, axis=0)
    return mirrored

@partial(jax.jit, static_argnums=(1,))
def rotate(grid, k):
    """Rotate ``grid`` by ``k`` quarter turns using ``jnp.rot90``.

    ``k`` is declared static for ``jax.jit``, so it must be a hashable Python
    value and each distinct ``k`` is compiled separately.
    """
    quarter_turns = k
    return jnp.rot90(grid, quarter_turns)

# Lookup table from augmentation name to the callable that applies it.
AUGMENTATIONS = dict([
    ('flip_horizontal', flip_horizontal),
    ('flip_vertical', flip_vertical),
    # The rotations share one implementation and differ in the number of quarter turns.
    ('rotate_90', partial(rotate, k=1)),
    ('rotate_180', partial(rotate, k=2)),
    ('rotate_270', partial(rotate, k=3)),
])

def augment_sample(train_inputs, train_outputs, test_inputs, test_outputs, rng_key):
    """Randomly transform all grids of one task with the same augmentations.

    One Bernoulli draw with probability ``aug_prob`` is made per entry of the
    module-level ``augmentations`` list. Each selected transform is applied to
    every grid in all four arguments, so the demonstration pairs and the test
    pairs of a task stay consistent with one another.

    Args:
        train_inputs, train_outputs, test_inputs, test_outputs: a 2-D grid or
            a pytree (e.g. a list) of 2-D grids. Grids must be square:
            ``jax.lax.cond`` requires both branches to return identical
            shapes, and quarter-turn rotations swap the two dimensions.
        rng_key: JAX PRNG key that drives the draws.

    Returns:
        Tuple ``(train_inputs, train_outputs, test_inputs, test_outputs)``
        with the same structure as the arguments.
    """
    apply_aug = jax.random.bernoulli(rng_key, p=aug_prob, shape=(len(augmentations),))
    grids = (train_inputs, train_outputs, test_inputs, test_outputs)
    for i, aug_name in enumerate(augmentations):
        aug_func = AUGMENTATIONS[aug_name]
        grids = jax.lax.cond(
            apply_aug[i],
            # tree_map applies the transform per grid, whatever container holds
            # the grids; aug_func is bound as a default to avoid late binding.
            lambda operand, f=aug_func: jax.tree_util.tree_map(f, operand),
            # lax.cond needs an explicit "leave unchanged" branch.
            lambda operand: operand,
            grids,
        )
    return grids

def train_generator(data, batch_size, augmentations=None, max_size=max_size, pad_value=pad_value, shuffle=True):
    """Yield padded training batches indefinitely, one pass over ``data`` after another.

    Args:
        data: sequence of ``(train_inputs, train_outputs, test_inputs,
            test_outputs)`` tuples; each element is a list of 2-D grids.
        batch_size: number of tasks per batch; the last batch of a pass may
            hold fewer.
        augmentations: when truthy, every task is passed through
            ``augment_sample``. NOTE(review): only the truthiness is used
            here; ``augment_sample`` decides which transforms exist.
        max_size: side length of the padded canvas.
        pad_value: fill value for the padding.
        shuffle: reshuffle the task order at the start of every pass, seeded
            with the pass number so the order is reproducible.

    Yields:
        ``(batch_train_inputs, batch_train_outputs, batch_test_inputs,
        batch_test_outputs)``. Each is a list with one entry per task, and
        each entry is a list of padded grids.
    """
    num_samples = len(data)
    if num_samples == 0:
        # Nothing to yield; without this guard the loop below would spin forever.
        return

    def pad_all(grids):
        return [pad_grid(grid, max_size=max_size, pad_value=pad_value) for grid in grids]

    indices = np.arange(num_samples)
    epoch = 0
    while True:
        if shuffle:
            rng = np.random.default_rng(seed=epoch)
            rng.shuffle(indices)
        for start_idx in range(0, num_samples, batch_size):
            excerpt = indices[start_idx:start_idx + batch_size]
            batch_train_inputs, batch_train_outputs, batch_test_inputs, batch_test_outputs = [], [], [], []
            for idx in excerpt:
                train_inputs, train_outputs, test_inputs, test_outputs = data[idx]
                train_inputs_pad = pad_all(train_inputs)
                train_outputs_pad = pad_all(train_outputs)
                test_inputs_pad = pad_all(test_inputs)
                padded_test_outputs = pad_all(test_outputs)
                if augmentations:
                    # One key per (pass, task): folding in only the batch offset
                    # would hand every task of a batch the same draws.
                    rng_key = jax.random.fold_in(jax.random.PRNGKey(epoch), int(idx))
                    # Augment the task as a whole so all of its grids receive the
                    # same transform. Assumes augment_sample accepts lists of
                    # equally-shaped 2-D grids -- TODO confirm.
                    augmented = augment_sample(
                        train_inputs_pad, train_outputs_pad, test_inputs_pad, padded_test_outputs, rng_key
                    )
                    # Back to NumPy so batches look the same with or without augmentation.
                    train_inputs_pad, train_outputs_pad, test_inputs_pad, padded_test_outputs = (
                        [np.asarray(grid) for grid in grids] for grids in augmented
                    )
                batch_train_inputs.append(train_inputs_pad)
                batch_train_outputs.append(train_outputs_pad)
                batch_test_inputs.append(test_inputs_pad)
                batch_test_outputs.append(padded_test_outputs)
            yield batch_train_inputs, batch_train_outputs, batch_test_inputs, batch_test_outputs
        epoch += 1

def eval_generator(data, batch_size, max_size=max_size, pad_value=pad_value):
    """Yield padded evaluation batches in dataset order, cycling indefinitely.

    The generator restarts from the first task after each full pass, so
    callers must bound how many batches they draw.

    Args:
        data: sequence of ``(train_inputs, train_outputs, test_inputs,
            test_outputs)`` tuples; each element is a list of 2-D grids.
        batch_size: number of tasks per batch; the last batch of a pass may
            hold fewer.
        max_size: side length of the padded canvas.
        pad_value: fill value for the padding.

    Yields:
        ``(batch_train_inputs, batch_train_outputs, batch_test_inputs,
        batch_test_outputs)``. Each is a list with one entry per task, and
        each entry is a list of padded grids.
    """
    num_samples = len(data)
    if num_samples == 0:
        # Nothing to yield; without this guard the loop below would spin forever.
        return

    def pad_all(grids):
        return [pad_grid(grid, max_size=max_size, pad_value=pad_value) for grid in grids]

    indices = np.arange(num_samples)
    while True:
        for start_idx in range(0, num_samples, batch_size):
            excerpt = indices[start_idx:start_idx + batch_size]
            batch_train_inputs, batch_train_outputs, batch_test_inputs, batch_test_outputs = [], [], [], []
            for idx in excerpt:
                train_inputs, train_outputs, test_inputs, test_outputs = data[idx]
                batch_train_inputs.append(pad_all(train_inputs))
                batch_train_outputs.append(pad_all(train_outputs))
                batch_test_inputs.append(pad_all(test_inputs))
                batch_test_outputs.append(pad_all(test_outputs))
            yield batch_train_inputs, batch_train_outputs, batch_test_inputs, batch_test_outputs

# Define the ConvNet encoder
# NOTE(review): unfinished stub -- the class body holds only a comment, so this
# cell does not parse until the module is implemented.
class ConvEncoder(nn.Module):
    # TODO: implement the encoder. `conv_features` (defined with the other
    # hyperparameters) is presumably its channel width -- confirm.

# Define the RNN sequence model
# NOTE(review): unfinished stub -- the class body holds only comments, so this
# cell does not parse until the module is implemented.
class RNNModel(nn.Module):
    # TODO: implement the sequence model. `rnn_hidden_size` (defined with the
    # other hyperparameters) is presumably its state width -- confirm.
    # Intended behaviour: process alternating training inputs and outputs,
    # then the test inputs, and predict the test outputs.

# Initialize models
conv_encoder = ConvEncoder()
rnn_model = RNNModel()

# Initialize optimizer
# NOTE(review): the two `init` calls below are unfinished -- each is missing its
# arguments and closing brackets, so this cell does not parse yet.
# NOTE(review): if both `init` calls return dicts that share a top-level key
# (Flax variable dicts are usually keyed by 'params'), the `{**a, **b}` merge
# keeps only the second module's entry -- confirm before relying on it.
params = conv_encoder.init(# TODO: supply the init arguments and close the call
params = {**params, **rnn_model.init( # TODO: supply the init arguments and close the call
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

# NOTE(review): unfinished stub -- no parameters and no body, so this does not
# parse yet. The training loop calls it as
# train_step(params, opt_state, batch_train_pairs, batch_train_outputs,
#            batch_test_inputs, batch_test_outputs)
# and unpacks the result as (params, opt_state, loss).
def train_step():
# TODO: implement one optimisation step matching the call described above.

# NOTE(review): unfinished stub -- no parameters and no body, so this does not
# parse yet. The training loop calls it as
# eval_step(params, batch_train_pairs, batch_train_outputs,
#           batch_test_inputs, batch_test_outputs)
# and unpacks the result as (eval_loss, eval_accuracy).
def eval_step():
# TODO: implement the evaluation step matching the call described above.

# Training loop
# Keep the generator under its own name: rebinding `train_generator` would hide
# the function and break a second run of this cell.
train_batches = train_generator(train_data, batch_size)
steps_per_epoch = len(train_data) // batch_size
# eval_generator keeps cycling over the data, so one evaluation pass is capped
# at ceil(len(eval_data) / batch_size) batches.
eval_steps = -(-len(eval_data) // batch_size)
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train_losses = []
    for step in range(steps_per_epoch):
        batch_train_pairs, batch_train_outputs, batch_test_inputs, batch_test_outputs = next(train_batches)
        params, opt_state, loss = train_step(params, opt_state, batch_train_pairs, batch_train_outputs, batch_test_inputs, batch_test_outputs)
        train_losses.append(loss)

        if (step + 1) % print_every == 0:
            current_step = epoch * steps_per_epoch + step + 1
            print(f"Step {step + 1}/{steps_per_epoch}, Loss: {loss:.4f}")
            wandb.log({"train_loss": loss.item()}, step=current_step)

    avg_train_loss = np.mean(train_losses)
    print(f"Epoch {epoch + 1} completed. Average training loss: {avg_train_loss:.4f}")
    wandb.log({"avg_train_loss": avg_train_loss}, step=(epoch + 1) * steps_per_epoch)

    # Evaluation
    # A fresh generator per epoch starts from the first evaluation task; the
    # dataset name `eval_data` is never rebound to a batch.
    eval_batches = eval_generator(eval_data, batch_size)
    eval_losses = []
    eval_accuracies = []
    for _ in range(eval_steps):
        batch_train_pairs, batch_train_outputs, batch_test_inputs, batch_test_outputs = next(eval_batches)
        eval_loss, eval_accuracy = eval_step(params, batch_train_pairs, batch_train_outputs, batch_test_inputs, batch_test_outputs)
        eval_losses.append(eval_loss)
        eval_accuracies.append(eval_accuracy)
    # Unweighted mean over batches; a smaller final batch counts the same as a full one.
    avg_eval_loss = np.mean(eval_losses)
    avg_eval_accuracy = np.mean(eval_accuracies)
    print(f"Validation loss: {avg_eval_loss:.4f}, Validation accuracy: {avg_eval_accuracy:.4f}")
    wandb.log({
        "val_loss": avg_eval_loss,
        "val_accuracy": avg_eval_accuracy
    }, step=(epoch + 1) * steps_per_epoch)

# Finish WandB run
wandb.finish()