In [None]:
import json
from dataclasses import dataclass
from datetime import datetime
import os
import uuid

import jax
import jax.numpy as jnp
import optax
import wandb

@dataclass(frozen=True)
class Hyperparams:
    """Immutable configuration for one training / submission run.

    The first group of fields takes its defaults from environment variables.
    Those defaults (and ``created_on``) are evaluated once, when the class body
    runs, so every instance created afterwards shares the same values.
    """

    # Run identity; overridable through the environment.
    seed: int = int(os.getenv("SEED", 42))
    morph: str = os.getenv("MORPH", "test")
    compute_backend: str = os.getenv("COMPUTE_BACKEND", "oop")
    wandb_entity: str = os.getenv("WANDB_ENTITY", "hug")
    wandb_project: str = os.getenv("WANDB_PROJECT", "arc-test")
    created_on: str = datetime.now().strftime("%Y%m%d%H%M%S")

    # Optimisation schedule.
    learning_rate: float = 1e-3
    batch_size: int = 8
    num_epochs: int = 1
    print_every: int = 1

    # Data layout: grids are padded to pad_dim_space x pad_dim_space cells,
    # stacks are resampled to pad_dim_time examples, colours are one-hot
    # encoded into num_channels channels, and padding cells hold pad_value.
    pad_dim_space: int = 30
    pad_dim_time: int = 5
    pad_value: int = 0
    num_channels: int = 10

    # Probability that each geometric augmentation is applied.
    augment_prob: float = 0.5


hp = Hyperparams()

def load_tasks(challenges_path, solutions_path):
    """Load an ARC challenges/solutions JSON pair into memory.

    Each returned task is a ``(train_in, train_out, eval_in, eval_out)`` tuple
    of lists of integer grids (nested lists). A challenge with several test
    inputs is split into one task per test input. Each split task shares the
    same demonstration pairs, so ``eval_in`` / ``eval_out`` always hold exactly
    one grid, which is what ``data_loader`` expects.

    Raises:
        KeyError: if a challenge id has no entry in the solutions file.
        AssertionError: if a challenge's test-input count does not match its
            number of solutions.
    """
    with open(challenges_path, 'r', encoding='utf-8') as f:
        challenges_dict = json.load(f)
    print(f"loading challenges from {challenges_path}, found {len(challenges_dict)} challenges")
    with open(solutions_path, 'r', encoding='utf-8') as f:
        solutions_dict = json.load(f)
    print(f"loading solutions from {solutions_path}, found {len(solutions_dict)} solutions")
    tasks = []
    for task_id, challenge in challenges_dict.items():
        # there may be multiple training pairs for each task
        task_train_in = [pair['input'] for pair in challenge['train']]
        task_train_out = [pair['output'] for pair in challenge['train']]
        test_inputs = [grid['input'] for grid in challenge['test']]
        test_outputs = list(solutions_dict[task_id])
        assert len(test_inputs) == len(test_outputs), (
            f"task {task_id}: {len(test_inputs)} test inputs but {len(test_outputs)} solutions"
        )
        # one task per test pair instead of asserting there is only one
        for eval_in, eval_out in zip(test_inputs, test_outputs):
            tasks.append((task_train_in, task_train_out, [eval_in], [eval_out]))
    return tasks

