# In [None]:
#<config>
from dataclasses import dataclass
from datetime import datetime
import os

@dataclass(frozen=True)
class Config:
    """Immutable settings for one run: data paths, augmentation, logging, training.

    All defaults are evaluated once, when this class body executes. The values
    read from environment variables and the ``created_on`` timestamp are
    therefore fixed at class-definition time and shared by every instance that
    does not override them.
    """
    # seed for the random number generators; overridable through $SEED
    seed: int = int(os.getenv("SEED", 42))

    # --- data
    # ARC Prize 2024 JSON files as mounted under /kaggle/input
    train_challenges: str = "/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json"
    train_solutions: str = "/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json"
    valid_challenges: str = "/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json"
    valid_solutions: str = "/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json"
    submission_challenges: str = "/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json"
    # number of re-ordered copies / recoloured copies created per task
    num_time_augs: int = 8
    num_color_augs: int = 32

    # --- logging
    # run label; becomes part of the output directory and of the wandb run name
    morph: str = os.getenv("MORPH", "test")
    # "kaggle" switches off wandb and writes outputs to the working directory
    compute_backend: str = os.getenv("COMPUTE_BACKEND", "oop")
    wandb_entity: str = os.getenv("WANDB_ENTITY", "hug")
    wandb_project: str = os.getenv("WANDB_PROJECT", "arc-test")
    # local-time timestamp, formatted YYYYmmddHHMMSS
    created_on: str = datetime.now().strftime("%Y%m%d%H%M%S")

    # --- training
    learning_rate: float = 1e-3
    num_epochs: int = 100
    batch_size: int = 32
    # print/log the training loss every this many steps
    print_every: int = 10
    # --- model

#<\config>

#<dataprep>
import json
import uuid
import numpy as np

def load_tasks(challenges_path: str, solutions_path: str, cfg: Config):
    """Load ARC tasks and their test solutions from two JSON files.

    The challenges file maps a task id to a dict with two keys:

    * ``train``: a list of input/output pairs (two to ten, typically three)
      from which the rule is to be inferred.
    * ``test``: a list of entries (one to three, typically one) whose
      ``input`` grid the inferred rule must be applied to.

    The solutions file maps the same task ids to the list of expected output
    grids for the ``test`` entries.

    A "grid" is a rectangular list of lists of integers between 0 and 9
    (inclusive), from 1x1 up to 30x30. 0 is the background colour and 1-9 are
    the pattern colours.

    Args:
        challenges_path: path of the challenges JSON file.
        solutions_path: path of the solutions JSON file.
        cfg: run configuration; accepted for a uniform call signature but not
            read by this function.

    Returns:
        A list with one tuple per task, in the order the task ids appear in
        the challenges file:
        ``(train_inputs, train_outputs, test_inputs, test_outputs)``, each
        element being a list of ``np.uint8`` arrays.

    Raises:
        KeyError: if a task id from the challenges file is absent from the
            solutions file.
    """
    # JSON interchange text is UTF-8; be explicit instead of relying on the
    # platform's default encoding.
    with open(challenges_path, 'r', encoding='utf-8') as f:
        challenges_dict = json.load(f)
    print(f"loading challenges from {challenges_path}, found {len(challenges_dict)} challenges")
    with open(solutions_path, 'r', encoding='utf-8') as f:
        solutions_dict = json.load(f)
    print(f"loading solutions from {solutions_path}, found {len(solutions_dict)} solutions")

    def to_grid(rows):
        # store as uint8 to save system memory (cell values are 0-9)
        return np.array(rows, dtype=np.uint8)

    tasks = []
    for task_id, challenge in challenges_dict.items():
        print(f"\t task {task_id}")
        try:
            solutions = solutions_dict[task_id]
        except KeyError:
            # same exception type as a plain lookup, but say which file is short
            raise KeyError(
                f"task {task_id!r} from {challenges_path} has no entry in {solutions_path}"
            ) from None
        task_train_in = [to_grid(pair['input']) for pair in challenge['train']]
        task_train_out = [to_grid(pair['output']) for pair in challenge['train']]
        task_eval_in = [to_grid(entry['input']) for entry in challenge['test']]
        task_eval_out = [to_grid(grid) for grid in solutions]
        tasks.append((task_train_in, task_train_out, task_eval_in, task_eval_out))
    return tasks

