# In [None]:
import functools
import json
import os
import uuid
from dataclasses import dataclass, replace
from datetime import datetime

import jax
import jax.numpy as jnp
import optax
import wandb

@dataclass(frozen=True)
class Hyperparams:
    """Immutable configuration for one training run.

    The environment-derived defaults are read once, when the class body is
    executed. Because the dataclass is frozen (and therefore hashable), an
    instance can be passed to ``jax.jit`` as a static argument. Use
    ``dataclasses.replace`` to derive a modified copy instead of assigning to
    fields.
    """

    # --- run identity, overridable through environment variables.
    # os.environ values are always strings, so numeric settings are converted.
    seed: int = int(os.environ.get("SEED", 42))
    morph: str = os.environ.get("MORPH", "test")
    compute_backend: str = os.environ.get("COMPUTE_BACKEND", "oop")
    wandb_entity: str = os.environ.get("WANDB_ENTITY", "hug")
    wandb_project: str = os.environ.get("WANDB_PROJECT", "arc-test")
    run_name: str = "test"
    # Timestamp captured once at class-definition time and shared by all instances.
    created_on: str = datetime.now().strftime("%Y%m%d%H%M%S")
    # --- optimisation
    learning_rate: float = 1e-3
    batch_size: int = 64
    num_epochs: int = 8
    print_every: int = 100
    # --- data layout / augmentation
    pad_dim_space: int = 30  # height and width every grid is padded to
    pad_dim_time: int = 5  # number of demonstration pairs per task after time padding
    pad_value: int = 0  # fill value written into padded cells
    num_channels: int = 9  # size of the trailing channel axis
    augment_prob: float = 0.5  # probability of applying each spatial augmentation

# Hyperparams is frozen, so the run name is set on a derived copy rather than
# by assigning to the field (which would raise FrozenInstanceError).
hp = Hyperparams()
hp = replace(hp, run_name=f"{hp.morph}.{hp.compute_backend}.{str(uuid.uuid4())[:6]}")

def load_tasks(challenges_path, solutions_path):
    """Load an ARC challenges/solutions pair of JSON files into RAM.

    Args:
        challenges_path: JSON object mapping task id -> ``{"train": [...], "test": [...]}``.
            Each ``"train"`` entry has ``"input"`` and ``"output"`` grids; each
            ``"test"`` entry has an ``"input"`` grid.
        solutions_path: JSON object mapping task id -> list of expected output
            grids, one per ``"test"`` entry.

    Returns:
        One ``(train_in, train_out, eval_in, eval_out)`` tuple per task, in the
        key order of the challenges file. Each element is a list of grids as
        parsed from JSON.
    """
    with open(challenges_path, 'r', encoding='utf-8') as f:
        challenges_dict = json.load(f)
    print(f"loading challenges from {challenges_path}, found {len(challenges_dict)} challenges")
    with open(solutions_path, 'r', encoding='utf-8') as f:
        solutions_dict = json.load(f)
    print(f"loading solutions from {solutions_path}, found {len(solutions_dict)} solutions")

    tasks = []
    for task_id, challenge in challenges_dict.items():
        # there may be multiple training pairs for each task
        train_in = [pair['input'] for pair in challenge['train']]
        train_out = [pair['output'] for pair in challenge['train']]
        # test entries are {"input": grid} objects, not bare grids; unwrap them so
        # eval_in holds grids exactly like train_in does
        eval_in = [pair['input'] for pair in challenge['test']]
        eval_out = list(solutions_dict[task_id])
        tasks.append((train_in, train_out, eval_in, eval_out))
    return tasks

@functools.partial(jax.jit, static_argnames=("hp",))
def pad_space(key, grid_in, grid_out, hp: "Hyperparams"):
    """Place an input/output grid pair at one random offset on a padded square canvas.

    Args:
        key: PRNG key used to draw the (row, column) offset.
        grid_in: array of shape [T, H_in, W_in, C].
        grid_out: array of shape [T, H_out, W_out, C]; H/W may differ from grid_in.
        hp: Hyperparams, static under jit; reads ``pad_dim_space`` and ``pad_value``.

    Returns:
        ``(grid_in, grid_out)``, each of shape [T, pad_dim_space, pad_dim_space, C].
        Both grids share the same top-left offset, and every cell outside them
        is ``hp.pad_value``.

    Raises:
        ValueError: if either grid is larger than ``pad_dim_space`` in H or W.
    """
    size = hp.pad_dim_space
    # Shapes are static under jit, so these are plain Python ints.
    max_h = max(grid_in.shape[1], grid_out.shape[1])
    max_w = max(grid_in.shape[2], grid_out.shape[2])
    if max_h > size or max_w > size:
        raise ValueError(
            f"grids of shape {grid_in.shape} / {grid_out.shape} do not fit in a {size}x{size} canvas"
        )

    key_h, key_w = jax.random.split(key)
    # randint's maxval is exclusive; +1 makes the flush bottom/right offset reachable.
    start_h = jax.random.randint(key_h, (), 0, size - max_h + 1)
    start_w = jax.random.randint(key_w, (), 0, size - max_w + 1)

    def place(grid):
        # jnp.pad needs concrete pad widths, but the offsets are traced values,
        # so write the grid into a pre-filled canvas instead.
        canvas = jnp.full((grid.shape[0], size, size, grid.shape[3]), hp.pad_value, dtype=grid.dtype)
        return jax.lax.dynamic_update_slice(canvas, grid, (0, start_h, start_w, 0))

    return place(grid_in), place(grid_out)

