# In [None]:
from functools import partial
import json
import os
from datetime import datetime
import io

import jax
import jax.numpy as jnp
import numpy as np
import wandb
from PIL import Image

import flax
from flax import linen as nn
import optax

from arckit import draw_task

# ---- Hyperparameters -------------------------------------------------------
learning_rate = 1e-3
batch_size = 64
num_epochs = 8
print_every = 100  # log/print the training loss every this many steps
augmentations = ['flip_horizontal', 'flip_vertical', 'rotate_90', 'rotate_180', 'rotate_270']
max_size = 30  # every grid is centred on a max_size x max_size canvas
pad_value = 0  # fill value for the canvas around a grid
num_visualization_samples = 3  # Number of samples to visualize during training and submission

# ---- Run identity from the environment -------------------------------------
# COMPUTE_BACKEND is optional; MORPH and the WANDB_* variables are required and
# raise KeyError when unset.
compute_backend = os.environ.get("COMPUTE_BACKEND", "unk")
morph, wandb_entity, wandb_project = (
    os.environ[variable] for variable in ("MORPH", "WANDB_ENTITY", "WANDB_PROJECT")
)

# ---- Initialize WandB -------------------------------------------------------
current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
run_name = f"{compute_backend}.{morph}.{current_datetime}"
wandb.login()
wandb.init(
    entity=wandb_entity,
    project=wandb_project,
    name=run_name,
    config=dict(
        morph=morph,
        learning_rate=learning_rate,
        batch_size=batch_size,
        num_epochs=num_epochs,
        augmentations=augmentations,
        max_size=max_size,
        pad_value=pad_value,
    ),
)

