In [1]:
using Pkg
Pkg.activate("..")  #one level up, where Project.toml lives
Pkg.instantiate()   #download/install anything missing
# Pkg.status();

[32m[1m  Activating[22m[39m project at `~/Documents/repos/JuliaExploreHRM`


In [2]:
using StatsBase
using Random, Statistics
using Flux
using Flux, Zygote, Optimisers
using Flux: onehotbatch, onecold
using DataFrames, Plots, CSV
using Test
using CSV, DataFrames
using Measures

using NNlib: gelu
using PositionalEmbeddings     # Absolute positional encodings


In [3]:
include(joinpath(@__DIR__, "..", "data", "long_addition_carries_gen.jl"))
using .AdditionCarryData


In [4]:
include(joinpath(@__DIR__, "..", "data", "hrm_common_addition_FLUX.jl"))
using .HRMFluxAddCarry

In [5]:
# pure length shift like your Dyck/Boolean splits (2-4 -> 5 -> 6-7)
splits = AdditionCarryData.make_addition_sum_splits(
    n_train=100_000, n_val=10_000, n_test_id=10_000, n_test_mid=10_000, n_test_ood=10_000,
    train_digits_a=2:4, train_digits_b=2:4,
    mid_digits_a=5:5,   mid_digits_b=5:5,
    ood_digits_a=6:7,   ood_digits_b=6:7,
    base=10, lsd_first=true, seed=1234,
    require_carry=nothing,     # leave unconstrained
    min_carry_chain_id=0,      # ID
    min_carry_chain_mid=0,     # MID
    min_carry_chain_ood=0      # OOD
)

AdditionCarryData.write_addition_sum_splits("datasets/addition_sum", splits; base=10, include_vocab=true)

In [8]:
"""
    make_verification_splits(; kwargs...) -> NamedTuple

Builds five splits for addition-with-carry verification (binary labels):
  - train_A: shorter regime (analogue of your Phase-A)
  - train_B: longer regime  (analogue of your Phase-B)
  - val_id  : ID validation   (same regime as A)
  - test_mid: slightly longer
  - test_ood: much longer
"""
function make_verification_splits(;  # sizes
    n_train_A::Int = 100_000,
    n_train_B::Int = 100_000,
    n_val::Int     = 10_000,
    n_test_mid::Int= 10_000,
    n_test_ood::Int= 10_000,

    # length regimes for A and B operands (MSD lengths)
    trainA_digits_a::UnitRange{Int} = 2:4,
    trainA_digits_b::UnitRange{Int} = 2:4,
    trainB_digits_a::UnitRange{Int} = 2:6,
    trainB_digits_b::UnitRange{Int} = 2:6,
    mid_digits_a::UnitRange{Int}    = 5:5,
    mid_digits_b::UnitRange{Int}    = 5:5,
    ood_digits_a::UnitRange{Int}    = 6:7,
    ood_digits_b::UnitRange{Int}    = 6:7,

    # formatting and arithmetic knobs
    base::Int = 10,
    lsd_first::Bool = true,
    allow_leading_zero::Bool = false,
    delimiter::String = " ",
    include_plus::Bool = true,
    include_equals::Bool = true,

    # carry constraints (optional)
    require_carry::Union{Bool,Nothing} = nothing,
    min_carry_chain_A::Int  = 0,
    min_carry_chain_B::Int  = 0,
    min_carry_chain_mid::Int= 0,
    min_carry_chain_ood::Int= 0,

    # class balance (true equations vs false)
    positive_fraction::Float64 = 0.5,

    # reproducibility
    seed::Int = 1234
)
    gen = AdditionCarryData.generate_addition_verification_dataset

    train_A = gen(n_train_A;
        digits_a=trainA_digits_a, digits_b=trainA_digits_b, base=base,
        positive_fraction=positive_fraction, random_seed=seed,
        lsd_first=lsd_first, allow_leading_zero=allow_leading_zero,
        require_carry=require_carry, min_carry_chain=min_carry_chain_A,
        delimiter=delimiter, include_plus=include_plus, include_equals=include_equals)

    train_B = gen(n_train_B;
        digits_a=trainB_digits_a, digits_b=trainB_digits_b, base=base,
        positive_fraction=positive_fraction, random_seed=seed+1,
        lsd_first=lsd_first, allow_leading_zero=allow_leading_zero,
        require_carry=require_carry, min_carry_chain=min_carry_chain_B,
        delimiter=delimiter, include_plus=include_plus, include_equals=include_equals)

    val_id = gen(n_val;
        digits_a=trainA_digits_a, digits_b=trainA_digits_b, base=base,
        positive_fraction=positive_fraction, random_seed=seed+2,
        lsd_first=lsd_first, allow_leading_zero=allow_leading_zero,
        require_carry=require_carry, min_carry_chain=min_carry_chain_A,
        delimiter=delimiter, include_plus=include_plus, include_equals=include_equals)

    test_mid = gen(n_test_mid;
        digits_a=mid_digits_a, digits_b=mid_digits_b, base=base,
        positive_fraction=positive_fraction, random_seed=seed+3,
        lsd_first=lsd_first, allow_leading_zero=allow_leading_zero,
        require_carry=require_carry, min_carry_chain=min_carry_chain_mid,
        delimiter=delimiter, include_plus=include_plus, include_equals=include_equals)

    test_ood = gen(n_test_ood;
        digits_a=ood_digits_a, digits_b=ood_digits_b, base=base,
        positive_fraction=positive_fraction, random_seed=seed+4,
        lsd_first=lsd_first, allow_leading_zero=allow_leading_zero,
        require_carry=require_carry, min_carry_chain=min_carry_chain_ood,
        delimiter=delimiter, include_plus=include_plus, include_equals=include_equals)

    # Convert to (strings, labels::Int)
    to_xy(ds) = begin
        xs = [r.input for r in ds]
        ys = [r.label ? 1 : 0 for r in ds]
        xs, ys
    end

    x_tr_A,  y_tr_A  = to_xy(train_A)
    x_tr_B,  y_tr_B  = to_xy(train_B)
    x_val,   y_val   = to_xy(val_id)
    x_mid,   y_mid   = to_xy(test_mid)
    x_ood,   y_ood   = to_xy(test_ood)

    return (trainA_x=x_tr_A, trainA_y=y_tr_A,
            trainB_x=x_tr_B, trainB_y=y_tr_B,
            val_x=x_val, val_y=y_val,
            mid_x=x_mid, mid_y=y_mid,
            ood_x=x_ood, ood_y=y_ood)
end


make_verification_splits

In [9]:
"Digits 0..(base-1), then '+', '=', and '<pad>'."
build_vocab_with_pad(base::Int) =
    AdditionCarryData.build_vocabulary(base; include_plus=true, include_equals=true,
                                       include_pad=true, include_eos=false)

"""
batch_tokenize_pad(strings, vocab; pad_id) -> (L_max, B) Int matrix.
Whitespace-split, map with `vocab`, then right-pad with `pad_id`.
"""
function batch_tokenize_pad(strings::Vector{String},
                            vocab::Dict{String,Int};
                            pad_id::Int)
    B = length(strings)
    token_lists = Vector{Vector{Int}}(undef, B)
    Lmax = 0
    @inbounds for i in 1:B
        ids = AdditionCarryData.tokenize_with_vocabulary(strings[i], vocab)
        token_lists[i] = ids
        Lmax = max(Lmax, length(ids))
    end
    X = fill(pad_id, Lmax, B)
    @inbounds for i in 1:B
        ids = token_lists[i]
        if !isempty(ids)
            X[1:length(ids), i] = ids
        end
    end
    X
end


batch_tokenize_pad

In [10]:
# ---- Splits ----
splits = make_verification_splits(
    n_train_A=100_000, n_train_B=100_000,
    n_val=10_000, n_test_mid=10_000, n_test_ood=10_000,
    trainA_digits_a=2:4, trainA_digits_b=2:4,
    trainB_digits_a=2:6, trainB_digits_b=2:6,
    mid_digits_a=5:5, mid_digits_b=5:5,
    ood_digits_a=6:7, ood_digits_b=6:7,
    base=10, lsd_first=true, seed=1234,
    require_carry=nothing,
    min_carry_chain_A=0, min_carry_chain_B=0,
    min_carry_chain_mid=0, min_carry_chain_ood=0,
    positive_fraction=0.5
)

# Assign in variables that mirror the workflow
x_tr_A,  y_tr_A  = splits.trainA_x, splits.trainA_y
x_tr_B,  y_tr_B  = splits.trainB_x, splits.trainB_y
x_te_ID, y_te_ID = splits.val_x,    splits.val_y   # ID validation
x_te_M,  y_te_M  = splits.mid_x,    splits.mid_y   # MID
x_te_O,  y_te_O  = splits.ood_x,    splits.ood_y   # OOD

# Keep references for any reports you may add later
xs_ID, xs_MID, xs_OOD = x_te_ID, x_te_M, x_te_O

# ---- Vocabulary and tokenization ----
vocab  = build_vocab_with_pad(10)
pad_id = vocab["<pad>"]

to_ids(xs::Vector{String}) = batch_tokenize_pad(xs, vocab; pad_id=pad_id)

Xids_tr_A  = to_ids(x_tr_A)
Xids_tr_B  = to_ids(x_tr_B)
Xids_te_ID = to_ids(x_te_ID)
Xids_te_M  = to_ids(x_te_M)
Xids_te_O  = to_ids(x_te_O)

# ---- Positional bound (add 1 for CLS in H blocks) ----
POS_L_MAX = maximum((size(Xids_tr_A,1), size(Xids_tr_B,1),
                     size(Xids_te_ID,1), size(Xids_te_M,1), size(Xids_te_O,1))) + 1

# ---- HRM config (binary verification: single logit) ----
cfg = (
    d_in   = 0,            # unused in token path
    d_hid  = 96,
    d_out  = 1,            # single logit (use σ + BCE)
    N      = 3,            # outer HRM cycles
    T      = POS_L_MAX,    # PE bound
    batch  = 32,
    lr     = 5e-5,
    num_tokens = length(vocab),   # 10 digits + '+' + '=' + '<pad>' = 13
    d_embed    = 128,
    l_heads    = 4,  l_ff_mult = 6,
    h_heads    = 4,  h_ff_mult = 6,
    dropout    = 0.1,
    pad_id     = pad_id
)

# ---- Build both architectures ----
models_HL = HRMFluxAddCarry.build_models_addcarry(cfg;
    arch = :HL,
    l_positional_encoding_kind = :none,
    h_positional_encoding_kind = :sinusoidal,
    pos_L_max = POS_L_MAX
)

models_HH = HRMFluxAddCarry.build_models_addcarry(cfg;
    arch = :HH,
    l_positional_encoding_kind = :sinusoidal,
    h_positional_encoding_kind = :sinusoidal,
    pos_L_max = POS_L_MAX
)

# ---- Keep PAD embedding neutral to avoid length bias ----
function freeze_pad!(models, pad_id::Int)
    if hasproperty(models, :tok_emb) && models.tok_emb !== nothing
        models.tok_emb.weight[:, pad_id] .= 0f0
    end
end
freeze_pad!(models_HL, pad_id)
freeze_pad!(models_HH, pad_id)


128-element view(::Matrix{Float32}, :, 13) with eltype Float32:
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 ⋮
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0

In [11]:
# Length-bucketed minibatches (by effective non-pad length)
function each_minibatch_bucketed(X::AbstractMatrix{<:Integer}, y::AbstractVector{<:Integer};
                                 batch::Int, buckets::Int=6, pad_id::Int,
                                 rng::AbstractRNG=Random.GLOBAL_RNG)
    L, N = size(X)
    lens = fill(L, N)
    @inbounds for i in 1:N, t in 1:L
        if X[t, i] == pad_id
            lens[i] = t - 1
            break
        end
    end
    order  = sortperm(1:N, by=i->lens[i])
    groups = Iterators.partition(order, max(1, ceil(Int, N/buckets)))
    out = Vector{Tuple{Matrix{Int}, Vector{Int}}}()
    for g in groups
        idx = collect(g); Random.shuffle!(rng, idx)
        for k in 1:batch:length(idx)
            sel = idx[k:min(k+batch-1, length(idx))]
            push!(out, (X[:, sel], y[sel]))
        end
    end
    out
end

# Mixed A/B stream with cosine schedule (for curriculum)
function each_minibatch_mixed(XA, yA, XB, yB; batch::Int, pB::Float64, buckets::Int,
                              pad_id::Int, rng::AbstractRNG=Random.GLOBAL_RNG)
    itA = each_minibatch_bucketed(XA, yA; batch=batch, buckets=buckets, pad_id=pad_id, rng=rng)
    itB = each_minibatch_bucketed(XB, yB; batch=batch, buckets=buckets, pad_id=pad_id, rng=rng)
    ia = 1; ib = 1
    out = Vector{Tuple{Matrix{Int}, Vector{Int}}}()
    while ia <= length(itA) || ib <= length(itB)
        useB = rand(rng) < pB
        if useB && ib <= length(itB)
            push!(out, itB[ib]); ib += 1
        elseif ia <= length(itA)
            push!(out, itA[ia]); ia += 1
        elseif ib <= length(itB)
            push!(out, itB[ib]); ib += 1
        end
    end
    out
end

# Loss and metrics (single-logit BCE)
logit_bce(ŷ, y::Vector{Int}) = Flux.Losses.logitbinarycrossentropy(vec(ŷ), Float32.(y))
accuracy(ŷ, y::Vector{Int})  = mean((σ.(vec(ŷ)) .>= 0.5) .== (y .== 1))

# Batch loss wrapper using your HRM runner
function batch_loss(models, x_batch::AbstractMatrix{<:Integer}, y_batch::AbstractVector{<:Integer}, cfg)
    B = size(x_batch, 2)
    low_state, high_state = HRMFluxAddCarry.init_states(B, cfg.d_hid)
    ŷ, _, _ = HRMFluxAddCarry.run_sequence_addcarry!(models, x_batch, low_state, high_state; N=cfg.N, cfg=cfg)
    logit_bce(ŷ, y_batch)
end

# Batched accuracy evaluation
function evaluate_matrix(models, X::AbstractMatrix{<:Integer}, y::Vector{Int}, cfg; batch_size::Int=256)
    total_acc, total_n = 0.0, 0
    N = size(X, 2)
    for k in 1:batch_size:N
        sel = k:min(k+batch_size-1, N)
        xb = X[:, sel]; yb = y[sel]
        B  = size(xb, 2)
        low_state, high_state = HRMFluxAddCarry.init_states(B, cfg.d_hid)
        ŷ, _, _ = HRMFluxAddCarry.run_sequence_addcarry!(models, xb, low_state, high_state; N=cfg.N, cfg=cfg)
        total_acc += mean((σ.(vec(ŷ)) .>= 0.5) .== (yb .== 1)) * B
        total_n   += B
    end
    total_acc / total_n
end


evaluate_matrix (generic function with 1 method)

In [12]:
# Cosine schedule for mixing Phase-B into the stream
const MIN_MIX_IN_B = 0.10
const MAX_MIX_IN_B = 0.60
cosine01(t) = 0.5 * (1 .- cos(pi * clamp(t, 0, 1)))

mutable struct TrainLog
    name::String
    epoch::Vector{Int}
    pB::Vector{Float64}
    loss::Vector{Float64}
    acc_id::Vector{Float64}
    acc_mid::Vector{Float64}
    acc_ood::Vector{Float64}
end
newlog(name::String) = TrainLog(name, Int[], Float64[], Float64[], Float64[], Float64[], Float64[])
function push_epoch!(L::TrainLog; epoch, pB, loss, acc_id, acc_mid, acc_ood)
    push!(L.epoch, epoch); push!(L.pB, pB); push!(L.loss, loss)
    push!(L.acc_id, acc_id); push!(L.acc_mid, acc_mid); push!(L.acc_ood, acc_ood)
end

function train!(name::String, models, Xids_tr_A, y_tr_A, Xids_tr_B, y_tr_B,
                Xids_te_ID, y_te_ID, Xids_te_M, y_te_M, Xids_te_O, y_te_O,
                cfg;
                epochs_A::Int, epochs_B::Int, rng::AbstractRNG)

    log = newlog(name)

    # Optimizer (AdamW if available, else Adam) + gradient clipping
    if isdefined(Optimisers, :AdamW)
        opt = Optimisers.OptimiserChain(
            Optimisers.ClipNorm(1.0),
            Optimisers.AdamW(cfg.lr, (0.9, 0.999), 1e-4),
        )
    else
        opt = Optimisers.OptimiserChain(
            Optimisers.ClipNorm(1.0),
            Optimisers.Adam(cfg.lr),
        )
    end
    opt_state = Optimisers.setup(opt, models)

    for epoch in 1:(epochs_A + epochs_B)
        pB = (epoch <= epochs_A) ? 0.0 :
             (MIN_MIX_IN_B + (MAX_MIX_IN_B - MIN_MIX_IN_B) * cosine01((epoch - epochs_A) / epochs_B))

        total_loss = 0.0; batches = 0

        for (xb, yb) in each_minibatch_mixed(Xids_tr_A, y_tr_A, Xids_tr_B, y_tr_B;
                                             batch=cfg.batch, pB=pB, buckets=6,
                                             pad_id=cfg.pad_id, rng=rng)
            # keep PAD embedding column neutral
            freeze_pad!(models, cfg.pad_id)

            L, back = Zygote.pullback(m -> batch_loss(m, xb, yb, cfg), models)
            grads = back(one(L))[1]
            opt_state, models = Optimisers.update(opt_state, models, grads)

            # re-zero PAD column after update
            freeze_pad!(models, cfg.pad_id)

            total_loss += Float64(L); batches += 1
        end

        # Evaluate
        acc_id  = evaluate_matrix(models, Xids_te_ID, y_te_ID, cfg)
        acc_mid = evaluate_matrix(models, Xids_te_M,  y_te_M,  cfg)
        acc_ood = evaluate_matrix(models, Xids_te_O,  y_te_O,  cfg)

        @info "$(name) epoch $(epoch)  loss $(round(total_loss/max(batches,1), digits=4))  " *
              "ID $(round(acc_id,digits=3))  MID $(round(acc_mid,digits=3))  OOD $(round(acc_ood,digits=3))"

        push_epoch!(log; epoch=epoch, pB=pB, loss=total_loss/max(batches,1),
                    acc_id=acc_id, acc_mid=acc_mid, acc_ood=acc_ood)
    end

    return models, log
end

# ---- Run trainings (analogue to your Dyck workflow) ----
epochs_A = 10
epochs_B = 20

rng_HL = MersenneTwister(13)
rng_HH = MersenneTwister(13)

models_HL, log_HL = train!("H+L", models_HL,
    Xids_tr_A, y_tr_A, Xids_tr_B, y_tr_B,
    Xids_te_ID, y_te_ID, Xids_te_M, y_te_M, Xids_te_O, y_te_O,
    cfg; epochs_A=epochs_A, epochs_B=epochs_B, rng=rng_HL)

models_HH, log_HH = train!("H+H", models_HH,
    Xids_tr_A, y_tr_A, Xids_tr_B, y_tr_B,
    Xids_te_ID, y_te_ID, Xids_te_M, y_te_M, Xids_te_O, y_te_O,
    cfg; epochs_A=epochs_A, epochs_B=epochs_B, rng=rng_HH)


┌ Info: H+L epoch 1  loss 0.0283  ID 0.5  MID 0.5  OOD 0.5
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/07_addition_carry_FLUX/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X16sZmlsZQ==.jl:69
┌ Info: H+L epoch 2  loss 0.0345  ID 0.5  MID 0.5  OOD 0.5
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/07_addition_carry_FLUX/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X16sZmlsZQ==.jl:69


InterruptException: InterruptException:

In [None]:
# ] add ProgressMeter
using ProgressMeter

# Replace inner 'for (xb, yb) in ...' loop with:
p = Progress(nb_est, 1, show_speed=true)
for (xb, yb) in each_minibatch_mixed(...)
    # ... same training step body ...
    next!(p; showvalues=[(:loss, total_loss/max(batches,1))])
end
