# In [None]:
from dataclasses import dataclass
from datetime import datetime
import json
import os
import pickle
import uuid

import numpy as np

#<config>
@dataclass(frozen=True)
class Config:
    """Immutable settings for one run: data paths, logging, and training.

    The environment-derived defaults (SEED, MORPH, COMPUTE_BACKEND) and the
    ``created_on`` timestamp are evaluated once, when the class is defined.
    """

    seed: int = int(os.getenv("SEED", 0))

    # --- data
    data_seed: int = 42
    train_challenges: str = "/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json"
    train_solutions: str = "/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json"
    valid_challenges: str = "/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json"
    valid_solutions: str = "/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json"
    submission_challenges: str = "/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json"
    # how many times the order of a task's train grids is re-shuffled
    num_order_augs: int = 2
    # how many colour permutations are applied (train and test pairs stay matched)
    num_color_augs: int = 4

    # --- logging
    morph: str = os.getenv("MORPH", "test")
    compute_backend: str = os.getenv("COMPUTE_BACKEND", "oop")
    wandb_entity: str = "hug"
    wandb_project: str = "arc-test"
    created_on: str = format(datetime.now(), "%Y%m%d%H%M%S")

    # --- model
    # TODO: add model hyperparameters here

    # --- training
    # number of epochs to train
    num_epochs: int = 100
    batch_size: int = 32
    # print the training loss every this many steps
    print_every: int = 10
    # stop training after this many epochs without improvement
    early_stopping_patience: int = 10
    # initial learning rate
    learning_rate: float = 1e-3
    # epochs without improvement before the learning rate is reduced
    lr_patience: int = 5
    # epochs to wait after a reduction before resuming normal operation
    lr_cooldown: int = 0
    # multiplicative factor applied when the learning rate is reduced
    lr_factor: float = 0.5
    # relative tolerance used when measuring a new optimum
    lr_rtol: float = 1e-4
    # number of iterations over which an average value is accumulated
    lr_accumulation_size: int = 200
#</config>

# Build the single configuration object the rest of the script reads from.
cfg = Config()

# Every backend except Kaggle writes into a per-morph folder; when submitting
# to Kaggle the output is saved to the current directory instead.
if cfg.compute_backend != "kaggle":
    output_dir = f"/arcnca/output/{cfg.morph}"
    os.makedirs(output_dir, exist_ok=True)
else:
    output_dir = os.getcwd()

print(f"output_dir: {output_dir}")
print(f"config:{json.dumps(cfg.__dict__, indent=4)}")

# Keep a copy of the configuration next to the other run outputs.
config_filepath = os.path.join(output_dir, "config.json")
with open(config_filepath, 'w') as f:
    json.dump(cfg.__dict__, f, indent=4)

# wandb is imported and initialised only when not running on Kaggle.
if cfg.compute_backend != "kaggle":
    import wandb
    wandb.login()
    wandb.init(
        entity=cfg.wandb_entity,
        project=cfg.wandb_project,
        name=f"{cfg.compute_backend}.{cfg.morph}.{str(uuid.uuid4())[:6]}",
        config=cfg.__dict__,
    )
    wandb.save(config_filepath)

def save_checkpoint(params, filename):
    """Pickle ``params`` into ``filename`` inside the run's ``output_dir``."""
    checkpoint_path = os.path.join(output_dir, filename)
    with open(checkpoint_path, 'wb') as sink:
        pickle.dump(params, sink)

def load_checkpoint(filename):
    """Return the object unpickled from ``filename`` inside ``output_dir``."""
    checkpoint_path = os.path.join(output_dir, filename)
    with open(checkpoint_path, 'rb') as source:
        restored = pickle.load(source)
    return restored