def pad_space(key, grid_in, grid_out, hp: Hyperparams):
    """Place grid(s) at a random offset on a ``pad_dim_space`` x ``pad_dim_space`` canvas.

    Grids are laid out as ``[..., H, W, C]``. The spatial axes are the third-
    and second-to-last, and any leading axes (e.g. a time axis) are kept.
    Every cell outside the grid is filled with ``hp.pad_value``. ``grid_in``
    and ``grid_out`` share one offset, so corresponding cells stay aligned.
    The offset may be a traced value, so this also works under ``jit``/``vmap``.

    Args:
        key: PRNG key used to draw the offset.
        grid_in: array ``[..., H, W, C]`` with ``H, W <= hp.pad_dim_space``
            (rectangular grids are fine).
        grid_out: ``None`` or an array with exactly ``grid_in``'s shape.
        hp: hyperparameters (reads ``pad_dim_space`` and ``pad_value``).

    Returns:
        ``(padded_in, padded_out, (start_h, start_w))``. ``padded_out`` is
        ``None`` when ``grid_out`` is ``None``. The offsets are scalar integer
        arrays giving the top-left corner of the grid on the canvas.
    """
    height, width = grid_in.shape[-3], grid_in.shape[-2]
    assert height <= hp.pad_dim_space, f"grid height {height} > pad_dim_space {hp.pad_dim_space}"
    assert width <= hp.pad_dim_space, f"grid width {width} > pad_dim_space {hp.pad_dim_space}"
    assert grid_out is None or grid_out.shape == grid_in.shape
    # independent keys: one shared key would tie the row offset to the column offset
    key_h, key_w = jax.random.split(key)
    start_h = jax.random.randint(key_h, (), 0, hp.pad_dim_space - height + 1)
    start_w = jax.random.randint(key_w, (), 0, hp.pad_dim_space - width + 1)

    def _place(grid):
        # write the grid into a pad_value-filled canvas at (start_h, start_w)
        canvas_shape = grid.shape[:-3] + (hp.pad_dim_space, hp.pad_dim_space, grid.shape[-1])
        canvas = jnp.full(canvas_shape, hp.pad_value, dtype=grid.dtype)
        zero = jnp.zeros((), dtype=start_h.dtype)  # all start indices must share a dtype
        start = (zero,) * (grid.ndim - 3) + (start_h, start_w, zero)
        return jax.lax.dynamic_update_slice(canvas, grid, start)

    grid_in = _place(grid_in)
    if grid_out is not None:
        grid_out = _place(grid_out)
    return grid_in, grid_out, (start_h, start_w)

def pad_time(key, grid_in, grid_out, hp: Hyperparams):
    """Resample the leading example axis to exactly ``hp.pad_dim_time`` entries.

    Grids are ``[T, H, W, C]`` stacks of examples.

    - With fewer than ``pad_dim_time`` examples, every example is kept once and
      randomly chosen examples are repeated to fill the gap.
    - With more, a random subset of ``pad_dim_time`` examples is kept.

    The result is shuffled along T. ``grid_out`` (if given) is reordered
    identically, so input/output pairs stay matched.

    Returns:
        ``(grid_in, grid_out)`` with leading dimension ``hp.pad_dim_time``.
        ``grid_out`` stays ``None`` if it was ``None``.
    """
    num_examples = grid_in.shape[0]
    assert num_examples >= 1, "pad_time needs at least one example"
    assert grid_in.shape[1] <= hp.pad_dim_space
    assert grid_in.shape[2] <= hp.pad_dim_space
    assert grid_in.shape[1] == grid_in.shape[2]
    assert grid_out is None or grid_out.shape == grid_in.shape
    # separate keys for the repeats and the shuffle; reusing one key made every
    # repeat pick the same example and correlated it with the permutation
    key_fill, key_perm = jax.random.split(key)
    if num_examples >= hp.pad_dim_time:
        # too many (or exactly enough): keep a random subset, already shuffled
        order = jax.random.permutation(key_perm, num_examples)[:hp.pad_dim_time]
    else:
        # too few: keep all once, top up with random repeats, then shuffle
        repeats = jax.random.randint(key_fill, (hp.pad_dim_time - num_examples,), 0, num_examples)
        order = jax.random.permutation(key_perm, jnp.concatenate([jnp.arange(num_examples), repeats]))
    grid_in = grid_in[order]
    if grid_out is not None:
        grid_out = grid_out[order]
    return grid_in, grid_out

# Geometric augmentations; random_augment applies each one independently.
# On an [H, W, C] grid: fliplr reverses W, flipud reverses H, and rot90
# rotates the (H, W) plane (its default axes are the two leading ones).
augmentations = [
    jnp.fliplr,
    jnp.flipud,
    jnp.rot90,
]

