# In [None]:
import json
from dataclasses import dataclass
from datetime import datetime
import os
import uuid

import jax
import jax.numpy as jnp
import optax
import wandb

@dataclass(frozen=True)
class Hyperparams:
    """Immutable configuration for one run.

    The environment-backed defaults and ``created_on`` are evaluated once,
    when this class body executes, so every instance created afterwards
    shares the same values unless a field is overridden explicitly.
    """

    # Run identity; each value can be overridden through the environment.
    seed: int = int(os.getenv("SEED", 42))
    morph: str = os.getenv("MORPH", "test")
    compute_backend: str = os.getenv("COMPUTE_BACKEND", "oop")
    wandb_entity: str = os.getenv("WANDB_ENTITY", "hug")
    wandb_project: str = os.getenv("WANDB_PROJECT", "arc-test")
    # Timestamp (YYYYmmddHHMMSS) taken when the class is defined.
    created_on: str = datetime.now().strftime("%Y%m%d%H%M%S")
    # ---
    # Optimisation schedule.
    learning_rate: float = 1e-3
    batch_size: int = 8
    num_epochs: int = 1
    print_every: int = 1
    # Padding targets: grids are padded in space to a square canvas of this
    # side, and example stacks are padded in time to the lengths below.
    pad_dim_space: int = 30
    pad_dim_time_train: int = 10  # max number of training pairs
    pad_dim_time_eval: int = 3    # max number of eval examples
    pad_value: int = 0
    # One-hot depth of a grid cell and per-flip augmentation probability.
    num_channels: int = 10
    augment_prob: float = 0.5


# Module-wide configuration instance used by the script sections below.
hp = Hyperparams()

def load_tasks(challenges_path, solutions_path, hp: Hyperparams):
    """Read a challenges/solutions JSON pair into a list of in-memory tasks.

    Args:
        challenges_path: JSON file mapping task id -> {'train': [...], 'test': [...]}.
        solutions_path: JSON file mapping task id -> list of eval output grids.
        hp: supplies ``pad_dim_time_train`` / ``pad_dim_time_eval``, the
            largest example counts a task may have.

    Returns:
        One tuple per task, in challenge-file order:
        ``(train_inputs, train_outputs, eval_inputs, eval_outputs)``. Grids are
        kept as the raw nested lists from the JSON. For a task with no entry in
        the solutions file, ``eval_outputs`` is a list of ``None`` with the
        same length as ``eval_inputs``.

    Raises:
        AssertionError: if input/output counts disagree or a task has more
            examples than the configured maxima.
    """
    with open(challenges_path, 'r') as f:
        challenges_dict = json.load(f)
    print(f"loading challenges from {challenges_path}, found {len(challenges_dict)} challenges")
    with open(solutions_path, 'r') as f:
        solutions_dict = json.load(f)
    print(f"loading solutions from {solutions_path}, found {len(solutions_dict)} solutions")

    tasks = []
    for task_id, challenge in challenges_dict.items():
        print(f"\t task {task_id}")

        # Demonstration pairs: a task may carry several of them.
        demo_inputs, demo_outputs = [], []
        for demo in challenge['train']:
            demo_inputs.append(demo['input'])
            demo_outputs.append(demo['output'])
            print(f"shape of task_train_in {jnp.array(demo['input']).shape}")
            print(f"shape of task_train_out {jnp.array(demo['output']).shape}")

        # Eval inputs come from the challenge file ...
        query_inputs = []
        for query in challenge['test']:
            query_inputs.append(query['input'])
            print(f"shape of task_eval_in {jnp.array(query['input']).shape}")

        # ... and their targets, when available, from the solutions file.
        if task_id not in solutions_dict:
            query_outputs = [None] * len(query_inputs)
        else:
            query_outputs = []
            for solution in solutions_dict[task_id]:
                query_outputs.append(solution)
                print(f"shape of task_eval_out {jnp.array(solution).shape}")
            assert len(query_inputs) == len(query_outputs)

        assert len(demo_inputs) == len(demo_outputs)
        assert len(query_inputs) <= hp.pad_dim_time_eval
        assert len(demo_inputs) <= hp.pad_dim_time_train
        tasks.append((demo_inputs, demo_outputs, query_inputs, query_outputs))
    return tasks

