In [None]:
import json
from dataclasses import dataclass
from datetime import datetime
import os
import uuid

import jax
import jax.numpy as jnp
import optax
import wandb

@dataclass(frozen=True)
class Hyperparams:
    """Immutable run configuration.

    The first block of fields describes the run itself and is read from
    environment variables (with fallbacks); the second block holds the
    data/optimisation settings. All defaults are evaluated once, when the
    class body executes.
    """

    # --- run identity (environment overridable) ---
    seed: int = int(os.getenv("SEED", 42))
    morph: str = os.getenv("MORPH", "test")
    compute_backend: str = os.getenv("COMPUTE_BACKEND", "oop")
    wandb_entity: str = os.getenv("WANDB_ENTITY", "hug")
    wandb_project: str = os.getenv("WANDB_PROJECT", "arc-test")
    # timestamp (YYYYmmddHHMMSS) taken when the class is defined
    created_on: str = format(datetime.now(), "%Y%m%d%H%M%S")
    # --- optimisation ---
    learning_rate: float = 1e-3
    batch_size: int = 8
    num_epochs: int = 1
    print_every: int = 1
    # --- data layout ---
    pad_dim_space: int = 30        # side length of the padded canvas
    pad_dim_time_train: int = 10   # max number of training pairs
    pad_dim_time_eval: int = 3     # max number of eval examples
    pad_value: int = 0             # fill value used when padding in space
    num_channels: int = 10         # depth of the one-hot channel axis
    augment_prob: float = 0.5      # per-augmentation application probability


# Module-wide configuration instance used by the rest of the script.
hp = Hyperparams()

def load_tasks(challenges_path, solutions_path, hp: Hyperparams):
    """Read the challenge and solution JSON files into an in-memory task list.

    Args:
        challenges_path: JSON file mapping task id -> {'train': [...], 'test': [...]}.
        solutions_path: JSON file mapping task id -> list of eval output grids.
        hp: hyperparameters; only the two ``pad_dim_time_*`` limits are read.

    Returns:
        A list with one ``(train_in, train_out, eval_in, eval_out)`` tuple of
        nested Python lists per task, in the order the challenge file lists them.

    Raises:
        AssertionError: if a task has mismatched pair counts or more examples
            than the configured time padding allows.
    """
    with open(challenges_path, 'r') as fh:
        challenges = json.load(fh)
    print(f"loading challenges from {challenges_path}, found {len(challenges)} challenges")
    with open(solutions_path, 'r') as fh:
        solutions = json.load(fh)
    print(f"loading solutions from {solutions_path}, found {len(solutions)} solutions")

    def report_shape(label, grid):
        # purely diagnostic: shows the grid shape as jax sees it
        print(f"shape of {label} {jnp.array(grid).shape}")

    loaded = []
    for task_id, challenge in challenges.items():
        print(f"\t task {task_id}")
        demo_in, demo_out, query_in, query_out = [], [], [], []
        # a task may carry several demonstration pairs
        for demo in challenge['train']:
            demo_in.append(demo['input'])
            demo_out.append(demo['output'])
            report_shape("task_train_in", demo['input'])
            report_shape("task_train_out", demo['output'])
        for query in challenge['test']:
            query_in.append(query['input'])
            report_shape("task_eval_in", query['input'])
        for answer in solutions[task_id]:
            query_out.append(answer)
            report_shape("task_eval_out", answer)
        assert len(query_in) == len(query_out)
        assert len(demo_in) == len(demo_out)
        assert len(query_in) <= hp.pad_dim_time_eval
        assert len(demo_in) <= hp.pad_dim_time_train
        loaded.append((demo_in, demo_out, query_in, query_out))
    return loaded