@functools.partial(jax.jit, static_argnames=("hp",))
def pad_time(key, grid_in, grid_out, hp: "Hyperparams"):
    """Pad the leading pair axis to ``hp.pad_dim_time`` and shuffle it.

    Missing slots are filled by repeating randomly chosen existing pairs. The
    same repeat indices and the same permutation are applied to both grids, so
    every input stays aligned with its output.

    Args:
        key: PRNG key.
        grid_in: array of shape [T, H, W, C].
        grid_out: array of shape [T, H', W', C'] with the same T as grid_in.
        hp: Hyperparams, static under jit; reads ``pad_dim_time``.

    Returns:
        ``(grid_in, grid_out)`` with leading dimension ``hp.pad_dim_time``.
        NOTE(review): when T > pad_dim_time only the first pad_dim_time pairs
        are kept (then shuffled), as before; confirm whether a random subset
        is preferable.
    """
    num_pairs = grid_in.shape[0]
    num_missing = max(hp.pad_dim_time - num_pairs, 0)
    # Separate keys, so the repeat choices and the shuffle are independent draws.
    key_repeat, key_perm = jax.random.split(key)
    if num_missing:
        # Draw every repeat index at once; reusing one key per iteration would
        # pick the same pair for every missing slot.
        repeat_idx = jax.random.randint(key_repeat, (num_missing,), 0, num_pairs)
        grid_in = jnp.concatenate([grid_in, grid_in[repeat_idx]], axis=0)
        grid_out = jnp.concatenate([grid_out, grid_out[repeat_idx]], axis=0)
    # shuffle the T dimension
    perm = jax.random.permutation(key_perm, hp.pad_dim_time)
    return grid_in[perm], grid_out[perm]

# Spatial augmentations on the H and W axes of [..., H, W, C] grids. For a single
# [H, W, C] grid they match jnp.fliplr, jnp.flipud and jnp.rot90. Unlike those,
# they also act on H/W (not T/H) for stacked [T, H, W, C] grids.
augmentations = [
    lambda grid: jnp.flip(grid, axis=-2),  # mirror left/right (W axis)
    lambda grid: jnp.flip(grid, axis=-3),  # mirror up/down (H axis)
    lambda grid: jnp.rot90(grid, axes=(-3, -2)),  # rotate 90 degrees in the H-W plane
]


@functools.partial(jax.jit, static_argnames=("hp",))
def random_augment(key, grid_in, grid_out, hp: "Hyperparams"):
    """Apply one random flip/rotation/channel-shuffle draw to an input/output pair.

    Each entry of ``augmentations`` is applied independently, in list order,
    with probability ``hp.augment_prob``. Then the channel axis is permuted.
    Both grids receive the same draws, so the pair stays consistent.

    Args:
        key: PRNG key.
        grid_in, grid_out: arrays shaped [..., H, W, C] with H == W and
            C == hp.num_channels.
        hp: Hyperparams, static under jit; reads ``augment_prob`` and ``num_channels``.

    Returns:
        ``(grid_in, grid_out)`` with unchanged shapes.

    Raises:
        ValueError: if a grid is not square in H/W (lax.cond needs the rotated and
            unrotated branches to have the same shape), or if its channel axis is
            not ``hp.num_channels`` (the permutation would gather out of range).
    """
    for grid in (grid_in, grid_out):
        if grid.shape[-3] != grid.shape[-2]:
            raise ValueError(f"random_augment needs square H/W, got shape {grid.shape}")
        if grid.shape[-1] != hp.num_channels:
            raise ValueError(f"expected {hp.num_channels} channels, got shape {grid.shape}")

    key_apply, key_perm = jax.random.split(key)
    apply = jax.random.bernoulli(key_apply, hp.augment_prob, shape=(len(augmentations),))
    for index, augmentation in enumerate(augmentations):
        # lax.cond needs a scalar predicate: one flag per augmentation
        grid_in = jax.lax.cond(apply[index], augmentation, lambda g: g, grid_in)
        grid_out = jax.lax.cond(apply[index], augmentation, lambda g: g, grid_out)
    # shuffle the C dimension, identically for both grids
    perm = jax.random.permutation(key_perm, hp.num_channels)
    return grid_in[..., perm], grid_out[..., perm]

