# In [None]:
from functools import partial
import json
import os
from datetime import datetime
import io

import jax
import jax.numpy as jnp
import numpy as np
import wandb
from PIL import Image
import drawsvg

import flax
from flax import linen as nn
import optax

import arckit.vis as vis
from arckit import draw_task

# ---- Training configuration --------------------------------------------------
learning_rate = 1e-3
batch_size = 8  # Reduced batch size due to complexity of handling multiple pairs
num_epochs = 8
print_every = 100  # print / log the step loss every `print_every` steps
# NOTE(review): these names are only recorded in the W&B config below; nothing
# later in this notebook applies the augmentations to the data.
augmentations = ['flip_horizontal', 'flip_vertical', 'rotate_90', 'rotate_180', 'rotate_270']
max_size = 30  # side length every grid is padded to
pad_value = 0  # padding fill value; coincides with class 0 of the 10-way output
num_visualization_samples = 3

# ---- Model hyperparameters ---------------------------------------------------
conv_features = 32    # output channels of each ConvEncoder layer
rnn_hidden_size = 64  # width of the recurrent hidden state

# ---- Weights & Biases run ----------------------------------------------------
# The run name carries a YYYYmmddHHMMSS timestamp so repeated runs stay distinct.
current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
run_name = f"contextual_rnn_convnet.{current_datetime}"
wandb.login()
wandb.init(
    project="arc-puzzle-solver",
    name=run_name,
    config=dict(
        learning_rate=learning_rate,
        batch_size=batch_size,
        num_epochs=num_epochs,
        augmentations=augmentations,
        max_size=max_size,
        pad_value=pad_value,
        conv_features=conv_features,
        rnn_hidden_size=rnn_hidden_size,
    ),
)

