# In [None]:
from dataclasses import dataclass
from datetime import datetime
from functools import partial
import json
import os
import uuid

import jax
import jax.numpy as jnp
import optax
import wandb

@dataclass(frozen=True)
class Hyperparams:
    """Frozen run configuration shared by the data pipeline, model and training loop.

    The run-identity defaults are read from environment variables when this class
    body executes (once, at definition time); changing the environment afterwards
    does not affect new instances. ``created_on`` is stamped at that same moment.
    Instances are hashable (frozen dataclass), which is what allows them to be
    passed as static arguments to the ``jax.jit``-compiled helpers in this file.
    """

    # -- run identity / logging (environment-overridable) --
    seed: int = int(os.getenv("SEED", 42))
    morph: str = os.getenv("MORPH", "test")
    compute_backend: str = os.getenv("COMPUTE_BACKEND", "oop")
    wandb_entity: str = os.getenv("WANDB_ENTITY", "hug")
    wandb_project: str = os.getenv("WANDB_PROJECT", "arc-test")
    created_on: str = f"{datetime.now():%Y%m%d%H%M%S}"
    # -- optimisation --
    learning_rate: float = 1e-3
    batch_size: int = 8
    num_epochs: int = 1
    print_every: int = 1  # log the training loss every N steps
    # -- grid encoding --
    # ARC grids are rectangular lists of lists whose cells are integers 0..9, with
    # sizes from 1x1 up to 30x30, so a 30x30 canvas can hold any grid.
    pad_dim_space: int = 30
    pad_dim_time_train: int = 10  # slots for training pairs per task
    pad_dim_time_eval: int = 3  # slots for evaluation (test) examples per task
    pad_value: int = 0  # fill value for canvas cells not covered by a grid
    num_channels: int = 10  # one-hot depth: one channel per colour 0..9
    augment_prob: float = 0.5  # independent probability of each random flip


hp = Hyperparams()

def load_tasks(challenges_path, solutions_path, hp: Hyperparams):
    """Read ARC challenges from JSON and merge known test outputs into them.

    Args:
        challenges_path: JSON file mapping task id -> task dict; each task dict has
            a ``'test'`` list whose entries carry an ``'input'`` grid.
        solutions_path: JSON file mapping task id -> list of expected test outputs.
        hp: not used here; accepted for signature symmetry with the pipeline.

    Returns:
        The task dicts, in the order the challenges file lists them. For every task
        id that also appears in the solutions file, ``task['test']`` is replaced by
        ``{'input': ..., 'output': ...}`` dicts, pairing test inputs with solutions
        by position (``zip`` stops at the shorter list). The dicts are the objects
        decoded from the challenges file, modified in place.
    """
    with open(challenges_path, 'r') as fh:
        challenges = json.load(fh)
    print(f"loading challenges from {challenges_path}, found {len(challenges)} challenges")
    with open(solutions_path, 'r') as fh:
        solutions = json.load(fh)
    print(f"loading solutions from {solutions_path}, found {len(solutions)} solutions")

    loaded = []
    for task_id, task in challenges.items():
        if task_id in solutions:
            answers = solutions[task_id]
            task['test'] = [
                {'input': example['input'], 'output': answer}
                for example, answer in zip(task['test'], answers)
            ]
        loaded.append(task)
    return loaded

@partial(jax.jit, static_argnums=(1,))
def one_hot_encode(grid, num_channels):
    """One-hot encode an integer grid along a new trailing axis.

    ``num_channels`` must be static under ``jax.jit``: ``jax.nn.one_hot`` needs a
    concrete ``num_classes`` to know the output shape, so passing it as a traced
    argument (plain ``@jax.jit``) fails on every call.

    Args:
        grid: integer array of any shape.
        num_channels: Python int, depth of the one-hot axis.

    Returns:
        Float array of shape ``grid.shape + (num_channels,)``.
    """
    return jax.nn.one_hot(grid, num_channels)

