In [1]:
using Pkg
Pkg.activate("..")  #one level up, where Project.toml lives
Pkg.instantiate()   #download/install anything missing
# Pkg.status();

[32m[1m  Activating[22m[39m project at `~/Documents/repos/JuliaExploreHRM`


In [2]:
include(joinpath(@__DIR__, "..", "data", "nested_boolean_gen.jl"))
include(joinpath(@__DIR__, "..", "data", "hrm_common_nested_boolean_FLUX.jl"))

using .BooleanDataGenerator
using .HRMFlux

using Random, Statistics
using Flux, Zygote, Optimisers
using Flux: onehotbatch, onecold

In [3]:
# Training data (depth 2-4)
X_train, y_trainainainain, _ = generate_data(100; min_depth=2, max_depth=4)

# Test data (depth 5-8) 
X_test, y_test, _ = generate_data(20; min_depth=5, max_depth=8)

# Test data (held-out NAND)
X_test_ops, y_test_ops, _ = generate_data(20; held_out_ops=[:NAND])

([1 1 … 1 1; 1 0 … 0 0; … ; 0 1 … 1 0; 0 1 … 0 1], [1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1], ["(NOT (AND x1 (XOR x1 x1)))", "(AND x5 x2)", "(OR x1 x3)", "(OR x5 (AND x4 x1))", "(OR x5 (XOR x3 x4))", "(XOR x3 x2)", "(AND x5 x4)", "(OR (XOR x2 x3) (OR x5 x4))", "(AND x2 (NOT x5))", "(OR x3 (OR x4 x5))", "(AND x4 (AND (XOR x4 x3) x3))", "(AND (OR x3 x5) x3)", "(NOT (XOR x1 x5))", "(NOT (OR (XOR x5 x5) x1))", "(XOR x1 x1)", "(XOR x3 x1)", "(OR x4 (XOR x4 x1))", "(NOT x5)", "(XOR x3 x2)", "(OR x2 x1)"])

In [4]:
# Generate training data (depth 2-4)
X_train, y_train, expr_train = generate_data(100; min_depth=2, max_depth=4, seed=42)

# Generate test data with depth generalization (depth 5-8)
X_test, y_test, expr_test = generate_data(20; min_depth=5, max_depth=8, seed=123)

# Generate test data with held-out operations (no NAND)
X_test_ops, y_test_ops, expr_test_ops = generate_data(20; held_out_ops=[:NAND], seed=456)

println("Training: X=$(size(X_train)), y=$(size(y_train))")
println("Test (depth): X=$(size(X_test)), y=$(size(y_test))")
println("Test (ops): X=$(size(X_test_ops)), y=$(size(y_test_ops))")

# Show a few examples
println("\nTraining examples:")
for i in 1:3
    println("$(expr_train[i]) | vars=$(X_train[i,:]) → $(y_train[i])")
end


Training: X=(100, 5), y=(100,)
Test (depth): X=(20, 5), y=(20,)
Test (ops): X=(20, 5), y=(20,)

Training examples:
(NOT (AND x1 (NAND x1 x1))) | vars=[1, 1, 1, 1, 1] → 1
(AND x5 x2) | vars=[1, 0, 0, 0, 0] → 0
(OR x1 x3) | vars=[1, 1, 1, 0, 1] → 1


In [5]:
# Train (depth 2-4), Test-ID (2-4), Test-OOD (5-8)
X_train, y_train, expr_train = BooleanDataGenerator.generate_data(2000; variable_count=6, min_depth=2, max_depth=4, seed=1)
X_id,    y_test_id,    expr_id    = BooleanDataGenerator.generate_data(500;  variable_count=6, min_depth=2, max_depth=4, seed=2)
X_ood,   y_ood,   expr_ood   = BooleanDataGenerator.generate_data(500;  variable_count=6, min_depth=5, max_depth=8, seed=3)


([0 1 … 0 1; 0 0 … 1 1; … ; 1 0 … 0 0; 0 1 … 0 0], [0, 0, 1, 1, 0, 1, 1, 0, 0, 1  …  1, 1, 1, 1, 0, 1, 0, 0, 0, 1], ["(AND (XOR (OR x5 x5) (AND (AND x2 x5) (AND x5 x6))) (NOT (XOR x6 (NOT x4))))", "(NOT (OR x3 (OR x3 (NOT x4))))", "(OR (XOR x4 (OR (OR x6 x2) (AND x5 x1))) x2)", "(NOT (XOR x4 (NOT (AND x2 x6))))", "(NAND (XOR x6 (OR (NOT (OR x2 x2)) (NAND x3 x1))) (NAND x4 x6))", "(XOR (NOT (XOR (AND x6 x6) x3)) x1)", "(OR (OR x1 (AND x6 x2)) (AND (AND x3 x2) (XOR x6 (XOR x2 x4))))", "(NOT (XOR x1 (NAND (NAND x6 x6) x2)))", "(AND x4 (NOT (NAND (NAND x1 x6) x4)))", "(OR x6 (OR (OR x5 (NAND (OR x1 x2) (NOT x6))) (NOT (OR x4 x4))))"  …  "(XOR x2 (XOR (NAND x5 (XOR (AND x2 x3) x1)) (OR x5 x3)))", "(AND (NAND x1 (NAND (NOT x4) (NAND (NAND x6 x4) (NOT (NAND x4 x5))))) x4)", "(XOR (XOR (XOR x3 (XOR x1 x1)) x3) x2)", "(NOT (NOT (OR (NAND (NOT x4) x6) (AND x6 x2))))", "(AND (OR (OR (XOR x3 x1) (OR (XOR x6 x6) x2)) (AND (AND x6 x4) x2)) (NAND x1 x2))", "(NOT (NOT (NOT (XOR x5 x5))))", "(XOR (AND 

In [6]:
function tokenize_with_assignment(expression::String, variable_row::Vector{Int})
    spaced = replace(expression, r"([()])" => s" \1 ")
    raw_tokens = split(strip(spaced))
    tokens = String[]
    for t in raw_tokens
        if startswith(t, "x")
            idx = parse(Int, t[2:end])
            push!(tokens, "x$(idx)=$(variable_row[idx])")
        else
            push!(tokens, t)
        end
    end
    return tokens
end

# Build token sequences
function build_token_sequences(expressions::Vector{String}, X::Array{Int,2})
    seqs = Vector{Vector{String}}(undef, length(expressions))
    for i in eachindex(expressions)
        seqs[i] = tokenize_with_assignment(expressions[i], vec(X[i, :]))
    end
    return seqs
end

train_tokens = build_token_sequences(expr_train, X_train)
id_tokens    = build_token_sequences(expr_id,    X_id)
ood_tokens   = build_token_sequences(expr_ood,   X_ood)


500-element Vector{Vector{String}}:
 ["(", "AND", "(", "XOR", "(", "OR", "x5=0", "x5=0", ")", "("  …  "(", "XOR", "x6=1", "(", "NOT", "x4=1", ")", ")", ")", ")"]
 ["(", "NOT", "(", "OR", "x3=0", "(", "OR", "x3=0", "(", "NOT", "x4=0", ")", ")", ")", ")"]
 ["(", "OR", "(", "XOR", "x4=0", "(", "OR", "(", "OR", "x6=0"  …  ")", "(", "AND", "x5=0", "x1=0", ")", ")", ")", "x2=1", ")"]
 ["(", "NOT", "(", "XOR", "x4=1", "(", "NOT", "(", "AND", "x2=0", "x6=0", ")", ")", ")", ")"]
 ["(", "NAND", "(", "XOR", "x6=0", "(", "OR", "(", "NOT", "("  …  "x1=0", ")", ")", ")", "(", "NAND", "x4=0", "x6=0", ")", ")"]
 ["(", "XOR", "(", "NOT", "(", "XOR", "(", "AND", "x6=1", "x6=1", ")", "x3=0", ")", ")", "x1=1", ")"]
 ["(", "OR", "(", "OR", "x1=1", "(", "AND", "x6=1", "x2=1", ")"  …  "XOR", "x6=1", "(", "XOR", "x2=1", "x4=1", ")", ")", ")", ")"]
 ["(", "NOT", "(", "XOR", "x1=0", "(", "NAND", "(", "NAND", "x6=1", "x6=1", ")", "x2=1", ")", ")", ")"]
 ["(", "AND", "x4=0", "(", "NOT", "(", "NAND", "(", "NAND", 

In [7]:
# Vocabulary from training tokens only
function build_vocab(token_sequences::Vector{Vector{String}})
    vocab = Dict{String,Int}()
    next_id = 1
    for seq in token_sequences
        for t in seq
            if !haskey(vocab, t)
                vocab[t] = next_id
                next_id += 1
            end
        end
    end
    return vocab
end

vocab = build_vocab(train_tokens)
unk_id = length(vocab) + 1  # just in case (should not be needed here)

# Map tokens to integer ids
token_to_id_tmp = t -> get(vocab, t, unk_id)
function map_to_ids(token_sequences::Vector{Vector{String}})
    return [map(token_to_id_tmp, seq) for seq in token_sequences]
end

train_ids = map_to_ids(train_tokens)
id_ids    = map_to_ids(id_tokens)
ood_ids   = map_to_ids(ood_tokens)


500-element Vector{Vector{Int64}}:
 [1, 2, 1, 14, 1, 8, 4, 4, 5, 1  …  1, 14, 3, 1, 16, 7, 5, 5, 5, 5]
 [1, 16, 1, 8, 15, 1, 8, 15, 1, 16, 13, 5, 5, 5, 5]
 [1, 8, 1, 14, 13, 1, 8, 1, 8, 9  …  5, 1, 2, 4, 18, 5, 5, 5, 12, 5]
 [1, 16, 1, 14, 7, 1, 16, 1, 2, 6, 9, 5, 5, 5, 5]
 [1, 11, 1, 14, 9, 1, 8, 1, 16, 1  …  18, 5, 5, 5, 1, 11, 13, 9, 5, 5]
 [1, 14, 1, 16, 1, 14, 1, 2, 3, 3, 5, 15, 5, 5, 19, 5]
 [1, 8, 1, 8, 19, 1, 2, 3, 12, 5  …  14, 3, 1, 14, 12, 7, 5, 5, 5, 5]
 [1, 16, 1, 14, 18, 1, 11, 1, 11, 3, 3, 5, 12, 5, 5, 5]
 [1, 2, 13, 1, 16, 1, 11, 1, 11, 18, 9, 5, 13, 5, 5, 5]
 [1, 8, 3, 1, 8, 1, 8, 4, 1, 11  …  1, 16, 1, 8, 13, 13, 5, 5, 5, 5]
 ⋮
 [1, 2, 1, 11, 18, 1, 11, 1, 16, 7  …  11, 7, 17, 5, 5, 5, 5, 5, 7, 5]
 [1, 14, 1, 14, 1, 14, 10, 1, 14, 18, 18, 5, 5, 10, 5, 12, 5]
 [1, 16, 1, 16, 1, 8, 1, 11, 1, 16  …  3, 5, 1, 2, 3, 12, 5, 5, 5, 5]
 [1, 2, 1, 8, 1, 8, 1, 14, 10, 19  …  5, 6, 5, 5, 1, 11, 19, 6, 5, 5]
 [1, 16, 1, 16, 1, 16, 1, 14, 4, 4, 5, 5, 5, 5]
 [1, 14, 1, 2, 1, 14, 17,

In [10]:
 VARIABLE_COUNT = 6
 TRAIN_SIZE     = 2000
 TEST_ID_SIZE   = 500
 TEST_OOD_SIZE  = 500

 DEPTH_TRAIN_MIN = 2
 DEPTH_TRAIN_MAX = 4
 DEPTH_OOD_MIN   = 5
 DEPTH_OOD_MAX   = 8

Random.seed!(1)

X_train, y_train, expr_train = BooleanDataGenerator.generate_data(
    TRAIN_SIZE; variable_count=VARIABLE_COUNT, min_depth=DEPTH_TRAIN_MIN, max_depth=DEPTH_TRAIN_MAX, seed=1)

X_test_id, y_test_id, expr_test_id = BooleanDataGenerator.generate_data(
    TEST_ID_SIZE; variable_count=VARIABLE_COUNT, min_depth=DEPTH_TRAIN_MIN, max_depth=DEPTH_TRAIN_MAX, seed=2)

X_test_ood, y_test_ood, expr_test_ood = BooleanDataGenerator.generate_data(
    TEST_OOD_SIZE; variable_count=VARIABLE_COUNT, min_depth=DEPTH_OOD_MIN, max_depth=DEPTH_OOD_MAX, seed=3)

# 2) Tokenize with assignment, build vocab, pad to fixed length
function tokenize_with_assignment(expression::String, variable_row::AbstractVector{<:Integer})
    spaced = replace(expression, r"([()])" => s" \1 ")
    raw_tokens = split(strip(spaced))
    tokens = String[]
    for t in raw_tokens
        if startswith(t, "x")
            idx = parse(Int, t[2:end])
            push!(tokens, "x$(idx)=$(variable_row[idx])")
        else
            push!(tokens, t)  # "(", ")", "AND", "OR", "XOR", "NAND", "NOT"
        end
    end
    return tokens
end

function build_token_sequences(expressions::Vector{String}, X::Array{Int,2})
    seqs = Vector{Vector{String}}(undef, length(expressions))
    for i in eachindex(expressions)
        seqs[i] = tokenize_with_assignment(expressions[i], vec(X[i, :]))
    end
    return seqs
end

tokens_train = build_token_sequences(expr_train,     X_train)
tokens_id    = build_token_sequences(expr_test_id,   X_test_id)
tokens_ood   = build_token_sequences(expr_test_ood,  X_test_ood)

# Vocabulary from training only, with <PAD>=1 and <UNK>=2
function build_vocab_with_specials(train_token_sequences::Vector{Vector{String}})
    vocab = Dict{String,Int}("<PAD>"=>1, "<UNK>"=>2)
    next_id = 3
    for seq in train_token_sequences
        for t in seq
            haskey(vocab, t) || (vocab[t] = next_id; next_id += 1)
        end
    end
    return vocab
end

vocab = build_vocab_with_specials(tokens_train)
const PAD_ID = vocab["<PAD>"]
const UNK_ID = vocab["<UNK>"]

global_max_len = maximum((
    maximum(length.(tokens_train)),
    maximum(length.(tokens_id)),
    maximum(length.(tokens_ood))
))

token_to_id(t) = get(vocab, t, UNK_ID)

function to_padded_id_matrix(token_sequences::Vector{Vector{String}}, max_len::Int)
    n = length(token_sequences)
    Xids = fill(PAD_ID, max_len, n)  # (sequence_len, n_samples)
    for i in 1:n
        seq = token_sequences[i]
        L = min(length(seq), max_len)
        @inbounds for j in 1:L
            Xids[j, i] = token_to_id(seq[j])
        end
    end
    return Xids
end

token_ids_train = to_padded_id_matrix(tokens_train, global_max_len)
token_ids_id    = to_padded_id_matrix(tokens_id,    global_max_len)
token_ids_ood   = to_padded_id_matrix(tokens_ood,   global_max_len)

# 3) Build HRM (classification setup)
cfg = (
    d_in   = 0,     # sequence length = columns of x_in
    d_hid  = 128,
    d_out  = 1,                  # binary logit
    N      = 3,                  # outer cycles
    T      = 0,                  # inner low-state rollouts
    batch  = 64,
    lr     = 2e-3,

    # token-ID path
    num_tokens = length(vocab),  # includes PAD and UNK
    d_embed    = 96,

    # transformer hyperparameters
    l_heads    = 2,
    l_ff_mult  = 4,
    h_heads    = 2,
    h_ff_mult  = 4,
    dropout    = 0.0,
    pad_id     = PAD_ID
)


function seq_lengths_from_pad_id(Xids::AbstractMatrix{<:Integer}, pad_id::Int)
    L, B = size(Xids)
    lengths = fill(L, B)
    @inbounds for b in 1:B
        for t in 1:L
            if Xids[t, b] == pad_id
                lengths[b] = t - 1
                break
            end
        end
    end
    return lengths
end


# models = HRMFlux.build_models(cfg; positional_encoding_kind = :sinusoidal, pos_L_max = global_max_len)
models = HRMFlux.build_models_GRU(cfg; positional_encoding_kind = :sinusoidal, pos_L_max = global_max_len)

# keep PAD embedding at zero to avoid length bias
if models.tok_emb !== nothing
    models.tok_emb.weight[:, PAD_ID] .= 0f0
end

# 4) Minibatching, loss, accuracy
function each_minibatch(X::AbstractMatrix{<:Integer}, y::AbstractVector{<:Integer}, batch_size::Int)
    idx = collect(1:size(X, 2))
    Random.shuffle!(idx)
    batches = Vector{Tuple{Matrix{Int}, Vector{Int}}}()
    for k in 1:batch_size:length(idx)
        sel = idx[k:min(k+batch_size-1, length(idx))]
        push!(batches, (X[:, sel], y[sel]))
    end
    return batches
end

function batch_loss(models, x_batch::AbstractMatrix{<:Integer}, y_batch::AbstractVector{<:Integer}, cfg)
    batch_size = size(x_batch, 2)
    low_state, high_state = HRMFlux.init_states(batch_size, cfg.d_hid)
    
    # yhat, _, _ = HRMFlux.run_segment_GRU!(models, x_batch, low_state, high_state; N=cfg.N, T=cfg.T, cfg=cfg)
    yhat, _, _ = HRMFlux.run_sequence_segment!(models, x_batch, low_state, high_state; N=cfg.N, cfg=cfg)

    targets = reshape(Float32.(y_batch), 1, batch_size)
    return Flux.logitbinarycrossentropy(yhat, targets)
end

function accuracy(models, X::AbstractMatrix{<:Integer}, y::AbstractVector{<:Integer}, cfg; batch_size::Int=256)
    correct = 0
    total = 0
    for (xb, yb) in each_minibatch(X, y, batch_size)
        bs = size(xb, 2)
        low_state, high_state = HRMFlux.init_states(bs, cfg.d_hid)

        # yhat, _, _ = HRMFlux.run_segment_GRU!(models, xb, low_state, high_state; N=cfg.N, T=cfg.T, cfg=cfg)
        yhat, _, _ = HRMFlux.run_sequence_segment!(models, xb, low_state, high_state; N=cfg.N, cfg=cfg)
        
        preds = @. yhat > 0  # threshold logit at 0
        correct += sum(Int.(preds[1, :]) .== yb)
        total   += bs
    end
    return correct / total
end

# 5) Train and evaluate
opt_state = Optimisers.setup(Optimisers.Adam(cfg.lr), models)

epochs = 15

for epoch in 1:epochs
    total_loss = 0.0
    batches = 0

    for (xb, yb) in each_minibatch(token_ids_train, y_train, cfg.batch)
        # keep PAD embedding zero
        Zygote.ignore() do
            if hasproperty(models, :tok_emb) && models.tok_emb !== nothing
                models.tok_emb.weight[:, PAD_ID] .= 0f0
            end
        end
        function closure(m)
            batch_loss(m, xb, yb, cfg)
        end
        L, back = Zygote.pullback(closure, models)
        grads = back(one(L))[1]
        opt_state, models = Optimisers.update(opt_state, models, grads)
        total_loss += Float64(L); batches += 1
    end

    acc_tr  = accuracy(models, token_ids_train, y_train, cfg)
    acc_id  = accuracy(models, token_ids_id,   y_test_id, cfg)
    acc_ood = accuracy(models, token_ids_ood,  y_test_ood, cfg)

    @info "epoch=$(epoch)  loss=$(round(total_loss/batches, digits=4))  " *
          "train_acc=$(round(acc_tr, digits=3))  id_acc=$(round(acc_id, digits=3))  ood_acc=$(round(acc_ood, digits=3))"
end

println("Done.")

┌ Info: epoch=1  loss=2.9487  train_acc=0.505  id_acc=0.512  ood_acc=0.558
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X13sZmlsZQ==.jl:207
┌ Info: epoch=2  loss=0.8009  train_acc=0.548  id_acc=0.496  ood_acc=0.554
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X13sZmlsZQ==.jl:207
┌ Info: epoch=3  loss=0.6881  train_acc=0.638  id_acc=0.636  ood_acc=0.444
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X13sZmlsZQ==.jl:207
┌ Info: epoch=4  loss=0.6455  train_acc=0.688  id_acc=0.63  ood_acc=0.52
└ @ Main /home/resort/Documents/repos/JuliaExploreHRM/05_nested_arithmetic_expression_Flux/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X13sZmlsZQ==.jl:207
┌ Info: epoch=5  loss=0.5188  train_acc=0.726  id_acc=0.68

InterruptException: InterruptException: