In [40]:
using Pkg
Pkg.activate("..")  #one level up, where Project.toml lives
Pkg.instantiate()   #download/install anything missing
# Pkg.status();

[32m[1m  Activating[22m[39m project at `~/Documents/repos/JuliaExploreHRM`


In [41]:
include(joinpath(@__DIR__, "..", "data", "nested_boolean_gen.jl"))
include(joinpath(@__DIR__, "..", "data", "hrm_common_nested_boolean_FLUX.jl"))

using .BooleanDataGenerator
using .HRMFlux

using StatsBase
using Random, Statistics
using Flux, Zygote, Optimisers
using Flux: onehotbatch, onecold



In [42]:
# Training data (depth 2-4)
X_train, y_trainainainain, _ = generate_data(100; min_depth=2, max_depth=4)

# Test data (depth 5-8) 
X_test, y_test, _ = generate_data(20; min_depth=5, max_depth=8)

# Test data (held-out NAND)
X_test_ops, y_test_ops, _ = generate_data(20; held_out_ops=[:NAND])

([1 1 … 1 1; 1 0 … 0 0; … ; 0 1 … 1 0; 0 1 … 0 1], [1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1], ["(NOT (AND x1 (XOR x1 x1)))", "(AND x5 x2)", "(NOT x3)", "(OR x2 x4)", "(OR x2 (AND x3 x3))", "(XOR (OR x2 x4) (XOR x3 x2))", "(XOR x2 x5)", "(OR x2 x4)", "(NOT (AND x3 (OR x4 x5)))", "(OR x2 x5)", "(OR x3 (OR x4 x5))", "(AND x4 (AND (XOR x4 x3) x3))", "(OR x1 (OR x1 x3))", "(OR x5 x2)", "(OR x4 x1)", "(AND (AND x2 (XOR x2 x4)) (AND x1 x1))", "(NOT x1)", "(OR x5 (OR x1 x4))", "(XOR x3 (NOT x5))", "(OR (NOT x5) (OR x1 x2))"])

In [43]:
# Generate training data (depth 2-4)
X_train, y_train, expr_train = generate_data(100; min_depth=2, max_depth=4, seed=42)

# Generate test data with depth generalization (depth 5-8)
X_test, y_test, expr_test = generate_data(20; min_depth=5, max_depth=8, seed=123)

# Generate test data with held-out operations (no NAND)
X_test_ops, y_test_ops, expr_test_ops = generate_data(20; held_out_ops=[:NAND], seed=456)

println("Training: X=$(size(X_train)), y=$(size(y_train))")
println("Test (depth): X=$(size(X_test)), y=$(size(y_test))")
println("Test (ops): X=$(size(X_test_ops)), y=$(size(y_test_ops))")

# Show a few examples
println("\nTraining examples:")
for i in 1:3
    println("$(expr_train[i]) | vars=$(X_train[i,:]) → $(y_train[i])")
end


Training: X=(100, 5), y=(100,)
Test (depth): X=(20, 5), y=(20,)
Test (ops): X=(20, 5), y=(20,)

Training examples:
(NOT (AND x1 (NAND x1 x1))) | vars=[1, 1, 1, 1, 1] → 1
(AND x5 x2) | vars=[1, 0, 0, 0, 0] → 0
(NOT x3) | vars=[1, 1, 1, 0, 1] → 0


In [44]:
# Train (depth 2-4), Test-ID (2-4), Test-OOD (5-8)
X_train, y_train, expr_train = BooleanDataGenerator.generate_data(2000; variable_count=6, min_depth=2, max_depth=4, seed=1)
X_id,    y_test_id,    expr_id    = BooleanDataGenerator.generate_data(500;  variable_count=6, min_depth=2, max_depth=4, seed=2)
X_ood,   y_ood,   expr_ood   = BooleanDataGenerator.generate_data(500;  variable_count=6, min_depth=5, max_depth=8, seed=3)


([0 1 … 0 1; 0 0 … 1 1; … ; 1 0 … 0 0; 0 1 … 0 0], [1, 0, 1, 1, 1, 1, 1, 0, 1, 1  …  1, 1, 0, 1, 1, 0, 1, 0, 0, 0], ["(OR (NAND (NOT (NAND x1 x5)) (XOR x6 (AND x2 x5))) (AND x5 x6))", "(NOT (NOT (OR x3 (OR x1 x3))))", "(XOR x5 (OR (XOR x4 (OR x6 x2)) x3))", "(NAND (OR (OR x3 x3) (NAND x1 x2)) (XOR (NAND (XOR x2 x2) x6) (OR x4 (NOT x2))))", "(NAND (NAND (NAND (OR x1 x3) (NOT x5)) x5) (NAND x4 x1))", "(NOT (XOR (NAND x6 (XOR x2 x1)) x4))", "(NOT (NOT (OR (NAND x1 x6) x5)))", "(AND x4 (OR (NAND x5 (NAND x6 x3)) (NOT (OR x4 x6))))", "(OR (XOR (XOR x4 (NAND x3 x3)) (AND x2 x3)) (NAND x2 x3))", "(NAND (XOR x5 x5) (XOR x4 (NOT (NAND x6 x4))))"  …  "(NAND x3 (AND (XOR (NAND x4 x3) x3) x5))", "(NAND (NAND x6 (XOR x6 x2)) (NOT (NOT (AND x3 x1))))", "(XOR (NAND x6 (NAND (XOR x6 x5) (NAND x4 x3))) (XOR x6 x2))", "(XOR x2 (OR (NAND x6 (NAND x4 x1)) (AND x4 x2)))", "(OR x5 (NOT (XOR (NAND x1 x2) x6)))", "(AND (NOT (XOR x2 (OR x4 x6))) (XOR x4 (XOR (AND x4 x2) (AND x6 (NAND x1 x1)))))", "(NOT (NOT (N

In [45]:
function tokenize_with_assignment(expression::String, variable_row::AbstractVector{<:Integer})
    spaced = replace(expression, r"([()])" => s" \1 ")
    raw_tokens = split(strip(spaced))
    tokens = String[]
    for t in raw_tokens
        if startswith(t, "x")
            idx = parse(Int, t[2:end])
            push!(tokens, "x$(idx)=$(variable_row[idx])")
        elseif t == "(" || t == ")"
            continue  # drop parentheses to shorten sequences
        else
            push!(tokens, t)  # "AND", "OR", "XOR", "NAND", "NOT"
        end
    end
    return tokens
end


# Build token sequences
function build_token_sequences(expressions::Vector{String}, X::Array{Int,2})
    seqs = Vector{Vector{String}}(undef, length(expressions))
    for i in eachindex(expressions)
        seqs[i] = tokenize_with_assignment(expressions[i], vec(X[i, :]))
    end
    return seqs
end

train_tokens = build_token_sequences(expr_train, X_train)
id_tokens    = build_token_sequences(expr_id,    X_id)
ood_tokens   = build_token_sequences(expr_ood,   X_ood)


500-element Vector{Vector{String}}:
 ["OR", "NAND", "NOT", "NAND", "x1=0", "x5=0", "XOR", "x6=1", "AND", "x2=1", "x5=0", "AND", "x5=0", "x6=1"]
 ["NOT", "NOT", "OR", "x3=0", "OR", "x1=0", "x3=0"]
 ["XOR", "x5=0", "OR", "XOR", "x4=0", "OR", "x6=0", "x2=1", "x3=0"]
 ["NAND", "OR", "OR", "x3=1", "x3=1", "NAND", "x1=1", "x2=0", "XOR", "NAND", "XOR", "x2=0", "x2=0", "x6=0", "OR", "x4=1", "NOT", "x2=0"]
 ["NAND", "NAND", "NAND", "OR", "x1=0", "x3=0", "NOT", "x5=1", "x5=1", "NAND", "x4=0", "x1=0"]
 ["NOT", "XOR", "NAND", "x6=1", "XOR", "x2=1", "x1=1", "x4=1"]
 ["NOT", "NOT", "OR", "NAND", "x1=1", "x6=1", "x5=1"]
 ["AND", "x4=0", "OR", "NAND", "x5=0", "NAND", "x6=1", "x3=1", "NOT", "OR", "x4=0", "x6=1"]
 ["OR", "XOR", "XOR", "x4=0", "NAND", "x3=1", "x3=1", "AND", "x2=0", "x3=1", "NAND", "x2=0", "x3=1"]
 ["NAND", "XOR", "x5=0", "x5=0", "XOR", "x4=0", "NOT", "NAND", "x6=1", "x4=0"]
 ⋮
 ["NAND", "NAND", "x6=0", "XOR", "x6=0", "x2=1", "NOT", "NOT", "AND", "x3=0", "x1=0"]
 ["XOR", "NAND", "x6=0", "

In [46]:
# Vocabulary from training tokens only
function build_vocab(token_sequences::Vector{Vector{String}})
    vocab = Dict{String,Int}()
    next_id = 1
    for seq in token_sequences
        for t in seq
            if !haskey(vocab, t)
                vocab[t] = next_id
                next_id += 1
            end
        end
    end
    return vocab
end

vocab = build_vocab(train_tokens)
unk_id = length(vocab) + 1  # just in case (should not be needed here)

# Map tokens to integer ids
token_to_id_tmp = t -> get(vocab, t, unk_id)
function map_to_ids(token_sequences::Vector{Vector{String}})
    return [map(token_to_id_tmp, seq) for seq in token_sequences]
end

train_ids = map_to_ids(train_tokens)
id_ids    = map_to_ids(id_tokens)
ood_ids   = map_to_ids(ood_tokens)


500-element Vector{Vector{Int64}}:
 [6, 9, 14, 9, 15, 3, 12, 2, 1, 10, 3, 1, 3, 2]
 [14, 14, 6, 13, 6, 15, 13]
 [12, 3, 6, 12, 11, 6, 7, 10, 13]
 [9, 6, 6, 8, 8, 9, 17, 4, 12, 9, 12, 4, 4, 7, 6, 5, 14, 4]
 [9, 9, 9, 6, 15, 13, 14, 16, 16, 9, 11, 15]
 [14, 12, 9, 2, 12, 10, 17, 5]
 [14, 14, 6, 9, 17, 2, 16]
 [1, 11, 6, 9, 3, 9, 2, 8, 14, 6, 11, 2]
 [6, 12, 12, 11, 9, 8, 8, 1, 4, 8, 9, 4, 8]
 [9, 12, 3, 3, 12, 11, 14, 9, 2, 11]
 ⋮
 [9, 9, 7, 12, 7, 10, 14, 14, 1, 13, 15]
 [12, 9, 7, 9, 12, 7, 3, 9, 11, 8, 12, 7, 10]
 [12, 10, 6, 9, 2, 9, 11, 15, 1, 11, 10]
 [6, 3, 14, 12, 9, 17, 4, 2]
 [1, 14, 12, 4, 6, 5, 7, 12, 5, 12, 1, 5, 4, 1, 7, 9, 15, 15]
 [14, 14, 9, 12, 2, 5, 12, 16, 10]
 [1, 7, 9, 1, 14, 3, 1, 15, 7, 14, 9, 11, 7]
 [14, 12, 12, 13, 7, 6, 6, 3, 5, 14, 5]
 [14, 9, 1, 1, 11, 3, 1, 3, 3, 12, 7, 13]

In [None]:
VARIABLE_COUNT = 6
N_A   = 4000   # Train Phase A (2-4)
N_B   = 8000   # Train Phase B (2-6)
N_ID  = 1000   # Test (2-4)
N_MID = 1000   # Test (5-6)
N_OOD = 1000   # Test (7-8)

Random.seed!(1)



cfg = (
    d_in   = 0,
    d_hid  = 128,
    d_out  = 1,
    N      = 4,
    T      = 0,
    batch  = 32,
    lr     = 1e-4,
    num_tokens = length(vocab),
    d_embed    = 128,
    l_heads    = 4,
    l_ff_mult  = 4,
    h_heads    = 4,
    h_ff_mult  = 4,
    dropout    = 0.15,
    pad_id     = PAD_ID
)



function make_depth_stratified_split(depth_counts::Dict{Int,Int};
    variable_count::Int, seed::Int)

    Xs = Matrix{Int}[]; ys = Vector{Int}[]; Es = Vector{String}[]
    for d in sort(collect(keys(depth_counts)))
        n = depth_counts[d]; n == 0 && continue
        X, y, E = BooleanDataGenerator.generate_data(n;
                            variable_count=variable_count, min_depth=d, max_depth=d, seed=seed + d,
                            mode = :exact)
        push!(Xs, X); push!(ys, y); push!(Es, E)
    end
    X_all = vcat(Xs...)                 # (N_total, variable_count)
    y_all = reduce(vcat, ys)            # (N_total,)
    E_all = reduce(vcat, Es)            # Vector{String} length N_total

    # Shuffle ROWS (samples) — NOT columns
    perm = randperm(size(X_all, 1))
    return X_all[perm, :], y_all[perm], E_all[perm]
end



# Train pools
# Train Phase A: (2,3,4)
X_tr_A, y_tr_A, expr_tr_A =
    make_depth_stratified_split(Dict(2=>N_A÷3, 3=>N_A÷3, 4=>N_A - 2*(N_A÷3));
                                variable_count=VARIABLE_COUNT, seed=101)

# Train Phase B: (2,3,4,5,6) equal-ish
X_tr_B, y_tr_B, expr_tr_B =
    make_depth_stratified_split(Dict(2=>N_B÷5, 3=>N_B÷5, 4=>N_B÷5, 5=>N_B÷5, 6=>N_B - 4*(N_B÷5));
                                variable_count=VARIABLE_COUNT, seed=102)

# Test ID: (2,3,4)
X_te_ID,  y_te_ID,  expr_te_ID  =
    make_depth_stratified_split(Dict(2=>N_ID÷3, 3=>N_ID÷3, 4=>N_ID - 2*(N_ID÷3));
                                variable_count=VARIABLE_COUNT, seed=201)

# Test MID: (5,6)
X_te_MID, y_te_MID, expr_te_MID =
    make_depth_stratified_split(Dict(5=>N_MID÷2, 6=>N_MID - (N_MID÷2));
                                variable_count=VARIABLE_COUNT, seed=202)

# Test OOD: (7,8)
X_te_OOD, y_te_OOD, expr_te_OOD =
    make_depth_stratified_split(Dict(7=>N_OOD÷2, 8=>N_OOD - (N_OOD÷2));
                                variable_count=VARIABLE_COUNT, seed=203)

let d = [max_paren_depth(e)+1 for e in expr_te_OOD]
    @assert all(x -> x==7 || x==8, d) "Found non-(7,8) depths in OOD split: $(StatsBase.countmap(d))"
end


function max_paren_depth(expr::String)
    d = 0; m = 0
    @inbounds for c in expr
        if c == '('
            d += 1; m = max(m, d)
        elseif c == ')'
            d -= 1
        end
    end
    return m
end


# 2) Tokenize with assignment, build vocab, pad to fixed length
function tokenize_with_assignment(expression::String, variable_row::AbstractVector{<:Integer})
    spaced = replace(expression, r"([()])" => s" \1 ")
    raw_tokens = split(strip(spaced))
    tokens = String[]
    for t in raw_tokens
        if startswith(t, "x")
            idx = parse(Int, t[2:end])
            push!(tokens, "x$(idx)=$(variable_row[idx])")
        elseif t == "(" || t == ")"
            continue  # drop parentheses to shorten sequences (consistent with earlier runs)
        else
            push!(tokens, t)  # "AND", "OR", "XOR", "NAND", "NOT"
        end
    end
    return tokens
end


function build_token_sequences(expressions::Vector{String}, X::Array{Int,2})
    seqs = Vector{Vector{String}}(undef, length(expressions))
    for i in eachindex(expressions)
        seqs[i] = tokenize_with_assignment(expressions[i], vec(X[i, :]))
    end
    return seqs
end


# Tokenize every split
tokens_tr_A  = build_token_sequences(expr_tr_A,  X_tr_A)
tokens_tr_B  = build_token_sequences(expr_tr_B,  X_tr_B)
tokens_te_ID = build_token_sequences(expr_te_ID, X_te_ID)
tokens_te_M  = build_token_sequences(expr_te_MID,X_te_MID)
tokens_te_O  = build_token_sequences(expr_te_OOD,X_te_OOD)

function build_vocab_with_specials(train_token_sequences::Vector{Vector{String}})
    vocab = Dict{String,Int}("<PAD>"=>1, "<UNK>"=>2)
    next_id = 3
    for seq in train_token_sequences
        for t in seq
            haskey(vocab, t) || (vocab[t] = next_id; next_id += 1)
        end
    end
    return vocab
end

vocab = build_vocab_with_specials(tokens_tr_A)
const PAD_ID = vocab["<PAD>"]
const UNK_ID = vocab["<UNK>"]
token_to_id(t) = get(vocab, t, UNK_ID)

# --- Map tokens -> ids (produces (L, N) matrices) ---
function to_ids(tokens::Vector{Vector{String}})
    Lmax = maximum(length.(tokens))
    n = length(tokens)
    Xids = fill(PAD_ID, Lmax, n)
    @inbounds for i in 1:n
        s = tokens[i]
        for j in 1:length(s)
            Xids[j, i] = token_to_id(s[j])
        end
    end
    return Xids
end

Xids_tr_A  = to_ids(tokens_tr_A)
Xids_tr_B  = to_ids(tokens_tr_B)
Xids_te_ID = to_ids(tokens_te_ID)
Xids_te_M  = to_ids(tokens_te_M)
Xids_te_O  = to_ids(tokens_te_O)


# Compute the maximum sequence length across *all* splits
L_A   = size(Xids_tr_A,  1)
L_B   = size(Xids_tr_B,  1)
L_ID  = size(Xids_te_ID, 1)
L_MID = size(Xids_te_M,  1)
L_OOD = size(Xids_te_O,  1)

# Vocabulary from training only, with <PAD>=1 and <UNK>=2
function build_vocab_with_specials(train_token_sequences::Vector{Vector{String}})
    vocab = Dict{String,Int}("<PAD>"=>1, "<UNK>"=>2)
    next_id = 3
    for seq in train_token_sequences
        for t in seq
            haskey(vocab, t) || (vocab[t] = next_id; next_id += 1)
        end
    end
    return vocab
end


# Vocab ONLY from Phase-A training (guards against leakage)
vocab = build_vocab_with_specials(tokens_tr_A)
const PAD_ID = vocab["<PAD>"]
const UNK_ID = vocab["<UNK>"]
token_to_id(t) = get(vocab, t, UNK_ID)


# Per-split padding (no global max)
function to_ids(tokens::Vector{Vector{String}})
    Lmax = maximum(length.(tokens))
    n = length(tokens)
    Xids = fill(PAD_ID, Lmax, n)
    @inbounds for i in 1:n
        s = tokens[i]
        for j in 1:length(s)
            Xids[j, i] = token_to_id(s[j])
        end
    end
    return Xids
end





depths(v) = [max_paren_depth(e)+1 for e in v]

@show StatsBase.countmap(depths(expr_tr_A))
@show StatsBase.countmap(depths(expr_tr_B))
@show StatsBase.countmap(depths(expr_te_ID))
@show StatsBase.countmap(depths(expr_te_MID))
@show StatsBase.countmap(depths(expr_te_OOD))








# +1 if you use CLS (the CLS token is prepended to H_in)
POS_L_MAX = max(L_A, L_B, L_ID, L_MID, L_OOD) + 1

models = HRMFlux.build_models(cfg; positional_encoding_kind=:sinusoidal, pos_L_max=POS_L_MAX)




function seq_lengths_from_pad_id(Xids::AbstractMatrix{<:Integer}, pad_id::Int)
    L, B = size(Xids)
    lengths = fill(L, B)
    @inbounds for b in 1:B
        for t in 1:L
            if Xids[t, b] == pad_id
                lengths[b] = t - 1
                break
            end
        end
    end
    return lengths
end



# keep PAD embedding at zero to avoid length bias
if models.tok_emb !== nothing
    models.tok_emb.weight[:, PAD_ID] .= 0f0
end


function each_minibatch_bucketed(X::AbstractMatrix{<:Integer}, y::AbstractVector{<:Integer};
                                 batch::Int, buckets::Int=6, pad_id::Int=PAD_ID)
    L, N = size(X)
    lens = fill(L, N)
    @inbounds for i in 1:N, t in 1:L
        if X[t,i] == pad_id
            lens[i] = t-1
            break
        end
    end
    order  = sortperm(1:N, by=i->lens[i])
    groups = Iterators.partition(order, max(1, ceil(Int, N/buckets)))
    out = Vector{Tuple{Matrix{Int}, Vector{Int}}}()
    for g in groups
        idx = collect(g); Random.shuffle!(idx)
        for k in 1:batch:length(idx)
            sel = idx[k:min(k+batch-1, length(idx))]
            push!(out, (X[:, sel], y[sel]))
        end
    end
    return out
end


function each_minibatch(X::AbstractMatrix{<:Integer}, y::AbstractVector{<:Integer}, batch_size::Int)
    idx = collect(1:size(X, 2))
    Random.shuffle!(idx)
    out = Vector{Tuple{Matrix{Int}, Vector{Int}}}()
    for k in 1:batch_size:length(idx)
        sel = idx[k:min(k+batch_size-1, length(idx))]
        push!(out, (X[:, sel], y[sel]))
    end
    return out
end


"Yield a mixed stream of mini-batches from A and B with probability pB for B."
function each_minibatch_mixed(XA::AbstractMatrix{<:Integer}, yA::AbstractVector{<:Integer},
                              XB::AbstractMatrix{<:Integer}, yB::AbstractVector{<:Integer};
                              batch::Int, pB::Float64, buckets::Int=6, pad_id::Int=PAD_ID)

    itA = each_minibatch_bucketed(XA, yA; batch=batch, buckets=buckets, pad_id=pad_id)
    itB = each_minibatch_bucketed(XB, yB; batch=batch, buckets=buckets, pad_id=pad_id)

    ia = 1; ib = 1
    out = Vector{Tuple{Matrix{Int}, Vector{Int}}}()
    while ia <= length(itA) || ib <= length(itB)
        useB = rand() < pB
        if useB && ib <= length(itB)
            push!(out, itB[ib]); ib += 1
        elseif ia <= length(itA)
            push!(out, itA[ia]); ia += 1
        elseif ib <= length(itB)
            push!(out, itB[ib]); ib += 1
        end
    end
    return out
end



function batch_loss(models, x_batch::AbstractMatrix{<:Integer}, y_batch::AbstractVector{<:Integer}, cfg)
    batch_size = size(x_batch, 2)
    low_state, high_state = HRMFlux.init_states(batch_size, cfg.d_hid)
    
    # yhat, _, _ = HRMFlux.run_segment_GRU!(models, x_batch, low_state, high_state; N=cfg.N, T=cfg.T, cfg=cfg)
    yhat, _, _ = HRMFlux.run_sequence_segment!(models, x_batch, low_state, high_state; N=cfg.N, cfg=cfg)

    targets = reshape(Float32.(y_batch), 1, batch_size)
    return Flux.logitbinarycrossentropy(yhat, targets)
end


# function accuracy(models, X::AbstractMatrix{<:Integer}, y::AbstractVector{<:Integer}, cfg; batch_size::Int=256)
#     correct = 0
#     total = 0
#     for (xb, yb) in each_minibatch(X, y, batch_size)
#         bs = size(xb, 2)
#         low_state, high_state = HRMFlux.init_states(bs, cfg.d_hid)
#         yhat, _, _ = HRMFlux.run_sequence_segment!(models, xb, low_state, high_state; N=cfg.N, cfg=cfg)
#         preds = @. yhat > 0
#         correct += sum(Int.(preds[1, :]) .== yb)
#         total   += bs
#     end
#     return correct / total
# end


function accuracy(models, X::AbstractMatrix{<:Integer}, y::AbstractVector{<:Integer}, cfg; batch_size::Int=256)
    correct = 0
    total = 0
    # use the function arguments X, y — not Xtr, ytr
    for (xb, yb) in each_minibatch_bucketed(X, y; batch=min(cfg.batch, batch_size), buckets=6)
        bs = size(xb, 2)
        low_state, high_state = HRMFlux.init_states(bs, cfg.d_hid)

        yhat, _, _ = HRMFlux.run_sequence_segment!(models, xb, low_state, high_state; N=cfg.N, cfg=cfg)
        preds = @. yhat > 0
        correct += sum(Int.(preds[1, :]) .== yb)
        total   += bs
    end
    return correct / total
end




function accuracy_by_depth(models, Xids, y, exprs; cfg, batch_size=256)
    de = [max_paren_depth(e) + 1 for e in exprs]  # structural depth
    mp = Dict{Int, Vector{Int}}()
    for (i,d) in enumerate(de)
        push!(get!(mp, d, Int[]), i)
    end
    out = Dict{Int,Float64}()
    for (d, idx) in sort(collect(mp); by=first)
        out[d] = accuracy(models, Xids[:, idx], y[idx], cfg; batch_size=batch_size)
    end
    return out
end



@assert size(X_tr_A,2) == VARIABLE_COUNT
@assert length(y_tr_A) == size(X_tr_A,1) == length(expr_tr_A)
@assert length(y_tr_B) == size(X_tr_B,1) == length(expr_tr_B)
@assert length(y_te_ID)  == size(X_te_ID,1)  == length(expr_te_ID)
@assert length(y_te_MID) == size(X_te_MID,1) == length(expr_te_MID)
@assert length(y_te_OOD) == size(X_te_OOD,1) == length(expr_te_OOD)

@assert size(Xids_tr_A,2) == size(X_tr_A,1)
@assert size(Xids_te_O, 2) == size(X_te_OOD,1)



# 5) Train and evaluate
# curriculum phases
epochs_A = 10    # Phase A (2-4)
epochs_B = 10    # Phase B (2-6)
total_epochs = epochs_A + epochs_B

# opt_state = Optimisers.setup(Optimisers.Adam(cfg.lr), models)
base_lr = cfg.lr
opt = Optimisers.Adam(base_lr)
opt_state = Optimisers.setup(opt, models)


for epoch in 1:total_epochs
    total_loss = 0.0; batches = 0

    # mixing schedule: 0.0 during Phase A, then ramp A→B across Phase B
    pB = epoch <= epochs_A ? 0.0 : (epoch - epochs_A) / epochs_B
    pB = clamp(pB, 0.3, 1.0)  # keep at least some B once Phase B starts

    

    for (xb, yb) in each_minibatch_mixed(Xids_tr_A, y_tr_A, Xids_tr_B, y_tr_B;
                                        batch=cfg.batch, pB=pB, buckets=6, pad_id=PAD_ID)

        # keep PAD column neutral (before & after update)
        if hasproperty(models, :tok_emb) && models.tok_emb !== nothing
            models.tok_emb.weight[:, PAD_ID] .= 0f0
        end

        L, back = Zygote.pullback(m -> batch_loss(m, xb, yb, cfg), models)
        grads = back(one(L))[1]
        opt_state, models = Optimisers.update(opt_state, models, grads)

        if hasproperty(models, :tok_emb) && models.tok_emb !== nothing
            models.tok_emb.weight[:, PAD_ID] .= 0f0
        end

        total_loss += Float64(L); batches += 1
    end

    acc_id  = accuracy(models, Xids_te_ID,  y_te_ID,  cfg)   # 2-4
    acc_mid = accuracy(models, Xids_te_M,   y_te_MID, cfg)   # 5-6
    acc_ood = accuracy(models, Xids_te_O,   y_te_OOD, cfg)   # 7-8
    @info "epoch=$(epoch)  loss=$(round(total_loss/batches, digits=4))  " *
          "ID(2-4)=$(round(acc_id,digits=3))  MID(5-6)=$(round(acc_mid,digits=3))  OOD(7-8)=$(round(acc_ood,digits=3))"

    ood = accuracy_by_depth(models, Xids_te_O, y_te_OOD, expr_te_OOD; cfg=cfg)
    acc7 = get(ood, 7, NaN); acc8 = get(ood, 8, NaN)
    ood_macro = (acc7 + acc8) / 2
    @info "OOD macro (7/8 equally) = $(round(ood_macro, digits=3))"

    if epoch % 5 == 0
        @show accuracy_by_depth(models, Xids_te_O,  y_te_OOD,  expr_te_OOD;  cfg=cfg)
        @show accuracy_by_depth(models, Xids_te_M,  y_te_MID,  expr_te_MID;  cfg=cfg)
        println("Final depth bins (OOD 7-8):")
        @show accuracy_by_depth(models, Xids_te_O, y_te_OOD, expr_te_OOD; cfg=cfg)
    end
end

println("Final depth bins (OOD 7-8):")
@show accuracy_by_depth(models, Xids_te_O, y_te_OOD, expr_te_OOD; cfg=cfg)

println("Final depth bins (MID 5-6):")
@show accuracy_by_depth(models, Xids_te_M, y_te_MID, expr_te_MID; cfg=cfg)

println("Done.")

StatsBase.countmap(depths(expr_tr_A)) = Dict(4 => 668, 2 => 666, 3 => 666)
StatsBase.countmap(depths(expr_tr_B)) = Dict(5 => 800, 4 => 800, 6 => 800, 2 => 800, 3 => 800)
StatsBase.countmap(depths(expr_te_ID)) = Dict(4 => 334, 2 => 333, 3 => 333)
StatsBase.countmap(depths(expr_te_MID)) = Dict(5 => 500, 6 => 500)
StatsBase.countmap(depths(expr_te_OOD)) = Dict(7 => 500, 8 => 500)


┌ Info: epoch=1  loss=0.7071  ID(2-4)=0.554  MID(5-6)=0.611  OOD(7-8)=0.612
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:437
┌ Info: OOD macro (7/8 equally) = 0.612
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:443
┌ Info: epoch=2  loss=0.6526  ID(2-4)=0.596  MID(5-6)=0.63  OOD(7-8)=0.631
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:437
┌ Info: OOD macro (7/8 equally) = 0.631
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:443
┌ Info: epoch=3  loss=0.6447  ID(2-4)=0.69  MID(5-6)=0.628  OOD(7-8)=0.628
└ @ Main /home/resort/Documents/repos/JuliaExplore

accuracy_by_depth(models, Xids_te_O, y_te_OOD, expr_te_OOD; cfg = cfg) = Dict(7 => 0.658, 8 => 0.652)
accuracy_by_depth(models, Xids_te_M, y_te_MID, expr_te_MID; cfg = cfg) = Dict(5 => 0.654, 6 => 0.72)
Final depth bins (OOD 7-8):
accuracy_by_depth(models, Xids_te_O, y_te_OOD, expr_te_OOD; cfg = cfg) = Dict(7 => 0.658, 8 => 0.652)


┌ Info: epoch=6  loss=0.4287  ID(2-4)=0.817  MID(5-6)=0.669  OOD(7-8)=0.643
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:437
┌ Info: OOD macro (7/8 equally) = 0.643
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:443
┌ Info: epoch=7  loss=0.3918  ID(2-4)=0.809  MID(5-6)=0.687  OOD(7-8)=0.678
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:437
┌ Info: OOD macro (7/8 equally) = 0.678
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:443
┌ Info: epoch=8  loss=0.3715  ID(2-4)=0.843  MID(5-6)=0.668  OOD(7-8)=0.622
└ @ Main /home/resort/Documents/repos/JuliaExplo

accuracy_by_depth(models, Xids_te_O, y_te_OOD, expr_te_OOD; cfg = cfg) = Dict(7 => 0.646, 8 => 0.668)
accuracy_by_depth(models, Xids_te_M, y_te_MID, expr_te_MID; cfg = cfg) = Dict(5 => 0.692, 6 => 0.662)
Final depth bins (OOD 7-8):
accuracy_by_depth(models, Xids_te_O, y_te_OOD, expr_te_OOD; cfg = cfg) = Dict(7 => 0.646, 8 => 0.668)


┌ Info: epoch=11  loss=0.3355  ID(2-4)=0.866  MID(5-6)=0.692  OOD(7-8)=0.665
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:437
┌ Info: OOD macro (7/8 equally) = 0.665
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:443
┌ Info: epoch=12  loss=0.3179  ID(2-4)=0.856  MID(5-6)=0.701  OOD(7-8)=0.665
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:437
┌ Info: OOD macro (7/8 equally) = 0.665
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:443
┌ Info: epoch=13  loss=0.3124  ID(2-4)=0.871  MID(5-6)=0.722  OOD(7-8)=0.676
└ @ Main /home/resort/Documents/repos/JuliaEx

accuracy_by_depth(models, Xids_te_O, y_te_OOD, expr_te_OOD; cfg = cfg) = Dict(7 => 0.678, 8 => 0.68)
accuracy_by_depth(models, Xids_te_M, y_te_MID, expr_te_MID; cfg = cfg) = Dict(5 => 0.708, 6 => 0.722)
Final depth bins (OOD 7-8):
accuracy_by_depth(models, Xids_te_O, y_te_OOD, expr_te_OOD; cfg = cfg) = Dict(7 => 0.678, 8 => 0.68)


┌ Info: epoch=16  loss=0.2876  ID(2-4)=0.884  MID(5-6)=0.7  OOD(7-8)=0.686
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:437
┌ Info: OOD macro (7/8 equally) = 0.686
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:443
┌ Info: epoch=17  loss=0.2789  ID(2-4)=0.893  MID(5-6)=0.712  OOD(7-8)=0.677
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:437
┌ Info: OOD macro (7/8 equally) = 0.677
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:443


InterruptException: InterruptException:

In [55]:
post_ID  = countmap([max_paren_depth(e)+1 for e in expr_te_ID])
post_MID = countmap([max_paren_depth(e)+1 for e in expr_te_MID])
post_OOD = countmap([max_paren_depth(e)+1 for e in expr_te_OOD])

@show post_ID
@show post_MID
@show post_OOD

post_ID = Dict(4 => 334, 2 => 333, 3 => 333)
post_MID = Dict(5 => 500, 6 => 500)
post_OOD = Dict(7 => 500, 8 => 500)


Dict{Int64, Int64} with 2 entries:
  7 => 500
  8 => 500

In [56]:
max_paren_depth(s) = (d=0;m=0; for c in s; c=='(' && (d+=1; m=max(m,d)); c==')' && (d-=1); end; m)
depths(v) = [max_paren_depth(e)+1 for e in v]
using StatsBase: countmap
@show countmap(depths(expr_te_OOD))  # expect Dict(7=>…, 8=>…)

countmap(depths(expr_te_OOD)) = Dict(7 => 500, 8 => 500)




Dict{Int64, Int64} with 2 entries:
  7 => 500
  8 => 500