@partial(jax.jit, static_argnums=(3,))
def pad_space(key, grid_in, grid_out, hp):
    """Place a grid (and optionally its paired output) on a square padded canvas.

    Both grids are written at the same random top-left offset, so an input and its
    output stay spatially aligned. The offset is drawn so that the larger of the
    two grids still fits entirely on the ``hp.pad_dim_space`` canvas.

    Args:
        key: PRNG key.
        grid_in: array of shape (H, W, ...); in this file, one-hot grids of shape
            (H, W, num_channels).
        grid_out: array of shape (H', W', ...) with the same trailing dims, or None.
        hp: hashable config (static under jit) providing ``pad_dim_space`` and
            ``pad_value``.

    Returns:
        ``(grid_in_padded, grid_out_padded, (start_h, start_w))``. The padded
        arrays have shape (pad_dim_space, pad_dim_space, ...). ``grid_out_padded``
        is None when ``grid_out`` is None. The last element is the grid's
        row/column offset on the canvas. Cells not covered by a grid hold
        ``hp.pad_value`` in every trailing position.

    A grid larger than the canvas is rejected by ``lax.dynamic_update_slice``.
    """
    size = hp.pad_dim_space
    # Grid shapes are static under jit, so the spans can be plain Python ints.
    if grid_out is None:
        span_h, span_w = grid_in.shape[0], grid_in.shape[1]
    else:
        span_h = max(grid_in.shape[0], grid_out.shape[0])
        span_w = max(grid_in.shape[1], grid_out.shape[1])
    span_h, span_w = min(span_h, size), min(span_w, size)

    # Separate keys: drawing both offsets from one key made start_h == start_w
    # whenever the two ranges were equal (e.g. for square grids).
    key_h, key_w = jax.random.split(key)
    start_h = jax.random.randint(key_h, shape=(), minval=0, maxval=size - span_h + 1)
    start_w = jax.random.randint(key_w, shape=(), minval=0, maxval=size - span_w + 1)

    def place(grid):
        # jnp.pad requires concrete pad widths, which a traced random offset is
        # not. Writing into a pre-filled canvas accepts traced start indices.
        canvas = jnp.full((size, size) + grid.shape[2:], hp.pad_value, dtype=grid.dtype)
        zero = jnp.zeros_like(start_h)  # same index dtype as the random offsets
        starts = (start_h, start_w) + (zero,) * (grid.ndim - 2)
        return jax.lax.dynamic_update_slice(canvas, grid, starts)

    grid_in_padded = place(grid_in)
    grid_out_padded = None if grid_out is None else place(grid_out)
    return grid_in_padded, grid_out_padded, (start_h, start_w)

@partial(jax.jit, static_argnums=(3, 4), static_argnames=("hp", "is_eval"))
def pad_time(key, grid_in, grid_out, hp, is_eval=False):
    """Resample the leading (example/"time") axis to a fixed length and shuffle it.

    With fewer examples than slots, existing examples are kept and randomly chosen
    ones are repeated to fill the remaining slots. With more examples than slots, a
    random subset is kept (previously this produced a negative repeat count and
    failed). The result is always randomly ordered. ``grid_in`` and ``grid_out``
    are gathered with the same index vector, so pairs stay aligned.

    Args:
        key: PRNG key.
        grid_in: array whose axis 0 indexes examples.
        grid_out: array with the same axis-0 length as ``grid_in``, or None.
        hp: static config providing ``pad_dim_time_train`` / ``pad_dim_time_eval``.
        is_eval: static flag; choose the eval slot count instead of the train one.
            May be given positionally or by keyword.

    Returns:
        ``(grid_in_padded, grid_out_padded)``, each with axis 0 of length
        ``pad_dim_time``; ``grid_out_padded`` is None when ``grid_out`` is None.
    """
    pad_dim_time = hp.pad_dim_time_eval if is_eval else hp.pad_dim_time_train
    num_examples = grid_in.shape[0]  # static under jit
    # Independent keys for the repeat draw and the shuffle.
    repeat_key, perm_key = jax.random.split(key)

    index = jnp.arange(num_examples)
    if num_examples < pad_dim_time:
        extra = jax.random.randint(
            repeat_key, shape=(pad_dim_time - num_examples,), minval=0, maxval=num_examples
        )
        index = jnp.concatenate([index, extra])
    # Shuffle, then truncate. When there are more examples than slots, this keeps
    # a random subset.
    index = jax.random.permutation(perm_key, index)[:pad_dim_time]

    grid_in_padded = grid_in[index]
    grid_out_padded = None if grid_out is None else grid_out[index]
    return grid_in_padded, grid_out_padded

