In [1]:
import typing
import lzma
import os
import dataclasses
import itertools
import functools
from importlib import reload
from pathlib import Path
from types import MappingProxyType, SimpleNamespace

import cbor2
import attrs
import tqdm.auto
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import jaxtyping as jt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from arc25 import symmetry, tools as arc25_tools
from arc25.symmetry import D4, transform_vector
from arc25 import serialisation
from arc25.dsl.types import Vector, Dir4
from arc25.vision2.symrep import SymDecompBase, SplitSymDecomp, SymDecompDims, standard_rep, RepSpec
from arc25.vision2.fields import FieldDims, CoordinateGrid
from arc25.vision2.linear import SpaceSymmetricLinear, SpaceSymmetricTensor, SymmetryMappingSpec, SymDecompLinear
from arc25.vision2 import fields, attention, encoder, transformer, mae, swiglu, arc_solver
from arc25.training import saving, dataset, mae as mae_trainer, knn_eval, linear_probe, arc_solver as solver_trainer, row_weighted_adam

In [2]:
# Expose two virtual host (CPU) devices to XLA, and set EPATH_USE_TF
# (presumably disables etils.epath's TensorFlow backend — TODO confirm).
# NOTE(review): `jax` is already imported in the first cell. XLA_FLAGS is only
# honoured if the JAX backend has not been initialised yet; the first device
# query happens later (jax.local_device_count() in the minibatch-sizing cell).
# Keep this cell ahead of any JAX computation.
os.environ.update(
    XLA_FLAGS="--xla_force_host_platform_device_count=2",
    EPATH_USE_TF="false",
)

In [3]:
# Project layout: the notebook lives one level below the project root.
proj_root = Path("..").resolve()
data_root = proj_root.joinpath("data")
model_dir = data_root.joinpath("models")

In [4]:
def _keep_same_shape(iop, ex):
    """Dataset filter: keep only pairs whose input and output share a shape (``ex`` is unused)."""
    return iop.input.shape == iop.output.shape


src_dataset = dataset.ImagesDataset.load_compressed_cbor(
    data_root / "repack/re-arc.cbor.xz",
    filter=_keep_same_shape,
)
# Deterministic (sorted) challenge ordering shared by all downstream datasets.
challenge_order = tuple(sorted(src_dataset.challenges))
print(len(src_dataset.examples))

518152


In [5]:
reload(dataset)
# NOTE(review): `src_dataset` was created with the pre-reload `dataset` module,
# so objects derived from it below are presumably instances of the stale classes.

small_eval = True

# Split by challenge into eval / train parts with a fixed-seed RNG.
eval_split, train_split = src_dataset.split_by_challenge(
    np.random.default_rng(seed=42),
    n_min=6 if small_eval else 100,
)

input_ds, output_ds = train_split.split_input_output()

size_cuts = [12, 21, 30]
#size_cuts = [8,12,16,24,30]
#size_cuts = [30]

# One bucket per (size_cut, size_cut) pair. If no cuts are configured, fall back
# to a single 30x30 bucket. An identity test like `x is not {}` is always True
# (a fresh literal is a new object), so it cannot be used to detect emptiness.
bucket_shapes = set(itertools.product(size_cuts, size_cuts)) or [(30, 30)]
training_ds = dataset.BucketedDataset.make(
    output_ds,
    bucket_shapes,
    challenges=challenge_order,
)

print(f"training_ds: { {kk:vv.n_examples for kk,vv in training_ds.buckets.items()} }")


training_ds: {(30, 30): 35519, (21, 30): 69584, (21, 21): 39576, (12, 12): 21717, (12, 21): 48056, (12, 30): 43704}


In [6]:
reload(solver_trainer)

# Run configuration for the ARC solver trainer.
config = solver_trainer.ArcSolverConfig(
    seed=42,
    # batch / minibatch sizing
    batch_size=16,
    ref_batch=16,
    minibatch_size=16,
    base_cell_cost=10,
    # optimisation length
    learning_rate=1e-5,
    max_num_epochs=1,
    max_num_ref_batches=128,
    # warmup and checkpoint cadence
    warmup_steps=64,
    checkpoint_every_steps=16,
    # model execution options
    mode="flat",
    remat=True,
    unroll=None,
    # evaluation cadence, in reference batches
    eval_every_ref_batch=0.5,
)



In [7]:
num_devices = jax.local_device_count()

# Minibatch sizing is driven by the config's reference minibatch size, reference
# image size and per-cell base cost.
minibatch_size_fn = dataset.MinibatchSizeFunction(
    reference_minibatch_size=config.minibatch_size,
    reference_image_size=config.reference_image_size,
    base_cost=config.base_cell_cost,
    granularity=num_devices,  # Ensure divisibility for pmap
)