def pad_space(key, grid_in, grid_out, hp: Hyperparams):
    """Place a grid pair at one shared random offset inside a square canvas.

    Both grids are ``[H, W, C]``. The canvas side is ``hp.pad_dim_space`` and
    the surrounding cells are filled with ``hp.pad_value``. The offset is
    drawn so that the bounding box of the two grids fits in the canvas, and
    the same offset is applied to both so they stay aligned at their top-left
    corner.

    Args:
        key: PRNG key; it is split so the row and column offsets are drawn
            independently.
        grid_in: input grid, ``[H, W, C]``.
        grid_out: output grid, ``[H, W, C]``, or ``None``.
        hp: supplies ``pad_dim_space`` and ``pad_value``.

    Returns:
        ``(padded_in, padded_out, (start_h, start_w))`` where ``padded_out``
        is ``None`` when ``grid_out`` is ``None`` and the offsets are the
        scalar arrays returned by ``jax.random.randint``.

    Raises:
        ValueError: if either grid is taller or wider than the canvas.
    """
    canvas = hp.pad_dim_space
    in_h, in_w = grid_in.shape[:2]
    if grid_out is None:
        box_h, box_w = in_h, in_w
    else:
        out_h, out_w = grid_out.shape[:2]
        box_h, box_w = max(in_h, out_h), max(in_w, out_w)

    # Fail early with a clear message instead of handing negative pad widths
    # to jnp.pad further down.
    if box_h > canvas or box_w > canvas:
        raise ValueError(
            f"grid extent {box_h}x{box_w} does not fit in a "
            f"{canvas}x{canvas} canvas (pad_dim_space)"
        )

    # Separate keys for the two axes: drawing both from one key would yield
    # correlated (for square boxes, identical) offsets.
    key_h, key_w = jax.random.split(key)
    start_h = jax.random.randint(key_h, shape=(), minval=0, maxval=canvas - box_h + 1)
    start_w = jax.random.randint(key_w, shape=(), minval=0, maxval=canvas - box_w + 1)
    # Concrete Python ints for the pad widths.
    off_h, off_w = int(start_h), int(start_w)

    def _embed(grid):
        """Pad one [H, W, C] grid out to the canvas at the drawn offset."""
        h, w = grid.shape[:2]
        widths = (
            (off_h, canvas - h - off_h),
            (off_w, canvas - w - off_w),
            (0, 0),
        )
        return jnp.pad(grid, widths, mode='constant', constant_values=hp.pad_value)

    padded_in = _embed(grid_in)
    padded_out = _embed(grid_out) if grid_out is not None else None
    return padded_in, padded_out, (start_h, start_w)

def pad_time(key, grid_in, grid_out, hp: Hyperparams, is_eval=False):
    """Grow a stack of examples to a fixed length by repeating random ones.

    Inputs are ``[T, H, W, C]``. Randomly chosen existing examples are
    appended until the leading axis reaches ``hp.pad_dim_time_eval`` (when
    ``is_eval``) or ``hp.pad_dim_time_train`` (otherwise), and the leading
    axis is then shuffled. ``grid_out`` receives the same repeats and the same
    shuffle so pairs stay aligned; it may be ``None``, in which case ``None``
    is returned in its place.
    """
    target_len = hp.pad_dim_time_eval if is_eval else hp.pad_dim_time_train
    num_real = grid_in.shape[0]
    assert num_real <= target_len
    assert grid_in.shape[1] <= hp.pad_dim_space
    assert grid_in.shape[2] <= hp.pad_dim_space

    paired = grid_out is not None
    frames_in = [grid_in[t] for t in range(num_real)]
    frames_out = [grid_out[t] for t in range(grid_out.shape[0])] if paired else None

    # One fresh subkey per appended example.
    for _ in range(target_len - num_real):
        key, draw_key = jax.random.split(key)
        pick = jax.random.randint(draw_key, shape=(), minval=0, maxval=num_real)
        frames_in.append(grid_in[pick])
        if paired:
            frames_out.append(grid_out[pick])

    # A single permutation reorders inputs and outputs together.
    key, shuffle_key = jax.random.split(key)
    order = jax.random.permutation(shuffle_key, target_len)
    shuffled_in = jnp.stack(frames_in)[order]
    shuffled_out = jnp.stack(frames_out)[order] if paired else grid_out
    return shuffled_in, shuffled_out