def load_tasks(challenges_path: str, solutions_path: str, cfg: Config):
    """Read ARC tasks from JSON and convert every grid to a uint8 array.

    Tasks are stored in JSON format. Each task has two keys:
      train: a list of two to ten input/output pairs (typically three), used
             to infer a rule.
      test:  a list of one to three inputs (typically one) for which an
             output must be constructed.
    A "grid" is a rectangular matrix (list of lists) of integers between 0
    and 9 (inclusive). The smallest possible grid is 1x1 and the largest is
    30x30. 0 represents the background color, 1-9 the pattern colors.

    Args:
        challenges_path: JSON file mapping task id -> {"train", "test"}.
        solutions_path: JSON file mapping task id -> list of test output
            grids, or None when no solutions are available.
        cfg: run configuration (not read by this function).

    Returns:
        A list of ``(task_id, train_in, train_out, test_in, test_out)``
        tuples, each of the four grid fields being a list of uint8 arrays.
        ``test_out`` is empty when ``solutions_path`` is None.
    """
    with open(challenges_path, 'r') as f:
        challenges_dict = json.load(f)
    print(f"loading challenges from {challenges_path}, found {len(challenges_dict)} challenges")

    has_solutions = solutions_path is not None
    if has_solutions:
        with open(solutions_path, 'r') as f:
            solutions_dict = json.load(f)
        print(f"loading solutions from {solutions_path}, found {len(solutions_dict)} solutions")

    def as_grid(rows):
        return np.array(rows, dtype=np.uint8)

    tasks = []
    for task_id, challenge in challenges_dict.items():
        demo_pairs = challenge['train']
        train_in = [as_grid(pair['input']) for pair in demo_pairs]
        train_out = [as_grid(pair['output']) for pair in demo_pairs]
        test_in = [as_grid(item['input']) for item in challenge['test']]
        if has_solutions:
            test_out = [as_grid(rows) for rows in solutions_dict[task_id]]
        else:
            test_out = []
        tasks.append((task_id, train_in, train_out, test_in, test_out))
    return tasks

def augmentation(tasks, cfg: Config):
    """Expand a task list with spatial, train-order and colour augmentations.

    The stages are cumulative: each one augments every task produced so far
    (originals included) and appends the results.

      1. spatial: left-right flip, up-down flip, and 90/270 degree rotations
         applied to every grid of a task (x5 tasks).
      2. order:   ``cfg.num_order_augs`` rounds, each adding one copy of every
         task with its train pairs shuffled; inputs and outputs are shuffled
         together so pairs stay aligned (x2 per round).
      3. colour:  ``cfg.num_color_augs`` rounds, each adding one copy of every
         task with colours 1-9 permuted consistently across all of that
         task's grids; background 0 is left unchanged (x2 per round).

    Args:
        tasks: list of ``(task_id, train_in, train_out, eval_in, eval_out)``
            tuples whose grid fields are lists of 2-D integer arrays.
        cfg: configuration providing ``data_seed``, ``num_order_augs`` and
            ``num_color_augs``.

    Returns:
        A new list containing the original tasks followed by all augmented
        tasks. The list passed in is not modified.

    Note:
        ``np.random`` is seeded with ``cfg.data_seed``, so the permutations
        are reproducible. The id suffixes come from ``uuid.uuid4()`` and are
        therefore different on every run.
    """
    np.random.seed(cfg.data_seed)
    # Work on a copy so the caller's list is not extended as a side effect.
    tasks = list(tasks)

    def derived_id(task_id):
        # Short random suffix to tell augmented tasks apart from their source.
        return f"{task_id}.{str(uuid.uuid4())[:6]}"

    # grid structure means we can use symmetry to augment the tasks
    spatial_transforms = [
        np.fliplr,
        np.flipud,
        lambda x: np.rot90(x, k=1),
        lambda x: np.rot90(x, k=3),
    ]
    augmented_tasks = []
    for task_id, train_in, train_out, eval_in, eval_out in tasks:
        for aug in spatial_transforms:
            augmented_tasks.append((
                derived_id(task_id),
                [aug(grid) for grid in train_in],
                [aug(grid) for grid in train_out],
                [aug(grid) for grid in eval_in],
                [aug(grid) for grid in eval_out],
            ))
    tasks += augmented_tasks
    print(f"after spatial augmentation, tasks count: {len(tasks)}")

    # assume order of train grids is valid
    for _ in range(cfg.num_order_augs):
        augmented_tasks = []
        # Unpack each task here: the permutation must be drawn for, and
        # applied to, the task being visited.
        for task_id, train_in, train_out, eval_in, eval_out in tasks:
            train_order = np.random.permutation(len(train_in))
            augmented_tasks.append((
                derived_id(task_id),
                [train_in[i] for i in train_order],
                [train_out[i] for i in train_order],
                eval_in,
                eval_out,
            ))
        tasks += augmented_tasks
    print(f"after order augmentation x{cfg.num_order_augs}, tasks count: {len(tasks)}")

    # all colors (except for background) are interchangeable (but must match entire set)
    for _ in range(cfg.num_color_augs):
        augmented_tasks = []
        for task_id, train_in, train_out, eval_in, eval_out in tasks:
            # uint8 lookup table keeps the remapped grids in the same dtype
            # as the grids produced by load_tasks.
            color_map = np.arange(10, dtype=np.uint8)
            color_map[1:] = np.random.permutation(color_map[1:])
            augmented_tasks.append((
                derived_id(task_id),
                [np.take(color_map, grid) for grid in train_in],
                [np.take(color_map, grid) for grid in train_out],
                [np.take(color_map, grid) for grid in eval_in],
                [np.take(color_map, grid) for grid in eval_out],
            ))
        tasks += augmented_tasks
    print(f"after color augmentation x{cfg.num_color_augs}, tasks count: {len(tasks)}")
    # Return the accumulated list (originals + every augmentation), which is
    # what the printed counts describe.
    return tasks