def load_data(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale-dependent default encoding, so loading behaves the same wherever the
    notebook runs.

    Args:
        path: Filesystem path to a JSON file.

    Returns:
        The parsed JSON value. The ARC files loaded in this notebook are consumed
        as dicts keyed by task id.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Load datasets: the training split is used for fitting, the evaluation split
# for validation. Files are read in the order listed.
(
    train_challenges,
    train_solutions,
    eval_challenges,
    eval_solutions,
) = map(load_data, (
    '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json',
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json',
))

def process_tasks(challenges, solutions):
    """Flatten ARC tasks into ``(input_grid, output_grid, task_index)`` records.

    For every task, the 'train' demonstration pairs come first, followed by each
    'test' input paired with the matching entry of ``solutions[task_id]``. Grids
    are converted to ``np.int32`` arrays.

    Args:
        challenges: Mapping task_id -> {'train': [{'input', 'output'}, ...],
            'test': [{'input'}, ...]}.
        solutions: Mapping task_id -> list of expected test output grids.

    Returns:
        ``(records, task_id_to_index)``, where ``task_id_to_index`` maps each task
        id to its position in ``challenges``' iteration order.
    """
    records = []
    index_of = {}
    for task_index, (task_id, task) in enumerate(challenges.items()):
        index_of[task_id] = task_index
        expected_outputs = solutions[task_id]
        for example in task['train']:
            records.append((
                np.array(example['input'], dtype=np.int32),
                np.array(example['output'], dtype=np.int32),
                task_index,
            ))
        for position, query in enumerate(task['test']):
            records.append((
                np.array(query['input'], dtype=np.int32),
                np.array(expected_outputs[position], dtype=np.int32),
                task_index,
            ))
    return records, index_of

# Flatten both splits into (input, output, task_index) records. Only the training
# split's task_id -> index mapping is kept; the evaluation mapping is discarded.
train_data, task_id_to_index = process_tasks(train_challenges, train_solutions)
eval_data = process_tasks(eval_challenges, eval_solutions)[0]

def pad_grid(grid, max_size=max_size, pad_value=pad_value):
    """Centre ``grid`` on a ``max_size`` x ``max_size`` int32 canvas.

    Cells outside the grid are filled with ``pad_value``. When the free space is
    odd, the extra row/column of padding goes to the bottom/right, because the
    offsets use floor division. ``unpad_grid`` uses the same offsets to crop the
    grid back out.

    Args:
        grid: 2-D array-like of integers (lists are accepted as well as arrays).
        max_size: Side length of the square canvas.
        pad_value: Fill value for the canvas.

    Returns:
        A new ``(max_size, max_size)`` ``np.int32`` array.

    Raises:
        ValueError: if ``grid`` is not 2-D, or is larger than ``max_size`` along
            either axis. Without this check an oversized grid produced a cryptic
            negative-offset broadcast error.
    """
    grid = np.asarray(grid)
    if grid.ndim != 2:
        raise ValueError(f"pad_grid expects a 2-D grid, got shape {grid.shape}")
    rows, cols = grid.shape
    if rows > max_size or cols > max_size:
        raise ValueError(
            f"grid of shape {grid.shape} does not fit in a {max_size}x{max_size} canvas"
        )
    padded = np.full((max_size, max_size), pad_value, dtype=np.int32)
    start_row = (max_size - rows) // 2
    start_col = (max_size - cols) // 2
    padded[start_row:start_row + rows, start_col:start_col + cols] = grid
    return padded

def unpad_grid(grid, original_shape, max_size=max_size):
    """Crop the centred ``original_shape`` window back out of a padded grid.

    Uses the same floor-division offsets as ``pad_grid``, so for matching
    ``max_size`` it returns exactly the region ``pad_grid`` filled. The result is
    a slice (view) of ``grid``, not a copy.
    """
    height, width = original_shape
    top = (max_size - height) // 2
    left = (max_size - width) // 2
    row_window = slice(top, top + height)
    col_window = slice(left, left + width)
    return grid[row_window, col_window]

@jax.jit
def flip_horizontal(grid):
    """Mirror ``grid`` left-to-right by reversing its column axis (axis 1)."""
    return jnp.flip(grid, axis=1)

@jax.jit
def flip_vertical(grid):
    """Mirror ``grid`` top-to-bottom by reversing its row axis (axis 0)."""
    return jnp.flip(grid, axis=0)

@partial(jax.jit, static_argnames=("k",))
def rotate(grid, k):
    """Rotate ``grid`` by ``k`` quarter turns with ``jnp.rot90``.

    ``k`` is static, so each distinct value gets its own compiled version. It may
    be passed positionally or by keyword (the AUGMENTATIONS table uses the keyword
    form).
    """
    return jnp.rot90(grid, k)

# Name -> single-grid transform. apply_random_augmentations_batch looks
# augmentations up here by name. Rotations are k quarter turns via `rotate`.
AUGMENTATIONS = dict(
    flip_horizontal=flip_horizontal,
    flip_vertical=flip_vertical,
    rotate_90=partial(rotate, k=1),
    rotate_180=partial(rotate, k=2),
    rotate_270=partial(rotate, k=3),
)

def apply_random_augmentations_batch(input_grids, output_grids, rng_keys, augmentations):
    """Randomly augment every (input, output) pair in a batch.

    For each sample, one Bernoulli(p=0.5) draw per entry of ``augmentations``
    decides whether that transform is applied. The selected transforms are
    applied in list order, identically to the input and the output grid, so the
    pair stays consistent. The per-sample work is vectorised with ``jax.vmap``
    over the leading axis of all three arrays.

    Args:
        input_grids: Batched input grids (leading axis = batch).
        output_grids: Batched output grids, same leading axis.
        rng_keys: One PRNG key per sample.
        augmentations: Names to look up in ``AUGMENTATIONS``.

    Returns:
        ``(augmented_inputs, augmented_outputs)``.
    """
    transforms = [AUGMENTATIONS[name] for name in augmentations]

    def _augment_pair(grid_in, grid_out, key):
        coin_flips = jax.random.bernoulli(key, p=0.5, shape=(len(transforms),))
        pair = (grid_in, grid_out)
        for position, transform in enumerate(transforms):
            # Bind the transform as a default arg so each branch closes over its own function.
            pair = jax.lax.cond(
                coin_flips[position],
                lambda grids, fn=transform: (fn(grids[0]), fn(grids[1])),
                lambda grids: grids,
                pair,
            )
        return pair

    return jax.vmap(_augment_pair)(input_grids, output_grids, rng_keys)

def data_generator(data, batch_size, augmentations=None, max_size=max_size, pad_value=pad_value, shuffle=True):
    """Yield padded (optionally augmented) batches from ``data`` indefinitely.

    One pass over ``data`` yields ``ceil(len(data) / batch_size)`` batches; the
    last batch of a pass may be smaller than ``batch_size``.

    Args:
        data: Sequence of ``(input_grid, output_grid, task_index)`` records, as
            built by ``process_tasks``.
        batch_size: Maximum number of records per batch.
        augmentations: Optional list of names from ``AUGMENTATIONS``. When
            non-empty, each batch goes through ``apply_random_augmentations_batch``.
        max_size: Forwarded to ``pad_grid``.
        pad_value: Forwarded to ``pad_grid``.
        shuffle: When true, the index order is permuted at the start of every
            pass using a NumPy generator seeded with the pass number, so runs are
            reproducible.

    Yields:
        ``(batch_inputs, batch_outputs, batch_task_indices)``. Inputs and outputs
        are stacked padded grids; task indices are an int32 vector.

    Raises:
        ValueError: on the first ``next()`` if ``data`` is empty. Otherwise the
            ``while True`` loop would never reach a ``yield`` and would spin forever.
    """
    num_samples = len(data)
    if num_samples == 0:
        raise ValueError("data_generator requires at least one record")
    indices = np.arange(num_samples)
    epoch = 0
    while True:
        if shuffle:
            # Seeded by pass number. The permutation is applied in place on top
            # of the previous pass's order.
            np.random.default_rng(seed=epoch).shuffle(indices)
        for start_idx in range(0, num_samples, batch_size):
            records = [data[idx] for idx in indices[start_idx:start_idx + batch_size]]
            batch_inputs = jnp.stack([pad_grid(inp, max_size, pad_value) for inp, _, _ in records])
            batch_outputs = jnp.stack([pad_grid(out, max_size, pad_value) for _, out, _ in records])
            batch_task_indices = jnp.array([task for _, _, task in records], dtype=jnp.int32)
            if augmentations:
                # One key per sample, derived from (pass, batch offset) so
                # augmentation is deterministic for a given pass and batch position.
                rng_key = jax.random.fold_in(jax.random.PRNGKey(epoch), start_idx)
                rng_keys = jax.random.split(rng_key, len(records))
                batch_inputs, batch_outputs = apply_random_augmentations_batch(
                    batch_inputs, batch_outputs, rng_keys, augmentations
                )
            yield batch_inputs, batch_outputs, batch_task_indices
        epoch += 1

def get_eval_batches(eval_data, batch_size, max_size=max_size, pad_value=pad_value):
    """Yield one in-order pass of padded, un-augmented batches over ``eval_data``.

    Records are ``(input_grid, output_grid, task_index)``. Batch ``b`` contains
    records ``b*batch_size`` up to ``(b+1)*batch_size - 1``; the final batch may be
    shorter. Each yielded item is ``(inputs, outputs, task_indices)``, with task
    indices as an int32 vector.
    """
    total = len(eval_data)
    for offset in range(0, total, batch_size):
        records = [eval_data[j] for j in range(offset, min(offset + batch_size, total))]
        padded_inputs = [pad_grid(src, max_size, pad_value) for src, _, _ in records]
        padded_outputs = [pad_grid(dst, max_size, pad_value) for _, dst, _ in records]
        yield (
            jnp.stack(padded_inputs),
            jnp.stack(padded_outputs),
            jnp.array([task for _, _, task in records], dtype=jnp.int32),
        )

class Model(nn.Module):
    """Fully-convolutional per-cell classifier.

    Maps a batch of integer grids, shape (batch_size, 30, 30) in this notebook, to
    per-cell logits of shape (batch_size, 30, 30, 10). The body is three 3x3
    'SAME' convolutions with ReLU (32 -> 64 -> 32 channels), followed by a 1x1
    projection to 10 classes.

    NOTE(review): cell values are fed as a single raw-valued channel rather than
    one-hot colour channels.
    """

    @nn.compact
    def __call__(self, x):
        h = x[..., None]  # single-channel "image": (batch, H, W, 1)
        for width in (32, 64, 32):
            h = nn.relu(nn.Conv(features=width, kernel_size=(3, 3), padding='SAME')(h))
        # 1x1 conv gives one logit per class for every cell
        return nn.Conv(features=10, kernel_size=(1, 1))(h)


model = Model()

def compute_loss(logits, labels):
    """Mean softmax cross-entropy over every grid cell.

    ``logits`` is flattened to ``(-1, 10)`` and ``labels`` to ``(-1,)``, so each
    cell contributes one 10-way classification term. The result is their
    unweighted mean, padding cells included.
    """
    num_classes = 10
    flat_targets = jax.nn.one_hot(labels.reshape(-1), num_classes=num_classes)
    per_cell = optax.softmax_cross_entropy(logits.reshape(-1, num_classes), flat_targets)
    return per_cell.mean()

# Build the Adam optimizer, initialise parameters by tracing the model on a dummy
# all-ones batch of the training shape, then create the optimizer state for them.
optimizer = optax.adam(learning_rate)
params = model.init(
    jax.random.PRNGKey(0),
    jnp.ones((batch_size, max_size, max_size)),
)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, batch_inputs, batch_outputs, opt_state):
    """Apply one optimizer update on a batch.

    Uses the module-level ``model``, ``compute_loss`` and ``optimizer``.

    Returns:
        ``(new_params, new_opt_state, loss)``. ``loss`` is the pre-update batch
        loss.
    """
    loss_and_grad = jax.value_and_grad(
        lambda p: compute_loss(model.apply(p, batch_inputs), batch_outputs)
    )
    loss, grads = loss_and_grad(params)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), new_opt_state, loss