# Rotations are deliberately not part of the augmentation set (rectangular
# grids). The flips address the spatial axes from the end so they behave the
# same on a single [H, W, C] grid and on a [T, H, W, C] stack; jnp.fliplr and
# jnp.flipud act on axes 1 and 0, which on a stack are H and T.
def _flip_width(grid):
    """Mirror a [..., H, W, C] array along its W axis."""
    return jnp.flip(grid, axis=-2)


def _flip_height(grid):
    """Mirror a [..., H, W, C] array along its H axis."""
    return jnp.flip(grid, axis=-3)


def _identity(grid):
    """Leave the grid unchanged (the 'augmentation not applied' branch)."""
    return grid


augmentations = [_flip_width, _flip_height]


def random_augment(key, grid_in, grid_out, hp: Hyperparams):
    """Apply random spatial flips and one random channel permutation.

    Each entry of ``augmentations`` is applied independently with probability
    ``hp.augment_prob``. The channel axis (last) is then permuted. ``grid_in``
    and ``grid_out`` always receive the same flips and the same permutation.

    Args:
        key: PRNG key; split into one key for the flip decisions and one for
            the channel permutation so that no key is consumed twice.
        grid_in: array shaped ``[..., H, W, C]`` with ``C == hp.num_channels``.
        grid_out: array of the same layout, or ``None``.
        hp: supplies ``augment_prob`` and ``num_channels``.

    Returns:
        ``(grid_in, grid_out)`` after augmentation; ``grid_out`` stays ``None``
        when it was passed as ``None``.
    """
    flip_key, perm_key = jax.random.split(key)
    apply = jax.random.bernoulli(flip_key, hp.augment_prob, shape=(len(augmentations),))
    for flag, augmentation in zip(apply, augmentations):
        grid_in = jax.lax.cond(flag, augmentation, _identity, grid_in)
        if grid_out is not None:
            grid_out = jax.lax.cond(flag, augmentation, _identity, grid_out)
    # shuffle the C dimension
    perm = jax.random.permutation(perm_key, hp.num_channels)
    grid_in = grid_in[..., perm]
    if grid_out is not None:
        grid_out = grid_out[..., perm]
    return grid_in, grid_out

def _one_hot_grids(grids, hp: Hyperparams):
    """One-hot encode raw integer grids; each becomes [H, W, hp.num_channels]."""
    return [
        jax.nn.one_hot(jnp.array(grid, dtype=jnp.int32), hp.num_channels)
        for grid in grids
    ]


def _pad_and_stack(key, grids_in, grids_out, hp: Hyperparams):
    """Pad every (input, output) pair in space and stack along a new leading axis.

    Each pair gets its own random offset. ``grids_out`` may be ``None``, in
    which case ``None`` is returned for the stacked outputs.
    """
    padded_in, padded_out = [], []
    for idx, grid_in in enumerate(grids_in):
        key, pair_key = jax.random.split(key)
        grid_out = grids_out[idx] if grids_out is not None else None
        p_in, p_out, _ = pad_space(pair_key, grid_in, grid_out, hp)
        padded_in.append(p_in)
        if p_out is not None:
            padded_out.append(p_out)
    stacked_out = jnp.stack(padded_out) if padded_out else None
    return jnp.stack(padded_in), stacked_out


