# In [None]:
import functools
import json
from dataclasses import dataclass
from datetime import datetime
import os
import uuid

import jax
import jax.numpy as jnp
from jax import random
import optax
import numpy as np
import wandb

@dataclass(frozen=True)
class Hyperparams:
    """Immutable run configuration for the ARC training script.

    The environment-backed defaults below, and ``created_on``, are evaluated
    once when the class body runs, not once per instance. The dataclass is
    frozen (and therefore hashable), so derive variants by constructing a new
    instance, e.g. ``Hyperparams(run_name=...)``, instead of assigning to
    fields.
    """

    # --- run identity; overridable through environment variables ---
    # Environment values are always strings, so SEED is cast to int here;
    # otherwise jax.random.PRNGKey would receive e.g. "7" instead of 7.
    seed: int = int(os.environ.get("SEED", 42))
    morph: str = os.environ.get("MORPH", "test")
    compute_backend: str = os.environ.get("COMPUTE_BACKEND", "oop")
    wandb_entity: str = os.environ.get("WANDB_ENTITY", "hug")
    wandb_project: str = os.environ.get("WANDB_PROJECT", "arc-test")
    run_name: str = "test"
    # Timestamp taken when the class is defined (YYYYmmddHHMMSS).
    created_on: str = datetime.now().strftime("%Y%m%d%H%M%S")
    # --- optimisation ---
    learning_rate: float = 1e-3
    batch_size: int = 64
    num_epochs: int = 8
    print_every: int = 100
    # --- data layout / augmentation ---
    pad_dim_space: int = 30  # side length grids are padded to
    pad_dim_time: int = 5  # number of grid slots along the time axis
    pad_value: int = 0  # fill value used for all padding
    augment_prob: float = 0.5  # per-transform probability in augmentation

# Hyperparams is frozen, so assigning hp.run_name would raise
# FrozenInstanceError; build a new instance carrying the run name instead.
_default_hp = Hyperparams()
hp = Hyperparams(
    run_name=f"{_default_hp.morph}.{_default_hp.compute_backend}.{str(uuid.uuid4())[:6]}",
)