def random_augment(key, grid_in, grid_out, hp: Hyperparams):
    """Randomly flip/rotate and colour-permute one padded example.

    Expects a single ``[pad_dim_space, pad_dim_space, num_channels]`` grid.
    The functions in ``augmentations`` act on the two leading axes, so H and W
    must be there. To give every example of a stack the same augmentation,
    call this under ``jax.vmap`` with the key unmapped, as ``data_loader`` does.

    Args:
        key: PRNG key. The same key always yields the same transform.
        grid_in: padded one-hot grid ``[P, P, C]``.
        grid_out: ``None`` or a grid of the same shape. It receives exactly the
            same transform.
        hp: hyperparameters (reads ``augment_prob``, ``num_channels``, shape sizes).

    Returns:
        ``(grid_in, grid_out)`` after augmentation. ``grid_out`` stays ``None``
        if it was ``None``.
    """
    expected_shape = (hp.pad_dim_space, hp.pad_dim_space, hp.num_channels)
    assert grid_in.shape == expected_shape, f"expected one {expected_shape} grid, got {grid_in.shape}"
    assert grid_out is None or grid_out.shape == grid_in.shape
    # independent keys for the geometric and the colour draws
    key_geometry, key_colour = jax.random.split(key)
    apply = jax.random.bernoulli(key_geometry, hp.augment_prob, shape=(len(augmentations),))
    for i, augmentation in enumerate(augmentations):
        # lax.cond keeps this traceable; both branches keep the square [P, P, C] shape
        grid_in = jax.lax.cond(apply[i], augmentation, lambda x: x, grid_in)
        if grid_out is not None:
            grid_out = jax.lax.cond(apply[i], augmentation, lambda x: x, grid_out)
    # shuffle the C dimension (colour permutation)
    perm = jax.random.permutation(key_colour, hp.num_channels)
    grid_in = grid_in[..., perm]
    if grid_out is not None:
        grid_out = grid_out[..., perm]
    return grid_in, grid_out

def _encode_pairs(key, grids_in, grids_out, hp: Hyperparams):
    """One-hot encode and space-pad a list of (input, output) grid pairs.

    Each pair is padded on its own because ARC grids are ragged: examples in
    one task, and an input versus its output, can differ in height/width. A
    single ``jnp.array`` over the raw nested lists therefore cannot be built.

    Same-shaped pairs share one offset, so input and output cells stay aligned.
    All pairs use the same ``key``, and ``pad_space`` draws its offset from the
    key and the grid size, so same-sized grids land at the same position.

    Returns:
        ``(stacked_in, stacked_out)``: the padded grids stacked on a new
        leading axis, one entry per pair.
    """
    padded_in, padded_out = [], []
    for grid_in, grid_out in zip(grids_in, grids_out):
        one_hot_in = jax.nn.one_hot(jnp.asarray(grid_in, dtype=jnp.int32), hp.num_channels)
        one_hot_out = jax.nn.one_hot(jnp.asarray(grid_out, dtype=jnp.int32), hp.num_channels)
        if one_hot_in.shape == one_hot_out.shape:
            one_hot_in, one_hot_out, _ = pad_space(key, one_hot_in, one_hot_out, hp)
        else:
            one_hot_in, _, _ = pad_space(key, one_hot_in, None, hp)
            one_hot_out, _, _ = pad_space(key, one_hot_out, None, hp)
        padded_in.append(one_hot_in)
        padded_out.append(one_hot_out)
    return jnp.stack(padded_in), jnp.stack(padded_out)