# Training tasks: load challenge/solution pairs, then pass them through
# `augmentation` (spatial, train-order and colour augmentations).
train_tasks = load_tasks(cfg.train_challenges, cfg.train_solutions, cfg)
train_tasks = augmentation(train_tasks, cfg)
# Validation tasks are loaded with their solutions and are not augmented.
valid_tasks = load_tasks(cfg.valid_challenges, cfg.valid_solutions, cfg)
# Submission tasks are loaded without a solutions file, so their test_out
# lists are empty.
submission_tasks = load_tasks(cfg.submission_challenges, None, cfg)


#<model>
#<model>
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

def process_grids(grids, max_num_pairs, grid_size=30):
    """Zero-pad grids to a square, flatten them, and join them into one vector.

    Each 2-D grid is padded with zeros on its bottom and right edges up to
    ``grid_size x grid_size`` and flattened. If fewer than ``max_num_pairs``
    grids are given, all-zero uint8 grids are appended to make up the
    difference. Extra grids beyond ``max_num_pairs`` are not truncated.

    Returns:
        A 1-D array; its length is ``max_num_pairs * grid_size * grid_size``
        whenever ``len(grids) <= max_num_pairs``.
    """
    cells_per_grid = grid_size * grid_size
    flat_grids = []
    for grid in grids:
        rows, cols = grid.shape
        bottom_right_padding = ((0, grid_size - rows), (0, grid_size - cols))
        padded = np.pad(grid, bottom_right_padding, mode='constant', constant_values=0)
        flat_grids.append(padded.ravel())

    # Fill the remaining slots with empty grids.
    missing = max_num_pairs - len(flat_grids)
    if missing > 0:
        flat_grids += [np.zeros(cells_per_grid, dtype=np.uint8) for _ in range(missing)]
    return np.concatenate(flat_grids)

def init_params(key, cfg: Config):
    """Create the weights and biases of the three-layer MLP.

    Weights use Glorot-uniform initialisation, each layer with its own
    sub-key split from ``key``; biases start at zero.

    Returns:
        Dict with keys ``W1, b1, W2, b2, W3, b3``.
    """
    cells_per_grid = 30 * 30
    hidden_size = 32
    layer_shapes = (
        # (suffix, fan_in, fan_out)
        ('1', (10 + 10 + 3) * cells_per_grid, hidden_size),  # train_in, train_out, test_in
        ('2', hidden_size, hidden_size),
        ('3', hidden_size, 3 * cells_per_grid * 10),  # test_out pixels with 10 classes
    )
    layer_keys = jax.random.split(key, 3)
    glorot_init = jax.nn.initializers.glorot_uniform()

    params = {}
    for (suffix, fan_in, fan_out), layer_key in zip(layer_shapes, layer_keys):
        params[f'W{suffix}'] = glorot_init(layer_key, (fan_in, fan_out))
        params[f'b{suffix}'] = jnp.zeros(fan_out)
    return params