def pad_space(key, grid_in, grid_out, hp: Hyperparams):
    """Place a grid (and its optional partner) at a random offset on a padded canvas.

    Args:
        key: JAX PRNG key; it is split so the row and column offsets are
            drawn independently.
        grid_in: array of shape [H, W, C] with H == W <= ``hp.pad_dim_space``.
        grid_out: array with exactly the same shape as ``grid_in``, or ``None``.
        hp: hyperparameters; ``pad_dim_space`` is the canvas side length and
            ``pad_value`` fills the border.

    Returns:
        ``(grid_in, grid_out, (start_h, start_w))``: the padded grids of shape
        [pad_dim_space, pad_dim_space, C] (``grid_out`` stays ``None`` if it was
        ``None``) and the offset, as Python ints, at which the original content
        starts.
    """
    assert grid_in.shape[0] <= hp.pad_dim_space
    assert grid_in.shape[1] <= hp.pad_dim_space
    assert grid_in.shape[0] == grid_in.shape[1]
    assert grid_out is None or grid_out.shape == grid_in.shape
    free_h = hp.pad_dim_space - grid_in.shape[0]
    free_w = hp.pad_dim_space - grid_in.shape[1]
    # randint needs an explicit shape; use separate keys so the two offsets
    # are not forced to be equal. int() gives static values for jnp.pad and
    # for slicing by callers.
    key_h, key_w = jax.random.split(key)
    start_h = int(jax.random.randint(key_h, (), minval=0, maxval=free_h + 1))
    start_w = int(jax.random.randint(key_w, (), minval=0, maxval=free_w + 1))
    pad_width = ((start_h, free_h - start_h),
                 (start_w, free_w - start_w),
                 (0, 0))  # never pad the channel axis
    grid_in = jnp.pad(grid_in, pad_width, mode='constant', constant_values=hp.pad_value)
    if grid_out is not None:
        grid_out = jnp.pad(grid_out, pad_width, mode='constant', constant_values=hp.pad_value)
    return grid_in, grid_out, (start_h, start_w)

def pad_time(key, grid_in, grid_out, hp: Hyperparams, is_eval=False):
    """Pad the leading (time) axis by repeating random examples, then shuffle it.

    Args:
        key: JAX PRNG key; split into one key for choosing the repeats and one
            for the final shuffle.
        grid_in: array of shape [T, H, W, C] with at least one example.
        grid_out: array with the same shape as ``grid_in``, or ``None``.
        hp: hyperparameters; the target length is ``pad_dim_time_eval`` when
            ``is_eval`` is true and ``pad_dim_time_train`` otherwise.
        is_eval: selects which of the two target lengths to use.

    Returns:
        ``(grid_in, grid_out)`` with the time axis extended to the target
        length. Both arrays receive the same repeats and the same permutation,
        so index ``t`` of one still belongs to index ``t`` of the other.
    """
    # different padding in time for eval and train examples
    pad_dim_time = hp.pad_dim_time_eval if is_eval else hp.pad_dim_time_train
    num_examples = grid_in.shape[0]
    assert 1 <= num_examples <= pad_dim_time  # nothing to repeat from if empty
    assert grid_in.shape[1] <= hp.pad_dim_space
    assert grid_in.shape[2] <= hp.pad_dim_space
    assert grid_in.shape[1] == grid_in.shape[2]
    assert grid_out is None or grid_out.shape == grid_in.shape
    repeat_key, perm_key = jax.random.split(key)
    num_missing = pad_dim_time - num_examples
    if num_missing > 0:
        # draw all filler indices at once, uniformly over the original examples
        repeat_idx = jax.random.randint(repeat_key, (num_missing,), minval=0, maxval=num_examples)
        grid_in = jnp.concatenate([grid_in, grid_in[repeat_idx]], axis=0)
        if grid_out is not None:
            grid_out = jnp.concatenate([grid_out, grid_out[repeat_idx]], axis=0)
    # shuffle the T dimension
    perm = jax.random.permutation(perm_key, pad_dim_time)
    grid_in = grid_in[perm]
    if grid_out is not None:
        grid_out = grid_out[perm]
    return grid_in, grid_out

# Spatial augmentations; each expects a single [H, W, C] grid.
augmentations = [jnp.fliplr, jnp.flipud, jnp.rot90]