def load_data(path):
    """Read and return the JSON document stored at *path*.

    The file is decoded explicitly as UTF-8 (the JSON interchange encoding)
    instead of with the platform's locale-dependent default encoding.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Load the ARC Prize 2024 challenge and solution files from the Kaggle input mount,
# in the order: training challenges, training solutions, evaluation challenges,
# evaluation solutions.
train_challenges, train_solutions, eval_challenges, eval_solutions = [
    load_data(json_path)
    for json_path in (
        '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json',
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json',
    )
]

def process_tasks(challenges, solutions):
    """Convert raw ARC JSON tasks into NumPy grids.

    Args:
        challenges: mapping ``task_id -> {'train': [...], 'test': [...]}``; each
            train entry has ``'input'`` and ``'output'`` grids, each test entry an
            ``'input'`` grid.
        solutions: mapping ``task_id -> list of expected test-output grids``.

    Returns:
        A list, in ``challenges`` iteration order, of
        ``(train_pairs, test_inputs, test_outputs)`` tuples. ``train_pairs`` is a
        list of ``(input, output)`` ``np.int32`` arrays; the other two are lists of
        ``np.int32`` arrays.

    Raises:
        KeyError: if a task id in *challenges* has no entry in *solutions*.
    """
    def to_grid(rows):
        return np.array(rows, dtype=np.int32)

    tasks = []
    for task_id, task in challenges.items():
        expected = solutions[task_id]
        pairs = [(to_grid(example['input']), to_grid(example['output']))
                 for example in task['train']]
        test_grids = [to_grid(example['input']) for example in task['test']]
        tasks.append((pairs, test_grids, [to_grid(grid) for grid in expected]))
    return tasks

# Flatten the raw JSON into per-task (train_pairs, test_inputs, test_outputs) tuples.
train_data = process_tasks(challenges=train_challenges, solutions=train_solutions)
eval_data = process_tasks(challenges=eval_challenges, solutions=eval_solutions)

def pad_grid(grid, max_size=max_size, pad_value=pad_value):
    """Centre *grid* on a ``max_size`` x ``max_size`` canvas filled with *pad_value*.

    Args:
        grid: 2-D array-like of integer values (NumPy array or nested lists).
        max_size: side length of the square output canvas.
        pad_value: value written to every cell not covered by *grid*.

    Returns:
        A new ``(max_size, max_size)`` ``np.int32`` array. When the slack along an
        axis is odd, the extra padding row/column ends up at the bottom/right.

    Raises:
        ValueError: if *grid* is not 2-D or is larger than the canvas. Previously an
            oversized grid produced a negative start index and an opaque numpy
            broadcast error.
    """
    grid = np.asarray(grid)
    if grid.ndim != 2:
        raise ValueError(f"pad_grid expects a 2-D grid, got shape {grid.shape}")
    rows, cols = grid.shape
    if rows > max_size or cols > max_size:
        raise ValueError(
            f"grid of shape {grid.shape} does not fit in a {max_size}x{max_size} canvas"
        )
    padded = np.full((max_size, max_size), pad_value, dtype=np.int32)
    top = (max_size - rows) // 2
    left = (max_size - cols) // 2
    padded[top:top + rows, left:left + cols] = grid
    return padded

def data_generator(data, batch_size, max_size=max_size, pad_value=pad_value, shuffle=True):
    """Yield padded mini-batches of tasks forever.

    Each item of *data* is a ``(train_pairs, test_inputs, test_outputs)`` tuple as
    produced by ``process_tasks``. Every grid is centred and padded with
    ``pad_grid`` to ``max_size`` x ``max_size`` using *pad_value*; both settings
    are now forwarded to ``pad_grid``, whereas before they were silently ignored.

    Yields:
        ``(batch_train_pairs, batch_test_inputs, batch_test_outputs)``: three lists
        with one entry per task in the batch (a list of padded ``(input, output)``
        pairs, a list of padded test inputs, and a list of padded test outputs).
        The last batch of each pass over *data* may hold fewer than *batch_size*
        tasks.

    When *shuffle* is true, the index order is shuffled in place at the start of
    every pass with ``np.random.default_rng(seed=pass_number)``, so the sequence
    is reproducible.

    Raises:
        ValueError (on the first ``next()``): if *data* is empty or *batch_size*
            is < 1. Either case would otherwise loop forever without yielding, or
            fail inside ``range``.
    """
    num_samples = len(data)
    if num_samples == 0:
        raise ValueError("data_generator needs at least one task")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    def pad(grid):
        return pad_grid(grid, max_size=max_size, pad_value=pad_value)

    indices = np.arange(num_samples)
    epoch = 0
    while True:
        if shuffle:
            # Seeded per pass for reproducibility. This reshuffles the previous
            # pass's order in place.
            rng = np.random.default_rng(seed=epoch)
            rng.shuffle(indices)
        for start_idx in range(0, num_samples, batch_size):
            batch_train_pairs, batch_test_inputs, batch_test_outputs = [], [], []
            for idx in indices[start_idx:start_idx + batch_size]:
                train_pairs, test_inputs, test_outputs = data[idx]
                batch_train_pairs.append([(pad(inp), pad(out)) for inp, out in train_pairs])
                batch_test_inputs.append([pad(inp) for inp in test_inputs])
                batch_test_outputs.append([pad(out) for out in test_outputs])
            yield batch_train_pairs, batch_test_inputs, batch_test_outputs
        epoch += 1

# ConvNet encoder: turns padded grids into per-pixel feature maps.
class ConvEncoder(nn.Module):
    """Two stacked 3x3 convolutions with ReLU, applied to a batch of grids.

    A trailing singleton channel axis is appended to the input before the first
    convolution. With 'SAME' padding the spatial size is preserved, and the
    output has ``self.features`` channels.
    """
    features: int = conv_features

    @nn.compact
    def __call__(self, x):
        # NOTE(review): colour values enter as raw scalars in a single channel.
        # Because the targets are 10 categorical classes, a one-hot input may suit
        # the data better.
        h = x[..., jnp.newaxis]
        for _ in range(2):  # submodules are auto-named Conv_0, Conv_1, as before
            conv = nn.Conv(features=self.features, kernel_size=(3, 3), padding='SAME')
            h = nn.relu(conv(h))
        return h

# RNN sequence model: summarises the training pairs, then decodes test inputs.
class RNNModel(nn.Module):
    """Summarise train pairs with a GRU and decode per-pixel logits for test inputs.

    Inputs are ConvEncoder feature maps:
      * ``train_pairs``: list of ``(input_features, output_features)`` tuples,
        each of shape ``(batch, H, W, F)``;
      * ``test_inputs``: list of ``(batch, H, W, F)`` feature maps.

    Returns a list with one ``(batch, H, W, num_classes)`` logit array per test
    input. This way ``pred.reshape(-1, 10)`` in the loss lines up with the
    flattened ``(batch, H, W)`` target grid. Previously the output was
    ``(batch, 10)``, which could not match.
    """
    hidden_size: int = rnn_hidden_size
    num_classes: int = 10

    @nn.compact
    def __call__(self, train_pairs, test_inputs):
        def pool(feature_map):
            # The GRU consumes vectors, not feature maps, so average over H and W.
            # Feeding the raw map against a (batch, hidden) carry was a shape error.
            return jnp.mean(feature_map, axis=(1, 2))

        # NOTE(review): assumes a Flax version whose GRUCell takes `features`
        # (0.7+); older releases inferred the size from the carry instead.
        rnn = nn.recurrent.GRUCell(features=self.hidden_size)
        # One head shared by all test inputs. Previously a new Dense was created
        # per test index, so tasks with more test inputs than were seen at init
        # had no parameters for the extra heads.
        head = nn.Dense(features=self.num_classes)

        reference = train_pairs[0][0] if train_pairs else test_inputs[0]
        context = jnp.zeros((reference.shape[0], self.hidden_size), dtype=jnp.float32)

        for input_features, output_features in train_pairs:
            step_input = jnp.concatenate([pool(input_features), pool(output_features)], axis=-1)
            context, _ = rnn(context, step_input)

        predictions = []
        for test_features in test_inputs:
            pooled = pool(test_features)
            # The output half is unknown at test time. Zeros keep the GRU input
            # width equal to the 2F used for train pairs.
            query = jnp.concatenate([pooled, jnp.zeros_like(pooled)], axis=-1)
            # Every test input starts from the train-pair context, so one test
            # input does not leak into the prediction for the next.
            query_state, _ = rnn(context, query)
            batch, height, width = test_features.shape[:3]
            tiled = jnp.broadcast_to(
                query_state[:, None, None, :], (batch, height, width, self.hidden_size)
            )
            predictions.append(head(jnp.concatenate([test_features, tiled], axis=-1)))

        return predictions

# Initialize models
conv_encoder = ConvEncoder()
rnn_model = RNNModel()

# Initialize parameters and optimizer.
# Each module is initialised on its own, then the two 'params' collections are
# merged one level down. The previous `{**a, **b}` over the top-level variable
# dicts let the RNN's 'params' entry replace the encoder's, which left
# conv_encoder.apply with no weights. Separate PRNG keys also decorrelate the two
# initialisations.
conv_key, rnn_key = jax.random.split(jax.random.PRNGKey(0))
dummy_grids = jnp.ones((batch_size, max_size, max_size))
dummy_features = jnp.ones((batch_size, max_size, max_size, conv_features))
conv_variables = conv_encoder.init(conv_key, dummy_grids)
rnn_variables = rnn_model.init(rnn_key, [(dummy_features, dummy_features)], [dummy_features])

clashing = set(conv_variables['params']) & set(rnn_variables['params'])
if clashing:
    # A shared submodule name would make one model silently use the other's weights.
    raise ValueError(f"ConvEncoder and RNNModel parameter names collide: {sorted(clashing)}")
params = {'params': {**conv_variables['params'], **rnn_variables['params']}}

optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, batch_train_pairs, batch_test_inputs, batch_test_outputs, opt_state):
    """Apply one Adam update on a batch of tasks; return ``(params, opt_state, loss)``.

    The batch layout matches what ``data_generator`` yields: three lists with one
    entry per task. Each task holds its own list of padded ``(input, output)``
    training grids, its padded test inputs, and its padded test outputs. The
    old code iterated ``batch_train_pairs`` as if it were a list of pairs, which
    fails for this layout. Tasks have different numbers of pairs and tests, so
    each task is run through the models as a batch of one. The loss is the mean
    softmax cross-entropy over every test prediction in the batch (previously a
    sum).

    NOTE(review): the nested list structure is static under ``jax.jit``, so each
    new combination of pair/test counts triggers a retrace and recompile.
    """
    def encode(p, grid):
        # 2-D grid -> (1, H, W, F) feature map.
        return conv_encoder.apply(p, jnp.asarray(grid)[None])

    def loss_fn(p):
        losses = []
        for train_pairs, test_inputs, test_outputs in zip(
            batch_train_pairs, batch_test_inputs, batch_test_outputs
        ):
            pairs_encoded = [(encode(p, inp), encode(p, out)) for inp, out in train_pairs]
            tests_encoded = [encode(p, inp) for inp in test_inputs]
            preds = rnn_model.apply(p, pairs_encoded, tests_encoded)
            for pred, true_output in zip(preds, test_outputs):
                logits = pred.reshape(-1, 10)
                labels = jnp.asarray(true_output).reshape(-1)
                one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
                losses.append(optax.softmax_cross_entropy(logits, one_hot_labels).mean())
        if not losses:
            raise ValueError("train_step received a batch with no test outputs")
        return jnp.mean(jnp.stack(losses))

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Training loop

# RGB colours indexed by cell value 0-9 (a commonly used ARC palette), used only
# to render grids for W&B.
_ARC_PALETTE = np.array([
    [0, 0, 0],
    [0, 116, 217],
    [255, 65, 54],
    [46, 204, 64],
    [255, 220, 0],
    [170, 170, 170],
    [240, 18, 190],
    [255, 133, 27],
    [127, 219, 255],
    [135, 12, 37],
], dtype=np.uint8)


def _render_grid(grid, cell_px=8):
    """Return an ``(H*cell_px, W*cell_px, 3)`` uint8 RGB image of a 2-D value grid."""
    colours = _ARC_PALETTE[np.clip(np.asarray(grid), 0, len(_ARC_PALETTE) - 1)]
    return np.repeat(np.repeat(colours, cell_px, axis=0), cell_px, axis=1)


def _predict_task(params, train_pairs, test_inputs):
    """Run encoder + RNN on one task (batch of one); return a 2-D argmax grid per test input."""
    def encode(grid):
        return conv_encoder.apply(params, jnp.asarray(grid)[None])

    preds = rnn_model.apply(
        params,
        [(encode(inp), encode(out)) for inp, out in train_pairs],
        [encode(inp) for inp in test_inputs],
    )
    return [np.asarray(jnp.argmax(pred[0], axis=-1)) for pred in preds]


def _log_visualizations(params, epoch, log_step):
    """Log 'input | prediction | target' strips for the first few training tasks.

    Samples come straight from ``train_data``, so no batch is taken from the
    training generator. The previous version drew grids via arckit with a plain
    dict and called ``.tolist()`` on a list, which crashed.
    """
    for i, (train_pairs, test_inputs, test_outputs) in enumerate(
        train_data[:num_visualization_samples]
    ):
        padded_pairs = [(pad_grid(inp), pad_grid(out)) for inp, out in train_pairs]
        test_input = pad_grid(test_inputs[0])
        target = pad_grid(test_outputs[0])
        predicted = _predict_task(params, padded_pairs, [test_input])[0]
        panels = [_render_grid(g) for g in (test_input, predicted, target)]
        separator = np.full((panels[0].shape[0], 8, 3), 255, dtype=np.uint8)
        strip = np.concatenate(
            [panels[0], separator, panels[1], separator, panels[2]], axis=1
        )
        caption = f'Epoch{epoch+1}_Sample{i+1}: input | prediction | target'
        wandb.log(
            {f"Training_Sample_{i+1}": wandb.Image(strip, caption=caption)},
            step=log_step,
        )


train_generator = data_generator(train_data, batch_size)
# Ceil division: one loop epoch consumes exactly one generator pass. That pass
# includes the final partial batch, so loop epochs and generator passes stay
# aligned.
steps_per_epoch = (len(train_data) + batch_size - 1) // batch_size
try:
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train_losses = []
        for step in range(steps_per_epoch):
            batch_train_pairs, batch_test_inputs, batch_test_outputs = next(train_generator)
            params, opt_state, loss = train_step(
                params, batch_train_pairs, batch_test_inputs, batch_test_outputs, opt_state
            )
            train_losses.append(loss)

            if (step + 1) % print_every == 0:
                current_step = epoch * steps_per_epoch + step + 1
                print(f"Step {step + 1}/{steps_per_epoch}, Loss: {float(loss):.4f}")
                wandb.log({"train_loss": float(loss)}, step=current_step)

        epoch_end_step = (epoch + 1) * steps_per_epoch
        avg_train_loss = float(np.mean(train_losses))
        print(f"Epoch {epoch + 1} completed. Average training loss: {avg_train_loss:.4f}")
        wandb.log({"avg_train_loss": avg_train_loss}, step=epoch_end_step)

        # Visualization: log a fixed sample of training predictions.
        _log_visualizations(params, epoch, epoch_end_step)
finally:
    # Finish the WandB run even if training raises, so the run is not left open.
    wandb.finish()