def model(params, batch_inputs, cfg: Config):
    """Apply the MLP to a batch of flattened inputs.

    Two ReLU hidden layers (``W1/b1``, ``W2/b2``) are followed by a linear
    output layer (``W3/b3``).

    Returns:
        Logits reshaped to ``(-1, 2700, 10)``, i.e. (batch_size, pixels,
        classes).
    """
    activations = batch_inputs
    for weight_name, bias_name in (('W1', 'b1'), ('W2', 'b2')):
        pre_activation = jnp.dot(activations, params[weight_name]) + params[bias_name]
        activations = jax.nn.relu(pre_activation)
    flat_logits = jnp.dot(activations, params['W3']) + params['b3']
    return flat_logits.reshape(-1, 2700, 10)

def loss_fn(test_out_predicted, test_out_target, cfg: Config):
    """Return the mean softmax cross-entropy between logits and integer labels.

    ``test_out_predicted`` holds the logits and ``test_out_target`` the
    integer class labels; the loss is averaged over every element.
    """
    per_element_loss = optax.softmax_cross_entropy_with_integer_labels(
        test_out_predicted, test_out_target
    )
    return per_element_loss.mean()

def accuracy_fn(test_out_predicted, test_out_target, cfg: Config):
    """Return the fraction of samples whose prediction matches exactly.

    The predicted class is the argmax over the last axis of the logits. A
    sample counts as correct only when every position along its last
    remaining axis equals the target.
    """
    predicted_classes = jnp.argmax(test_out_predicted, axis=-1)  # (batch_size, 2700)
    sample_is_exact = jnp.all(predicted_classes == test_out_target, axis=-1)
    return jnp.mean(sample_is_exact.astype(jnp.float32))