def random_augment(key, grid_in, grid_out, hp: Hyperparams):
    """Randomly flip/rotate the grids in space and permute their channels.

    Args:
        key: JAX PRNG key; split into one key for the on/off decisions and one
            for the channel permutation. Calling twice with the same key
            applies the same transformation.
        grid_in: array of shape [T, pad_dim_space, pad_dim_space, num_channels]
            where T is either of the two configured time lengths.
        grid_out: array with the same shape as ``grid_in``, or ``None``.
        hp: hyperparameters; ``augment_prob`` is the probability of applying
            each entry of ``augmentations``.

    Returns:
        ``(grid_in, grid_out)`` transformed identically; shapes are unchanged.
    """
    assert grid_out is None or grid_in.shape == grid_out.shape
    assert grid_in.shape[0] in (hp.pad_dim_time_train, hp.pad_dim_time_eval)
    assert grid_in.shape[1:] == (hp.pad_dim_space, hp.pad_dim_space, hp.num_channels)
    apply_key, perm_key = jax.random.split(key)
    apply = jax.random.bernoulli(apply_key, hp.augment_prob, shape=(len(augmentations),))
    for i, augmentation in enumerate(augmentations):
        # vmap over T so the flips/rotation act on (H, W) of every example
        # instead of on the time axis of the stacked tensor
        per_example = jax.vmap(augmentation)
        grid_in = jax.lax.cond(apply[i], per_example, lambda x: x, grid_in)
        if grid_out is not None:
            grid_out = jax.lax.cond(apply[i], per_example, lambda x: x, grid_out)
    # shuffle the C dimension
    perm = jax.random.permutation(perm_key, hp.num_channels)
    grid_in = grid_in[..., perm]
    if grid_out is not None:
        grid_out = grid_out[..., perm]
    return grid_in, grid_out

def _pad_task_in_space(key, grids_in, grids_out, hp: Hyperparams):
    """Pad every (input, output) pair of one task in space and stack them along time."""
    # the same key is passed for every pair, as pad_space is called per pair
    padded_pairs = [pad_space(key, g_in, g_out, hp)[:2] for g_in, g_out in zip(grids_in, grids_out)]
    padded_in, padded_out = zip(*padded_pairs)
    return jnp.stack(padded_in), jnp.stack(padded_out)