def augmentation(tasks, cfg: Config):
    """Expand a list of tasks with spatial, ordering and colour augmentations.

    For every task the result contains, in this order:

    1. the original task (the same tuple object),
    2. four "spatial" variants: left-right flip, up-down flip, and
       ``np.rot90`` with k=1 and k=3, applied to every grid of the task,
    3. ``cfg.num_time_augs`` "time" variants in which the training pairs are
       re-ordered (skipped when the task has a single training pair). One
       permutation is drawn per variant and applied to inputs AND outputs so
       that each input stays paired with its own output. The grids themselves
       are reused, not copied,
    4. ``cfg.num_color_augs`` "channel" variants in which colours 1-9 are
       shuffled; 0 (background) always maps to itself.

    Pre-augmenting enlarges the dataset held in system memory; since the
    dataset is small, trading memory for less GPU work at train time is
    worth it.

    Args:
        tasks: list of ``(train_in, train_out, eval_in, eval_out)`` tuples,
            each element a list of integer numpy grids with values 0-9.
        cfg: provides ``seed``, ``num_time_augs`` and ``num_color_augs``.

    Returns:
        A new list of task tuples in the layout described above.

    Side effects:
        Re-seeds numpy's global random state with ``cfg.seed``.
    """
    np.random.seed(cfg.seed)
    spatial_augs = (
        np.fliplr,
        np.flipud,
        lambda x: np.rot90(x, 1),
        lambda x: np.rot90(x, 3),
    )
    augmented_tasks = []
    for task in tasks:
        train_in, train_out, eval_in, eval_out = task
        augmented_tasks.append(task)  # start with the original task

        for aug in spatial_augs:
            augmented_tasks.append((
                [aug(grid) for grid in train_in],
                [aug(grid) for grid in train_out],
                [aug(grid) for grid in eval_in],
                [aug(grid) for grid in eval_out],
            ))

        if len(train_in) > 1:
            for _ in range(cfg.num_time_augs):
                # a single shared permutation keeps input i next to output i
                order = np.random.permutation(len(train_in))
                augmented_tasks.append((
                    [train_in[i] for i in order],
                    [train_out[i] for i in order],
                    eval_in,
                    eval_out,
                ))

        for _ in range(cfg.num_color_augs):
            # uint8 lookup table so recoloured grids stay as compact as the
            # loaded ones instead of being promoted to int64
            color_map = np.arange(10, dtype=np.uint8)
            np.random.shuffle(color_map[1:])  # keep 0 as background color
            augmented_tasks.append((
                [color_map[grid] for grid in train_in],
                [color_map[grid] for grid in train_out],
                [color_map[grid] for grid in eval_in],
                [color_map[grid] for grid in eval_out],
            ))
    return augmented_tasks

# Build the run configuration, pick the output directory, persist the
# hyper-parameters, optionally start a wandb run, then load the datasets.
cfg = Config()

if cfg.compute_backend != "kaggle":
    output_dir = f"/arcnca/output{cfg.morph}"
    os.makedirs(output_dir, exist_ok=True)
else:
    # when submitting to kaggle, save the output to the current directory
    output_dir = os.getcwd()

# dump every config field next to the other run artifacts
hparams_filepath = os.path.join(output_dir, "config.json")
with open(hparams_filepath, 'w') as f:
    json.dump(vars(cfg), f, indent=4)

if cfg.compute_backend != "kaggle":
    # wandb is imported lazily: it is only needed outside of kaggle
    import wandb
    wandb.login()
    wandb.init(
        entity=cfg.wandb_entity,
        project=cfg.wandb_project,
        # run name = backend, run label and a 6-character random suffix
        name=f"{cfg.compute_backend}.{cfg.morph}.{uuid.uuid4().hex[:6]}",
        config=vars(cfg),
    )
    wandb.save(hparams_filepath)