def make_dataloader(tasks, cfg: Config, train_mode=True):
    """Yield ``(batch_inputs, batch_labels)`` batches built from ``tasks``.

    For every task the padded train inputs (10 slots), train outputs
    (10 slots) and test inputs (3 slots) are concatenated into one input
    vector; the padded test outputs (3 slots) form the label vector.

    Args:
        tasks: list of ``(task_id, train_in, train_out, test_in, test_out)``.
        cfg: configuration providing ``batch_size``.
        train_mode: when True the task order is shuffled with ``np.random``
            and a trailing partial batch is dropped. When False the order is
            kept and the trailing partial batch is yielded as well, so no
            task is skipped during evaluation or prediction.

    Yields:
        ``batch_inputs``: array with one row per task in the batch.
        ``batch_labels``: array with one row per task, or None when no task
        in the batch has test outputs.

    Raises:
        ValueError: if only some tasks of a batch have test outputs, since
            labels could then no longer be matched to their inputs.
    """
    num_tasks = len(tasks)
    indices = np.arange(num_tasks)
    if train_mode:
        np.random.shuffle(indices)
    batch_size = cfg.batch_size
    if train_mode:
        # Only full batches while training.
        usable = (num_tasks // batch_size) * batch_size
    else:
        # Evaluation must see every task, including the final partial batch.
        usable = num_tasks

    for start in range(0, usable, batch_size):
        batch_indices = indices[start:min(start + batch_size, usable)]
        batch_input_vectors = []
        batch_label_vectors = []
        for idx in batch_indices:
            _, train_in, train_out, test_in, test_out = tasks[idx]
            # Process and concatenate inputs
            batch_input_vectors.append(np.concatenate([
                process_grids(train_in, max_num_pairs=10, grid_size=30),
                process_grids(train_out, max_num_pairs=10, grid_size=30),
                process_grids(test_in, max_num_pairs=3, grid_size=30),
            ]))
            # Process labels (absent for tasks without solutions)
            if test_out:
                batch_label_vectors.append(
                    process_grids(test_out, max_num_pairs=3, grid_size=30)
                )
        if batch_label_vectors and len(batch_label_vectors) != len(batch_input_vectors):
            raise ValueError(
                "batch mixes tasks with and without test outputs; "
                "labels cannot be aligned with inputs"
            )
        batch_inputs = jnp.asarray(np.stack(batch_input_vectors))
        if batch_label_vectors:
            batch_labels = jnp.asarray(np.stack(batch_label_vectors))
        else:
            batch_labels = None
        yield batch_inputs, batch_labels
#</model>

#<training>
import optax
from optax import contrib
from optax import tree_utils as otu

# Initial model parameters, drawn from a PRNG key seeded with cfg.seed.
key = jax.random.PRNGKey(cfg.seed)
params = init_params(key, cfg)

# Adam chained with a reduce-on-plateau transform configured from the
# lr_* settings in cfg.
opt = optax.chain(
    optax.adam(learning_rate=cfg.learning_rate),
    contrib.reduce_on_plateau(
        factor=cfg.lr_factor,
        patience=cfg.lr_patience,
        rtol=cfg.lr_rtol,
        cooldown=cfg.lr_cooldown,
        accumulation_size=cfg.lr_accumulation_size,
    ),
)
opt_state = opt.init(params)

def train_step(params, opt_state, batch_inputs, batch_labels, cfg: Config):
    """Run one optimisation step on a single batch.

    Args:
        params: current model parameters.
        opt_state: current optimizer state.
        batch_inputs: input batch as yielded by ``make_dataloader``.
        batch_labels: label batch as yielded by ``make_dataloader``.
        cfg: run configuration, forwarded to ``model`` and ``loss_fn``.

    Returns:
        ``(params, opt_state, loss)`` after applying the update.
    """
    def batch_loss(candidate_params):
        logits = model(candidate_params, batch_inputs, cfg)
        return loss_fn(logits, batch_labels, cfg)

    loss, grads = jax.value_and_grad(batch_loss)(params)
    # The current loss is passed as `value` for the reduce-on-plateau
    # transform in the optimizer chain.
    updates, opt_state = opt.update(grads, opt_state, params, value=loss)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


def valid_step(params, valid_gen, cfg: Config):
    """Average loss and exact-match accuracy over all batches of ``valid_gen``.

    Args:
        params: model parameters to evaluate.
        valid_gen: iterable of ``(batch_inputs, batch_labels)`` pairs.
        cfg: run configuration, forwarded to the model and metric functions.

    Returns:
        ``(avg_loss, avg_acc)``, each the mean of the per-batch values.

    Raises:
        ValueError: if ``valid_gen`` yields no batches.
    """
    total_loss = 0.0
    total_acc = 0.0
    num_batches = 0
    for batch_inputs, batch_labels in valid_gen:
        logits = model(params, batch_inputs, cfg)
        total_loss += loss_fn(logits, batch_labels, cfg)
        total_acc += accuracy_fn(logits, batch_labels, cfg)
        num_batches += 1
    if num_batches == 0:
        raise ValueError("validation dataloader yielded no batches")
    return total_loss / num_batches, total_acc / num_batches


global_step = 0
best_valid_loss = float('inf')
epochs_without_improvement = 0
for epoch in range(cfg.num_epochs):
    print(f"epoch {epoch + 1}/{cfg.num_epochs}")
    steps_per_epoch = len(train_tasks) // cfg.batch_size
    # training: iterate the dataloader directly so the loop consumes exactly
    # the batches it produces
    train_gen = make_dataloader(train_tasks, cfg, train_mode=True)
    for step, (batch_inputs, batch_labels) in enumerate(train_gen):
        global_step += 1
        params, opt_state, train_loss = train_step(params, opt_state, batch_inputs, batch_labels, cfg)
        if step % cfg.print_every == 0:
            print(f"global step {global_step} epoch step {step}/{steps_per_epoch} loss = {train_loss.item():.4f}")
            if cfg.compute_backend != "kaggle":
                wandb.log({"train_loss": train_loss.item()}, step=global_step)
    # validation
    valid_gen = make_dataloader(valid_tasks, cfg, train_mode=False)
    valid_loss, valid_acc = valid_step(params, valid_gen, cfg)
    print(f'valid_loss: {valid_loss.item():.4f}, valid_acc: {valid_acc.item():.4f}')
    if cfg.compute_backend != "kaggle":
        wandb.log({
            "valid_loss": valid_loss.item(),
            "valid_acc": valid_acc.item(),
            # convert the scale leaf of the optimizer state to a plain float
            "lr_scale": float(otu.tree_get(opt_state, "scale")),
            }, step=global_step)
    # early stopping (compare plain floats rather than device arrays)
    current_valid_loss = valid_loss.item()
    if current_valid_loss < best_valid_loss:
        best_valid_loss = current_valid_loss
        epochs_without_improvement = 0
        save_checkpoint(params, "best.pkl")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= cfg.early_stopping_patience:
            print(f"early stopping at epoch {epoch + 1}")
            break
if best_valid_loss == float('inf'):
    # No epoch ever improved on the initial value (e.g. zero epochs or a
    # non-finite loss): still write best.pkl so both checkpoints named
    # below exist for this run.
    save_checkpoint(params, "best.pkl")
save_checkpoint(params, "final.pkl")
# submission will be made with two model checkpoints
attempt1_ckpt = "best.pkl"
attempt2_ckpt = "final.pkl"
#</training>

"""
for each task output in the evaluation set, you should make exactly 2 predictions (attempt_1, attempt_2).
most tasks only have a single output (a single dictionary enclosed in a list), although some tasks have multiple outputs that must be predicted.
when a task has multiple test outputs that need to be predicted, they must be in the same order as the corresponding test inputs.
"""
# Build the submission: for every task, a list with one dict per test input,
# each dict holding both attempts ({"attempt_1": grid, "attempt_2": grid}),
# in the same order as the task's test inputs.
predictions = {}
attempt_params = [load_checkpoint(ckpt) for ckpt in (attempt1_ckpt, attempt2_ckpt)]
for task_id, train_in, train_out, test_in, _ in submission_tasks:
    # Same input layout as make_dataloader: 10 train inputs, 10 train
    # outputs and 3 test inputs, each padded to 30x30 and flattened.
    input_vector = np.concatenate([
        process_grids(train_in, max_num_pairs=10, grid_size=30),
        process_grids(train_out, max_num_pairs=10, grid_size=30),
        process_grids(test_in, max_num_pairs=3, grid_size=30),
    ])
    batch_inputs = jnp.asarray(input_vector[None, :])

    attempt_grids = []
    for ckpt_params in attempt_params:
        logits = model(ckpt_params, batch_inputs, cfg)  # (1, 2700, 10)
        predicted_colors = np.asarray(jnp.argmax(logits, axis=-1))
        # 2700 pixels = 3 test-output slots of 30x30
        attempt_grids.append(predicted_colors.reshape(3, 30, 30))

    task_predictions = []
    for slot, test_grid in enumerate(test_in):
        # NOTE(review): the model does not predict an output size, so each
        # prediction is cropped to the shape of its test input. This can only
        # be right for tasks whose output has the input's shape — confirm
        # whether a different size rule is wanted.
        height, width = test_grid.shape
        task_predictions.append({
            f"attempt_{attempt + 1}": grids[slot, :height, :width].tolist()
            for attempt, grids in enumerate(attempt_grids)
        })
    predictions[task_id] = task_predictions
submission_filepath = os.path.join(output_dir, "submission.json")

with open(submission_filepath, 'w') as f:
    json.dump(predictions, f)

# Store the validation accuracy from the most recent validation pass.
results_filepath = os.path.join(output_dir, "results.json")
results = {"accuracy": valid_acc.item()}

with open(results_filepath, 'w') as f:
    json.dump(results, f, indent=4)

# Off-Kaggle runs also hand the output files to wandb and close the run.
if cfg.compute_backend != "kaggle":
    wandb.save(submission_filepath)
    wandb.save(results_filepath)
    wandb.finish()