In [1]:
import typing
import lzma
import os
import dataclasses
import itertools
import functools
import json
import contextlib
import zipfile
from importlib import reload
from pathlib import Path
from types import MappingProxyType, SimpleNamespace

import cbor2
import attrs
import tqdm.auto
import jax
import jax.numpy as jnp
import etils.epath
from flax import nnx
import optax
import jaxtyping as jt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from arc25 import symmetry, tools as arc25_tools
from arc25.symmetry import D4, transform_vector
from arc25 import serialisation
from arc25.dsl.types import Vector, Dir4
from arc25.vision2.symrep import SymDecompBase, SplitSymDecomp, SymDecompDims, standard_rep, RepSpec
from arc25.vision2.fields import FieldDims, CoordinateGrid
from arc25.vision2.linear import SpaceSymmetricLinear, SpaceSymmetricTensor, SymmetryMappingSpec, SymDecompLinear
from arc25.vision2 import fields, attention, encoder, transformer, mae, swiglu, arc_solver
from arc25.training import saving, dataset, mae as mae_trainer, knn_eval, linear_probe, arc_solver as solver_trainer, row_weighted_adam

In [2]:
# Optional: uncomment to ask XLA for two host-platform devices.
#os.environ["XLA_FLAGS"]="--xla_force_host_platform_device_count=2"
# NOTE(review): this is set after `import etils.epath` in the import cell —
# confirm that etils reads the variable lazily rather than at import time.
os.environ.update(EPATH_USE_TF="false")

In [3]:
# Project layout, resolved relative to the directory the notebook runs from.
proj_root = Path(os.pardir).resolve()
data_root = proj_root.joinpath("data")
model_dir = data_root.joinpath("models")

In [4]:
# Read both JSON members straight out of the competition archive, without
# extracting it to disk.
with zipfile.ZipFile(data_root / "external/arc-prize-2025.zip", "r") as zfh:
    raw_challenge_data = json.loads(zfh.read("arc-agi_test_challenges.json"))
    sample_submission = json.loads(zfh.read("sample_submission.json"))


In [5]:
# Flatten the raw challenge JSON into one list of image examples per split,
# tracking the elementwise maximum of all image shapes along the way.
max_size = np.array([0, 0])
examples_by_split = {"train": [], "test": []}
for challenge_id, challenge in raw_challenge_data.items():
    for split, split_examples in examples_by_split.items():
        # A challenge whose train pairs change shape between input and output
        # contributes nothing to the train split.
        if split == "train" and any(
            np.array(pair["input"]).shape != np.array(pair["output"]).shape
            for pair in challenge[split]
        ):
            continue
        # Test pairs contribute their inputs only.
        roles = ("input",) if split == "test" else ("input", "output")
        for pair_idx, pair in enumerate(challenge[split]):
            for role in roles:
                image = dataset.Image(np.array(pair[role], "i1"))
                split_examples.append(
                    dataset.ImageExample(
                        challenge=challenge_id,
                        example_idx=pair_idx,
                        example_type=role,
                        image=image,
                    )
                )
                max_size = np.maximum(max_size, np.array(image.shape))

challenges = frozenset(raw_challenge_data)
challenge_order = tuple(sorted(challenges))
datasets = SimpleNamespace(
    **{
        split: dataset.ImagesDataset(
            examples=tuple(split_examples),
            challenges=challenges,
            max_size=tuple(int(extent) for extent in max_size),
        )
        for split, split_examples in examples_by_split.items()
    }
)
max_size

array([30, 30])

In [6]:
# Number of latent programs (independent solution attempts) allotted to each
# challenge; the solver is sized as len(challenge_order) * num_solution_attempts.
num_solution_attempts = 16

In [7]:
# One latent program per (challenge, attempt) combination.
num_latent_programs = num_solution_attempts * len(challenge_order)
solver = arc_solver.ARCSolver(
    **arc_solver.configs["small"],
    num_latent_programs=num_latent_programs,
    dtype=jnp.float32,
    rngs=nnx.Rngs(42),
)

In [8]:
# NOTE(review): nothing in the visible cells references `arc25.training.cli`;
# presumably imported for its side effects — confirm, and move it to the
# import cell at the top of the notebook if it is needed.
import arc25.training.cli

In [9]:
# Remote (GCS) location of the pretrained solver checkpoint.
chkp_path = etils.epath.Path(
    "gs://576e2361-arc-agi-2/checkpoints/20251030-1638-vertex-ai-arc-solver-small-4xv6e/"
    "20251030-1638-vertex-ai-arc-solver-small-4xv6e-chkp-000256.msgpack.xz"
)
checkpoint_data = saving.load_model(chkp_path)
solver_checkpoint = checkpoint_data.state.model
# Remove the checkpoint's latent program table before copying state into
# `solver`, presumably so the freshly initialised table (sized
# len(challenge_order) * num_solution_attempts) is kept — TODO confirm.
# Note that this mutates `checkpoint_data.state.model` in place.
del solver_checkpoint["latent_program_embeddings"]
nnx.update(solver, solver_checkpoint)