@partial(jax.jit, static_argnums=(2,))
def random_augment(key, grid_in, hp):
    """Randomly mirror a grid and shuffle its channel (colour) axis.

    Args:
        key: PRNG key. It is split three ways: the first subkey decides the two
            flips, the last drives the channel permutation, and the middle one is
            unused.
        grid_in: array whose last axis has ``hp.num_channels`` entries. As used in
            this file it is an (H, W, num_channels) grid, so ``fliplr`` mirrors
            axis 1 and ``flipud`` mirrors axis 0.
        hp: static config providing ``augment_prob`` (probability of each flip,
            decided independently) and ``num_channels``.

    Returns:
        Array with the same shape as ``grid_in``. The channel permutation is
        always applied; the flips are conditional.
    """
    flip_key, _, perm_key = jax.random.split(key, 3)
    do_flip = jax.random.bernoulli(flip_key, hp.augment_prob, shape=(2,))

    def keep(x):
        return x

    mirrored = jax.lax.cond(do_flip[0], jnp.fliplr, keep, grid_in)
    mirrored = jax.lax.cond(do_flip[1], jnp.flipud, keep, mirrored)

    channel_order = jax.random.permutation(perm_key, hp.num_channels)
    return jnp.take(mirrored, channel_order, axis=-1)

def process_single_task(key, task, hp):
    """Turn one ARC task dict into fixed-shape, augmented arrays.

    ``task`` has the structure produced by ``load_tasks``:
    - ``task['train']`` is a list of ``{'input': grid, 'output': grid}`` pairs.
    - ``task['test']`` is a list of ``{'input': grid}`` dicts, optionally with an
      ``'output'`` grid.
    - Grids are lists of lists of ints.

    Processing steps:
    - Every grid is one-hot encoded and placed on a canvas by ``pad_space``. An
      input and its output share one random offset.
    - The example axis is fixed to a constant length by ``pad_time``.
    - One random augmentation (flips plus channel permutation, via
      ``random_augment``) is applied identically to every grid of the task. The
      previous per-example augmentation gave different pairs of the same task
      different colour mappings and left eval grids unaugmented, which breaks the
      task's input->output rule.

    Not jitted: ``task`` is a nested Python structure whose sizes differ per
    task, and the loops below are Python-level. The array work happens inside
    the jitted helpers.

    Args:
        key: PRNG key.
        task: task dict as described above.
        hp: config with ``num_channels`` and the fields used by the helpers.

    Returns:
        ``(train_in, train_out, eval_in, eval_out)``, as shaped by the helpers:
        - ``train_in`` / ``train_out``: (pad_dim_time_train, P, P, C).
        - ``eval_in`` / ``eval_out``: (pad_dim_time_eval, P, P, C).
        Here P = ``pad_dim_space`` and C = ``num_channels``. ``eval_out`` is None
        when any test example has no ``'output'`` grid. ``valid_step`` and
        ``train_step`` consume these eval targets.
    """
    k_train_space, k_train_time, k_eval_space, k_eval_time, k_aug = jax.random.split(key, 5)

    def encode(grid):
        return jax.nn.one_hot(jnp.asarray(grid, dtype=jnp.int32), hp.num_channels)

    # Training pairs: one shared canvas offset per input/output pair.
    train_pairs = task['train']
    train_in, train_out = [], []
    for pair, k in zip(train_pairs, jax.random.split(k_train_space, len(train_pairs))):
        g_in, g_out, _ = pad_space(k, encode(pair['input']), encode(pair['output']), hp)
        train_in.append(g_in)
        train_out.append(g_out)
    train_in, train_out = pad_time(k_train_time, jnp.stack(train_in), jnp.stack(train_out), hp)

    # Evaluation examples: targets are optional (absent for hidden test tasks).
    test_examples = task['test']
    has_eval_out = all('output' in ex for ex in test_examples)
    eval_in, eval_out = [], []
    for ex, k in zip(test_examples, jax.random.split(k_eval_space, len(test_examples))):
        target = encode(ex['output']) if has_eval_out else None
        g_in, g_out, _ = pad_space(k, encode(ex['input']), target, hp)
        eval_in.append(g_in)
        eval_out.append(g_out)
    eval_in, eval_out = pad_time(
        k_eval_time,
        jnp.stack(eval_in),
        jnp.stack(eval_out) if has_eval_out else None,
        hp,
        is_eval=True,
    )

    # Same key for every grid: vmap over the example axis so each (P, P, C) grid
    # receives the identical flip/channel permutation.
    augment = jax.vmap(lambda grid: random_augment(k_aug, grid, hp))
    train_in = augment(train_in)
    train_out = augment(train_out)
    eval_in = augment(eval_in)
    if eval_out is not None:
        eval_out = augment(eval_out)

    return train_in, train_out, eval_in, eval_out

