In [1]:
using Pkg
Pkg.activate("..")  #one level up, where Project.toml lives
Pkg.instantiate()   #download/install anything missing
# Pkg.status();

[32m[1m  Activating[22m[39m project at `~/Documents/repos/JuliaExploreHRM`


In [19]:
include(joinpath(@__DIR__, "..", "hrm_common.jl"))
using .HRMCommon
using Statistics, Random, Test
using LinearAlgebra: norm
rng = Random.default_rng()

# Y1, st1 = HRMCommon.transformer_forward!(blk, ps, st, X1)
# Y2, st2 = HRMCommon.transformer_forward!(blk, ps, st1, X2)  # reuse states




TaskLocalRNG()

In [20]:
d, L, B = 8, 4, 2
X = randn(Float32, d, L, B)

blk = HRMCommon.make_transformer_block(d; nheads=2, ff_mult=2, pos_kind=:sinusoidal)

ps = (
  pos_layer = Lux.setup(rng, blk.pos_layer)[1],
  mha  = Lux.setup(rng, blk.mha )[1],
  ln1  = Lux.setup(rng, blk.ln1 )[1],
  ff   = Lux.setup(rng, blk.ff  )[1],
  ln2  = Lux.setup(rng, blk.ln2 )[1],
)
st = (
  pos_layer = Lux.setup(rng, blk.pos_layer)[2],
  mha  = Lux.setup(rng, blk.mha )[2],
  ln1  = Lux.setup(rng, blk.ln1 )[2],
  ff   = Lux.setup(rng, blk.ff  )[2],
  ln2  = Lux.setup(rng, blk.ln2 )[2],
)

Y, st2 = HRMCommon.transformer_forward!(blk, ps, st, X)
@assert size(Y) == (d, L, B)


In [26]:

# Small helper to build ps/st safely even when pos_layer is `nothing`
function setup_parameters_and_state(rng, blk)
    if blk.pos_layer === nothing
        pos_ps, pos_st = (;), (;)
    else
        pos_ps, pos_st = Lux.setup(rng, blk.pos_layer)
    end
    mha_ps, mha_st = Lux.setup(rng, blk.mha)
    ln1_ps, ln1_st = Lux.setup(rng, blk.ln1)
    ff_ps,  ff_st  = Lux.setup(rng, blk.ff)
    ln2_ps, ln2_st = Lux.setup(rng, blk.ln2)

    ps = (; pos_layer=pos_ps, mha=mha_ps, ln1=ln1_ps, ff=ff_ps, ln2=ln2_ps)
    st = (; pos_layer=pos_st, mha=mha_st, ln1=ln1_st, ff=ff_st, ln2=ln2_st)
    return ps, st
end

# Mutate-in-place helper to zero all MHA + FF weights (keeps NamedTuples intact)
function zero_mha_and_ff!(ps)
    # Multi-head attention projections
    ps.mha.q_proj.weight .= 0
    ps.mha.k_proj.weight .= 0
    ps.mha.v_proj.weight .= 0
    ps.mha.out_proj.weight .= 0
    # FFN layers
    ps.ff.layer_1.weight .= 0
    ps.ff.layer_1.bias   .= 0
    ps.ff.layer_2.weight .= 0
    ps.ff.layer_2.bias   .= 0
    return ps
end



@testset "HRMCommon - utilities" begin
    d, L, B = 8, 5, 3

    # _as3d
    X2 = randn(Float32, d, B)
    X3 = randn(Float32, d, L, B)
    Y3, was2d = HRMCommon._as3d(X2)
    @test was2d == true
    @test size(Y3) == (d, 1, B)
    Y3b, was2d_b = HRMCommon._as3d(X3)
    @test was2d_b == false
    @test size(Y3b) == size(X3)

    # apply_tokenwise vs explicit batched apply
    ln = Lux.LayerNorm(d)
    ps_ln, st_ln = Lux.setup(rng, ln)

    # 2-D case
    Y2_tok, _ = HRMCommon.apply_tokenwise(ln, ps_ln, st_ln, X2)
    Y2_dir, _  = Lux.apply(ln, X2, ps_ln, st_ln)
    @test size(Y2_tok) == size(X2)
    @test eltype(Y2_tok) == eltype(X2)
    @test Y2_tok ≈ Y2_dir

    # 3-D case
    Y3_tok, _ = HRMCommon.apply_tokenwise(ln, ps_ln, st_ln, X3)
    Y2_dir2, _ = Lux.apply(ln, reshape(X3, d, L*B), ps_ln, st_ln)
    @test size(Y3_tok) == size(X3)
    @test Y3_tok ≈ reshape(Y2_dir2, d, L, B)
end

@testset "HRMCommon - positional encodings (sinusoidal)" begin
    d, L, B = 8, 7, 2
    X1 = randn(Float32, d, L, B)
    X2 = randn(Float32, d, L, B)

    pe, pe_apply = HRMCommon.make_positional_layer(d; kind=:sinusoidal)
    ps_pe, st_pe = Lux.setup(rng, pe)

    Y1, st1 = pe_apply(X1, pe, ps_pe, st_pe)
    Y2, st2 = pe_apply(X2, pe, ps_pe, st1)

    @test size(Y1) == (d, L, B)
    @test eltype(Y1) == Float32

    # The added positional term should depend only on positions (L,B), not on content of X.
    P1 = Y1 .- X1
    P2 = Y2 .- X2
    @test P1 ≈ P2
    @test any(!iszero, P1)  # it really adds something
end

@testset "HRMCommon - transformer block (shapes + residual identity)" begin
    d, L, B = 8, 4, 2
    X = randn(Float32, d, L, B)

    # 1) No positional encodings; with zeroed projections FFN, block should be identity.
    blk_none = HRMCommon.make_transformer_block(d; nheads=2, ff_mult=2,
        attention_dropout_probability=0.0, pos_kind=:none)
    ps_none, st_none = setup_parameters_and_state(rng, blk_none)

    ps_id = deepcopy(ps_none)
    zero_mha_and_ff!(ps_id)

    Y_id, st_id = HRMCommon.transformer_forward!(blk_none, ps_id, st_none, X)
    @test size(Y_id) == size(X)
    @test Y_id ≈ X

    # 2) Sinusoidal PEs; with zeroed projections FFN, output should be X + P (constant shift).
    blk_pe = HRMCommon.make_transformer_block(d; nheads=2, ff_mult=2,
        attention_dropout_probability=0.0, pos_kind=:sinusoidal)
    ps_pe, st_pe = setup_parameters_and_state(rng, blk_pe)

    ps_pe_id = deepcopy(ps_pe)
    zero_mha_and_ff!(ps_pe_id)

    Y_pe1, stp1 = HRMCommon.transformer_forward!(blk_pe, ps_pe_id, st_pe, X)
    P_inferred1 = Y_pe1 .- X

    # change input, same (L,B); inferred P should be identical
    X_alt = randn(Float32, d, L, B)
    Y_pe2, stp2 = HRMCommon.transformer_forward!(blk_pe, ps_pe_id, stp1, X_alt)
    P_inferred2 = Y_pe2 .- X_alt
    @test P_inferred1 ≈ P_inferred2
end

@testset "HRMCommon - gradients" begin
    d, L, B = 8, 6, 3
    X = randn(Float32, d, L, B)
    T = randn(Float32, d, L, B)   # random target for MSE

    blk = HRMCommon.make_transformer_block(d; nheads=2, ff_mult=2,
        attention_dropout_probability=0.0, pos_kind=:sinusoidal)
    ps, st = setup_parameters_and_state(rng, blk)

    # Loss as a function of parameters (state treated as non-differentiable accumulator)
    function loss(ps_blk)
        Y, _ = HRMCommon.transformer_forward!(blk, ps_blk, st, X)
        return sum(abs2, Y .- T) / length(Y)
    end

    g_ps = first(Zygote.gradient(loss, ps))

    # Representative checks: gradients for attention and FFN exist and are nonzero
    @test haskey(g_ps, :mha) && haskey(g_ps.mha, :q_proj)
    @test g_ps.mha.q_proj.weight !== nothing
    @test sum(abs2, g_ps.mha.q_proj.weight) > 0

    @test haskey(g_ps, :ff) && haskey(g_ps.ff, :layer_1)
    @test g_ps.ff.layer_1.weight !== nothing
    @test sum(abs2, g_ps.ff.layer_1.weight) > 0

    # Gradient w.r.t. input X (use a closure over X)
    function loss_X(Xin)
        Y, _ = HRMCommon.transformer_forward!(blk, ps, st, Xin)
        return sum(abs2, Y .- T) / length(Y)
    end
    gX = first(Zygote.gradient(loss_X, X))
    @test size(gX) == size(X)
    @test norm(gX) > 0
end

@testset "HRMCommon - variable sequence length" begin
    d, B = 8, 2
    X1 = randn(Float32, d, 5, B)
    X2 = randn(Float32, d, 11, B)

    blk = HRMCommon.make_transformer_block(d; nheads=2, ff_mult=2,
        attention_dropout_probability=0.0, pos_kind=:sinusoidal)
    ps, st = setup_parameters_and_state(rng, blk)

    Y1, st1 = HRMCommon.transformer_forward!(blk, ps, st, X1)
    @test size(Y1) == size(X1)

    Y2, st2 = HRMCommon.transformer_forward!(blk, ps, st1, X2)
    @test size(Y2) == size(X2)
end

@testset "MHA attention weights sum to 1" begin
    d, L, B = 8, 5, 2
    blk = HRMCommon.make_transformer_block(d; nheads=2, ff_mult=2, pos_kind=:none)
    rng = MersenneTwister(0)
    ps, st = begin
        ps_tmp = (;);
        st_tmp = (;);
        # small helper to setup all parts
        function setup_all(rng, blk)
            if blk.pos_layer === nothing
                pos_ps, pos_st = (;), (;)
            else
                pos_ps, pos_st = Lux.setup(rng, blk.pos_layer)
            end
            mha_ps, mha_st = Lux.setup(rng, blk.mha)
            ln1_ps, ln1_st = Lux.setup(rng, blk.ln1)
            ff_ps,  ff_st  = Lux.setup(rng, blk.ff)
            ln2_ps, ln2_st = Lux.setup(rng, blk.ln2)
            ps = (; pos_layer=pos_ps, mha=mha_ps, ln1=ln1_ps, ff=ff_ps, ln2=ln2_ps)
            st = (; pos_layer=pos_st, mha=mha_st, ln1=ln1_st, ff=ff_st, ln2=ln2_st)
            return ps, st
        end

        setup_all(rng, blk)
    end

    X = randn(Float32, d, L, B)
    Xn1, _ = HRMCommon.apply_tokenwise(blk.ln1, ps.ln1, st.ln1, X)
    (A, attn), _ = Lux.apply(blk.mha, (Xn1, Xn1, Xn1), ps.mha, st.mha)

    sz = size(attn)
    candidates = [ax for ax in 1:length(sz) if sz[ax] == L]

    # For each candidate axis, sum along it and measure how close sums are to 1.
    errors = Float32[]
    for ax in candidates
        summed = sum(attn; dims=ax)
        rows   = dropdims(summed; dims=ax)
        push!(errors, maximum(abs, rows .- 1f0))
    end

    # Require that at least one axis produces ~1 row-sums
    @test minimum(errors) <= 1f-4

end


[0m[1mTest Summary:         | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
HRMCommon - utilities | [32m   9  [39m[36m    9  [39m[0m0.0s
[0m[1mTest Summary:                                 | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
HRMCommon - positional encodings (sinusoidal) | [32m   4  [39m[36m    4  [39m[0m0.0s
[0m[1mTest Summary:                                              | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
HRMCommon - transformer block (shapes + residual identity) | [32m   3  [39m[36m    3  [39m[0m0.0s
[0m[1mTest Summary:         | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
HRMCommon - gradients | [32m   8  [39m[36m    8  [39m[0m0.9s
[0m[1mTest Summary:                        | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
HRMCommon - variable sequence length | [32m   2  [39m[36m    2  [

Test.DefaultTestSet("MHA attention weights sum to 1", Any[], 1, false, false, true, 1.756352106715557e9, 1.756352107069145e9, false, "/home/resort/Documents/repos/JuliaExploreHRM/04_transformer_blocks_common/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_W4sZmlsZQ==.jl")

In [3]:
using Plots, Test

In [None]:
#data generator
# y = w^T x + 0.1 * ||x||^2 + eps, where ( eps is N(0, 0.01) )

struct ToyGen
    w::Vector{Float32}
end

function ToyGen(d_in::Int; seed::Int=123)
    Random.seed!(seed)
    w = randn(Float32, d_in)
    return ToyGen(w)
end

#returns (x, y) with shapes (d_in, batch), (1, batch)
function sample!(gen::ToyGen, batch::Int)
    d = length(gen.w)
    x = randn(Float32, d, batch)
    # linear part
    y_lin = gen.w' * x      # (1, batch)
    # small quadratic interaction
    y_quad = 0.1f0 .* sum(abs2, x; dims=1)    # (1, batch)
    # noise
    eps = 0.01f0 .* randn(Float32, 1, batch)
    y = Float32.(y_lin .+ y_quad .+ eps)
    return x, y
end

In [None]:
"""
    quantize_to_tokens(x; num_tokens::Int, lo::Real, hi::Real)

Uniformly bins each element of x in [lo, hi] into 1..num_tokens.
Clamps values outside [lo, hi] to the nearest edge bin.
"""
LO = -3.0
HI =  3.0

function quantize_to_tokens(x; num_tokens::Int, lo::Real, hi::Real)
    if num_tokens ≤ 0
        error("quantize_to_tokens called with num_tokens <= 0; use raw floats instead.")
    end
    @assert num_tokens ≥ 2
    xn = @. clamp((x - lo) / (hi - lo + eps(eltype(x))), 0, 1)
    ids = floor.(Int, xn * (num_tokens - 1)) .+ 1
    return ids
end

In [None]:
CFG = (
    d_in   = 16,
    d_hid  = 64,
    d_out  = 1,
    N      = 2,
    T      = 3,
    M      = 1,
    batch  = 64,
    lr     = 1e-3,
    steps  = 300,
    seed   = 42,

    # input encoding
    num_tokens = 0,     # set >0 to use embeddings with IDs; 0 = raw float encoder
    d_embed    = 32,

    # transformer hyperparameters (shared or separated)
    l_heads    = 2,     # L-module heads
    l_ff_mult  = 4,     # FFN expansion for L
    h_heads    = 2,     # H-module heads
    h_ff_mult  = 4,     # FFN expansion for H
    dropout    = 0.0
)


Random.seed!(42)

In [None]:
# Promote 2-D (d,B) to 3-D (d,1,B); remember if we promoted
_as3d(X) = ndims(X) == 2 ? (reshape(X, size(X,1), 1, size(X,2)), true)  :
          ndims(X) == 3 ? (X, false) :
          error("Expected 2-D or 3-D tensor, got ndims=$(ndims(X))")

# Apply a (d,B) layer tokenwise over (d,L,B)
function apply_tokenwise(layer, ps_layer, st_layer, X)
    X3, was2d = _as3d(X)
    d, L, B = size(X3)
    X2 = reshape(X3, d, L*B)
    Y2, st2 = Lux.apply(layer, X2, ps_layer, st_layer)
    Y3 = reshape(Y2, d, L, B)
    return was2d ? dropdims(Y3; dims=2) : Y3, st2
end

# Apply a (d,B) Chain tokenwise over (d,L,B)
function apply_tokenwise_chain(chain, ps_chain, st_chain, X)
    X3, was2d = _as3d(X)
    d, L, B = size(X3)
    X2 = reshape(X3, d, L*B)
    Y2, st2 = Lux.apply(chain, X2, ps_chain, st_chain)
    Y3 = reshape(Y2, d, L, B)
    return was2d ? dropdims(Y3; dims=2) : Y3, st2
end