batch_spec = dataset.BatchSpec(
    target_batch_weight=config.batch_size,
    reference_image_size=config.reference_image_size,
    # each image gets equal weight
    area_weight_exponent=None,
)

# Seed from the run config instead of a second hard-coded 42, so that changing
# `config.seed` also reseeds the collator.
collator = dataset.BucketedCollator.make(
    dataset=training_ds,
    batch_spec=batch_spec,
    minibatch_size=minibatch_size_fn,
    seed=config.seed,
)

In [8]:
def _bucket_order(shape):
    """Sort key for bucket shapes: smaller area first, ties broken by |h - w| (squarer first)."""
    return (shape[0] * shape[1], abs(shape[0] - shape[1]))


input_src = dataset.OnDemandBucketDataset(
    input_ds,
    bucket_shapes=tuple(sorted(training_ds.buckets.keys(), key=_bucket_order)),
    challenges=challenge_order,
    # the weight function ignores `area` and always yields None
    weight_fun=lambda area: None,
)

In [9]:
reload(arc_solver)

# Start from the "tiny" MAE configuration, dropping `decoder_cell_infusion` so
# it is not forwarded to ARCSolver, then add the solver-specific settings.
model_config = {
    k: v
    for k, v in mae.configs["tiny"].items()
    if k not in {"decoder_cell_infusion"}
}
model_config.update(
    num_program_tokens=4,
    # one latent program per training challenge
    num_latent_programs=len(challenge_order),
)

# Seed parameter initialisation from the run config instead of a separate
# literal 42, so that changing `config.seed` changes the initialisation too.
solver = arc_solver.ARCSolver(
    **model_config,
    dtype=jnp.float32,
    rngs=nnx.Rngs(config.seed),
)
width = solver.hidden_size

In [10]:
chkp_path = model_dir / "20251023-1137-vertex-ai-mae-tiny-4xL4-chkp-006912.msgpack.xz"
print(f"Loading encoder from checkpoint: {chkp_path}")
encoder_checkpoint = saving.load_model(chkp_path)
# Only the encoder sub-state is copied into the solver; all other solver
# parameters keep their fresh initialisation.
pretrained_encoder_state = encoder_checkpoint.state.model.encoder
nnx.update(solver.encoder, pretrained_encoder_state)
print("Encoder loaded successfully")

Loading encoder from checkpoint: /Users/yves/git-private/arc-2025/data/models/20251023-1137-vertex-ai-mae-tiny-4xL4-chkp-006912.msgpack.xz
Encoder loaded successfully


In [11]:
# Pick up source edits to the optimizer and trainer modules without restarting
# the kernel. row_weighted_adam is reloaded first — presumably because
# solver_trainer depends on it (TODO confirm).
# NOTE(review): `config` was built from the pre-reload
# solver_trainer.ArcSolverConfig class. After this reload it is an instance of
# the stale class, so isinstance checks against the new class would fail.
reload(row_weighted_adam)
reload(solver_trainer)

<module 'arc25.training.arc_solver' from '/Users/yves/git-private/arc-2025/src/arc25/training/arc_solver.py'>

In [12]:
# Assemble the trainer from the config, the solver with its pretrained encoder,
# the training collator, the on-demand input source and the held-out split.
trainer = solver_trainer.ArcSolverTrainer.make(
    config=config,
    model=solver,
    collator=collator,
    inputs_src=input_src,
    eval_dataset=eval_split,
    num_devices=num_devices,
    # Derive the trainer RNG from the run config rather than a second literal 42.
    rngs=nnx.Rngs(config.seed),
    with_progress_bars=True,
)


In [None]:
# Run the training loop. The returned `stats` is iterated by the analysis cells
# below, where each element is indexed with "debug".
stats = trainer.run_main()

--- ArcSolverTrainer ---
Run: 20251029-1302-ArcSolverTrainer
Devices: 2 × cpu
Training batch data weight: 16 (1 optimizer step)
Reference step data weight: 16 (~1.00 optimizer steps)
Total steps: 128
Evaluation: every 0.5 reference steps
----------------------------

Starting training...


Training:   0%|          | 0/128 [00:00<?, ?it/s]

Tracing _compute_grads for shape dict(cell_weight=(4,21,21), input_sizes=(4,2), inputs=(4,21,21), latent_program_idx=(4), output_masks=(4,21,21), outputs=(4,21,21)) (kw=dict(mode='flat', remat=True, unroll=None))
Tracing _compute_grads for shape dict(cell_weight=(5,12,30), input_sizes=(5,2), inputs=(5,12,30), latent_program_idx=(5), output_masks=(5,12,30), outputs=(5,12,30)) (kw=dict(mode='flat', remat=True, unroll=None))
Tracing _apply_update
Tracing _compute_grads for shape dict(cell_weight=(5,12,30), input_sizes=(5,2), inputs=(5,12,30), latent_program_idx=(5), output_masks=(5,12,30), outputs=(5,12,30)) (kw=dict(mode='flat', remat=True, unroll=None))
Tracing _compute_grads for shape dict(cell_weight=(2,30,30), input_sizes=(2,2), inputs=(2,30,30), latent_program_idx=(2), output_masks=(2,30,30), outputs=(2,30,30)) (kw=dict(mode='flat', remat=True, unroll=None))
Tracing _apply_update