# only the training split is augmented; validation tasks are used as loaded
train_tasks = augmentation(load_tasks(cfg.train_challenges, cfg.train_solutions, cfg), cfg)
valid_tasks = load_tasks(cfg.valid_challenges, cfg.valid_solutions, cfg)
#<\dataprep>

#<model>
import jax
import jax.numpy as jnp
import optax

def init_params(key, cfg: Config):
    """Create the initial model parameters (not implemented yet).

    Args:
        key: PRNG key; the training script passes
            ``jax.random.PRNGKey(cfg.seed)``.
        cfg: run configuration.

    Returns:
        Nothing yet. The training script hands the result to
        ``optimizer.init`` and ``train_step``.

    Raises:
        NotImplementedError: always, until the initialisation is written.
            The stub used to ``return params``, a name that is not defined
            in this function: a ``NameError`` on a fresh run, or silently
            whatever stale global ``params`` a previous notebook run left
            behind.
    """
    # TODO: Implement this function
    raise NotImplementedError(
        "init_params is a stub: model parameter initialisation is not implemented yet"
    )

def model(params, task_train_in, task_train_out, task_eval_in, cfg: Config):
    """Placeholder for the forward pass; currently returns ``None``.

    ``train_step`` and ``valid_step`` call this with the current parameters,
    the training pairs of a batch and the evaluation inputs, and pass the
    result to ``loss_fn`` (and ``accuracy_fn``) as the predicted evaluation
    outputs.
    """
    # TODO: Implement this function
    return None

def loss_fn(task_eval_out_predicted, task_eval_out_target, cfg: Config):
    """Placeholder for the training loss; currently returns ``None``.

    ``train_step`` differentiates the value returned here with
    ``jax.value_and_grad``; ``valid_step`` sums it over batches.
    """
    # TODO: Implement this function
    return None

def accuracy_fn(task_eval_out_predicted, task_eval_out_target, cfg: Config):
    """Placeholder for the accuracy metric; currently returns ``None``.

    ``valid_step`` sums the value returned here over validation batches and
    divides by the number of batches.
    """
    # TODO: Implement this function
    return None

def make_dataloader(tasks, cfg: Config, train_mode=True):
    """Placeholder for the batch iterator; currently returns ``None``.

    The training script calls ``next()`` on the result and ``valid_step``
    iterates over it; both unpack each item as
    ``(task_train_in, task_train_out, task_eval_in, task_eval_out)``.
    """
    # TODO: Implement this function
    return None

def make_predictions(params, cfg):
    """Placeholder for submission inference; currently returns an empty dict.

    The submission step serialises the returned object with ``json.dump``.
    """
    # TODO: Implement this function
    predictions = {}
    return predictions
#<\model>

#<training>
def train_step(params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, cfg: Config):
    """Run one optimisation step on a single batch.

    The loss is ``loss_fn`` applied to the output of ``model`` on the batch;
    its gradient with respect to ``params`` is fed to the module-level
    ``optimizer``.

    Returns:
        ``(params, opt_state, loss)``: the updated parameters, the updated
        optimiser state and the loss value computed before the update.
    """
    def batch_loss(candidate_params):
        # only the parameters vary; the batch and cfg are closed over
        predicted = model(candidate_params, task_train_in, task_train_out, task_eval_in, cfg)
        return loss_fn(predicted, task_eval_out, cfg)

    loss, grads = jax.value_and_grad(batch_loss)(params)
    updates, next_opt_state = optimizer.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, loss