def data_loader(key, tasks, hp: Hyperparams, train_mode=True):
    """Yield padded, one-hot batches of tasks.

    Args:
        key: PRNG key driving task shuffling, padding offsets, time padding
            and augmentation.
        tasks: list of ``(train_in, train_out, eval_in, eval_out)`` tuples of
            raw grids, as produced by ``load_tasks``; ``eval_out`` may be a
            list of ``None``.
        hp: hyperparameters (batch size, padding sizes, channel count).
        train_mode: when true, tasks are shuffled and augmented.

    Yields:
        ``(task_train_in, task_train_out, task_eval_in, task_eval_out)``, each
        stacked over the tasks of the batch. The last batch may hold fewer
        than ``hp.batch_size`` tasks. When a task has no eval outputs, zeros
        shaped like its eval inputs are used in their place.
    """
    num_tasks = len(tasks)
    if train_mode:
        key, shuffle_key = jax.random.split(key)
        # Plain Python ints: indexing a list with device scalars is slow.
        order = jax.random.permutation(shuffle_key, num_tasks).tolist()
        tasks = [tasks[idx] for idx in order]

    for start_idx in range(0, num_tasks, hp.batch_size):
        batch_key, key = jax.random.split(key)
        end_idx = min(start_idx + hp.batch_size, num_tasks)
        task_keys = jax.random.split(batch_key, end_idx - start_idx)

        batch_train_in = []
        batch_train_out = []
        batch_eval_in = []
        batch_eval_out = []
        for task, task_key in zip(tasks[start_idx:end_idx], task_keys):
            # Use the split keys directly; re-seeding from one word of each
            # key would discard half of it.
            k_train_space, k_train_time, k_eval_space, k_eval_time, k_aug = (
                jax.random.split(task_key, 5)
            )
            raw_train_in, raw_train_out, raw_eval_in, raw_eval_out = task

            train_in = _one_hot_grids(raw_train_in, hp)
            train_out = _one_hot_grids(raw_train_out, hp)
            eval_in = _one_hot_grids(raw_eval_in, hp)
            has_eval_out = bool(raw_eval_out) and all(g is not None for g in raw_eval_out)
            eval_out = _one_hot_grids(raw_eval_out, hp) if has_eval_out else None

            # Pad in space, then in time, for the demonstration pairs ...
            train_in, train_out = _pad_and_stack(k_train_space, train_in, train_out, hp)
            train_in, train_out = pad_time(k_train_time, train_in, train_out, hp)
            # ... and for the eval examples.
            eval_in, eval_out = _pad_and_stack(k_eval_space, eval_in, eval_out, hp)
            eval_in, eval_out = pad_time(k_eval_time, eval_in, eval_out, hp, is_eval=True)

            if train_mode:
                # One key for both calls: demonstrations and eval examples of
                # a task must receive the same flips and the same channel
                # permutation, otherwise the demonstrations no longer describe
                # the mapping the eval examples are scored on.
                train_in, train_out = random_augment(k_aug, train_in, train_out, hp)
                eval_in, eval_out = random_augment(k_aug, eval_in, eval_out, hp)

            batch_train_in.append(train_in)
            batch_train_out.append(train_out)
            batch_eval_in.append(eval_in)
            batch_eval_out.append(eval_out if eval_out is not None else jnp.zeros_like(eval_in))

        yield (
            jnp.stack(batch_train_in),
            jnp.stack(batch_train_out),
            jnp.stack(batch_eval_in),
            jnp.stack(batch_eval_out),
        )

# <model>

def init_params(key, hp: Hyperparams):
    """Build the parameter dictionary for the model.

    Returns a dict with three entries:
      * ``'encoder'``: three 3x3 conv layers for the training pairs (input
        and output grids concatenated along channels, hence
        ``2 * hp.num_channels`` input channels);
      * ``'test_encoder'``: three 3x3 conv layers for the test input grid;
      * ``'decoder'``: one dense layer mapping a 256-vector to
        ``pad_dim_space * pad_dim_space * num_channels`` logits.
    """
    # Eight keys are drawn; the first seven are consumed below.
    keys = jax.random.split(key, num=8)

    def conv_stack(stack_keys, first_in_channels):
        """Three conv layers with widths first_in_channels -> 32 -> 64 -> 128."""
        widths = (first_in_channels, 32, 64, 128)
        return {
            f'conv{i + 1}': conv_layer_params(
                layer_key,
                in_channels=widths[i],
                out_channels=widths[i + 1],
                kernel_size=3,
            )
            for i, layer_key in enumerate(stack_keys)
        }

    canvas_cells = hp.pad_dim_space * hp.pad_dim_space
    return {
        'encoder': conv_stack((keys[0], keys[1], keys[2]), hp.num_channels * 2),
        'test_encoder': conv_stack((keys[3], keys[4], keys[5]), hp.num_channels),
        'decoder': mlp_layer_params(keys[6], in_dim=256, out_dim=canvas_cells * hp.num_channels),
    }