def data_loader(key, tasks, hp: Hyperparams, train_mode=True):
    """Yield batches of one-hot encoded, padded tasks.

    Args:
        key: PRNG key. The task order, padding offsets and augmentations are
            all derived from it.
        tasks: list of ``(train_in, train_out, eval_in, eval_out)`` tuples of
            integer grids, as returned by ``load_tasks``. ``eval_in`` /
            ``eval_out`` must hold exactly one grid.
        hp: hyperparameters.
        train_mode: if True, shuffle the task order and apply ``random_augment``.

    Yields:
        ``(train_in, train_out, eval_in, eval_out)`` arrays.

        - ``train_in`` / ``train_out`` are ``[B, pad_dim_time, P, P, C]``.
        - ``eval_in`` / ``eval_out`` are ``[B, 1, P, P, C]``.

        Here ``B = hp.batch_size``, ``P = hp.pad_dim_space`` and
        ``C = hp.num_channels``. A trailing partial batch is dropped.
    """
    num_tasks = len(tasks)
    # separate streams so the task order does not correlate with padding/augmentation
    shuffle_key, batches_key = jax.random.split(key)
    if train_mode:
        order = jax.random.permutation(shuffle_key, num_tasks).tolist()
        tasks = [tasks[idx] for idx in order]
    num_batches = num_tasks // hp.batch_size
    # key unmapped (None): every example of a stack gets the same flips / colour permutation
    augment = jax.vmap(lambda k, g_in, g_out: random_augment(k, g_in, g_out, hp), in_axes=(None, 0, 0))
    for batch_idx in range(num_batches):
        batch_tasks = tasks[batch_idx * hp.batch_size:(batch_idx + 1) * hp.batch_size]
        task_keys = jax.random.split(jax.random.fold_in(batches_key, batch_idx), hp.batch_size)
        batch_train_in, batch_train_out, batch_eval_in, batch_eval_out = [], [], [], []
        for task, task_key in zip(batch_tasks, task_keys):
            train_in_grids, train_out_grids, eval_in_grids, eval_out_grids = task
            k_train_space, k_time, k_eval_space, k_augment = jax.random.split(task_key, 4)
            # one-hot + pad in space, then pad the demonstrations in time
            train_in, train_out = _encode_pairs(k_train_space, train_in_grids, train_out_grids, hp)
            train_in, train_out = pad_time(k_time, train_in, train_out, hp)
            eval_in, eval_out = _encode_pairs(k_eval_space, eval_in_grids, eval_out_grids, hp)
            if train_mode:
                # the SAME key for demonstrations and query: a different flip or colour
                # permutation on the eval pair would change the rule the task encodes
                train_in, train_out = augment(k_augment, train_in, train_out)
                eval_in, eval_out = augment(k_augment, eval_in, eval_out)
            batch_train_in.append(train_in)
            batch_train_out.append(train_out)
            batch_eval_in.append(eval_in)
            batch_eval_out.append(eval_out)
        task_train_in = jnp.stack(batch_train_in)
        assert task_train_in.shape == (hp.batch_size, hp.pad_dim_time, hp.pad_dim_space, hp.pad_dim_space, hp.num_channels)
        task_train_out = jnp.stack(batch_train_out)
        assert task_train_out.shape == (hp.batch_size, hp.pad_dim_time, hp.pad_dim_space, hp.pad_dim_space, hp.num_channels)
        task_eval_in = jnp.stack(batch_eval_in)
        assert task_eval_in.shape == (hp.batch_size, 1, hp.pad_dim_space, hp.pad_dim_space, hp.num_channels)
        task_eval_out = jnp.stack(batch_eval_out)
        assert task_eval_out.shape == (hp.batch_size, 1, hp.pad_dim_space, hp.pad_dim_space, hp.num_channels)
        yield task_train_in, task_train_out, task_eval_in, task_eval_out

def init_params(key, hp: Hyperparams):
    """Create the initial model parameters.

    The training script passes the result to ``model`` and to
    ``optimizer.init``, so it must be a pytree that optax can handle.

    TODO: not implemented yet. This raises until it is, so a missing
    implementation fails loudly. A returned unbound name would instead resolve
    to a same-named module global or fail with a confusing NameError.
    """
    raise NotImplementedError("init_params is a TODO stub")

def model(params, task_train_in, task_train_out, task_eval_in, hp: Hyperparams):
    """Predict the eval output from the demonstration pairs and the eval input.

    Inputs as produced by ``data_loader``:

    - ``task_train_in`` / ``task_train_out`` are ``[B, pad_dim_time, P, P, C]``.
    - ``task_eval_in`` is ``[B, 1, P, P, C]``.

    Callers take ``argmax`` over the last axis and compare it with
    ``[B, P, P]`` target classes. The output is therefore presumably
    ``[B, P, P, C]`` scores (TODO confirm).

    TODO: not implemented yet. This raises until it is. A returned unbound
    name would resolve to a module global, e.g. ``task_eval_out`` is the
    training script's target batch, and would silently leak the labels.
    """
    raise NotImplementedError("model is a TODO stub")

def loss_fn(task_eval_out_pred, task_eval_out_targ, hp: Hyperparams):
    """Return a scalar loss comparing predicted and target eval grids.

    ``train_step`` passes ``task_eval_out.squeeze(1)`` as the target. That is
    presumably a one-hot ``[B, P, P, C]`` array (TODO confirm against
    ``model``'s output layout).

    TODO: not implemented yet. This raises until it is, so a missing
    implementation fails loudly. A returned unbound name would resolve to a
    module global or raise a confusing NameError.
    """
    raise NotImplementedError("loss_fn is a TODO stub")