def create_data_loader(tasks, hp, batch_size, is_train=True):
    """Return a generator function that yields batches of processed tasks.

    Args:
        tasks: list of task dicts (as returned by ``load_tasks``).
        hp: config forwarded to ``process_single_task``.
        batch_size: maximum tasks per batch; the last batch may be smaller.
        is_train: if True, task order is reshuffled on every call of the returned
            function; otherwise tasks are visited in the given order.

    Returns:
        ``data_loader(key)``, a generator function. Each batch is a tuple with one
        entry per value returned by ``process_single_task``. Each entry stacks that
        value across the batch's tasks on a new leading axis, or is None when any
        task returned None for it.
    """

    def stack_field(values):
        # e.g. eval targets are None for tasks without known test outputs
        if any(v is None for v in values):
            return None
        return jnp.stack(values)

    def prepare_batch(key, batch):
        # Tasks differ in grid sizes and example counts, so a list of task dicts
        # cannot be vmapped. Process each task on its own (results are
        # fixed-shape) and stack afterwards.
        keys = jax.random.split(key, len(batch))
        processed = [process_single_task(k, task, hp) for k, task in zip(keys, batch)]
        return tuple(stack_field(list(field)) for field in zip(*processed))

    def data_loader(key):
        num_tasks = len(tasks)
        num_batches = (num_tasks + batch_size - 1) // batch_size  # ceil division

        if is_train:
            key, shuffle_key = jax.random.split(key)
            order = jax.random.permutation(shuffle_key, num_tasks).tolist()
            ordered_tasks = [tasks[i] for i in order]
        else:
            ordered_tasks = tasks

        for i in range(num_batches):
            batch_key, key = jax.random.split(key)
            batch = ordered_tasks[i * batch_size:(i + 1) * batch_size]
            yield prepare_batch(batch_key, batch)

    return data_loader

# <model>