def _one_hot_grid(grid, hp: "Hyperparams"):
    """Encode one 2-D integer grid as a [1, H, W, hp.num_channels] one-hot int32 array."""
    # NOTE(review): values >= hp.num_channels encode to an all-zero vector. That is
    # indistinguishable from padding when hp.pad_value == 0, so confirm that
    # num_channels covers every value that occurs in the grids.
    return jax.nn.one_hot(jnp.asarray(grid, dtype=jnp.int32), hp.num_channels, dtype=jnp.int32)[None]


def _encode_task(key, task, hp: "Hyperparams", train_mode):
    """Turn one load_tasks tuple into fixed-shape (train_in, train_out, eval_in, eval_out) arrays."""
    train_in, train_out, eval_in, eval_out = task
    key_time, key_eval, key_aug, *pair_keys = jax.random.split(key, len(train_in) + 3)
    # Pairs in one task can have different H x W (and an output can differ from its
    # input), so each pair is padded onto the common canvas before stacking along T.
    # NOTE(review): pad_space is jitted, so it compiles once per distinct shape combination.
    padded = [
        pad_space(pair_key, _one_hot_grid(grid_in, hp), _one_hot_grid(grid_out, hp), hp)
        for pair_key, grid_in, grid_out in zip(pair_keys, train_in, train_out)
    ]
    task_train_in = jnp.concatenate([p_in for p_in, _ in padded], axis=0)
    task_train_out = jnp.concatenate([p_out for _, p_out in padded], axis=0)
    task_train_in, task_train_out = pad_time(key_time, task_train_in, task_train_out, hp)
    # Only the first test pair is used, so every task contributes exactly one eval
    # slot. NOTE(review): tasks with several test inputs lose the others.
    task_eval_in, task_eval_out = pad_space(
        key_eval, _one_hot_grid(eval_in[0], hp), _one_hot_grid(eval_out[0], hp), hp
    )
    if train_mode:
        # One shared key, so the demos and the test pair get the same random draw
        # and the task's rule is transformed consistently.
        task_train_in, task_train_out = random_augment(key_aug, task_train_in, task_train_out, hp)
        task_eval_in, task_eval_out = random_augment(key_aug, task_eval_in, task_eval_out, hp)
    return task_train_in, task_train_out, task_eval_in, task_eval_out


def data_loader(key, tasks, hp: "Hyperparams", train_mode=True):
    """Yield padded, one-hot encoded batches from the output of ``load_tasks``.

    This is a plain (single-pass) generator; it cannot be jitted. Create a new
    one for every epoch. Trailing tasks that do not fill a whole batch are dropped.

    Args:
        key: PRNG key for shuffling, padding offsets and augmentation.
        tasks: list of ``(train_in, train_out, eval_in, eval_out)`` tuples.
        hp: Hyperparams.
        train_mode: if True, shuffle the task order and apply ``random_augment``.

    Yields:
        ``(task_train_in, task_train_out, task_eval_in, task_eval_out)`` where
        ``task_train_*`` are [batch_size, pad_dim_time, pad_dim_space, pad_dim_space, num_channels]
        and ``task_eval_*`` are [batch_size, 1, pad_dim_space, pad_dim_space, num_channels].
    """
    num_tasks = len(tasks)
    if train_mode:
        key, shuffle_key = jax.random.split(key)
        order = jax.random.permutation(shuffle_key, num_tasks).tolist()
    else:
        order = list(range(num_tasks))
    num_batches = num_tasks // hp.batch_size
    for i in range(num_batches):
        batch = ([], [], [], [])
        for j in order[i * hp.batch_size:(i + 1) * hp.batch_size]:
            # fresh key per task instead of reusing one key for every draw
            key, task_key = jax.random.split(key)
            for column, array in zip(batch, _encode_task(task_key, tasks[j], hp, train_mode)):
                column.append(array)
        # stack into batches
        yield tuple(jnp.stack(column) for column in batch)

def init_params(key, hp: "Hyperparams"):
    """Create the initial model parameters (not implemented yet).

    TODO: implement. The result is handed to ``optimizer.init`` and then
    threaded through ``train_step``.

    Raises:
        NotImplementedError: always, so a missing implementation fails here
            instead of passing ``None`` to code that expects parameters.
    """
    raise NotImplementedError("init_params is not implemented yet")

