In [27]:
# Environment bootstrap for this notebook: activate the project one directory up and
# install whatever its manifest lists but is not yet present locally.
# NOTE(review): ".." is a relative path, so it resolves against the current working
# directory — confirm the notebook is always launched from its own folder.
using Pkg
Pkg.activate("..")  # one level up, where Project.toml lives
Pkg.instantiate()   # download/install anything missing
# Pkg.status();

[32m[1m  Activating[22m[39m project at `~/Documents/repos/JuliaExploreHRM`


In [35]:
# Load the shared HRM source file that sits one directory above this one and bring the
# module it defines into scope, together with the stdlib helpers used by later cells.
include(joinpath(@__DIR__, "..", "hrm_common.jl"))
using .HRMCommon
using Statistics, Random, Test
using LinearAlgebra: norm
# Handle to the default RNG; later cells rebind `rng` to explicitly seeded generators.
rng = Random.default_rng()





TaskLocalRNG()

In [37]:
using Random, Lux
using .HRMCommon

# Smoke test: build one TransformerBlock, run a forward pass on random input and check
# that the output keeps the (d, L, B) shape of the input.
rng = MersenneTwister(0)
d, L, B = 8, 4, 2
# Draw the input from the seeded `rng`; a bare `randn(Float32, ...)` would use the global
# RNG and make the input differ from run to run despite the fixed seed above.
X = randn(rng, Float32, d, L, B)

blk = HRMCommon.TransformerBlock(d; nheads=2, ff_mult=2, pos_kind=:sinusoidal)
ps, st = Lux.setup(rng, blk)
Y, st2 = Lux.apply(blk, X, ps, st)
@assert size(Y) == (d, L, B)


In [None]:
using Test
using Random, Lux
using Zygote


# Zero, in place, the attention projection weights (q/k/v/out) under `ps.mha` and the
# weights and biases of both feed-forward layers under `ps.ff.layer`. Every other entry
# of `ps` is left untouched. Returns the same `ps` object for chaining.
function _zero_attn_and_ff!(ps)
    for proj in (:q_proj, :k_proj, :v_proj, :out_proj)
        fill!(getproperty(ps.mha, proj).weight, 0)
    end
    for name in (:layer_1, :layer_2)
        dense = getproperty(ps.ff.layer, name)
        fill!(dense.weight, 0)
        fill!(dense.bias, 0)
    end
    return ps
end

# Checks that wrapping LayerNorm in HRMCommon.Tokenwise and applying it to a (d, L, B)
# tensor gives the same result as flattening the token and batch axes by hand, applying
# LayerNorm to the (d, L*B) matrix, and reshaping back.
@testset "Tokenwise wrapper = reference reshape" begin
    rng = MersenneTwister(1)
    d, L, B = 8, 7, 3
    # Sample from the seeded rng; a bare randn(...) would draw from the global RNG and
    # the fixed seed above would not make the input reproducible.
    X = randn(rng, Float32, d, L, B)

    ln = Lux.LayerNorm(d)            # single instance shared by both paths
    tw = HRMCommon.Tokenwise(ln)

    # tokenwise path
    ps_tw, st_tw = Lux.setup(rng, tw)
    Y_tok, _ = Lux.apply(tw, X, ps_tw, st_tw)

    # reference path (manual reshape)
    ps_ln, st_ln = Lux.setup(rng, ln)
    Y_flat, _ = Lux.apply(ln, reshape(X, d, L * B), ps_ln, st_ln)
    Y_ref = reshape(Y_flat, d, L, B)

    @test size(Y_tok) == size(X)
    @test Y_tok ≈ Y_ref
end

# With no positional encoding and all attention/feed-forward parameters listed in
# `_zero_attn_and_ff!` set to zero, the block is expected to return its input unchanged.
@testset "TransformerBlock identity with pos_kind=:none when projections are zero" begin
    rng = MersenneTwister(2)
    d, L, B = 8, 5, 2
    # Seeded draw so the input is reproducible (bare randn uses the global RNG).
    X = randn(rng, Float32, d, L, B)

    blk = HRMCommon.TransformerBlock(d; nheads=2, ff_mult=2, pos_kind=:none)
    ps, st = Lux.setup(rng, blk)
    _zero_attn_and_ff!(ps)

    Y, _ = Lux.apply(blk, X, ps, st)
    @test Y ≈ X
end

# With attention and feed-forward zeroed, whatever the block adds to its input should
# depend on position only: two different inputs must receive the same non-zero offset.
@testset "Sinusoidal PE adds a content-independent shift" begin
    rng = MersenneTwister(3)
    d, L, B = 8, 6, 2
    # Both inputs come from the seeded rng so the test data is reproducible
    # (bare randn would use the global RNG).
    X1 = randn(rng, Float32, d, L, B)
    X2 = randn(rng, Float32, d, L, B)

    blk = HRMCommon.TransformerBlock(d; nheads=2, ff_mult=2, pos_kind=:sinusoidal)
    ps, st = Lux.setup(rng, blk)
    _zero_attn_and_ff!(ps)  # isolate the positional effect

    Y1, st1 = Lux.apply(blk, X1, ps, st)
    Y2, _   = Lux.apply(blk, X2, ps, st1)   # second call continues from the returned state

    shift1 = Y1 .- X1
    shift2 = Y2 .- X2
    @test shift1 ≈ shift2
    @test any(!iszero, shift1)
end

# Checks that a mean-squared-error loss on the block output produces non-zero gradients
# for an attention weight, a feed-forward weight, and the input itself.
# Requires Zygote to be loaded in the session (`using Zygote`).
@testset "Gradients flow through MHA and FF" begin
    rng = MersenneTwister(4)
    d, L, B = 8, 7, 3
    # Input and target are drawn from the seeded rng so the test data is reproducible
    # (bare randn would use the global RNG).
    X = randn(rng, Float32, d, L, B)
    T = randn(rng, Float32, d, L, B)

    blk = HRMCommon.TransformerBlock(d; nheads=2, ff_mult=2, pos_kind=:sinusoidal)
    ps, st = Lux.setup(rng, blk)

    # One loss in both parameters and input, so a single backward pass yields both
    # gradients instead of two near-identical closures.
    function mse(p, x)
        Y, _ = Lux.apply(blk, x, p, st)
        return sum(abs2, Y .- T) / length(Y)
    end
    g_ps, gX = Zygote.gradient(mse, ps, X)

    # gradient w.r.t. parameters
    @test sum(abs2, g_ps.mha.q_proj.weight) > 0
    @test sum(abs2, g_ps.ff.layer.layer_1.weight) > 0

    # gradient w.r.t. input
    @test size(gX) == size(X)
    @test sum(abs2, gX) > 0
end


[0m[1mTest Summary:                         | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
Tokenwise wrapper = reference reshape | [32m   2  [39m[36m    2  [39m[0m0.0s
[0m[1mTest Summary:                                                           | [22m

In [3]:
using Plots, Test

In [None]:
# Data generator
# y = w^T x + 0.1 * ||x||^2 + eps, where eps ~ N(0, 0.01^2), i.e. standard deviation 0.01

"""
    ToyGen

State of the toy regression data generator: the weight vector `w` of the linear term.
"""
struct ToyGen
    w::Vector{Float32}
end

"""
    ToyGen(d_in::Int; seed::Int=123)

Seed the global RNG with `seed`, then draw a `d_in`-element `Float32` weight vector from
it and wrap it in a `ToyGen`. Note the side effect: the global RNG is left in the state
that follows this draw.
"""
ToyGen(d_in::Int; seed::Int=123) = begin
    Random.seed!(seed)
    ToyGen(randn(Float32, d_in))
end

"""
    sample!(gen::ToyGen, batch::Int; rng=Random.default_rng())

Draw one batch from the toy regression task

    y = wᵀx + 0.1‖x‖² + noise,    noise = 0.01 · randn

# Arguments
- `gen`: generator holding the weight vector `w`; its length sets the input width.
- `batch`: number of samples to draw.

# Keywords
- `rng`: random number generator used for both the inputs and the noise. The default is
  the global default RNG, which reproduces the behaviour of calling `randn` without an
  explicit generator; pass a seeded RNG for reproducible batches.

# Returns
`(x, y)` with sizes `(length(gen.w), batch)` and `(1, batch)`, both `Float32`.

Despite the `!` in the name, `gen` is not modified; only the RNG state advances.
"""
function sample!(gen::ToyGen, batch::Int; rng::Random.AbstractRNG=Random.default_rng())
    d = length(gen.w)
    x = randn(rng, Float32, d, batch)
    # linear part
    y_lin = gen.w' * x                          # (1, batch)
    # small quadratic interaction
    y_quad = 0.1f0 .* sum(abs2, x; dims=1)      # (1, batch)
    # additive noise (named `noise` so Base.eps is not shadowed)
    noise = 0.01f0 .* randn(rng, Float32, 1, batch)
    # every term is already Float32, so no element-type conversion is needed
    y = y_lin .+ y_quad .+ noise
    return x, y
end

In [None]:
"""
    quantize_to_tokens(x; num_tokens::Int, lo::Real, hi::Real)

Uniformly bins each element of x in [lo, hi] into 1..num_tokens.
Clamps values outside [lo, hi] to the nearest edge bin.
"""
LO = -3.0
HI =  3.0

function quantize_to_tokens(x; num_tokens::Int, lo::Real, hi::Real)
    if num_tokens ≤ 0
        error("quantize_to_tokens called with num_tokens <= 0; use raw floats instead.")
    end
    @assert num_tokens ≥ 2
    xn = @. clamp((x - lo) / (hi - lo + eps(eltype(x))), 0, 1)
    ids = floor.(Int, xn * (num_tokens - 1)) .+ 1
    return ids
end

In [None]:
# Experiment configuration for this notebook, kept in a single NamedTuple.
# NOTE(review): the meaning of N, T and M is not visible in this section — presumably
# HRM loop/segment counts; confirm against the model code.
CFG = (
    d_in   = 16,
    d_hid  = 64,
    d_out  = 1,
    N      = 2,
    T      = 3,
    M      = 1,
    batch  = 64,
    lr     = 1e-3,
    steps  = 300,
    seed   = 42,

    # input encoding
    num_tokens = 0,     # set >0 to use embeddings with IDs; 0 = raw float encoder
    d_embed    = 32,

    # transformer hyperparameters (shared or separated)
    l_heads    = 2,     # L-module heads
    l_ff_mult  = 4,     # FFN expansion for L
    h_heads    = 2,     # H-module heads
    h_ff_mult  = 4,     # FFN expansion for H
    dropout    = 0.0
)


# Seed the global RNG from the configuration rather than repeating the literal, so that
# changing CFG.seed actually changes the run.
Random.seed!(CFG.seed)

In [None]:
# Bring X to rank 3. A (d,B) matrix gets a singleton middle axis and becomes (d,1,B);
# a rank-3 array is passed through as is. The Bool in the returned tuple records whether
# the singleton axis was inserted. Any other rank raises an error.
function _as3d(X)
    rank = ndims(X)
    if rank == 2
        return reshape(X, size(X, 1), 1, size(X, 2)), true
    elseif rank == 3
        return X, false
    end
    return error("Expected 2-D or 3-D tensor, got ndims=$(ndims(X))")
end

"""
    apply_tokenwise(layer, ps_layer, st_layer, X)

Apply a layer that works on `(features, batch)` matrices to every token of `X`.

`X` may be `(d, B)` or `(d, L, B)`. The token and batch axes are flattened into one,
`layer` is applied once via `Lux.apply`, and the result is reshaped back. The output
feature width is taken from the layer's output, so layers that change the feature
dimension are supported.

# Returns
`(Y, st)`, where `Y` has the same rank as `X` and `st` is the state returned by
`Lux.apply`.
"""
function apply_tokenwise(layer, ps_layer, st_layer, X)
    X3, was2d = _as3d(X)
    d, L, B = size(X3)
    Y2, st2 = Lux.apply(layer, reshape(X3, d, L * B), ps_layer, st_layer)
    # Use the layer's own output width rather than the input width `d`; reshaping with
    # `d` fails as soon as the layer maps d features to a different number.
    Y3 = reshape(Y2, size(Y2, 1), L, B)
    # Undo the singleton token axis only if it was inserted on the way in.
    Y = was2d ? dropdims(Y3; dims=2) : Y3
    return Y, st2
end

"""
    apply_tokenwise_chain(chain, ps_chain, st_chain, X)

Apply a `(features, batch)` chain tokenwise over `X`, which may be `(d, B)` or
`(d, L, B)`.

A chain goes through the same `Lux.apply` call as a single layer, so this simply
forwards to `apply_tokenwise` instead of keeping a second copy of the same body.

# Returns
`(Y, st)` exactly as returned by `apply_tokenwise`.
"""
apply_tokenwise_chain(chain, ps_chain, st_chain, X) =
    apply_tokenwise(chain, ps_chain, st_chain, X)