def data_loader(key, tasks, hp: Hyperparams, train_mode=True):
    """Yield padded, one-hot encoded batches of tasks.

    Args:
        key: JAX PRNG key driving shuffling, padding offsets and augmentation.
        tasks: list of ``(train_in, train_out, eval_in, eval_out)`` tuples of
            nested integer lists, as produced by ``load_tasks``.
        hp: hyperparameters (batch size, padding sizes, channel count).
        train_mode: when true the task order is shuffled and every task is
            randomly augmented; when false neither happens.

    Yields:
        ``(task_train_in, task_train_out, task_eval_in, task_eval_out)`` with
        shapes ``[B, pad_dim_time_train, S, S, C]`` for the train tensors and
        ``[B, pad_dim_time_eval, S, S, C]`` for the eval tensors, where
        ``S = pad_dim_space``, ``C = num_channels`` and ``B`` is ``batch_size``
        except possibly for the last, smaller batch.
    """
    num_tasks = len(tasks)
    if train_mode:
        # use a dedicated key for the shuffle so `key` is not consumed twice
        shuffle_key, key = jax.random.split(key)
        order = jax.random.permutation(shuffle_key, num_tasks)
        tasks = [tasks[int(idx)] for idx in order]
    num_batches = -(-num_tasks // hp.batch_size)  # ceiling division
    for i in range(num_batches):
        batch_key, key = jax.random.split(key)
        start_idx = i * hp.batch_size
        end_idx = min(start_idx + hp.batch_size, num_tasks)
        current_batch_size = end_idx - start_idx
        task_keys = jax.random.split(batch_key, current_batch_size)
        batch_train_in = []
        batch_train_out = []
        batch_eval_in = []
        batch_eval_out = []
        for task, task_key in zip(tasks[start_idx:end_idx], task_keys):
            k_space_train, k_time_train, k_space_eval, k_time_eval, k_augment = jax.random.split(task_key, 5)
            # integer grids -> one-hot along a new trailing channel axis
            train_in, train_out, eval_in, eval_out = (
                jax.nn.one_hot(jnp.array(grids, dtype=jnp.int32), hp.num_channels)
                for grids in task
            )
            # pad in space, then in time
            train_in, train_out = _pad_task_in_space(k_space_train, train_in, train_out, hp)
            train_in, train_out = pad_time(k_time_train, train_in, train_out, hp)
            # same for eval data
            eval_in, eval_out = _pad_task_in_space(k_space_eval, eval_in, eval_out, hp)
            eval_in, eval_out = pad_time(k_time_eval, eval_in, eval_out, hp, is_eval=True)
            if train_mode:
                # One key for both calls: the demonstrations and the queries of
                # a task must receive the same flips and the same channel
                # permutation, otherwise the demonstrated transformation no
                # longer matches the queries.
                train_in, train_out = random_augment(k_augment, train_in, train_out, hp)
                eval_in, eval_out = random_augment(k_augment, eval_in, eval_out, hp)
            batch_train_in.append(train_in)
            batch_train_out.append(train_out)
            batch_eval_in.append(eval_in)
            batch_eval_out.append(eval_out)
        train_shape = (current_batch_size, hp.pad_dim_time_train, hp.pad_dim_space, hp.pad_dim_space, hp.num_channels)
        eval_shape = (current_batch_size, hp.pad_dim_time_eval, hp.pad_dim_space, hp.pad_dim_space, hp.num_channels)
        task_train_in = jnp.stack(batch_train_in)
        assert task_train_in.shape == train_shape
        task_train_out = jnp.stack(batch_train_out)
        assert task_train_out.shape == train_shape
        task_eval_in = jnp.stack(batch_eval_in)
        assert task_eval_in.shape == eval_shape
        task_eval_out = jnp.stack(batch_eval_out)
        assert task_eval_out.shape == eval_shape
        yield task_train_in, task_train_out, task_eval_in, task_eval_out

# <model>

def init_params(key, hp: Hyperparams):
    """Build the parameter tree for both conv encoders and the linear decoder.

    Returns a dict with keys ``'encoder'`` (training pairs, input and output
    concatenated on the channel axis), ``'test_encoder'`` (test input grid) and
    ``'decoder'`` (maps the 256-d joint representation to a flattened grid of
    logits).
    """
    # eight keys are drawn; the first seven are consumed below
    layer_keys = jax.random.split(key, num=8)
    widths = (32, 64, 128)

    def conv_stack(stack_keys, first_in_channels):
        # three 3x3 conv layers whose input width chains from the previous output
        in_widths = (first_in_channels,) + widths[:-1]
        return {
            f'conv{idx}': conv_layer_params(layer_key, in_channels=c_in, out_channels=c_out, kernel_size=3)
            for idx, (layer_key, c_in, c_out) in enumerate(zip(stack_keys, in_widths, widths), start=1)
        }

    grid_logits = hp.pad_dim_space * hp.pad_dim_space * hp.num_channels
    return {
        'encoder': conv_stack(layer_keys[0:3], hp.num_channels * 2),
        'test_encoder': conv_stack(layer_keys[3:6], hp.num_channels),
        'decoder': mlp_layer_params(layer_keys[6], in_dim=256, out_dim=grid_logits),
    }

def conv_layer_params(key, in_channels, out_channels, kernel_size):
    """Create weights and bias for one square-kernel convolution.

    The kernel has shape (kernel_size, kernel_size, in_channels, out_channels)
    and is drawn from a normal distribution scaled by sqrt(2 / fan_in); the
    bias is a zero vector of length ``out_channels``.
    """
    # only the first half of the split is needed: the bias is deterministic
    weight_key, _ = jax.random.split(key)
    fan_in = kernel_size * kernel_size * in_channels
    kernel_shape = (kernel_size, kernel_size, in_channels, out_channels)
    kernel = jax.random.normal(weight_key, kernel_shape) * jnp.sqrt(2.0 / fan_in)
    return {'w': kernel, 'b': jnp.zeros(out_channels)}

def mlp_layer_params(key, in_dim, out_dim):
    """Create weights and bias for one dense layer.

    The weight matrix has shape (in_dim, out_dim) and is drawn from a normal
    distribution scaled by sqrt(2 / in_dim); the bias is a zero vector of
    length ``out_dim``.
    """
    # only the first half of the split is needed: the bias is deterministic
    weight_key, _ = jax.random.split(key)
    scale = jnp.sqrt(2.0 / in_dim)
    weights = jax.random.normal(weight_key, (in_dim, out_dim)) * scale
    return {'w': weights, 'b': jnp.zeros(out_dim)}

def model(params, task_train_in, task_train_out, task_eval_in, hp: Hyperparams):
    """Predict the output grid for the first eval input of every task in the batch.

    Args:
        params: parameter tree with ``'encoder'``, ``'test_encoder'`` (each
            holding ``conv1``..``conv3`` with ``'w'`` in HWIO layout and
            ``'b'``) and ``'decoder'`` (``'w'``, ``'b'``).
        task_train_in: [B, T, H, W, C] training inputs.
        task_train_out: [B, T, H, W, C] training outputs.
        task_eval_in: [B, T_eval, H, W, C] eval inputs; only index 0 along
            ``T_eval`` is read.
        hp: hyperparameters; ``pad_dim_space`` and ``num_channels`` give the
            shape the decoder output is reshaped to.

    Returns:
        Logits of shape [B, pad_dim_space, pad_dim_space, num_channels].
    """
    # Grids are channel-last and kernels are (kh, kw, in, out); the library
    # default is NCHW/OIHW, so the layout has to be spelled out.
    conv_dims = ('NHWC', 'HWIO', 'NHWC')

    def encode(layers, grids):
        # [N, H, W, C_in] -> [N, C_out]: three conv+relu layers, then global average pooling
        h = grids
        for name in ('conv1', 'conv2', 'conv3'):
            h = jax.lax.conv_general_dilated(
                h, layers[name]['w'], window_strides=(1, 1), padding='SAME',
                dimension_numbers=conv_dims,
            )
            h = jax.nn.relu(h + layers[name]['b'])
        return jnp.mean(h, axis=(1, 2))

    def process_task(train_in, train_out, eval_in):
        # each training pair is encoded jointly: input and output share the channel axis
        pairs = jnp.concatenate([train_in, train_out], axis=-1)  # (T, H, W, 2C)
        train_repr = jnp.mean(encode(params['encoder'], pairs), axis=0)  # average over pairs
        test_repr = encode(params['test_encoder'], eval_in[:1])[0]  # first eval grid only
        combined = jnp.concatenate([train_repr, test_repr], axis=0)
        # Decoder
        logits = jnp.dot(combined, params['decoder']['w']) + params['decoder']['b']  # (H*W*C,)
        return logits.reshape(hp.pad_dim_space, hp.pad_dim_space, hp.num_channels)  # (H, W, C)

    task_eval_out_pred = jax.vmap(process_task, in_axes=(0, 0, 0))(task_train_in, task_train_out, task_eval_in)
    return task_eval_out_pred  # (batch_size, H, W, num_channels)

def loss_fn(task_eval_out_pred, task_eval_out_targ, hp: Hyperparams):
    """Mean softmax cross-entropy over all cells.

    Both arguments are flattened to rows of ``hp.num_channels`` entries; the
    target rows are turned into class indices with argmax before being compared
    against the predicted logits.
    """
    num_classes = hp.num_channels
    flat_logits = jnp.reshape(task_eval_out_pred, (-1, num_classes))
    flat_targets = jnp.reshape(task_eval_out_targ, (-1, num_classes))
    class_ids = jnp.argmax(flat_targets, axis=-1)
    per_cell = optax.softmax_cross_entropy_with_integer_labels(flat_logits, class_ids)
    return jnp.mean(per_cell)

# </model>

def train_step(params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp: Hyperparams):
    """Run one optimisation step on a batch.

    Uses the module-level ``optimizer`` to turn gradients into updates.

    Args:
        params: model parameter tree.
        opt_state: optimizer state matching ``params``.
        task_train_in, task_train_out, task_eval_in: model inputs for the batch.
        task_eval_out: [B, T_eval, H, W, C] eval targets.
        hp: hyperparameters forwarded to ``model`` and ``loss_fn``.

    Returns:
        ``(params, opt_state, loss)`` after the update; ``loss`` is the value
        computed before the update was applied.
    """
    # The model predicts one grid per task, for eval example 0, so the target
    # is that same example. Indexing (rather than squeeze) also works when the
    # eval time axis is longer than 1.
    target = task_eval_out[:, 0]

    def compute_loss(current_params):
        task_eval_out_pred = model(current_params, task_train_in, task_train_out, task_eval_in, hp)
        return loss_fn(task_eval_out_pred, target, hp)

    loss, grads = jax.value_and_grad(compute_loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

def valid_step(params, valid_gen, hp: Hyperparams):
    """Evaluate the model over every batch produced by ``valid_gen``.

    Args:
        params: model parameter tree.
        valid_gen: iterable of ``(task_train_in, task_train_out, task_eval_in,
            task_eval_out)`` batches.
        hp: hyperparameters forwarded to ``model`` and ``loss_fn``.

    Returns:
        ``(avg_loss, avg_acc)``: loss and per-cell accuracy, each averaged over
        batches (every batch weighs the same regardless of its size).

    Raises:
        ValueError: if ``valid_gen`` yields no batches.
    """
    total_loss = 0.0
    total_acc = 0.0
    num_batches = 0
    for task_train_in, task_train_out, task_eval_in, task_eval_out in valid_gen:
        # the model predicts eval example 0 only; compare against that example
        target = task_eval_out[:, 0]
        task_eval_out_pred = model(params, task_train_in, task_train_out, task_eval_in, hp)
        total_loss += loss_fn(task_eval_out_pred, target, hp)
        pred_classes = jnp.argmax(task_eval_out_pred, axis=-1)
        true_classes = jnp.argmax(target, axis=-1)
        total_acc += jnp.mean(pred_classes == true_classes)
        num_batches += 1
    if num_batches == 0:
        raise ValueError("valid_gen yielded no batches")
    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches
    return avg_loss, avg_acc

# --- training run: log the configuration, load data, train, validate each epoch ---
print(f"hyperparameters: {hp}")
wandb.login()
wandb.init(
    entity=hp.wandb_entity,
    project=hp.wandb_project,
    name=f"{hp.morph}.{hp.compute_backend}.{str(uuid.uuid4())[:6]}",
    config=hp.__dict__,
)
key = jax.random.PRNGKey(hp.seed)
train_tasks = load_tasks(
    '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json', hp)
valid_tasks = load_tasks(
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json', hp)
# give initialisation its own key so the root key is not consumed twice
init_key, key = jax.random.split(key)
params = init_params(init_key, hp)
optimizer = optax.adam(hp.learning_rate)
opt_state = optimizer.init(params)
# number of batches data_loader yields per pass (ceiling division)
steps_per_epoch = -(-len(train_tasks) // hp.batch_size)
for epoch in range(hp.num_epochs):
    print(f"epoch {epoch + 1}/{hp.num_epochs}")
    train_key, key = jax.random.split(key)
    train_gen = data_loader(train_key, train_tasks, hp, train_mode=True)
    for step, (task_train_in, task_train_out, task_eval_in, task_eval_out) in enumerate(train_gen):
        params, opt_state, train_loss = train_step(params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp)
        if step % hp.print_every == 0:
            print(f"step {step + 1}/{steps_per_epoch}: loss = {train_loss.item():.4f}")
            # the logging step must keep increasing across epochs, so offset
            # the in-epoch step by the epochs already completed
            global_step = epoch * steps_per_epoch + step + 1
            wandb.log({"train_loss": train_loss.item()}, step=global_step)
    valid_key, key = jax.random.split(key)
    valid_gen = data_loader(valid_key, valid_tasks, hp, train_mode=False)
    valid_loss, valid_acc = valid_step(params, valid_gen, hp)
    print(f'valid_loss: {valid_loss.item():.4f}, valid_acc: {valid_acc.item():.4f}')
    wandb.log({"valid_loss": valid_loss.item(), "valid_acc": valid_acc.item()}, step=(epoch + 1) * steps_per_epoch)
wandb.finish()

# --- submission: two attempts per test input, written as
# {task_id: [{"attempt_1": grid, "attempt_2": grid}, ...]} with one dict per test input ---
submission_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
predictions = {}
with open(submission_challenges_path, 'r') as f:
    challenges_dict = json.load(f)
print(f"loading challenges from {submission_challenges_path}, found {len(challenges_dict)} challenges")
for task_id, task in challenges_dict.items():
    # NOTE(review): jnp.array on the nested lists requires all grids of a task
    # to share one shape, and pad_space additionally requires square grids.
    task_train_in = jnp.array([pair['input'] for pair in task['train']], dtype=jnp.int32)
    task_train_out = jnp.array([pair['output'] for pair in task['train']], dtype=jnp.int32)
    task_eval_in = jnp.array([grid['input'] for grid in task['test']], dtype=jnp.int32)
    task_train_in = jax.nn.one_hot(task_train_in, hp.num_channels)
    task_train_out = jax.nn.one_hot(task_train_out, hp.num_channels)
    task_eval_in = jax.nn.one_hot(task_eval_in, hp.num_channels)
    # one {"attempt_1": ..., "attempt_2": ...} dict per test input
    task_attempts = [{} for _ in range(task_eval_in.shape[0])]
    for attempt_id in range(2):
        # each attempt differs only in its random padding / ordering
        key = jax.random.PRNGKey(attempt_id)
        space_key, time_key = jax.random.split(key)
        # Pad training data
        task_train_in_padded = []
        task_train_out_padded = []
        for t_in, t_out in zip(task_train_in, task_train_out):
            t_in_padded, t_out_padded, _ = pad_space(space_key, t_in, t_out, hp)
            task_train_in_padded.append(t_in_padded)
            task_train_out_padded.append(t_out_padded)
        task_train_in_padded = jnp.stack(task_train_in_padded)
        task_train_out_padded = jnp.stack(task_train_out_padded)
        task_train_in_padded, task_train_out_padded = pad_time(time_key, task_train_in_padded, task_train_out_padded, hp)
        # The model only reads eval example 0, so predict the test inputs one
        # at a time, each placed alone at index 0 of the eval time axis.
        for eval_example_id, e_in in enumerate(task_eval_in):
            e_in_padded, _, (sh, sw) = pad_space(space_key, e_in, None, hp)
            task_eval_out_pred = model(
                params,
                task_train_in_padded[None],
                task_train_out_padded[None],
                e_in_padded[None, None],  # add batch and time axes
                hp,
            )
            grid_out_pred = jnp.argmax(task_eval_out_pred[0], axis=-1)  # (H, W) class ids
            # Un-pad: crop the region the input occupied (the prediction is
            # taken to have the input's height and width)
            h, w = e_in.shape[0], e_in.shape[1]
            grid_out_pred = grid_out_pred[sh:sh + h, sw:sw + w]
            task_attempts[eval_example_id][f"attempt_{attempt_id+1}"] = grid_out_pred.tolist()
    predictions[task_id] = task_attempts

with open('submission.json', 'w') as f:
    json.dump(predictions, f)