@jax.jit
def eval_step(params, batch_inputs, batch_outputs):
    """Compute loss and non-zero-cell accuracy for one batch.

    Returns:
        ``(loss, accuracy)``:
        - ``loss`` is ``compute_loss`` over all cells.
        - ``accuracy`` is the fraction of cells with a non-zero target whose
          argmax prediction matches. Cells whose target is 0 are ignored; 0 is
          also the padding value, so including them would reward predicting
          background. A batch with no non-zero targets reports 0.0 instead of
          dividing by zero.
    """
    logits = model.apply(params, batch_inputs)
    loss = compute_loss(logits, batch_outputs)
    preds = jnp.argmax(logits, axis=-1)
    scored = batch_outputs != 0
    correct = jnp.logical_and(scored, preds == batch_outputs)
    # Normalise by the number of scored cells, not by all cells. Dividing by all
    # cells would measure "non-zero and correct" rather than accuracy on
    # non-zero cells.
    accuracy = jnp.sum(correct) / jnp.maximum(jnp.sum(scored), 1)
    return loss, accuracy

def evaluate(params, eval_data, batch_size):
    """Average ``eval_step`` metrics over one pass of ``eval_data``.

    Each batch's metrics are weighted by the number of samples in it. The final,
    possibly partial batch therefore no longer counts as much as a full one. This
    gives the exact mean loss, since every sample has the same cell count. For
    accuracy it is a sample-weighted average of per-batch accuracies.

    Returns:
        ``(avg_loss, avg_accuracy)`` as Python floats, or ``(nan, nan)`` when
        ``eval_data`` is empty.
    """
    total_loss = 0.0
    total_accuracy = 0.0
    num_samples = 0
    for batch_inputs, batch_outputs, _ in get_eval_batches(eval_data, batch_size):
        loss, accuracy = eval_step(params, batch_inputs, batch_outputs)
        n = batch_inputs.shape[0]
        total_loss += float(loss) * n
        total_accuracy += float(accuracy) * n
        num_samples += n
    if num_samples == 0:
        return float('nan'), float('nan')
    return total_loss / num_samples, total_accuracy / num_samples