def conv_layer_params(key, in_channels, out_channels, kernel_size):
    """Create the weight and bias for one square-kernel conv layer.

    The kernel has shape ``(kernel_size, kernel_size, in_channels,
    out_channels)`` and is drawn from a normal distribution scaled by
    ``sqrt(2 / fan_in)``; the bias is a zero vector of length ``out_channels``.
    """
    # Only the first half of the split is used; the bias needs no randomness.
    weight_key, _ = jax.random.split(key)
    fan_in = kernel_size * kernel_size * in_channels
    kernel_shape = (kernel_size, kernel_size, in_channels, out_channels)
    kernel = jax.random.normal(weight_key, kernel_shape) * jnp.sqrt(2.0 / fan_in)
    return {'w': kernel, 'b': jnp.zeros(out_channels)}

def mlp_layer_params(key, in_dim, out_dim):
    """Create the weight and bias for one dense layer.

    The weight is ``(in_dim, out_dim)``, drawn from a normal distribution
    scaled by ``sqrt(2 / in_dim)``; the bias is a zero vector of length
    ``out_dim``.
    """
    # Only the first half of the split is used; the bias needs no randomness.
    weight_key, _ = jax.random.split(key)
    weight = jax.random.normal(weight_key, (in_dim, out_dim)) * jnp.sqrt(2.0 / in_dim)
    return {'w': weight, 'b': jnp.zeros(out_dim)}

def _encode_grids(enc_params, grids):
    """Encode a stack of grids into one feature vector per grid.

    Applies ``conv1`` -> ``conv2`` -> ``conv3`` (each followed by bias and
    ReLU) and then averages over the two spatial axes.

    Args:
        enc_params: dict with ``'conv1'``, ``'conv2'``, ``'conv3'`` entries,
            each holding a kernel ``'w'`` laid out ``(kh, kw, in, out)`` and a
            bias ``'b'``.
        grids: array ``[N, H, W, C_in]``.

    Returns:
        Array ``[N, C_out]`` where ``C_out`` is the output width of ``conv3``.
    """
    h = grids
    for name in ('conv1', 'conv2', 'conv3'):
        layer = enc_params[name]
        h = jax.lax.conv_general_dilated(
            h,
            layer['w'],
            window_strides=(1, 1),
            padding='SAME',
            # The data is channels-last and the kernels are (kh, kw, in, out).
            # Without this argument JAX assumes NCHW / OIHW and rejects the
            # shapes.
            dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
        )
        h = jax.nn.relu(h + layer['b'])
    return jnp.mean(h, axis=(1, 2))  # global average pooling


def model(params, task_train_in, task_train_out, task_eval_in, hp: Hyperparams):
    """Predict eval output logits from training pairs and eval inputs.

    Args:
        params: dict with ``'encoder'``, ``'test_encoder'`` and ``'decoder'``
            entries (see ``_encode_grids`` for the encoder layout; the decoder
            holds a dense ``'w'`` and ``'b'``).
        task_train_in: ``[B, T_train, H, W, C]`` training input grids.
        task_train_out: ``[B, T_train, H, W, C]`` training output grids.
        task_eval_in: ``[B, T_eval, H, W, C]`` eval input grids.
        hp: supplies ``pad_dim_space`` and ``num_channels`` for reshaping the
            decoder output.

    Returns:
        Logits shaped ``[B, T_eval, pad_dim_space, pad_dim_space, num_channels]``.
    """

    def process_task(train_in, train_out, eval_in):
        # Each training pair is encoded jointly (input and output concatenated
        # along channels) and the pair encodings are averaged into one task
        # vector.
        pairs = jnp.concatenate([train_in, train_out], axis=-1)
        train_repr = jnp.mean(_encode_grids(params['encoder'], pairs), axis=0)
        # Every eval input gets its own encoding.
        test_repr = _encode_grids(params['test_encoder'], eval_in)

        def decode(test_r):
            combined = jnp.concatenate([train_repr, test_r], axis=0)
            logits = jnp.dot(combined, params['decoder']['w']) + params['decoder']['b']
            return logits.reshape(hp.pad_dim_space, hp.pad_dim_space, hp.num_channels)

        return jax.vmap(decode)(test_repr)  # (T_eval, H, W, C)

    # Apply to each task in the batch.
    return jax.vmap(process_task, in_axes=(0, 0, 0))(task_train_in, task_train_out, task_eval_in)