def model(params, task_train_in, task_train_out, task_eval_in, hp: "Hyperparams"):
    """Model forward pass (not implemented yet).

    TODO: implement. Receives the demonstration pairs (``task_train_in`` /
    ``task_train_out``) together with ``task_eval_in``; presumably it should
    return predictions for the eval outputs.

    Raises:
        NotImplementedError: always, until implemented.
    """
    raise NotImplementedError("model is not implemented yet")

def loss_fn(params, task_eval_out_pred, task_eval_out_targ, hp: "Hyperparams"):
    """Loss between predicted and target eval outputs (not implemented yet).

    TODO: implement.

    Raises:
        NotImplementedError: always, until implemented.
    """
    raise NotImplementedError("loss_fn is not implemented yet")

def train_step(params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp: "Hyperparams"):
    """One optimisation step on a batch (not implemented yet).

    TODO: implement. The training loop unpacks the result as
    ``(params, opt_state, train_loss)`` and calls ``train_loss.item()``.

    Raises:
        NotImplementedError: always, so the loop fails here rather than on
            unpacking ``None``.
    """
    raise NotImplementedError("train_step is not implemented yet")

def valid_step(params, task_eval_in, task_eval_out, hp: "Hyperparams"):
    """Evaluate one batch (not implemented yet).

    TODO: implement. Callers expect ``(valid_loss, valid_acc)``, each supporting
    ``.item()``. NOTE(review): this signature receives no demonstration pairs,
    although ``model`` takes them; revisit when implementing.

    Raises:
        NotImplementedError: always, until implemented.
    """
    raise NotImplementedError("valid_step is not implemented yet")

print(f"training run {hp.run_name} with hyperparameters: {hp}")
wandb.login()
# entity= selects the W&B user/team; name= is this run's display name.
wandb.init(entity=hp.wandb_entity, project=hp.wandb_project, name=hp.run_name, config=hp.__dict__)
try:
    key = jax.random.PRNGKey(hp.seed)
    train_dataset = load_tasks(
        '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json',
    )
    valid_dataset = load_tasks(
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json',
    )
    # Independent keys for init, validation and each epoch, instead of one reused key.
    key, init_key, valid_key = jax.random.split(key, 3)
    params = init_params(init_key, hp)
    optimizer = optax.adam(hp.learning_rate)
    opt_state = optimizer.init(params)
    steps_per_epoch = len(train_dataset) // hp.batch_size
    total_steps = hp.num_epochs * steps_per_epoch
    global_step = 0
    for epoch in range(hp.num_epochs):
        print(f"epoch {epoch + 1}/{hp.num_epochs}")
        # data_loader yields a single pass over the data, so build a fresh
        # generator each epoch rather than exhausting one shared generator.
        key, epoch_key = jax.random.split(key)
        for task_train_in, task_train_out, task_eval_in, task_eval_out in data_loader(
            epoch_key, train_dataset, hp, train_mode=True
        ):
            global_step += 1
            params, opt_state, train_loss = train_step(
                params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp
            )
            if global_step % hp.print_every == 0:
                print(f"step {global_step}/{total_steps}: loss = {train_loss.item():.4f}")
                wandb.log({"train_loss": train_loss.item()}, step=global_step)
        # valid_step scores one batch; average its metrics over every validation
        # batch. The same key each epoch keeps padding offsets comparable.
        batch_metrics = [
            valid_step(params, task_eval_in, task_eval_out, hp)
            for _, _, task_eval_in, task_eval_out in data_loader(
                valid_key, valid_dataset, hp, train_mode=False
            )
        ]
        if batch_metrics:
            valid_loss = jnp.mean(jnp.asarray([loss for loss, _ in batch_metrics]))
            valid_acc = jnp.mean(jnp.asarray([acc for _, acc in batch_metrics]))
            print(f'valid_loss: {valid_loss.item():.4f}, valid_acc: {valid_acc.item():.4f}')
            wandb.log({"valid_loss": valid_loss.item(), "valid_acc": valid_acc.item()}, step=global_step)
finally:
    # close the W&B run even if training fails part-way
    wandb.finish()



submission_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
tasks = []
with open(submission_challenges_path, 'r', encoding='utf-8') as f:
    challenges_dict = json.load(f)
print(f"loading challenges from {submission_challenges_path}, found {len(challenges_dict)} challenges")
# task_id -> prediction. It must exist before the dump below, even while the
# loop is still a stub (it was previously undefined, raising NameError).
predictions = {}
for task_id in challenges_dict:
    # TODO: implement — fill predictions[task_id].
    # NOTE(review): the ARC Prize 2024 format is presumably a list with one
    # {"attempt_1": grid, "attempt_2": grid} dict per test input; confirm.
    pass

with open('submission.json', 'w', encoding='utf-8') as f:
    json.dump(predictions, f)