In [10]:
num_devices = jax.local_device_count()

# Every ordered pairing of the size cut-offs, deduplicated and sorted.
bucket_cuts = [12, 20, 30]
bucket_shapes = tuple(
    sorted({(first, second) for first in bucket_cuts for second in bucket_cuts})
)


In [11]:
config = solver_trainer.ArcSolverConfig(
    max_num_epochs=None,
    max_num_ref_batches=1,
    eval_batch_size=64,
)

# Minimal stand-in for an input source: it only carries the challenge order
# and the bucket shapes.
inputs_src = SimpleNamespace(
    challenges=challenge_order,
    bucket_shapes=bucket_shapes,
)

trainer = solver_trainer.ArcSolverTrainer.make(
    config=config,
    model=solver,
    collator=None,
    inputs_src=inputs_src,
    num_devices=num_devices,
    rngs=None,  # alternative: nnx.Rngs(config.seed)
    # lr_schedule=lr_schedule,
    eval_dataset=datasets.train,
    with_progress_bars=True,
)

In [12]:
# Run the trainer's (private) embedding-caching step up front.
# `accumulate_grads` later reads `trainer._eval_data_cache`, which this call
# presumably populates — TODO confirm.
trainer._cache_embeddings()

  0%|          | 0/12 [00:00<?, ?it/s]