def loss_fn(task_eval_out_pred, task_eval_out_targ, hp: Hyperparams):
    """Mean softmax cross-entropy over every grid cell.

    Both arguments are flattened to ``(-1, hp.num_channels)``. The targets are
    turned into integer class ids with ``argmax`` over the channel axis, and
    the predictions are used as logits.

    Returns:
        Scalar mean of the per-cell cross-entropy.
    """
    cell_logits = jnp.reshape(task_eval_out_pred, (-1, hp.num_channels))
    cell_targets = jnp.reshape(task_eval_out_targ, (-1, hp.num_channels))
    class_ids = jnp.argmax(cell_targets, axis=-1)
    per_cell = optax.softmax_cross_entropy_with_integer_labels(cell_logits, class_ids)
    return jnp.mean(per_cell)

# </model>

def train_step(params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp: Hyperparams):
    """Run one optimisation step on a batch.

    Computes the loss and its gradient with respect to ``params``, asks the
    module-level ``optimizer`` for updates and applies them.

    Returns:
        ``(params, opt_state, loss)`` with the updated parameters and
        optimizer state, and the loss evaluated before the update.
    """

    def batch_loss(candidate_params):
        prediction = model(candidate_params, task_train_in, task_train_out, task_eval_in, hp)
        return loss_fn(prediction, task_eval_out, hp)

    loss, grads = jax.value_and_grad(batch_loss)(params)
    updates, next_opt_state = optimizer.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, loss

def valid_step(params, valid_gen, hp: Hyperparams):
    """Average loss and cell accuracy over every batch of ``valid_gen``.

    Accuracy is the fraction of cells whose predicted class (argmax over the
    channel axis) equals the target class (argmax of the target). Both metrics
    are averaged per batch, not per example.

    Returns:
        ``(avg_loss, avg_acc)``.
    """
    loss_sum = 0.0
    acc_sum = 0.0
    batches_seen = 0
    for batches_seen, batch in enumerate(valid_gen, start=1):
        task_train_in, task_train_out, task_eval_in, task_eval_out = batch
        prediction = model(params, task_train_in, task_train_out, task_eval_in, hp)
        loss_sum += loss_fn(prediction, task_eval_out, hp)
        hits = jnp.argmax(prediction, axis=-1) == jnp.argmax(task_eval_out, axis=-1)
        acc_sum += jnp.mean(hits)
    return loss_sum / batches_seen, acc_sum / batches_seen

# --- Training script: log in to wandb, load data, train, validate. ---
print(f"hyperparameters: {hp}")
wandb.login()
wandb.init(
    entity=hp.wandb_entity,
    project=hp.wandb_project,
    name=f"{hp.morph}.{hp.compute_backend}.{str(uuid.uuid4())[:6]}",
    config=hp.__dict__,
)
key = jax.random.PRNGKey(hp.seed)
train_tasks = load_tasks(
    '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json', hp)