def init_params(key, hp: Hyperparams):
    """Create the parameter pytree consumed by ``model``.

    The returned dict has three entries, each of the form ``{'w': ..., 'b': ...}``:
    - ``'conv_train'``: 3x3 kernel, 2 * num_channels -> 16 channels.
    - ``'conv_eval'``: 3x3 kernel, num_channels -> 16 channels.
    - ``'fc'``: dense layer, 32 (2 x 16 pooled features) -> num_channels.

    Weights are standard-normal draws scaled by 0.01, and biases are zeros. The
    key is split five ways; the first three subkeys go to the layers in the order
    listed. A line with each weight shape is printed.
    """
    n_feat = 16
    layer_specs = (
        ('conv_train', (3, 3, 2 * hp.num_channels, n_feat)),
        ('conv_eval', (3, 3, hp.num_channels, n_feat)),
        ('fc', (2 * n_feat, hp.num_channels)),
    )

    params = {}
    # zip stops after the three specs, so subkeys 3 and 4 are left unused.
    for (name, w_shape), subkey in zip(layer_specs, jax.random.split(key, 5)):
        params[name] = {
            'w': jax.random.normal(subkey, w_shape) * 0.01,
            'b': jnp.zeros((w_shape[-1],)),
        }
        print(f"Initialized {name} weights with shape: {params[name]['w'].shape}")

    return params

def model(params, task_train_in, task_train_out, task_eval_in, hp):
    """Predict per-cell colour logits for the evaluation grids of a batch of tasks.

    Each task is summarised by two feature vectors:
    - one from its training pairs (input and output stacked on the channel axis);
    - one from its eval inputs.
    Each comes from a 3x3 SAME convolution + ReLU followed by a mean over the
    examples and both spatial axes. A dense layer maps the 32 concatenated
    features to ``num_channels`` logits. Those logits are broadcast to every cell
    of every eval grid, so the model predicts one colour distribution per task.

    Args:
        params: dict from ``init_params`` (entries 'conv_train', 'conv_eval', 'fc').
        task_train_in, task_train_out: (batch, T_train, H, W, num_channels).
        task_eval_in: (batch, T_eval, H_eval, W_eval, num_channels). Its spatial
            size may differ from the training grids.
        hp: unused; kept for signature compatibility with the callers.

    Returns:
        Logits of shape (batch, T_eval, H_eval, W_eval, num_channels).
    """
    # Debug prints to check shapes
    print(f"task_train_in shape: {task_train_in.shape}")
    print(f"task_train_out shape: {task_train_out.shape}")
    print(f"task_eval_in shape: {task_eval_in.shape}")

    def conv_pool(layer, label):
        """Build a per-task extractor: conv + ReLU, then mean over (T, H, W)."""

        def extract(examples):  # (T, H, W, channels) for one task
            print(f"{label} with shape: {examples.shape}")
            x = jax.lax.conv_general_dilated(
                examples,
                layer['w'],
                window_strides=(1, 1),
                padding='SAME',
                dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
            )
            x = jax.nn.relu(x + layer['b'])
            return jnp.mean(x, axis=(0, 1, 2))  # (features,)

        return extract

    # Pair each training input with its output on the channel axis (2 * num_channels).
    train_combined = jnp.concatenate([task_train_in, task_train_out], axis=-1)
    train_features = jax.vmap(conv_pool(params['conv_train'], "Processing train example"))(train_combined)
    # Eval grids use their own spatial size; the original reshaped them with the
    # training H/W, which fails whenever the two differ.
    eval_features = jax.vmap(conv_pool(params['conv_eval'], "Processing eval input"))(task_eval_in)

    combined_features = jnp.concatenate([train_features, eval_features], axis=-1)  # (batch, 32)
    logits = combined_features @ params['fc']['w'] + params['fc']['b']  # (batch, num_channels)

    # Same task-level logits for every cell; broadcasting avoids a tiled copy.
    out_shape = task_eval_in.shape[:-1] + (logits.shape[-1],)
    logits = jnp.broadcast_to(logits[:, None, None, None, :], out_shape)

    # Debug print
    print(f"Final logits shape: {logits.shape}")

    return logits


# </model>