Tracing embed_inputs for shape dict(input_sizes=(1,7,2), inputs=(1,7,12,12)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(1,64,2), inputs=(1,64,12,12)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(1,39,2), inputs=(1,39,12,20)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(1,8,2), inputs=(1,8,12,30)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(1,22,2), inputs=(1,22,20,20)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(1,64,2), inputs=(1,64,20,20)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(1,12,2), inputs=(1,12,20,30)) (kw=dict(mode='flat', remat=True, unroll=No

In [13]:
def loss_fn(model, minibatch):
    """Summed, cell-weighted cross-entropy of every solution attempt.

    Each leaf of ``minibatch`` is repeated ``num_solution_attempts`` times
    along a new axis 1, and attempt ``j`` of latent program ``p`` is decoded
    with latent program index ``p * num_solution_attempts + j``.

    Args:
        model: Solver exposing ``decode`` and ``latent_program_embeddings``.
        minibatch: Mapping with the keys ``embeddings``, ``output_sizes``,
            ``latent_program_idx``, ``outputs``, ``output_masks`` and
            ``cell_weight``.

    Returns:
        ``(loss, stats)``. ``loss`` is the masked, weighted cross-entropy
        summed over everything. ``stats`` holds the summed
        ``pair_crossentropy``, ``cell_accuracy`` and ``pair_accuracy`` plus
        ``per_task``: a dict with the keys ``crossentropy`` and ``accuracy``
        whose arrays have one row per group of ``num_solution_attempts``
        consecutive latent programs and one column per attempt.
    """
    attempts = num_solution_attempts

    def _per_attempt(leaf):
        # Insert the attempt axis at position 1 and repeat the leaf along it.
        repeats = (1, attempts) + (1,) * (leaf.ndim - 1)
        return jnp.tile(leaf[:, None, ...], repeats)

    batch = jax.tree.map(_per_attempt, minibatch)

    # Attempt j of program p lives at row p * attempts + j of the latent table.
    program_idx = batch["latent_program_idx"] * attempts
    trailing = (None,) * (program_idx.ndim - 2)
    program_idx = program_idx + np.arange(attempts)[(None, slice(None), *trailing)]

    logits = model.decode(
        batch["embeddings"],
        output_size=batch["output_sizes"],
        latent_program_idx=program_idx,
        mode="flat",
        remat=True,
        unroll=None,
        deterministic=True,
    )
    logits = logits.astype(jnp.float32)

    targets = batch["outputs"]
    valid = batch["output_masks"]
    weights = batch["cell_weight"]
    spatial_axes = (-2, -1)

    # Cross-entropy is computed for every cell; cells outside the output mask
    # are zeroed before the weighted sum over the two trailing axes.
    cell_xent = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=targets, axis=-1
    )
    pair_crossentropy = jnp.where(valid, cell_xent * weights, 0).sum(
        axis=spatial_axes
    )
    loss = pair_crossentropy.sum()

    # Weighted fraction of correctly predicted cells inside the mask.
    hits = jnp.argmax(logits, axis=-1) == targets
    cell_accuracy = (
        jnp.where(hits & valid, weights, 0)
        .astype(jnp.float32)
        .sum(axis=spatial_axes)
    )

    # A pair counts as solved only if every cell inside the mask is correct;
    # cells outside the mask never count against it.
    pair_accuracy = (hits | ~valid).all(axis=spatial_axes).astype(jnp.float32)

    # Scatter the per-pair statistics back onto their latent program rows.
    per_pair = dict(crossentropy=pair_crossentropy, accuracy=pair_accuracy)
    stacked = jnp.stack(list(per_pair.values()), axis=-1)

    num_programs = model.latent_program_embeddings.shape[0]
    num_stats = stacked.shape[-1]
    per_task = jnp.zeros((num_programs, num_stats), dtype=stacked.dtype)
    per_task = per_task.at[program_idx].add(stacked)
    per_task = per_task.reshape(-1, attempts, num_stats)

    stats = dict(
        pair_crossentropy=pair_crossentropy.sum(),
        cell_accuracy=cell_accuracy.sum(),
        pair_accuracy=pair_accuracy.sum(),
        per_task={name: per_task[..., col] for col, name in enumerate(per_pair)},
    )

    return loss, stats

In [14]:
# Selects the parameters whose path contains "latent_program_embeddings".
train_filter = nnx.PathContains("latent_program_embeddings")


@nnx.jit
def train_step(model, minibatch):
    """Evaluate ``loss_fn`` on one minibatch and differentiate it.

    Differentiation is restricted to the part of ``model`` (argument 0)
    selected by ``train_filter``.

    Returns:
        ``(grads, stats)``; the scalar loss itself is discarded.
    """
    differentiate = nnx.value_and_grad(
        loss_fn,
        argnums=nnx.DiffState(0, train_filter),
        has_aux=True,
    )
    (_loss, aux_stats), grads = differentiate(model, minibatch)
    return grads, aux_stats


In [15]:
def accumulate_grads(trainer):
    """Sum ``train_step`` gradients and statistics over the cached minibatches.

    Args:
        trainer: Trainer providing ``_eval_data_cache``, ``num_devices``,
            ``train_state.model`` and ``with_progress_bars``.

    Returns:
        ``(grads, stats)``. ``grads`` is the sum of the per-minibatch
        gradients. ``stats`` maps each summed statistic to a float divided by
        ``max(1, eval_data.total_weight)`` and adds, for every per-task
        statistic, ``best_<name>``: the best attempt of each task (lowest for
        cross-entropy, highest otherwise) averaged over tasks.

    Raises:
        ValueError: If the eval data cache holds no minibatches.
    """
    eval_data = trainer._eval_data_cache

    mesh = jax.make_mesh(
        (trainer.num_devices,),
        ("batch",),
        axis_types=(jax.sharding.AxisType.Auto,),
    )

    def reshard(a, *args):
        return jax.device_put(
            a,
            jax.NamedSharding(mesh, jax.sharding.PartitionSpec(*args)),
        )

    # Every model leaf is placed with an empty PartitionSpec on the mesh.
    model_graph, model_state = nnx.split(trainer.train_state.model)
    model_state = jax.tree.map(reshard, model_state)
    resharded_model = nnx.merge(model_graph, model_state)

    res = None
    with contextlib.ExitStack() as stack:
        stack.enter_context(jax.set_mesh(mesh))
        if trainer.with_progress_bars:
            import tqdm.auto

            pbar = stack.enter_context(
                tqdm.auto.tqdm(total=len(eval_data.minibatches), leave=False)
            )
        else:
            pbar = None

        for minibatch in eval_data.minibatches:
            step_res = train_step(
                resharded_model,
                minibatch,
            )
            if res is None:
                res = step_res
            else:
                res = jax.tree.map(lambda a, b: a + b, res, step_res)

            if pbar is not None:
                pbar.update()

    if res is None:
        raise ValueError("accumulate_grads: the eval data cache has no minibatches")

    grads, stats = res

    per_task = stats.pop("per_task")
    # Same table that `loss_fn` uses to size its per-task accumulator.
    num_tasks = (
        resharded_model.latent_program_embeddings.shape[0] // num_solution_attempts
    )
    shape = (num_tasks, num_solution_attempts)
    # NOTE(review): assumes `per_class_total_weight` has one entry per latent
    # program, laid out task-major / attempt-minor like `loss_fn`'s indices —
    # confirm against the trainer.
    per_class_weight = np.maximum(1, eval_data.per_class_total_weight).reshape(*shape)
    per_task = {k: np.asarray(v) / per_class_weight for k, v in per_task.items()}

    total_weight = max(1, eval_data.total_weight)
    stats = {k: float(v) / total_weight for k, v in stats.items()}

    for name, values in per_task.items():
        # "Best" over the attempt axis: lower is better for a loss, higher is
        # better for an accuracy.
        best = values.min(1) if "crossentropy" in name else values.max(1)
        stats[f"best_{name}"] = float(best.mean())
    return grads, stats

In [16]:
# Number of optimisation steps. 0 only builds the optimizer, which is what the
# previously hard-coded `trange(0)` did; raise it to actually train.
num_train_steps = 0

# Adam over the parameters selected by `train_filter` only (the latent program
# embeddings); everything else in the model is left untouched.
tx = optax.adam(1e-2, b1=0.9, b2=0.99)
optimizer = nnx.Optimizer(
    trainer.train_state.model,
    tx,
    wrt=train_filter,
)

for step in (pbar := tqdm.auto.trange(num_train_steps)):
    # One step = gradients accumulated over all cached minibatches.
    grads, stats = accumulate_grads(trainer)
    optimizer.update(trainer.train_state.model, grads)
    print(stats)

0it [00:00, ?it/s]