In [1]:
import typing
import lzma
import os
import dataclasses
import itertools
import functools
from importlib import reload
from pathlib import Path
from types import MappingProxyType, SimpleNamespace

import cbor2
import attrs
import tqdm.auto
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import jaxtyping as jt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from arc25 import symmetry, tools as arc25_tools
from arc25.symmetry import D4, transform_vector
from arc25 import serialisation
from arc25.dsl.types import Vector, Dir4
from arc25.vision2.symrep import SymDecompBase, SplitSymDecomp, SymDecompDims, standard_rep, RepSpec
from arc25.vision2.fields import FieldDims, CoordinateGrid
from arc25.vision2.linear import SpaceSymmetricLinear, SpaceSymmetricTensor, SymmetryMappingSpec, SymDecompLinear
from arc25.vision2 import fields, attention, encoder, transformer, mae, swiglu, arc_solver
from arc25.training import saving, dataset, mae as mae_trainer, knn_eval, linear_probe, arc_solver as solver_trainer

In [2]:
# Process-level configuration read by JAX/XLA and etils.
# NOTE(review): XLA generally honours XLA_FLAGS only if it is set before JAX
# initialises its backends. `jax` is already imported in the first cell, so this
# relies on nothing having touched a device yet. Moving this cell above the
# imports would be more robust.
os.environ.update({
    # Expose two host (CPU) devices so multi-device code paths can run on CPU.
    "XLA_FLAGS": "--xla_force_host_platform_device_count=2",
    # presumably keeps etils.epath from using a TensorFlow-based backend — confirm
    "EPATH_USE_TF": "false",
})

In [3]:
# The notebook is run from a sub-directory of the project; data lives in <root>/data.
proj_root = Path("..").resolve()
data_root = proj_root.joinpath("data")

In [4]:
def _same_shape_io(iop, ex):
    """Keep only examples whose input and output have identical shapes."""
    return iop.input.shape == iop.output.shape


src_dataset = dataset.ImagesDataset.load_compressed_cbor(
    data_root / "repack/re-arc.cbor.xz",
    filter=_same_shape_io,
)
print(len(src_dataset.examples))

518152


In [5]:
# Split the images by role. Each predicate binds its key explicitly.
# The previous comprehension form `lambda img: img.example_type==k` closed over
# the loop variable `k` (late binding). If `filtered` stored the predicate and
# applied it lazily, both datasets would have filtered on "output".
input_ds = src_dataset.filtered(lambda img, _kind="input": img.example_type == _kind)
output_ds = src_dataset.filtered(lambda img, _kind="output": img.example_type == _kind)

In [6]:
# Canonical sorted ordering of challenge ids, passed as `challenges=` to every
# dataset below and used to size `num_latent_programs` for the solver
# (presumably so challenge indices agree everywhere — confirm).
challenge_order = tuple(sorted(src_dataset.challenges))

In [18]:
# Sanity check of the loaded dataset's class (saved as In [18], i.e. executed out
# of order; safe to drop from a Restart-&-Run-All run).
type(src_dataset)

arc25.training.dataset.ImagesDataset

In [7]:
reload(dataset)
# NOTE(review): `src_dataset` / `input_ds` / `output_ds` were built before this
# reload, so they are instances of the pre-reload classes. Any `isinstance`
# check inside the reloaded module would not recognise them.

# Toggle for a quick smoke-test evaluation split vs. a larger one.
small_eval = True

# `n_min` is presumably the minimum number of examples held out per challenge —
# confirm against `split_by_challenge`.
eval_split, train_split = output_ds.split_by_challenge(
    np.random.default_rng(seed=42),
    n_min=2 if small_eval else 100,
)


size_cuts = [12, 21, 30]
#size_cuts = [8,12,16,24,30]
#size_cuts = [30]

# Every (cut, cut) pair becomes a bucket shape.
# The previous code used `... if s is not {} else [(30,30)]`. `{}` builds a new
# dict each time it is evaluated, so that identity test was always True and the
# `[(30, 30)]` fallback was dead code. It is removed here, keeping the behavior
# the notebook actually had. A fresh set is built for each split, as before.
eval_ds, training_ds = [
    dataset.BucketedDataset.make(
        split,
        set(itertools.product(size_cuts, size_cuts)),
        challenges=challenge_order,
    )
    for split in [eval_split, train_split]
]

for k,v in dict(
    eval_ds=eval_ds,
    training_ds=training_ds,
).items():
    print(f"{k}: { {kk:vv.n_examples for kk,vv in v.buckets.items()} }")


eval_ds: {(30, 30): 69, (21, 30): 134, (21, 21): 87, (12, 12): 39, (12, 21): 107, (12, 30): 90}
training_ds: {(30, 30): 35557, (21, 30): 69656, (21, 21): 39609, (12, 12): 21746, (12, 21): 48107, (12, 30): 43746}