def loss_fn(task_eval_out_pred, task_eval_out_targ, hp: Hyperparams):
    """Masked softmax cross-entropy between predicted logits and one-hot targets.

    Args:
        task_eval_out_pred: logits, (batch_size, pad_dim_time_eval, H, W, num_channels).
        task_eval_out_targ: targets of the same shape. Real cells are one-hot. Cells
            added by ``pad_space`` hold ``hp.pad_value`` in every channel, so their
            channel sum is not 1 (it is 0 with the default pad_value of 0).
        hp: provides ``num_channels``.

    Returns:
        Scalar mean cross-entropy over real (channel-sum == 1) cells only, or 0.0
        if there are none. Without the mask, all-zero padding rows argmax to
        colour 0. They then dominate the loss, because a small grid covers only a
        few cells of the padded canvas.
    """
    logits = task_eval_out_pred.reshape(-1, hp.num_channels)
    targets = task_eval_out_targ.reshape(-1, hp.num_channels)
    labels = jnp.argmax(targets, axis=-1)
    is_real = (jnp.sum(targets, axis=-1) == 1).astype(logits.dtype)

    # Equivalent to optax.softmax_cross_entropy_with_integer_labels(logits, labels).
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    per_cell = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]

    return jnp.sum(per_cell * is_real) / jnp.maximum(jnp.sum(is_real), 1.0)


def train_step(params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp: Hyperparams):
    """Run one optimisation step on a batch and return the pre-update loss.

    Uses the module-level ``optimizer`` (an optax transformation) to turn the
    gradients of ``loss_fn(model(...), task_eval_out)`` into parameter updates.

    Returns:
        ``(new_params, new_opt_state, loss)``.
    """
    def objective(p):
        prediction = model(p, task_train_in, task_train_out, task_eval_in, hp)
        return loss_fn(prediction, task_eval_out, hp)

    loss, grads = jax.value_and_grad(objective)(params)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

def valid_step(params, valid_gen, hp: Hyperparams):
    """Average loss and cell accuracy of ``model`` over a validation generator.

    Args:
        params: model parameters.
        valid_gen: iterable of ``(task_train_in, task_train_out, task_eval_in,
            task_eval_out)`` batches.
        hp: config forwarded to ``model`` and ``loss_fn``.

    Returns:
        ``(avg_loss, avg_acc)``, each averaged per batch. Accuracy counts only
        cells whose target is one-hot (channel sum of 1). Padding cells from
        ``pad_space`` are excluded; with the default pad_value of 0 they argmax to
        colour 0 and would otherwise score as correct whenever the model predicts 0.

    Raises:
        ValueError: if ``valid_gen`` yields no batches.
    """
    total_loss = 0.0
    total_acc = 0.0
    num_batches = 0
    for task_train_in, task_train_out, task_eval_in, task_eval_out in valid_gen:
        task_eval_out_pred = model(params, task_train_in, task_train_out, task_eval_in, hp)
        total_loss += loss_fn(task_eval_out_pred, task_eval_out, hp)

        real_cells = jnp.sum(task_eval_out, axis=-1) == 1
        pred_classes = jnp.argmax(task_eval_out_pred, axis=-1)
        true_classes = jnp.argmax(task_eval_out, axis=-1)
        hits = (pred_classes == true_classes) & real_cells
        total_acc += jnp.sum(hits) / jnp.maximum(jnp.sum(real_cells), 1)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("valid_gen yielded no batches; cannot average validation metrics")
    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches
    return avg_loss, avg_acc

# --- Experiment setup: logging, data, parameters, optimiser ---
print(f"hyperparameters: {hp}")
wandb.login()
wandb.init(
    entity=hp.wandb_entity,
    project=hp.wandb_project,
    name=f"{hp.compute_backend}.{hp.morph}.{str(uuid.uuid4())[:6]}",
    config=hp.__dict__,
)
key = jax.random.PRNGKey(hp.seed)
train_tasks = load_tasks(
    '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json', hp)