def wrap_draw_task(task_dict, include_test=False):
    """Render a plain-dict task with arckit's ``draw_task``.

    ``task_dict`` needs an 'id', a 'train' list of (input, output) grid pairs and,
    optionally, a 'test' list of pairs. The grids are converted to NumPy arrays
    and wrapped in a small object exposing ``.id``, ``.train`` and ``.test``
    (presumably the attributes ``draw_task`` reads; verify against arckit).

    Returns:
        Whatever ``draw_task`` returns. The callers here call ``.save_png`` on it.
    """

    def _as_array_pairs(pairs):
        return [(np.array(src), np.array(dst)) for src, dst in pairs]

    class TaskVisualization:
        def __init__(self, task_id, train, test):
            self.id = task_id
            self.train = _as_array_pairs(train)
            self.test = _as_array_pairs(test) if test else []

    task_obj = TaskVisualization(
        task_dict['id'], task_dict['train'], task_dict.get('test', [])
    )
    return draw_task(task_obj, include_test=include_test)

# Create data generators
train_generator = data_generator(train_data, batch_size, augmentations=augmentations)

# data_generator yields ceil(len(train_data) / batch_size) batches per pass, the
# last one possibly partial. Consuming exactly that many keeps training epochs
# aligned with generator passes. Floor division would leave the leftover batch to
# spill into the next epoch, drifting the boundaries. The smaller final batch may
# trigger one extra jit trace of train_step.
steps_per_epoch = -(-len(train_data) // batch_size)

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train_losses = []
    for step in range(steps_per_epoch):
        batch_inputs, batch_outputs, batch_task_indices = next(train_generator)
        params, opt_state, loss = train_step(params, batch_inputs, batch_outputs, opt_state)
        train_losses.append(loss)

        if (step + 1) % print_every == 0:
            current_step = epoch * steps_per_epoch + step + 1
            print(f"Step {step + 1}/{steps_per_epoch}, Loss: {loss:.4f}")
            wandb.log({"train_loss": loss.item()}, step=current_step)

    epoch_end_step = (epoch + 1) * steps_per_epoch
    avg_train_loss = np.mean(train_losses)
    print(f"Epoch {epoch + 1} completed. Average training loss: {avg_train_loss:.4f}")
    wandb.log({"avg_train_loss": avg_train_loss}, step=epoch_end_step)

    # Evaluation
    avg_eval_loss, avg_eval_accuracy = evaluate(params, eval_data, batch_size)
    print(f"Validation loss: {avg_eval_loss:.4f}, Validation accuracy: {avg_eval_accuracy:.4f}")
    wandb.log({
        "val_loss": avg_eval_loss,
        "val_accuracy": avg_eval_accuracy
    }, step=epoch_end_step)

    # Visualization: for the first few evaluation pairs, draw the ground truth
    # next to the model's prediction. get_eval_batches iterates in order, so row i
    # of its first batch is eval_data[i]; use that record's shapes to crop away
    # the padding.
    sample_batch_inputs, sample_batch_outputs, _ = next(get_eval_batches(eval_data, batch_size))
    preds = jnp.argmax(model.apply(params, sample_batch_inputs), axis=-1)

    for i in range(min(num_visualization_samples, len(sample_batch_inputs))):
        original_input, original_output, _ = eval_data[i]
        input_grid = unpad_grid(np.array(sample_batch_inputs[i]), original_input.shape)
        true_output = unpad_grid(np.array(sample_batch_outputs[i]), original_output.shape)
        # The model predicts a full padded canvas; crop it to the ground-truth
        # footprint so it lines up with true_output.
        predicted_output = unpad_grid(np.array(preds[i]), original_output.shape)

        task_visual = {
            'id': f'Epoch{epoch+1}_Sample{i+1}',
            # Pair 1: input -> ground truth. Pair 2: input -> model prediction.
            'train': [
                (input_grid.tolist(), true_output.tolist()),
                (input_grid.tolist(), predicted_output.tolist()),
            ],
            'test': []
        }

        drawing = wrap_draw_task(task_visual, include_test=False)
        img_buffer = io.BytesIO()
        drawing.save_png(img_buffer)
        img_buffer.seek(0)  # rewind so PIL reads the PNG from the start
        image = Image.open(img_buffer)
        wandb.log({f"Evaluation_Sample_{i+1}": wandb.Image(image)}, step=epoch_end_step)


def process_test_tasks(challenges):
    """Flatten every task's 'test' inputs into ``(task_id, test_index, grid)`` records.

    Records follow ``challenges``' iteration order, and within a task the order of
    its 'test' list. Grids are ``np.int32`` arrays.
    """
    return [
        (task_id, position, np.array(query['input'], dtype=np.int32))
        for task_id, task in challenges.items()
        for position, query in enumerate(task['test'])
    ]

# Load the test challenges and flatten them into (task_id, test_index, grid) records.
test_challenges = load_data(
    '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
)
test_data = process_test_tasks(test_challenges)

def generate_predictions(params, test_data, batch_size):
    """Run the model over test inputs and build the submission mapping.

    Each prediction is the argmax over the padded canvas, cropped back to the
    *input* grid's shape. The model has no separate output-size head, so an
    output whose size differs from its input cannot be represented.

    Args:
        params: Model parameters.
        test_data: ``(task_id, test_index, input_grid)`` records.
        batch_size: Number of records per forward pass.

    Returns:
        ``(predictions, visualization_samples)``:
        - ``predictions[task_id][test_index]`` is
          ``{"attempt_1": grid, "attempt_2": grid}``, with both attempts set to
          the same nested-list grid.
        - ``visualization_samples`` holds up to ``num_visualization_samples``
          ``(input_grid, cropped_prediction)`` array pairs.
    """
    predictions = {}
    visualization_samples = []
    for offset in range(0, len(test_data), batch_size):
        chunk = test_data[offset:offset + batch_size]
        shapes = [grid.shape for _, _, grid in chunk]
        stacked = jnp.stack([pad_grid(grid, max_size, pad_value) for _, _, grid in chunk])
        labels = np.array(jnp.argmax(model.apply(params, stacked), axis=-1), dtype=int)

        for (task_id, input_index, grid), shape, pred in zip(chunk, shapes, labels):
            cropped = unpad_grid(pred, shape)
            as_list = cropped.tolist()
            slots = predictions.setdefault(task_id, [])
            # Grow the per-task list with placeholders so input_index is addressable.
            slots.extend({} for _ in range(input_index + 1 - len(slots)))
            slots[input_index] = {"attempt_1": as_list, "attempt_2": as_list}

            if len(visualization_samples) < num_visualization_samples:
                visualization_samples.append((np.array(grid), cropped))

    return predictions, visualization_samples

predictions, visualization_samples = generate_predictions(params, test_data, batch_size)

# Log submission visualizations: each image pairs one test input with the
# model's cropped prediction.
for i, (input_grid, pred_grid) in enumerate(visualization_samples):
    drawing = wrap_draw_task(
        {
            'id': f'Submission_Sample_{i+1}',
            'train': [(input_grid.tolist(), pred_grid.tolist())],
            'test': [],
        },
        include_test=False,
    )
    png_buffer = io.BytesIO()
    drawing.save_png(png_buffer)
    png_buffer.seek(0)  # rewind before PIL reads the PNG back
    wandb.log({f"Submission_Sample_{i+1}": wandb.Image(Image.open(png_buffer))})

# Save predictions to submission.json
# TODO: Uncomment when cleaning for submission during export
# with open('submission.json', 'w') as f:
#     json.dump(predictions, f)

# Finish WandB run
wandb.finish()