[Step 1] Preparing input embeddings for evaluation...


  0%|          | 0/6 [00:00<?, ?it/s]

Tracing embed_inputs for shape dict(input_sizes=(2,32,2), inputs=(2,32,12,12)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(2,79,2), inputs=(2,79,12,21)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(2,66,2), inputs=(2,66,12,30)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(2,60,2), inputs=(2,60,21,21)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(2,103,2), inputs=(2,103,21,30)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing embed_inputs for shape dict(input_sizes=(2,53,2), inputs=(2,53,30,30)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Embedding inputs for evaluation completed in 111.5s

[Step 1] Running evaluation...


  0%|          | 0/6 [00:00<?, ?it/s]

Tracing evaluate for shape dict(cell_weight=(2,32,12,12), embeddings=(2, 32), latent_program_idx=(2,32), output_masks=(2,32,12,12), output_sizes=(2,32,2), outputs=(2,32,12,12)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing evaluate for shape dict(cell_weight=(2,79,12,21), embeddings=(2, 79), latent_program_idx=(2,79), output_masks=(2,79,12,21), output_sizes=(2,79,2), outputs=(2,79,12,21)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing evaluate for shape dict(cell_weight=(2,66,12,30), embeddings=(2, 66), latent_program_idx=(2,66), output_masks=(2,66,12,30), output_sizes=(2,66,2), outputs=(2,66,12,30)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing evaluate for shape dict(cell_weight=(2,60,21,21), embeddings=(2, 60), latent_program_idx=(2,60), output_masks=(2,60,21,21), output_sizes=(2,60,2), outputs=(2,60,21,21)) (kw=dict(mode='flat', remat=True, unroll=None, deterministic=True))
Tracing evaluate for

  np.asarray(per_class.pop("loss")) / per_class_loss_weight,


Evaluation completed in 47.1s: cell_accuracy: 0.724 pair_accuracy: 0.004 pair_crossentropy: 2.060 loss: 2.055 class_accuracy_histogram: [397,0,0,2,0,1,0,0,0,0]
Tracing _compute_grads for shape dict(cell_weight=(4,21,21), input_sizes=(4,2), inputs=(4,21,21), latent_program_idx=(4), output_masks=(4,21,21), outputs=(4,21,21)) (kw=dict(mode='flat', remat=True, unroll=None))

[Step 2] Running evaluation...


  0%|          | 0/6 [00:00<?, ?it/s]

Evaluation completed in 32.6s: cell_accuracy: 0.724 pair_accuracy: 0.004 pair_crossentropy: 2.059 loss: 2.050 class_accuracy_histogram: [397,0,0,2,0,1,0,0,0,0]
Tracing _compute_grads for shape dict(cell_weight=(2,21,30), input_sizes=(2,2), inputs=(2,21,30), latent_program_idx=(2), output_masks=(2,21,30), outputs=(2,21,30)) (kw=dict(mode='flat', remat=True, unroll=None))

[Step 3] Running evaluation...


  0%|          | 0/6 [00:00<?, ?it/s]

Evaluation completed in 31.2s: cell_accuracy: 0.724 pair_accuracy: 0.004 pair_crossentropy: 2.057 loss: 2.045 class_accuracy_histogram: [397,0,0,2,0,1,0,0,0,0]

[Step 4] Running evaluation...


  0%|          | 0/6 [00:00<?, ?it/s]

Evaluation completed in 29.9s: cell_accuracy: 0.725 pair_accuracy: 0.004 pair_crossentropy: 2.055 loss: 2.040 class_accuracy_histogram: [397,0,0,2,0,1,0,0,0,0]
Tracing _compute_grads for shape dict(cell_weight=(12,12,12), input_sizes=(12,2), inputs=(12,12,12), latent_program_idx=(12), output_masks=(12,12,12), outputs=(12,12,12)) (kw=dict(mode='flat', remat=True, unroll=None))

[Step 5] Running evaluation...


  0%|          | 0/6 [00:00<?, ?it/s]

Evaluation completed in 30.3s: cell_accuracy: 0.725 pair_accuracy: 0.004 pair_crossentropy: 2.052 loss: 2.034 class_accuracy_histogram: [397,0,0,2,0,1,0,0,0,0]

[Step 6] Running evaluation...


  0%|          | 0/6 [00:00<?, ?it/s]

Evaluation completed in 35.2s: cell_accuracy: 0.725 pair_accuracy: 0.004 pair_crossentropy: 2.049 loss: 2.027 class_accuracy_histogram: [397,0,0,2,0,1,0,0,0,0]
Tracing _compute_grads for shape dict(cell_weight=(7,12,21), input_sizes=(7,2), inputs=(7,12,21), latent_program_idx=(7), output_masks=(7,12,21), outputs=(7,12,21)) (kw=dict(mode='flat', remat=True, unroll=None))

[Step 7] Running evaluation...


  0%|          | 0/6 [00:00<?, ?it/s]

Evaluation completed in 31.1s: cell_accuracy: 0.726 pair_accuracy: 0.004 pair_crossentropy: 2.044 loss: 2.019 class_accuracy_histogram: [397,0,0,2,0,1,0,0,0,0]

[Step 8] Running evaluation...


  0%|          | 0/6 [00:00<?, ?it/s]

In [None]:
# One row per stats entry, one column per key of that entry's "debug" mapping.
debug_records = [entry["debug"] for entry in stats]
df = pd.DataFrame(debug_records)
df.head()

In [None]:
# One stacked panel per debug metric, all sharing the x axis (row index of `df`).
# squeeze=False keeps `axes` 2-D even when there is only a single metric column;
# without it plt.subplots(1, 1) returns a bare Axes and the zip below fails.
fig, axes = plt.subplots(len(df.columns), 1, figsize=(8, 8), sharex=True, squeeze=False)
for ax, k in zip(axes[:, 0], df.columns):
    ax.plot(df.loc[:, k])
    # metric name in the lower-left corner of its panel
    ax.annotate(k, (0, 0), xycoords="axes fraction")
axes[-1, 0].set_xlabel("row index into stats");

In [None]:
cache = trainer._eval_data_cache
# Compare the cached total weight with the sum of the per-class totals
# (presumably these should agree).
print(cache.total_weight, cache.per_class_total_weight.sum())
print(cache.per_class_total_weight)

# Recompute the total from the minibatches themselves: use an explicit "weight"
# array when present, otherwise count one unit per latent_program_idx entry.
recomputed_weight = 0
for minibatch in cache.minibatches:
    mb_weight = minibatch.get("weight")
    if mb_weight is not None:
        recomputed_weight += mb_weight.sum()
    else:
        recomputed_weight += minibatch["latent_program_idx"].size
print(recomputed_weight)


In [None]:
# Per-accelerator measurements. Each row unpacks below as
# (imsz, fsz, fspd, bsz, bspd): presumably the image size, then a size/speed
# pair for an "f" setting and for a "b" setting — TODO confirm meaning and units.
v6e = np.r_[
    30,14, 5.46,14, 5.46,
    24,20, 8.07,23, 7.76,
    20,24, 9.44,29, 9.38,
    16,48,14.03,48,14.03,
    12,80,19.83,80,19.83,
].reshape(-1,5)

# The commented-out rows are excluded from the comparison.
L4 = np.r_[
    30, 5, 1.63, 5, 1.63,
#    24, 7, 2.42, 7, 2.42,
    20,16, 4.31,16, 4.31,
#    16,10, 4.64,10, 4.64,
    12,40,11.17,40,11.17,
].reshape(-1,5)


fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True)
# (label, measurements, cost, ccb) per accelerator. `ccb` is added to the cell
# count in the memory estimate and reported as the "base cost".
for name, measurements, cost, ccb in [("v6e", v6e, 2.97, 64), ("L4", L4, 1.05, 0)]:
    imsz, fsz, fspd, bsz, bspd = measurements.T

    # Left panel: minibatch size times a memory estimate normalised to 15x15.
    ax = axes[0]
    memory = (imsz**2 + ccb) / 15**2
    print(f"At base cost {ccb}, minibatch size should be {(bsz*memory).min():.1f}")
    line, = ax.plot(imsz, bsz*memory, 'o-', label=name)
    c = line.get_color()
    #memory = 7*imsz*((imsz+7)//8) / 15**2
    #ax.plot(imsz, bsz*memory,'o:',c=c)

    # Right panel: speeds scaled by work (imsz^2) per weight (imsz/2) per cost,
    # solid for fspd and dotted for bspd, in the same colour as the left panel.
    ax = axes[1]
    work = imsz*imsz
    weight = 0.5*imsz
    fac = work/weight/cost
    ax.plot(imsz, fspd*fac, 'o-', c=c, label=f"{name} fspd")
    ax.plot(imsz, bspd*fac, 'o:', c=c, label=f"{name} bspd")

axes[0].set(xlabel="imsz", ylabel="bsz × (imsz² + ccb) / 15²")
axes[1].set(xlabel="imsz", ylabel="speed × imsz² / (0.5·imsz) / cost")
for ax in axes:
    ax.legend()