def valid_step(params, valid_gen, cfg: Config):
    """Average loss and accuracy over every batch produced by ``valid_gen``.

    Args:
        params: model parameters passed through to ``model``.
        valid_gen: iterable of
            ``(task_train_in, task_train_out, task_eval_in, task_eval_out)``
            batches; it is consumed completely.
        cfg: run configuration forwarded to ``model``, ``loss_fn`` and
            ``accuracy_fn``.

    Returns:
        ``(avg_loss, avg_acc)``: the per-batch mean of ``loss_fn`` and of
        ``accuracy_fn``.

    Raises:
        ValueError: if ``valid_gen`` yields no batch at all. There is nothing
            to average, and the bare division would otherwise fail with an
            unexplained ``ZeroDivisionError``.
    """
    total_loss = 0.0
    total_acc = 0.0
    num_batches = 0
    for task_train_in, task_train_out, task_eval_in, task_eval_out in valid_gen:
        task_eval_out_pred = model(params, task_train_in, task_train_out, task_eval_in, cfg)
        total_loss += loss_fn(task_eval_out_pred, task_eval_out, cfg)
        total_acc += accuracy_fn(task_eval_out_pred, task_eval_out, cfg)
        num_batches += 1
    if num_batches == 0:
        raise ValueError("valid_gen yielded no batches; cannot average validation metrics")
    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches
    return avg_loss, avg_acc

# Initialise parameters and optimiser, then alternate one pass over the
# training batches with one full validation pass per epoch.
model_cfg = Config()
key = jax.random.PRNGKey(cfg.seed)
params = init_params(key, model_cfg)
optimizer = optax.adam(model_cfg.learning_rate)
opt_state = optimizer.init(params)
# ceil(len(train_tasks) / batch_size) in plain integer arithmetic; the value
# does not change between epochs, so it is computed once up front
steps_per_epoch = (len(train_tasks) + cfg.batch_size - 1) // cfg.batch_size
for epoch in range(model_cfg.num_epochs):
    print(f"epoch {epoch + 1}/{cfg.num_epochs}")
    # training
    train_gen = make_dataloader(train_tasks, cfg, train_mode=True)
    for step in range(steps_per_epoch):
        try:
            task_train_in, task_train_out, task_eval_in, task_eval_out = next(train_gen)
        except StopIteration:
            # the loader ran out of batches before steps_per_epoch was reached
            break
        params, opt_state, train_loss = train_step(params, opt_state, task_train_in, task_train_out, task_eval_in, task_eval_out, model_cfg)
        if step % cfg.print_every == 0:
            print(f"step {step + 1}/{steps_per_epoch}: loss = {train_loss.item():.4f}")
            if cfg.compute_backend != "kaggle":
                # wandb expects a step that never decreases. Offsetting by the
                # epoch keeps training logs on the same axis as the validation
                # logs below instead of restarting at 1 every epoch (which
                # makes wandb discard them from the second epoch on).
                wandb.log({"train_loss": train_loss.item()}, step=epoch * steps_per_epoch + step + 1)
    # validation
    valid_gen = make_dataloader(valid_tasks, cfg, train_mode=False)
    valid_loss, valid_acc = valid_step(params, valid_gen, cfg)
    print(f'valid_loss: {valid_loss.item():.4f}, valid_acc: {valid_acc.item():.4f}')
    if cfg.compute_backend != "kaggle":
        wandb.log({"valid_loss": valid_loss.item(), "valid_acc": valid_acc.item()}, step=(epoch + 1) * steps_per_epoch)
#<\training>

#<submission>
# Write the submission and the final validation accuracy into output_dir
# and, outside of kaggle, upload both files to wandb and close the run.
predictions = make_predictions(params, cfg)
submission_filepath = os.path.join(output_dir, "submission.json")
# Write to the path that is uploaded below. On kaggle output_dir is the
# working directory, so the file still lands at ./submission.json there.
with open(submission_filepath, 'w') as f:
    json.dump(predictions, f)
# valid_acc is the accuracy from the last epoch's validation pass
results = {"accuracy": valid_acc.item()}
results_filepath = os.path.join(output_dir, "results.json")
with open(results_filepath, 'w') as f:
    json.dump(results, f, indent=4)
if cfg.compute_backend != "kaggle":
    wandb.save(submission_filepath)
    wandb.save(results_filepath)
    wandb.finish()
#<\submission>