valid_tasks = load_tasks(
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json', hp)
# Give parameter initialisation its own key so the key that is split for the
# data loaders below is never also consumed directly.
init_key, key = jax.random.split(key)
params = init_params(init_key, hp)
optimizer = optax.adam(hp.learning_rate)
opt_state = optimizer.init(params)
# Ceil division: the last batch of an epoch may be partial.
steps_per_epoch = -(-len(train_tasks) // hp.batch_size)
for epoch in range(hp.num_epochs):
    print(f"epoch {epoch + 1}/{hp.num_epochs}")
    train_key, key = jax.random.split(key)
    train_gen = data_loader(train_key, train_tasks, hp, train_mode=True)
    for step, batch in enumerate(train_gen):
        task_train_in, task_train_out, task_eval_in, task_eval_out = batch
        params, opt_state, train_loss = train_step(params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp)
        if step % hp.print_every == 0:
            print(f"step {step + 1}/{steps_per_epoch}: loss = {train_loss.item():.4f}")
            # Log against a run-wide step so values keep increasing across
            # epochs instead of restarting at 1.
            global_step = epoch * steps_per_epoch + step + 1
            wandb.log({"train_loss": train_loss.item()}, step=global_step)
    valid_key, key = jax.random.split(key)
    valid_gen = data_loader(valid_key, valid_tasks, hp, train_mode=False)
    valid_loss, valid_acc = valid_step(params, valid_gen, hp)
    print(f'valid_loss: {valid_loss.item():.4f}, valid_acc: {valid_acc.item():.4f}')
    wandb.log({"valid_loss": valid_loss.item(), "valid_acc": valid_acc.item()}, step=(epoch + 1) * steps_per_epoch)
wandb.finish()

# Submission code with corrections
# For every test task, two attempts are produced per test input. Each attempt
# uses its own PRNG seed and therefore its own padding offsets. Predictions
# are cropped back to the size of the test input before being written out.
submission_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
predictions = {}
with open(submission_challenges_path, 'r') as f:
    challenges_dict = json.load(f)
print(f"loading challenges from {submission_challenges_path}, found {len(challenges_dict)} challenges")
for task_id, task in challenges_dict.items():
    # One-hot encode the demonstration pairs once per task.
    task_train_in_processed = [
        jax.nn.one_hot(jnp.array(pair['input'], dtype=jnp.int32), hp.num_channels)
        for pair in task['train']
    ]
    task_train_out_processed = [
        jax.nn.one_hot(jnp.array(pair['output'], dtype=jnp.int32), hp.num_channels)
        for pair in task['train']
    ]
    task_eval_in = [example['input'] for example in task['test']]
    num_test_inputs = len(task_eval_in)
    outputs_list = [{} for _ in range(num_test_inputs)]

    for attempt_id in range(2):
        key = jax.random.PRNGKey(attempt_id)
        key, demo_space_key, demo_time_key = jax.random.split(key, num=3)

        # Pad the demonstrations in space (one offset per pair), then in time.
        demo_in_padded, demo_out_padded = [], []
        for demo_in, demo_out in zip(task_train_in_processed, task_train_out_processed):
            demo_space_key, pair_key = jax.random.split(demo_space_key)
            padded_in, padded_out, _ = pad_space(pair_key, demo_in, demo_out, hp)
            demo_in_padded.append(padded_in)
            demo_out_padded.append(padded_out)
        task_train_in_padded, task_train_out_padded = pad_time(
            demo_time_key, jnp.stack(demo_in_padded), jnp.stack(demo_out_padded), hp)

        for eval_example_id, raw_grid in enumerate(task_eval_in):
            e_in = jnp.array(raw_grid, dtype=jnp.int32)
            e_in_one_hot = jax.nn.one_hot(e_in, hp.num_channels)
            # Pad the single test input in space, add a time axis, pad in time.
            key, eval_space_key, eval_time_key = jax.random.split(key, num=3)
            e_in_padded, _, (sh, sw) = pad_space(eval_space_key, e_in_one_hot, None, hp)
            e_in_padded, _ = pad_time(eval_time_key, e_in_padded[None], None, hp, is_eval=True)
            # Predict with a batch of one and keep the first time slot.
            task_eval_out_pred = model(
                params, task_train_in_padded[None], task_train_out_padded[None], e_in_padded[None], hp)[0, 0]
            # Undo the spatial padding using the test input's own extent.
            h, w = e_in.shape[0], e_in.shape[1]
            grid_out_pred = jnp.argmax(task_eval_out_pred, axis=-1)[sh:sh + h, sw:sw + w].tolist()
            outputs_list[eval_example_id][f"attempt_{attempt_id+1}"] = grid_out_pred
    predictions[task_id] = outputs_list

with open('submission.json', 'w') as f:
    json.dump(predictions, f)