In [8]:
reload(solver_trainer)

# Training hyper-parameters for the ARC solver run.
config = solver_trainer.ArcSolverConfig(
    seed=42,

    # batch / minibatch sizing
    batch_size=16,
    ref_batch=16,
    minibatch_size=16,
    base_cell_cost=10,

    # optimisation schedule
    learning_rate=1e-5,
    max_num_epochs=1,
    max_num_ref_batches=128,

    warmup_steps=64,
    checkpoint_every_steps=16,

    # execution options
    mode="flat",
    remat=True,
    unroll=None,

    eval_every_ref_batch=2,
)



In [9]:
num_devices = jax.local_device_count()

# Minibatch size as a function of image size, relative to the configured
# reference minibatch/image size.
minibatch_size_fn = dataset.MinibatchSizeFunction(
    reference_minibatch_size=config.minibatch_size,
    reference_image_size=config.reference_image_size,
    base_cost=config.base_cell_cost,
    granularity=num_devices,  # Ensure divisibility for pmap
)

batch_spec = dataset.BatchSpec(
    target_batch_weight=config.batch_size,
    reference_image_size=config.reference_image_size,
    # each image gets equal weight
    area_weight_exponent=None,
)

# Use the configured seed rather than a duplicated literal, so the collator
# follows `config.seed`.
collator = dataset.BucketedCollator.make(
    dataset=training_ds,
    batch_spec=batch_spec,
    minibatch_size=minibatch_size_fn,
    seed=config.seed,
)

In [10]:
def _bucket_sort_key(shape):
    """Order bucket shapes by area, breaking ties by how far from square they are."""
    height, width = shape[0], shape[1]
    return (height * width, abs(height - width))


# Bucket the input images on demand, using the same bucket shapes as the
# training set, ordered by `_bucket_sort_key`.
ordered_shapes = sorted(training_ds.buckets.keys(), key=_bucket_sort_key)
input_src = dataset.OnDemandBucketDataset(
    input_ds,
    bucket_shapes=tuple(ordered_shapes),
    challenges=challenge_order,
    weight_fun=lambda area: None,
)

In [11]:
reload(arc_solver)

# Reuse the MAE "tiny" hyper-parameters, dropping `decoder_cell_infusion`.
# Presumably this keeps the encoder compatible with the mae-tiny checkpoint —
# confirm.
model_config = {
    k: v
    for k, v in mae.configs["tiny"].items()
    if k not in {"decoder_cell_infusion"}
}
model_config.update(
    num_program_tokens=4,
    # as many latent programs as there are challenges in `challenge_order`
    num_latent_programs=len(challenge_order),
)

# Seed model initialisation from `config.seed` instead of repeating the literal 42.
solver = arc_solver.ARCSolver(
    **model_config,
    dtype=jnp.float32,
    rngs=nnx.Rngs(config.seed),
)
width = solver.hidden_size

In [17]:
import etils.epath

# Pretrained MAE run whose encoder weights initialise the solver's encoder.
ENCODER_CHECKPOINT_URI = "gs://576e2361-arc-agi-2/aiplatform-custom-training-2025-10-23-13:37:52.100/checkpoints/20251023-1137-vertex-ai-mae-tiny-4xL4/20251023-1137-vertex-ai-mae-tiny-4xL4-chkp-007568-final.msgpack.xz"
chkp_path = etils.epath.Path(ENCODER_CHECKPOINT_URI)

print(f"Loading encoder from checkpoint: {chkp_path}")
with chkp_path.open("rb") as checkpoint_file:
    encoder_checkpoint = saving.load_model(checkpoint_file)

# Only the encoder sub-state of the checkpointed model is copied into the solver.
encoder_state = encoder_checkpoint.state.model.encoder
nnx.update(solver.encoder, encoder_state)
print("Encoder loaded successfully")

Loading encoder from checkpoint: gs://576e2361-arc-agi-2/aiplatform-custom-training-2025-10-23-13:37:52.100/checkpoints/20251023-1137-vertex-ai-mae-tiny-4xL4/20251023-1137-vertex-ai-mae-tiny-4xL4-chkp-007568-final.msgpack.xz
Encoder loaded successfully


In [None]:
# Assemble the trainer from the pieces built above. Its RNG stream comes from
# `config.seed` rather than a duplicated literal 42, so it follows the config.
trainer = solver_trainer.ArcSolverTrainer.make(
    config=config,
    model=solver,
    collator=collator,
    inputs_src=input_src,
    num_devices=num_devices,
    rngs=nnx.Rngs(config.seed),
)


In [None]:
# Launch training; the return value is kept as `stats`
# (presumably training/eval statistics — confirm against ArcSolverTrainer.run_main).
stats = trainer.run_main()