valid_tasks = load_tasks(
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json', hp)
train_loader = create_data_loader(train_tasks, hp, batch_size=hp.batch_size, is_train=True)
valid_loader = create_data_loader(valid_tasks, hp, batch_size=hp.batch_size, is_train=False)
# Split off a dedicated init key. Reusing `key` both here and for the later
# splits would correlate parameter init with the data randomness.
key, init_key = jax.random.split(key)
params = init_params(init_key, hp)
optimizer = optax.adam(hp.learning_rate)
opt_state = optimizer.init(params)

# --- Training loop ---
steps_per_epoch = (len(train_tasks) + hp.batch_size - 1) // hp.batch_size  # ceil division
for epoch in range(hp.num_epochs):
    print(f"epoch {epoch + 1}/{hp.num_epochs}")
    train_key, key = jax.random.split(key)

    for step, batch in enumerate(train_loader(train_key)):
        # train_step needs eval targets; fail loudly if the loader does not supply them.
        if len(batch) < 4 or batch[3] is None:
            raise RuntimeError(
                "training batches must be (task_train_in, task_train_out, task_eval_in, "
                "task_eval_out) with eval targets; check process_single_task's return value"
            )
        task_train_in, task_train_out, task_eval_in, task_eval_out = batch
        params, opt_state, train_loss = train_step(
            params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp
        )
        # wandb steps must be monotonic across epochs, so use a global counter.
        global_step = epoch * steps_per_epoch + step + 1
        if step % hp.print_every == 0:
            print(f"step {step + 1}/{steps_per_epoch}: loss = {train_loss.item():.4f}")
            wandb.log({"train_loss": train_loss.item()}, step=global_step)

    valid_key, key = jax.random.split(key)
    valid_iter = valid_loader(valid_key)
    valid_loss, valid_acc = valid_step(params, valid_iter, hp)
    print(f'valid_loss: {valid_loss.item():.4f}, valid_acc: {valid_acc.item():.4f}')
    wandb.log({"valid_loss": valid_loss.item(), "valid_acc": valid_acc.item()}, step=(epoch + 1) * steps_per_epoch)

# Submission: predict two attempts for every test input of the test challenges with
# the trained `params`, collected into `predictions` as
# task_id -> [ {"attempt_1": grid, "attempt_2": grid}, ... ] (one dict per test input).
# No random_augment is applied here; only the random padding of pad_space / pad_time.
submission_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
predictions = {}
with open(submission_challenges_path, 'r') as f:
    challenges_dict = json.load(f)
print(f"loading challenges from {submission_challenges_path}, found {len(challenges_dict)} challenges")
for task_id in challenges_dict.keys():
    task = challenges_dict[task_id]
    # Collect the raw grids (lists of lists of ints) of the train pairs and test inputs.
    task_train_in = []
    task_train_out = []
    task_eval_in = []
    for pair in task['train']:
        task_train_in.append(pair['input'])
        task_train_out.append(pair['output'])
    for grid in task['test']:
        task_eval_in.append(grid['input'])
    # Process training data: one-hot encode only. Padding is redone per attempt
    # below, because it is randomised.
    task_train_in_processed = []
    task_train_out_processed = []
    for t_in, t_out in zip(task_train_in, task_train_out):
        t_in = jnp.array(t_in, dtype=jnp.int32)
        t_out = jnp.array(t_out, dtype=jnp.int32)
        t_in = jax.nn.one_hot(t_in, hp.num_channels)
        t_out = jax.nn.one_hot(t_out, hp.num_channels)
        task_train_in_processed.append(t_in)
        task_train_out_processed.append(t_out)
    num_test_inputs = len(task_eval_in)
    outputs_list = [{} for _ in range(num_test_inputs)]
    for attempt_id in range(2):
        # Fixed seed per attempt (0 and 1), identical for every task. The attempts
        # differ only through the random padding/resampling below.
        # NOTE: this rebinds the module-level `key` that was used during training.
        key = jax.random.PRNGKey(attempt_id)
        key, subkey1, subkey2 = jax.random.split(key, num=3)
        # Pad training data: a fresh subkey per pair for its canvas offset, then
        # resample/shuffle the pair axis to hp.pad_dim_time_train slots.
        task_train_in_padded = []
        task_train_out_padded = []
        k1 = subkey1
        for t_in, t_out in zip(task_train_in_processed, task_train_out_processed):
            k1, subk = jax.random.split(k1)
            t_in_padded, t_out_padded, _ = pad_space(subk, t_in, t_out, hp)
            task_train_in_padded.append(t_in_padded)
            task_train_out_padded.append(t_out_padded)
        task_train_in_padded = jnp.stack(task_train_in_padded)
        task_train_out_padded = jnp.stack(task_train_out_padded)
        task_train_in_padded, task_train_out_padded = pad_time(subkey2, task_train_in_padded, task_train_out_padded, hp)
        # Process each test input separately
        for eval_example_id in range(num_test_inputs):
            e_in = task_eval_in[eval_example_id]
            e_in = jnp.array(e_in, dtype=jnp.int32)
            e_in_one_hot = jax.nn.one_hot(e_in, hp.num_channels)
            # Pad eval data. start_pos is the grid's (row, col) offset on the
            # canvas and is needed to crop the prediction back out.
            key, subkey3, subkey4 = jax.random.split(key, num=3)
            e_in_padded, _, start_pos = pad_space(subkey3, e_in_one_hot, None, hp)
            e_in_padded = e_in_padded[None]  # Add time dimension
            # Fill hp.pad_dim_time_eval slots (here: copies of this single grid).
            e_in_padded, _ = pad_time(subkey4, e_in_padded, None, hp, is_eval=True)
            # Predict; [None] adds a batch axis of size 1 to each input.
            task_eval_out_pred = model(params, task_train_in_padded[None], task_train_out_padded[None], e_in_padded[None], hp)
            task_eval_out_pred = task_eval_out_pred[0, 0]  # Remove batch and time dimension
            # Un-padding and converting outputs: colour = argmax over channels,
            # cropped at the test input's offset and size.
            # NOTE(review): this assumes the output grid has the same size as the
            # test input, which is not true for every ARC task — confirm.
            grid_out_pred = jnp.argmax(task_eval_out_pred, axis=-1)
            sh, sw = start_pos
            h = e_in.shape[0]
            w = e_in.shape[1]
            grid_out_pred = grid_out_pred[sh:sh + h, sw:sw + w]
            grid_out_pred = grid_out_pred.tolist()
            outputs_list[eval_example_id][f"attempt_{attempt_id+1}"] = grid_out_pred
    predictions[task_id] = outputs_list

# --- Persist run artifacts and upload them to wandb ---
# Named compute backends write into a per-morph directory; anything else (e.g. a
# Kaggle submission run) writes into the current working directory.
if hp.compute_backend in ['oop', 'ojo', 'big']:
    # NOTE(review): there is no separator between "output" and the morph name;
    # confirm "/arcnca/output<morph>" (not "/arcnca/output/<morph>") is intended.
    output_dir = f"/arcnca/output{hp.morph}"
    os.makedirs(output_dir, exist_ok=True)
else:
    # when submitting to kaggle, save the output to the current directory
    output_dir = os.getcwd()

# Write to submission_filepath (not a bare 'submission.json' in the CWD) so that
# wandb.save below uploads the file that was actually written.
submission_filepath = os.path.join(output_dir, "submission.json")
with open(submission_filepath, 'w') as f:
    json.dump(predictions, f)
wandb.save(submission_filepath)

hparams_filepath = os.path.join(output_dir, "hparams.json")
with open(hparams_filepath, 'w') as f:
    json.dump(hp.__dict__, f)
wandb.save(hparams_filepath)

# `valid_acc` only exists if at least one validation pass ran (hp.num_epochs > 0).
try:
    final_valid_acc = valid_acc.item()
except NameError:
    final_valid_acc = None

results_filepath = os.path.join(output_dir, "results.json")
with open(results_filepath, 'w') as f:
    json.dump({"accuracy": final_valid_acc}, f)

wandb.finish()