def train_step(params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp: Hyperparams):
    """Run one optimiser step on a batch.

    Uses the module-level ``optimizer`` (an optax transformation created by the
    training script) to turn gradients into parameter updates.

    Args:
        params: current model parameters.
        opt_state: current optimiser state.
        task_train_in, task_train_out, task_eval_in: model inputs from ``data_loader``.
        task_eval_out: eval targets ``[B, 1, P, P, C]``. The singleton eval
            axis is squeezed before ``loss_fn``.
        hp: hyperparameters.

    Returns:
        ``(params, opt_state, loss)`` after the update.
    """
    def compute_loss(current_params):
        task_eval_out_pred = model(current_params, task_train_in, task_train_out, task_eval_in, hp)
        # loss_fn takes (pred, target, hp); params must not be passed here
        return loss_fn(task_eval_out_pred, task_eval_out.squeeze(1), hp)

    loss, grads = jax.value_and_grad(compute_loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

def valid_step(params, valid_gen, hp: Hyperparams):
    """Evaluate ``params`` on every batch yielded by ``valid_gen``.

    Args:
        params: model parameters.
        valid_gen: iterable of ``(train_in, train_out, eval_in, eval_out)``
            batches from ``data_loader``. ``eval_out`` is ``[B, 1, P, P, C]``.
        hp: hyperparameters (``pad_value`` identifies padding cells).

    Returns:
        ``(avg_loss, accuracy)``:

        - ``avg_loss`` is the mean of the per-batch losses.
        - ``accuracy`` is the fraction of non-padding target cells whose
          predicted colour (argmax over the last axis) is correct.

    Raises:
        ValueError: if ``valid_gen`` yields no batches (e.g. fewer tasks than
            ``hp.batch_size``).
    """
    total_loss = 0.0
    num_batches = 0
    correct_cells = 0
    real_cells = 0
    for task_train_in, task_train_out, task_eval_in, task_eval_out in valid_gen:
        # drop the singleton eval axis, as train_step does, so loss and accuracy see [B, P, P, C]
        target = task_eval_out.squeeze(1)
        task_eval_out_pred = model(params, task_train_in, task_train_out, task_eval_in, hp)
        total_loss += loss_fn(task_eval_out_pred, target, hp)
        # padding holds pad_value in every channel, while real cells are one-hot and
        # differ from it somewhere. Masking padding stops its argmax (0) from being
        # counted as correct colour-0 predictions.
        is_real = jnp.any(target != hp.pad_value, axis=-1)
        is_correct = jnp.argmax(task_eval_out_pred, axis=-1) == jnp.argmax(target, axis=-1)
        correct_cells += jnp.sum(is_correct & is_real)
        real_cells += jnp.sum(is_real)
        num_batches += 1
    if num_batches == 0:
        raise ValueError("valid_gen yielded no batches; need at least hp.batch_size validation tasks")
    avg_loss = total_loss / num_batches
    avg_acc = correct_cells / jnp.maximum(real_cells, 1)
    return avg_loss, avg_acc

print(f"hyperparameters: {hp}")
wandb.login()
wandb.init(
    entity=hp.wandb_entity,
    project=hp.wandb_project,
    name=f"{hp.morph}.{hp.compute_backend}.{str(uuid.uuid4())[:6]}",
    config=hp.__dict__,
)
key = jax.random.PRNGKey(hp.seed)
train_tasks = load_tasks(
    '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json',
)
valid_tasks = load_tasks(
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
    '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json',
)
# Split the root key instead of reusing it. Parameter init, validation padding
# and each epoch's shuffling/augmentation then get independent random streams.
key, init_key, valid_key = jax.random.split(key, 3)
params = init_params(init_key, hp)
optimizer = optax.adam(hp.learning_rate)
opt_state = optimizer.init(params)
steps_per_epoch = len(train_tasks) // hp.batch_size
try:
    for epoch in range(hp.num_epochs):
        print(f"epoch {epoch + 1}/{hp.num_epochs}")
        # a fresh key per epoch gives a different task order / augmentation each epoch
        key, epoch_key = jax.random.split(key)
        train_gen = data_loader(epoch_key, train_tasks, hp, train_mode=True)
        for step, (task_train_in, task_train_out, task_eval_in, task_eval_out) in enumerate(train_gen):
            params, opt_state, train_loss = train_step(
                params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp
            )
            if step % hp.print_every == 0:
                # wandb ignores logs whose step is below one already logged, so the
                # step must keep increasing across epochs
                global_step = epoch * steps_per_epoch + step + 1
                print(f"step {step + 1}/{steps_per_epoch}: loss = {train_loss.item():.4f}")
                wandb.log({"train_loss": train_loss.item()}, step=global_step)
        # the same valid_key every epoch keeps validation padding fixed, so epochs are comparable
        valid_gen = data_loader(valid_key, valid_tasks, hp, train_mode=False)
        valid_loss, valid_acc = valid_step(params, valid_gen, hp)
        print(f'valid_loss: {valid_loss.item():.4f}, valid_acc: {valid_acc.item():.4f}')
        wandb.log({"valid_loss": valid_loss.item(), "valid_acc": valid_acc.item()}, step=(epoch + 1) * steps_per_epoch)
finally:
    # close the wandb run even if training raises or is interrupted
    wandb.finish()

submission_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'


def _pad_pair_for_submission(key, grid_in, grid_out, hp: Hyperparams):
    """One-hot encode and space-pad one demonstration (input, output) pair.

    Same-shaped grids share a single offset, so input and output cells stay
    aligned. Differently-shaped grids are placed independently.
    """
    one_hot_in = jax.nn.one_hot(jnp.asarray(grid_in, dtype=jnp.int32), hp.num_channels)
    one_hot_out = jax.nn.one_hot(jnp.asarray(grid_out, dtype=jnp.int32), hp.num_channels)
    if one_hot_in.shape == one_hot_out.shape:
        one_hot_in, one_hot_out, _ = pad_space(key, one_hot_in, one_hot_out, hp)
    else:
        one_hot_in, _, _ = pad_space(key, one_hot_in, None, hp)
        one_hot_out, _, _ = pad_space(key, one_hot_out, None, hp)
    return one_hot_in, one_hot_out


def _predict_test_grid(params, key, task, test_grid, hp: Hyperparams):
    """Predict the output for one test input of ``task`` as a nested list of ints.

    The inputs are laid out like a ``data_loader`` batch of size 1:
    ``[1, T, P, P, C]`` demonstrations and a ``[1, 1, P, P, C]`` query.

    The prediction is cropped at the query's padding offset to the query's own
    height/width. NOTE(review): this assumes the answer has the test input's
    shape, which does not hold for every ARC task.

    The crop is then converted to colour indices with an argmax over the last
    axis. This assumes ``model`` returns ``[B, P, P, C]`` scores, as
    ``valid_step``'s argmax implies (TODO confirm).
    """
    key_train_space, key_time, key_eval_space = jax.random.split(key, 3)
    pairs = [_pad_pair_for_submission(key_train_space, pair['input'], pair['output'], hp)
             for pair in task['train']]
    train_in = jnp.stack([grid_in for grid_in, _ in pairs])
    train_out = jnp.stack([grid_out for _, grid_out in pairs])
    train_in, train_out = pad_time(key_time, train_in, train_out, hp)
    eval_in = jax.nn.one_hot(jnp.asarray(test_grid, dtype=jnp.int32), hp.num_channels)
    # remember the unpadded size and the offset: both are needed to crop the prediction
    height, width = eval_in.shape[0], eval_in.shape[1]
    eval_in, _, (start_h, start_w) = pad_space(key_eval_space, eval_in, None, hp)
    start_h, start_w = int(start_h), int(start_w)
    task_eval_out_pred = model(params, train_in[None], train_out[None], eval_in[None, None], hp)
    # un-pad for submission
    cropped = task_eval_out_pred[0, start_h:start_h + height, start_w:start_w + width]
    return jnp.argmax(cropped, axis=-1).tolist()


with open(submission_challenges_path, 'r', encoding='utf-8') as f:
    challenges_dict = json.load(f)
print(f"loading challenges from {submission_challenges_path}, found {len(challenges_dict)} challenges")
# ARC Prize submission format: task_id -> list containing one
# {"attempt_1": grid, "attempt_2": grid} dict per test input, in test order
predictions = {}
for task_id, task in challenges_dict.items():
    task_predictions = []
    for test_pair in task['test']:
        attempts = {}
        for attempt_id in (1, 2):
            # use different padding for each attempt; grids are rebuilt from the raw
            # JSON every time, so no attempt sees another attempt's padded arrays
            attempt_key = jax.random.PRNGKey(attempt_id)
            attempts[f"attempt_{attempt_id}"] = _predict_test_grid(params, attempt_key, task, test_pair['input'], hp)
        task_predictions.append(attempts)
    predictions[task_id] = task_predictions

with open('submission.json', 'w', encoding='utf-8') as f:
    json.dump(predictions, f)