def load_tasks(challenges_path, solutions_path):
    """Load an ARC challenges/solutions JSON pair into memory.

    Args:
        challenges_path: path to a challenges JSON file mapping task id to
            ``{"train": [{"input": grid, "output": grid}, ...], "test": [...]}``.
        solutions_path: path to the matching solutions JSON file, mapping task
            id to a list of output grids.

    Returns:
        One ``(train_in, train_out, eval_in, eval_out)`` tuple per task, in the
        challenges file's key order. Each element is a list of grids exactly as
        parsed from JSON.

    Raises:
        KeyError: if a challenge task id is missing from the solutions file.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale.
    with open(challenges_path, 'r', encoding='utf-8') as f:
        challenges_dict = json.load(f)
    print(f"loading challenges from {challenges_path}, found {len(challenges_dict)} challenges")
    with open(solutions_path, 'r', encoding='utf-8') as f:
        solutions_dict = json.load(f)
    print(f"loading solutions from {solutions_path}, found {len(solutions_dict)} solutions")

    tasks = []
    for task_id, challenge in challenges_dict.items():
        train_pairs = challenge['train']
        task_train_in = [pair['input'] for pair in train_pairs]
        task_train_out = [pair['output'] for pair in train_pairs]
        # ARC challenge files store each test entry as {"input": grid} (a train
        # pair without "output"); unwrap it so eval_in holds grids like the
        # other three lists. Bare grids are passed through unchanged.
        task_eval_in = [
            entry['input'] if isinstance(entry, dict) else entry
            for entry in challenge['test']
        ]
        task_eval_out = list(solutions_dict[task_id])
        tasks.append((task_train_in, task_train_out, task_eval_in, task_eval_out))
    return tasks

@functools.partial(jax.jit, static_argnames=("hp",))
def pad_space(key, grid, hp: "Hyperparams"):
    """Place ``grid`` at a random offset on a square canvas of ``hp.pad_value``.

    The first two axes (H, W) are padded to ``hp.pad_dim_space`` each. Any
    trailing axes (e.g. a channel axis C) are left unchanged, so both
    ``[H, W]`` and ``[H, W, C]`` arrays are accepted.

    Args:
        key: PRNG key; split internally so the row and column offsets are
            drawn independently.
        grid: array with ``ndim >= 2`` and H, W <= ``hp.pad_dim_space``.
        hp: reads ``pad_dim_space`` and ``pad_value``. It is a static jit
            argument, so it must be hashable (e.g. a frozen dataclass).

    Returns:
        Array of shape ``(pad_dim_space, pad_dim_space, *grid.shape[2:])`` with
        ``grid``'s dtype.

    Raises:
        ValueError: if ``grid`` has fewer than two axes or does not fit.
    """
    size = hp.pad_dim_space
    if grid.ndim < 2:
        raise ValueError(f"expected a grid with at least 2 axes, got shape {grid.shape}")
    height, width = grid.shape[0], grid.shape[1]
    if height > size or width > size:
        raise ValueError(f"grid of shape {grid.shape} does not fit in a {size}x{size} canvas")
    key_h, key_w = jax.random.split(key)
    # randint's maxval is exclusive: the +1 lets the grid sit flush against
    # the bottom/right edge, and makes a full-size grid land at offset 0.
    start_h = jax.random.randint(key_h, (), 0, size - height + 1)
    start_w = jax.random.randint(key_w, (), 0, size - width + 1)
    canvas = jnp.full((size, size) + grid.shape[2:], hp.pad_value, dtype=grid.dtype)
    # jnp.pad needs concrete pad widths, which traced random offsets are not
    # under jit; writing the grid into a pre-filled canvas works instead.
    return jax.lax.dynamic_update_slice(
        canvas, grid, (start_h, start_w) + (0,) * (grid.ndim - 2)
    )

@functools.partial(jax.jit, static_argnames=("hp",))
def pad_time(key, grid, hp: "Hyperparams"):
    """Place ``grid`` at a random offset along its leading (time) axis.

    Axis 0 is padded to ``hp.pad_dim_time`` with ``hp.pad_value``; all other
    axes are left unchanged. The leading axis is presumably a stack of
    example grids — TODO confirm with callers.

    Args:
        key: PRNG key used to draw the start offset along axis 0.
        grid: array with ``ndim >= 1`` and ``grid.shape[0] <= hp.pad_dim_time``.
        hp: reads ``pad_dim_time`` and ``pad_value``. It is a static jit
            argument, so it must be hashable (e.g. a frozen dataclass).

    Returns:
        Array of shape ``(pad_dim_time, *grid.shape[1:])`` with ``grid``'s dtype.

    Raises:
        ValueError: if ``grid`` is a scalar or has too many steps along axis 0.
    """
    length = hp.pad_dim_time
    if grid.ndim < 1:
        raise ValueError("expected an array with a leading time axis, got a scalar")
    steps = grid.shape[0]
    if steps > length:
        raise ValueError(f"{steps} time steps do not fit in pad_dim_time={length}")
    # maxval is exclusive; +1 allows the block to end exactly at the last slot.
    start_t = jax.random.randint(key, (), 0, length - steps + 1)
    canvas = jnp.full((length,) + grid.shape[1:], hp.pad_value, dtype=grid.dtype)
    # jnp.pad cannot take traced pad widths under jit; write into a canvas.
    return jax.lax.dynamic_update_slice(
        canvas, grid, (start_t,) + (0,) * (grid.ndim - 1)
    )

@functools.partial(jax.jit, static_argnames=("hp",))
def random_augment(key, grid, hp: "Hyperparams"):
    """Randomly flip/rotate the H and W axes of ``grid``, then shuffle axis 2.

    Each of ``fliplr``, ``flipud`` and ``rot90`` is applied independently with
    probability ``hp.augment_prob``, in that order. Afterwards axis 2 (the C
    axis of an ``[H, W, C]`` grid) is randomly permuted.

    Args:
        key: PRNG key; split into one sub-key per random decision so the
            decisions are independent and the parent key is never reused.
        grid: array with ``ndim >= 3`` and ``H == W``. ``rot90`` swaps H and W,
            and ``lax.cond`` needs both branches to have the same shape.
        hp: reads ``augment_prob``. It is a static jit argument, so it must be
            hashable (e.g. a frozen dataclass).

    Returns:
        Array with the same shape and dtype as ``grid``.

    Raises:
        ValueError: if ``grid`` has fewer than 3 axes or is not square in H, W.
    """
    if grid.ndim < 3 or grid.shape[0] != grid.shape[1]:
        raise ValueError(f"expected a square [H, W, C] grid, got shape {grid.shape}")
    key_lr, key_ud, key_rot, key_perm = jax.random.split(key, 4)
    for aug_key, aug_func in ((key_lr, jnp.fliplr), (key_ud, jnp.flipud), (key_rot, jnp.rot90)):
        do_aug = jax.random.bernoulli(aug_key, hp.augment_prob)
        grid = jax.lax.cond(do_aug, aug_func, lambda x: x, grid)
    return jax.random.permutation(key_perm, grid, axis=2)

def _pad_grid_stack(space_key, time_key, grids, hp: Hyperparams):
    """Space-pad each grid, stack them on a new leading axis, then time-pad it.

    Only the first ``hp.pad_dim_time`` grids are kept.
    """
    # NOTE(review): grids beyond pad_dim_time are dropped instead of raising,
    # so tasks with many examples still load; confirm this truncation is wanted.
    kept = list(grids)[:hp.pad_dim_time]
    if kept:
        grid_keys = jax.random.split(space_key, len(kept))
        stacked = jnp.stack([
            pad_space(grid_key, jnp.asarray(grid), hp)
            for grid_key, grid in zip(grid_keys, kept)
        ])
    else:
        # A zero-length stack, which pad_time fills entirely with padding.
        # assumes 2-D integer grids — TODO confirm
        stacked = jnp.full(
            (0, hp.pad_dim_space, hp.pad_dim_space), hp.pad_value, dtype=jnp.int32
        )
    return pad_time(time_key, stacked, hp)


def _prepare_task(task_key, task, hp: Hyperparams, train_mode):
    """Turn one ``(train_in, train_out, eval_in, eval_out)`` task into 4 padded arrays."""
    train_in, train_out, eval_in, eval_out = task
    (key_train_in, key_train_out, key_eval_in, key_eval_out,
     key_train_time, key_eval_time, key_aug) = jax.random.split(task_key, 7)
    # Inputs and outputs share a time key, so example i of the inputs stays in
    # the same slot as example i of the outputs after the random time offset.
    arrays = (
        _pad_grid_stack(key_train_in, key_train_time, train_in, hp),
        _pad_grid_stack(key_train_out, key_train_time, train_out, hp),
        _pad_grid_stack(key_eval_in, key_eval_time, eval_in, hp),
        _pad_grid_stack(key_eval_out, key_eval_time, eval_out, hp),
    )
    if train_mode:
        # One shared key, so every array of the task gets the same flips/rotation
        # and the same permutation of random_augment's axis 2. That keeps the
        # task's input->output rule intact. Moving the time axis last makes
        # [T, P, P] -> [P, P, T], so for 2-D grids that permuted axis is the
        # example slot axis.
        arrays = tuple(
            jnp.moveaxis(random_augment(key_aug, jnp.moveaxis(array, 0, -1), hp), -1, 0)
            for array in arrays
        )
    return arrays


def data_loader(key, tasks, hp: Hyperparams, train_mode=True):
    """Yield batches of ARC tasks as fixed-size, padded JAX arrays.

    Args:
        key: PRNG key; split into independent sub-keys for the task shuffle and
            for each task's padding offsets and augmentation. Pass a fresh key
            per epoch to get new randomness.
        tasks: sequence of ``(train_in, train_out, eval_in, eval_out)`` tuples
            of grid lists, as returned by ``load_tasks``.
        hp: reads ``batch_size``, ``pad_dim_space``, ``pad_dim_time``,
            ``pad_value`` and ``augment_prob``.
        train_mode: if True, shuffle task order and apply ``random_augment``;
            otherwise keep the given order and do not augment.

    Yields:
        ``len(tasks) // hp.batch_size`` tuples ``(train_in, train_out, eval_in,
        eval_out)``; a trailing partial batch is dropped. For 2-D grids each
        array has shape ``[batch_size, pad_dim_time, pad_dim_space,
        pad_dim_space]``.
    """
    # Not jitted: this is a Python generator over ragged host-side lists.
    num_batches = len(tasks) // hp.batch_size
    if num_batches == 0:
        return
    shuffle_key, tasks_key = jax.random.split(key)
    if train_mode:
        order = np.asarray(jax.random.permutation(shuffle_key, len(tasks)))
    else:
        order = np.arange(len(tasks))
    # One independent key per task position, so no two tasks share randomness.
    task_keys = jax.random.split(tasks_key, len(tasks))
    for batch_idx in range(num_batches):
        start = batch_idx * hp.batch_size
        batch = [
            _prepare_task(task_keys[pos], tasks[order[pos]], hp, train_mode)
            for pos in range(start, start + hp.batch_size)
        ]
        # zip(*batch) regroups the per-task 4-tuples into the 4 components.
        yield tuple(jnp.stack(component) for component in zip(*batch))

def init_params(key, hp: Hyperparams):
    """Return the initial model parameters.

    The training script passes the result to ``optax.adam(...).init`` and to
    ``train_step``, so it should be an optax-compatible pytree of arrays.

    Raises:
        NotImplementedError: always, until implemented. Returning ``None``
            silently would only fail later with a confusing error.
    """
    # TODO: implement
    raise NotImplementedError("init_params is not implemented yet")

def model(params, task_train_in, task_train_out, task_eval_in, hp: Hyperparams):
    """Predict outputs for ``task_eval_in`` given the demonstration pairs.

    NOTE(review): presumably returns the ``task_eval_out_pred`` consumed by
    ``loss_fn``; TODO confirm once implemented.

    Raises:
        NotImplementedError: always, until implemented.
    """
    # TODO: implement
    raise NotImplementedError("model is not implemented yet")

def loss_fn(params, task_eval_out_pred, task_eval_out_targ, hp: Hyperparams):
    """Compare predicted eval outputs against targets.

    NOTE(review): presumably returns a scalar loss array; TODO confirm once
    implemented.

    Raises:
        NotImplementedError: always, until implemented.
    """
    # TODO: implement
    raise NotImplementedError("loss_fn is not implemented yet")

def train_step(params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp: Hyperparams):
    """Run one optimisation step on a batch.

    The training script unpacks the result as ``params, opt_state, train_loss``
    and calls ``train_loss.item()``, so the loss should be a scalar array.

    Raises:
        NotImplementedError: always, until implemented.
    """
    # TODO: implement
    raise NotImplementedError("train_step is not implemented yet")

def valid_step(params, task_eval_in, task_eval_out, hp: Hyperparams):
    """Evaluate the model on one batch of eval grids.

    The training script reads ``.item()`` on both results, so this should
    return ``(loss, accuracy)`` as scalar arrays.

    NOTE(review): ``model`` also takes the train pairs, which this signature
    does not receive; confirm how predictions will be made here.

    Raises:
        NotImplementedError: always, until implemented.
    """
    # TODO: implement
    raise NotImplementedError("valid_step is not implemented yet")

print(f"training run {hp.run_name} with hyperparameters: {hp}")
wandb.login()
# `name` labels this run; `entity` is the W&B user/team that owns the project.
wandb.init(
    entity=hp.wandb_entity,
    project=hp.wandb_project,
    name=hp.run_name,
    config=hp.__dict__,
)
try:
    key = jax.random.PRNGKey(hp.seed)
    train_dataset = load_tasks(
        '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json',
    )
    valid_dataset = load_tasks(
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json',
        '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json',
    )
    # Separate sub-keys: reusing one key for init and the loaders would
    # correlate their randomness.
    key, init_key, valid_key = jax.random.split(key, 3)
    params = init_params(init_key, hp)
    optimizer = optax.adam(hp.learning_rate)
    opt_state = optimizer.init(params)
    steps_per_epoch = len(train_dataset) // hp.batch_size
    global_step = 0
    for epoch in range(hp.num_epochs):
        print(f"epoch {epoch + 1}/{hp.num_epochs}")
        # A generator is exhausted after one pass, so build a fresh train
        # loader (with a fresh key) for every epoch.
        key, epoch_key = jax.random.split(key)
        train_gen = data_loader(epoch_key, train_dataset, hp, train_mode=True)
        for step, (task_train_in, task_train_out, task_eval_in, task_eval_out) in enumerate(train_gen):
            global_step = epoch * steps_per_epoch + step + 1
            params, opt_state, train_loss = train_step(
                params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, hp
            )
            if global_step % hp.print_every == 0:
                loss_value = train_loss.item()
                print(f"step {step + 1}/{steps_per_epoch}: loss = {loss_value:.4f}")
                wandb.log({"train_loss": loss_value}, step=global_step)

        # The same valid_key every epoch gives identical validation batches, so
        # metrics are comparable across epochs.
        valid_losses, valid_accs = [], []
        valid_gen = data_loader(valid_key, valid_dataset, hp, train_mode=False)
        for _, _, task_eval_in, task_eval_out in valid_gen:
            batch_loss, batch_acc = valid_step(params, task_eval_in, task_eval_out, hp)
            valid_losses.append(batch_loss)
            valid_accs.append(batch_acc)
        if valid_losses:
            # Mean of per-batch results; assumes valid_step returns per-batch
            # means and batches are equal-sized.
            valid_loss = jnp.mean(jnp.asarray(valid_losses))
            valid_acc = jnp.mean(jnp.asarray(valid_accs))
            print(f'valid_loss: {valid_loss.item():.4f}, valid_acc: {valid_acc.item():.4f}')
            wandb.log({"valid_loss": valid_loss.item(), "valid_acc": valid_acc.item()}, step=global_step)
        else:
            print("no complete validation batch; skipping validation")
finally:
    # Close the W&B run even if training raises, so it is not left open.
    wandb.finish()



submission_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
with open(submission_challenges_path, 'r', encoding='utf-8') as f:
    challenges_dict = json.load(f)
print(f"loading challenges from {submission_challenges_path}, found {len(challenges_dict)} challenges")

# NOTE(review): assumes the ARC Prize 2024 submission format -- for each task
# id, one {"attempt_1": grid, "attempt_2": grid} dict per test input. Confirm
# against the competition rules.
predictions = {}
for task_id, challenge in challenges_dict.items():
    # TODO: implement -- replace the placeholder attempts with model outputs.
    predictions[task_id] = [
        {"attempt_1": [[0]], "attempt_2": [[0]]}
        for _ in challenge['test']
    ]

with open('submission.json', 'w') as f:
    json.dump(predictions, f)