In [3]:
using CSV, DataFrames, DataFramesMeta, Missings
using StatsBase, Statistics, MatrixLM
using Random, Distributions, StatsModels
using LinearAlgebra, PrettyTables
using FreqTables, Plots, StatsPlots

In [4]:
"""
simulate_bilinear_identity_1level

Simulates X (n×p), true B (p×m) with 1-level subclass hierarchy, and Y (n×m)
under Y = X*B + E (equivalently Z = I).

Arguments:
  n, m, p : dims
  H       : number of subclasses
  prop_sub : length-H vector summing to 1 controlling metabolite proportions per subclass
  theta0_scale : prior SD for theta0[k]
  tau_v   : SD across subclasses (between-subclass variability)
  tau_w   : SD across metabolites within subclass
  sigma_y : residual SD for Y

Returns NamedTuple with X, Y, B_true, theta0, beta, subclass_of_met
"""
function simulate_bilinear_identity_1level(;
    n::Int = 98,
    m::Int = 770,
    p::Int = 10,
    H::Int = 4,
    prop_sub = fill(1/H, H),     # proportions of metabolites per subclass
    theta0_scale::Real = 0.2f0,
    tau_v::Real = 0.15f0,
    tau_w::Real = 0.10f0,
    sigma_y::Real = 1.0f0,
    seed::Int = 123,
    T = Float32
)
    @assert length(prop_sub) == H
    @assert abs(sum(prop_sub) - 1) < 1e-6 "prop_sub must sum to 1"

    Random.seed!(seed)

    # ---- cast scalars to T ---------#
    theta0_scaleT = T(theta0_scale)
    tau_vT        = T(tau_v)
    tau_wT        = T(tau_w)
    sigma_yT      = T(sigma_y)

    # ---- X
    X = randn(T, n, p)

    # ---- build subclass_of_met with controlled proportions
    counts = round.(Int, m .* collect(prop_sub))
    # fix rounding so counts sum to m
    counts[end] += m - sum(counts)

    subclass_of_met = Vector{Int}(undef, m)
    idx = 1
    for h in 1:H
        for _ in 1:counts[h]
            subclass_of_met[idx] = h
            idx += 1
        end
    end
    # shuffle metabolites so subclasses are mixed
    perm = randperm(m)
    subclass_of_met = subclass_of_met[perm]

    # ---- true hierarchical coefficients
    theta0 = rand.(Normal(T(0), theta0_scaleT), p)              # length p

    beta = Matrix{T}(undef, H, p)                               # H × p
    for h in 1:H, k in 1:p
        beta[h, k] = rand(Normal(theta0[k], tau_vT))
    end

    theta = Matrix{T}(undef, m, p)                              # m × p (met × cov)
    for j in 1:m
        h = subclass_of_met[j]
        for k in 1:p
            theta[j, k] = rand(Normal(beta[h, k], tau_wT))
        end
    end

    # B_true is p×m, with B[k,j] = theta[j,k]
    B_true = permutedims(theta)                                 # p × m

    # ---- generate Y = X*B + E
    Y_mean = X * B_true                                         # n × m
    Y = Y_mean .+ randn(T, n, m) .* sigma_yT

    return (X=X, Y=Y, B_true=B_true, theta0=theta0, beta=beta,
            subclass_of_met=subclass_of_met, perm=perm)
end

simulate_bilinear_identity_1level

In [5]:
"""
fit_matrixlm_identity(X, Y)

Returns:
B_hat (p×m), SE_hat (p×m), tstat (p×m), fitobj
"""
function fit_matrixlm_identity(X::AbstractMatrix, Y::AbstractMatrix)
    n, p = size(X)
    n2, m = size(Y)
    @assert n == n2

    Z = Diagonal(ones(m))  # identity without forming dense m×m

    fit = mlm(RawData(Response(Y), Predictors(X, Z)), addXIntercept=false, addZIntercept=false)
    B_hat = MatrixLM.coef(fit)
    tstat = MatrixLM.t_stat(fit)

    # SE = |coef|/|t|
    SE_hat = similar(B_hat)
    @inbounds for k in 1:size(B_hat,1), j in 1:size(B_hat,2)
        denom = max(abs(tstat[k,j]), eps(Float64))
        SE_hat[k,j] = abs(B_hat[k,j]) / denom
    end
    return B_hat, SE_hat, tstat, fit
end

fit_matrixlm_identity

In [13]:
sim = simulate_bilinear_identity_1level(n=98, m=770, p=6, H=4,
                                        prop_sub=[0.25, 0.25, 0.25, 0.25], # balanced proportions
                                        theta0_scale=0.2f0,
                                        tau_v=0.12f0, tau_w=0.08f0,
                                        sigma_y=1.0f0, seed=12, T=Float32)

X = sim.X
Y = sim.Y
B_true = sim.B_true
subclass_of_met = sim.subclass_of_met
H = length(unique(subclass_of_met))

4

In [14]:
# center Y matrix
Yc = Y .- mean(Y, dims=1)

98×770 Matrix{Float32}:
  1.97105     0.581971     1.22157    …   1.31865    0.959001    0.0704058
 -0.397028    1.9378       1.27785        0.66921   -2.23077    -0.123867
 -0.695264   -1.43451      0.562154      -0.699391   1.07866    -0.411321
 -0.339796   -1.88825     -1.0929         0.785635  -1.2949      1.30891
  0.48478     0.00929036  -0.915866       0.810651  -0.687755    0.804492
  0.571123   -0.450447    -1.15496    …   0.691421  -0.796221   -0.560865
  1.08879     0.782309     0.908724       0.989689  -2.90473    -0.952254
 -0.662977   -1.82188      0.0953922      0.647613   1.27111     1.33439
 -0.270724    0.68638      1.53218        1.81858    1.44593    -2.21134
  0.0176364  -0.665557     0.879218       1.2925    -0.233305   -1.49674
 -2.22917    -0.0596513    0.639729   …   0.247535  -1.25877    -0.992649
 -0.0548428   1.01829     -0.492791       1.6083     0.727312   -0.140394
  1.40984    -0.214576    -1.0961        -0.195066  -1.75523     0.505993
  ⋮              

In [15]:
"""
Split rows of Y (n×m) and X (n×p) into train/test.

Returns:
  (Y_tr, X_tr, Y_te, X_te, idx_tr, idx_te)
"""
function train_test_split_rows(Y, X; train_frac=0.70, seed=1234, shuffle=true)
    n = size(Y, 1)
    @assert size(X, 1) == n "X and Y must have the same number of rows (individuals)."

    rng = MersenneTwister(seed)
    idx = collect(1:n)
    if shuffle
        Random.shuffle!(rng, idx)
    end

    n_tr = floor(Int, train_frac * n)
    idx_tr = idx[1:n_tr]
    idx_te = idx[n_tr+1:end]

    Y_tr = @view Y[idx_tr, :]
    X_tr = @view X[idx_tr, :]
    Y_te = @view Y[idx_te, :]
    X_te = @view X[idx_te, :]

    return (Y_tr, X_tr, Y_te, X_te, idx_tr, idx_te)
end

# ---- run split COPD Data ----
Y_tr, X_tr, Y_te, X_te, idx_tr, idx_te =
    train_test_split_rows(Yc, X; train_frac=0.70, seed=5)

@show size(Y_tr) size(X_tr) size(Y_te) size(X_te)
@show length(idx_tr) length(idx_te)

size(Y_tr) = (68, 770)
size(X_tr) = (68, 6)
size(Y_te) = (30, 770)
size(X_te) = (30, 6)
length(idx_tr) = 68
length(idx_te) = 30


30

In [16]:
# MatrixLM
B_hat, SE_hat, tstat, fit = fit_matrixlm_identity(X_tr, Y_tr)

([0.08477096619157676 -0.17594671631079148 … 0.07822600297850994 0.1054209006299151; 0.429454236098947 0.15655513133216947 … 0.31322891780407286 -0.12606926829675863; … ; 0.08071930026558433 0.015480202738743808 … 0.15981794606999192 0.2675223987505882; 0.0967470278097964 -0.034737271163372774 … 0.22918927557836233 0.07751253849210515], [0.11442562068051199 0.11181458816105178 … 0.11913150441946382 0.10246605888699424; 0.11919120315238922 0.11647142671059556 … 0.1240930768884035 0.10673355117840695; … ; 0.12351776106708402 0.1206992586289146 … 0.12859756941612735 0.1106079049763463; 0.13252536706600276 0.12950132366554024 … 0.1379756234523397 0.11867405206147412], [0.7408390331415894 -1.5735577906647316 … 0.6566357351039152 1.0288372732885105; 3.6030698972799025 1.344150541927958 … 2.5241449858299396 -1.1811587537833506; … ; 0.6535035898338918 0.12825433158903748 … 1.2427757911414246 2.418655328548157; 0.7300264843757243 -0.26823873440157137 … 1.6610852688593223 0.6531548990334722], Ml

In [12]:
# Gibbs Sampler for 1-level hierarchical model

function gibbs_meta_hier_traces_db(
    b_obs::Vector{Float64},           # length m
    se_obs::Vector{Float64},          # length m
    subclass_of_met::Vector{Int},     # length m, values in 1..H (Total subclass)
    H::Int;                           # number of subclasses
    mu0::Float64 = 0.0,
    s0::Float64 = 1.0,
    halfcauchy_scale::Float64 = 1.0,
    n_iter::Int = 2000,
    burnin::Int = 500,
    thin::Int = 1,
    seed::Int = 1234,
)
    @assert length(b_obs) == length(se_obs) "b_obs and se_obs must have same length"
    @assert length(subclass_of_met) == length(b_obs) "subclass_of_met must have length m"
    @assert maximum(subclass_of_met) == H "subclass_of_met must be in 1..H"
    @assert burnin < n_iter "burnin must be < n_iter"
    @assert thin ≥ 1

    Random.seed!(seed)
    m = length(b_obs)
    lam_j = 1.0 ./ (se_obs .^ 2)  # observation precisions

    # --- initialize state ---
    theta = copy(b_obs)                                     # length m (metabolite-level effects)
    beta  = [mean(theta[subclass_of_met .== h]) for h in 1:H]  # class means
    theta0 = mean(beta)                                     # global mean over classes

    tau_w2 = 1.0   # within-class variance for theta_j | beta_h
    tau_v2 = 1.0   # between-class variance for beta_h | theta0

    lambda_w = 1.0    # IG mixture auxiliaries for half-Cauchy
    lambda_v = 1.0

    # --- precompute group indices ---
    idx_by_sub = [findall(==(h), subclass_of_met) for h in 1:H]

    # --- storage sizes ---
    n_keep = floor(Int, (n_iter - burnin) ÷ thin)

    # scalars per draw (for quick monitoring)
    draws_scalar = DataFrame(
        theta0 = Vector{Float64}(undef, n_keep),
        tau_w2 = Vector{Float64}(undef, n_keep),
        tau_v2 = Vector{Float64}(undef, n_keep),
    )

    # vectors per draw
    beta_draws  = Array{Float64}(undef, H, n_keep)   # columns = kept iters
    theta_draws = Array{Float64}(undef, m, n_keep)

    keep_idx = 0

    for it in 1:n_iter
        #######################
        # 1) update theta_j   #
        #######################
        for j in 1:m
            h = subclass_of_met[j]
            lik_prec   = lam_j[j]
            lik_sum    = lam_j[j] * b_obs[j]
            prior_prec = 1.0 / tau_w2
            prior_mean = beta[h]
            mean_th = (lik_sum + prior_prec * prior_mean) / (lik_prec + prior_prec)
            var_th  = 1.0 / (lik_prec + prior_prec)
            theta[j] = rand(Normal(mean_th, sqrt(var_th)))
        end

        #######################
        # 2) update beta_h    #
        #######################
        for h in 1:H
            J = idx_by_sub[h]
            mh = length(J)

            # likelihood: theta_j | beta_h ~ N(beta_h, tau_w2)
            lik_prec   = mh / tau_w2
            lik_sum    = (1.0 / tau_w2) * sum(theta[J])

            # prior: beta_h | theta0 ~ N(theta0, tau_v2)
            prior_prec = 1.0 / tau_v2
            prior_mean = theta0

            mean_b = (lik_sum + prior_prec * prior_mean) / (lik_prec + prior_prec)
            var_b  = 1.0 / (lik_prec + prior_prec)
            beta[h] = rand(Normal(mean_b, sqrt(var_b)))
        end

        #######################
        # 3) update theta0    #
        #######################
        # beta_h | theta0 ~ N(theta0, tau_v2)
        lik_prec   = H / tau_v2
        lik_sum    = (1.0 / tau_v2) * sum(beta)

        # prior: theta0 ~ N(mu0, s0^2)
        prior_prec = 1.0 / (s0^2)
        prior_mean = mu0

        mean_t0 = (lik_sum + prior_prec * prior_mean) / (lik_prec + prior_prec)
        var_t0  = 1.0 / (lik_prec + prior_prec)
        theta0  = rand(Normal(mean_t0, sqrt(var_t0)))

        ############################################
        # 4) update tau_w2 and tau_v2 via IG mix   #
        ############################################

        # tau_w2: within-class spread of theta_j around beta_{class(j)}
        ssw = sum((theta .- beta[subclass_of_met]).^2)
        tau_w2   = rand(InverseGamma((m + 1.0) / 2.0, 0.5 * ssw + 1.0 / lambda_w))
        lambda_w = rand(InverseGamma(1.0, 1.0 / (halfcauchy_scale^2) + 1.0 / tau_w2))

        # tau_v2: between-class spread of beta_h around theta0
        ssv = sum((beta .- theta0).^2)
        tau_v2   = rand(InverseGamma((H + 1.0) / 2.0, 0.5 * ssv + 1.0 / lambda_v))
        lambda_v = rand(InverseGamma(1.0, 1.0 / (halfcauchy_scale^2) + 1.0 / tau_v2))

        ###################################
        # 5) store post-burn (with thin) #
        ###################################
        if it > burnin && ((it - burnin) % thin == 0)
            keep_idx += 1
            draws_scalar.theta0[keep_idx] = theta0
            draws_scalar.tau_w2[keep_idx] = tau_w2
            draws_scalar.tau_v2[keep_idx] = tau_v2

            beta_draws[:, keep_idx]  .= beta
            theta_draws[:, keep_idx] .= theta
        end
    end

    return (; draws_scalar,
            beta_draws,
            theta_draws,
            last_state = (; theta0, tau_w2, tau_v2,
                          beta = copy(beta), theta = copy(theta)))
end

gibbs_meta_hier_traces_db (generic function with 1 method)

In [17]:
"""
Given theta_draws (m × n_keep), return:
  mean_theta :: Vector{Float64} length m
  sd_theta   :: Vector{Float64} length m
"""
function summarize_theta_draws(theta_draws::AbstractMatrix{<:Real})
    m, n_keep = size(theta_draws)
    mean_theta = vec(mean(theta_draws; dims=2))
    sd_theta   = vec(std(theta_draws; dims=2, corrected=true))
    return mean_theta, sd_theta
end

summarize_theta_draws

In [18]:
"""
Run Gibbs meta-hierarchy for each covariate (row of B_obs/SE_obs).

Returns:
  B_bayes  :: Matrix{Float64}  (p×m) posterior means
  SE_bayes :: Matrix{Float64}  (p×m) posterior SDs   (NOT frequentist SEs)
  res_list :: Vector           results per covariate (optional to keep)
"""
function fit_bayes_all_covariates(
    B_obs::AbstractMatrix{<:Real},          # p×m
    SE_obs::AbstractMatrix{<:Real},         # p×m
    subclass_of_met::Vector{Int},
    H::Int;
    mu0::Float64 = 0.0,
    s0::Float64 = 1.0,
    halfcauchy_scale::Float64 = 1.0,
    n_iter::Int = 5000,
    burnin::Int = 1000,
    thin::Int = 1,
    seed0::Int = 42,
    keep_results::Bool = false
)
    p, m = size(B_obs)
    @assert size(SE_obs) == (p, m) "SE_obs must have same shape as B_obs"

    B_bayes  = Array{Float64}(undef, p, m)
    SE_bayes = Array{Float64}(undef, p, m)

    res_list = keep_results ? Vector{Any}(undef, p) : Any[]

    for k in 1:p
        b_vec  = vec(Float64.(B_obs[k, :]))
        se_vec = vec(Float64.(SE_obs[k, :]))

        #mu0 = mu0_mode === :empirical ? mean(b_vec) : 0.0

        res_k = gibbs_meta_hier_traces_db(
            b_vec, se_vec, subclass_of_met, H;
            mu0 = mu0, s0 = s0,
            halfcauchy_scale = halfcauchy_scale,
            n_iter = n_iter, burnin = burnin, thin = thin,
            seed = seed0 + k
        )

        @assert res_k.theta_draws !== nothing "theta_draws missing; ensure save_theta=true"

        mean_theta, sd_theta = summarize_theta_draws(res_k.theta_draws)

        # store into row k
        @inbounds begin
            B_bayes[k, :]  .= mean_theta
            SE_bayes[k, :] .= sd_theta
        end

        if keep_results
            res_list[k] = res_k
        end

        println("Done covariate k=$k / $p")
    end

    return B_bayes, SE_bayes, res_list
end

fit_bayes_all_covariates

In [19]:
B_bayes_tr, SE_bayes_tr, res_by_cov =
    fit_bayes_all_covariates(
        B_hat, SE_hat,
        subclass_of_met, H;
        mu0 = 0.0,      
        s0 = 1.0,
        n_iter = 5000, burnin = 1000, thin = 1,
        seed0 = 1000,
        keep_results = false
    )

@show size(B_bayes_tr) size(SE_bayes_tr)

Done covariate k=1 / 6
Done covariate k=2 / 6
Done covariate k=3 / 6
Done covariate k=4 / 6
Done covariate k=5 / 6
Done covariate k=6 / 6
size(B_bayes_tr) = (6, 770)
size(SE_bayes_tr) = (6, 770)


(6, 770)

In [20]:
mse_coef_mlm  = mean((B_hat .- B_true).^2)
mse_coef_bays = mean((B_bayes_tr .- B_true).^2)

println("Coef MSE (MatrixLM): ", mse_coef_mlm)
println("Coef MSE (Bayes):    ", mse_coef_bays)

Coef MSE (MatrixLM): 0.015255531612656952
Coef MSE (Bayes):    0.004625125020177808


In [21]:
"""
Predict Yhat and compute MSE summaries.

Inputs
  Y_te :: (n_te × m) matrix
  X_te :: (n_te × p) matrix
  B    :: (p × m) coefficient matrix

Returns
  mse_all :: Float64
  mse_met :: Vector{Float64} length m   (MSE per metabolite)
  mse_ind :: Vector{Float64} length n_te (MSE per individual)
"""
function test_mse(Y_te::AbstractMatrix{<:Real},
                  X_te::AbstractMatrix{<:Real},
                  B::AbstractMatrix{<:Real})

    n_te, m = size(Y_te)
    @assert size(X_te, 1) == n_te "X_te and Y_te must have same number of rows"
    p = size(X_te, 2)
    @assert size(B) == (p, m) "B must be p×m with p=size(X_te,2), m=size(Y_te,2)"

    # predictions
    Yhat = X_te * B                       # n_te × m
    R = Y_te .- Yhat                      # residuals

    mse_all = mean(abs2, R)               # overall mean squared error
    mse_met = vec(mean(abs2, R; dims=1))  # average over individuals -> per metabolite
    mse_ind = vec(mean(abs2, R; dims=2))  # average over metabolites -> per individual

    return mse_all, mse_met, mse_ind
end

test_mse

In [22]:
# Ensure Float64 (optional but avoids mixed types)
Y_te_f = Float64.(Y_te)
X_te_f = Float64.(X_te)

B_mlm  = Float64.(B_hat)
B_bys  = Float64.(B_bayes_tr)

mse_mlm, mse_mlm_met, mse_mlm_ind = test_mse(Y_te_f, X_te_f, B_mlm)
mse_bys, mse_bys_met, mse_bys_ind = test_mse(Y_te_f, X_te_f, B_bys)

println("Test MSE (MatrixLM): ", mse_mlm)
println("Test MSE (Bayes):    ", mse_bys)
println("Relative improvement (positive is better): ",
        (mse_mlm - mse_bys) / mse_mlm)

Test MSE (MatrixLM): 1.0755549234189303
Test MSE (Bayes):    1.0049701891294378
Relative improvement (positive is better): 0.06562634111247488


In [23]:
# Coef MSE vs truth (both p x m)
coef_mse(B_est::AbstractMatrix, B_true::AbstractMatrix) = mean((B_est .- B_true) .^ 2)

# Test MSE on Yhat = X * B (Y is n x m, X is n x p, B is p x m)
test_mse(Y::AbstractMatrix, X::AbstractMatrix, B::AbstractMatrix) = mean((Y .- (X * B)) .^ 2)

test_mse (generic function with 2 methods)

In [24]:
# Simulation study comparing MatrixLM vs Bayesian hierarchical model (one-level hierarchy) across 10 random datasets

function run_simstudy_10(;
    nrep::Int = 10,
    sim_kwargs = (; n=98, m=770, p=10, H=4,
                   prop_sub=[0.55, 0.25, 0.15, 0.05],
                   theta0_scale=0.2f0, tau_v=0.12f0, tau_w=0.08f0,
                   sigma_y=1.0f0, T=Float32),
    split_frac::Float64 = 0.70,
    seed_base::Int = 1,

    # Bayes settings
    bayes_s0::Float64 = 1.0,
    bayes_halfcauchy_scale::Float64 = 1.0,
    bayes_n_iter::Int = 2000,
    bayes_burnin::Int = 100,
    bayes_thin::Int = 1,
    bayes_mu0::Float64 = 0.0
)
    rows = Vector{NamedTuple}()

    for r in 1:nrep
        sim_seed = seed_base + 10_000 * r

        # (1) simulate
        sim = simulate_bilinear_identity_1level(; seed=sim_seed, sim_kwargs...)
        X = sim.X
        Y = sim.Y
        B_true = sim.B_true
        subclass_of_met = sim.subclass_of_met
        H = sim_kwargs.H

        # (2) split
        (Y_tr, X_tr, Y_te, X_te, _, _) = train_test_split_rows(Y, X; train_frac=split_frac, seed=sim_seed + 777)

        # (3) fit MatrixLM on train
        (B_hat, SE_hat, tstat, fitobj) = fit_matrixlm_identity(X_tr, Y_tr)

        # (4) Bayes fit on train (meta-hierarchy per covariate)
        (B_bayes, SE_bayes, res_by_cov) = fit_bayes_all_covariates(B_hat, SE_hat, subclass_of_met, H;
                                            mu0 = bayes_mu0,
                                            s0 = bayes_s0,
                                            halfcauchy_scale = bayes_halfcauchy_scale,
                                            n_iter = bayes_n_iter,
                                            burnin = bayes_burnin,
                                            thin = bayes_thin,
                                            seed0=sim_seed + 2000,
                                            keep_results=false)
        

        # (5) coef MSE vs truth
        cmse_mlm  = coef_mse(Float64.(B_hat),  Float64.(B_true))
        cmse_bys  = coef_mse(Float64.(B_bayes), Float64.(B_true))

        # (6) test MSE on held-out rows
        tmse_mlm  = test_mse(Float64.(Y_te), Float64.(X_te), Float64.(B_hat))
        tmse_bys  = test_mse(Float64.(Y_te), Float64.(X_te), Float64.(B_bayes))
        tmse_mlm_scalar = tmse_mlm[1]
        tmse_bys_scalar = tmse_bys[1]

        @show typeof(cmse_mlm) typeof(cmse_bys) typeof(tmse_mlm) typeof(tmse_bys)

        push!(rows, (rep=r,
                        coef_mse_mlm=cmse_mlm, coef_mse_bayes=cmse_bys,
                        test_mse_mlm=tmse_mlm_scalar, test_mse_bayes=tmse_bys_scalar))
        println("rep=$r  coef_mse: mlm=$(cmse_mlm)  bayes=$(cmse_bys)  test_mse: mlm=$(tmse_mlm_scalar)  bayes=$(tmse_bys_scalar)")
    end

    df = DataFrame(rows)

    avg = (coef_mse_mlm = mean(df.coef_mse_mlm),
           coef_mse_bayes = mean(df.coef_mse_bayes),
           test_mse_mlm = mean(df.test_mse_mlm),
           test_mse_bayes = mean(df.test_mse_bayes))

    return (per_rep=df, avg=avg)
end

run_simstudy_10 (generic function with 1 method)

In [25]:
out = run_simstudy_10(nrep=10)

Done covariate k=1 / 10
Done covariate k=2 / 10
Done covariate k=3 / 10
Done covariate k=4 / 10
Done covariate k=5 / 10
Done covariate k=6 / 10
Done covariate k=7 / 10
Done covariate k=8 / 10
Done covariate k=9 / 10
Done covariate k=10 / 10
typeof(cmse_mlm) = Float64
typeof(cmse_bys) = Float64
typeof(tmse_mlm) = Tuple{Float64, Vector{Float64}, Vector{Float64}}
typeof(tmse_bys) = Tuple{Float64, Vector{Float64}, Vector{Float64}}
rep=1  coef_mse: mlm=0.016069176035020147  bayes=0.004774836852793692  test_mse: mlm=1.1911520278998815  bayes=1.0655373908025807
Done covariate k=1 / 10
Done covariate k=2 / 10
Done covariate k=3 / 10
Done covariate k=4 / 10
Done covariate k=5 / 10
Done covariate k=6 / 10
Done covariate k=7 / 10
Done covariate k=8 / 10
Done covariate k=9 / 10
Done covariate k=10 / 10
typeof(cmse_mlm) = Float64
typeof(cmse_bys) = Float64
typeof(tmse_mlm) = Tuple{Float64, Vector{Float64}, Vector{Float64}}
typeof(tmse_bys) = Tuple{Float64, Vector{Float64}, Vector{Float64}}
rep=2  c

(per_rep = [1m10×5 DataFrame[0m
[1m Row [0m│[1m rep   [0m[1m coef_mse_mlm [0m[1m coef_mse_bayes [0m[1m test_mse_mlm [0m[1m test_mse_bayes [0m
     │[90m Int64 [0m[90m Float64      [0m[90m Float64        [0m[90m Float64      [0m[90m Float64        [0m
─────┼───────────────────────────────────────────────────────────────────
   1 │     1     0.0160692      0.00477484       1.19115         1.06554
   2 │     2     0.016122       0.00504205       1.17829         1.05531
   3 │     3     0.0181979      0.00527192       1.19439         1.05743
   4 │     4     0.0176656      0.00514767       1.15926         1.04581
   5 │     5     0.0158825      0.00499128       1.16957         1.06923
   6 │     6     0.01728        0.00491266       1.16133         1.03732
   7 │     7     0.0177615      0.00522695       1.17777         1.05179
   8 │     8     0.0166239      0.0051621        1.15599         1.0459
   9 │     9     0.0173684      0.00497272       1.21355         1.0

In [26]:
# Simulation study comparing MatrixLM vs Bayesian hierarchical model (one-level hierarchy)
# across a GRID of (n,m) settings; each setting repeated nrep times.
#
# Output:
#   per_rep : one row per replicate per (n,m)
#   summary : mean Coef MSE / Test MSE per (n,m)

function run_simstudy_grid_1level(;
    nrep::Int = 10,

    # 10 datasets via 10 (n,m) pairs (edit these if you want)
    nm_grid = [
        (n=60,  m=300),
        (n=60,  m=770),
        (n=60,  m=1200),
        (n=98,  m=300),
        (n=98,  m=770),
        (n=98,  m=1200),
        (n=200, m=300),
        (n=200, m=770),
        (n=200, m=1200),
        (n=400, m=770),
    ],

    # use same settings as first simulation
    sim_base_kwargs = (; p=10, H=4,
        prop_sub=[0.55, 0.25, 0.15, 0.05],
        theta0_scale=0.2f0, tau_v=0.12f0, tau_w=0.08f0,
        sigma_y=1.0f0, T=Float32
    ),

    split_frac::Float64 = 0.70,
    seed_base::Int = 1,

    # Bayes settings
    bayes_s0::Float64 = 1.0,
    bayes_halfcauchy_scale::Float64 = 1.0,
    bayes_n_iter::Int = 2000,
    bayes_burnin::Int = 100,
    bayes_thin::Int = 1,
    bayes_mu0::Float64 = 0.0
)

    rows = Vector{NamedTuple}()

    # loop over (n,m) scenarios
    for (sid, nm) in enumerate(nm_grid)
        n = nm.n
        m = nm.m

        # build sim kwargs for this scenario
        sim_kwargs = merge(sim_base_kwargs, (; n=n, m=m))

        # repeat each scenario nrep times
        for r in 1:nrep
            sim_seed = seed_base + 1_000_000*sid + 10_000*r

            # (1) simulate
            sim = simulate_bilinear_identity_1level(; seed=sim_seed, sim_kwargs...)
            X = sim.X
            Y = sim.Y
            B_true = sim.B_true
            subclass_of_met = sim.subclass_of_met
            H = sim_kwargs.H

            # (2) split
            (Y_tr, X_tr, Y_te, X_te, _, _) =
                train_test_split_rows(Y, X; train_frac=split_frac, seed=sim_seed + 777)

            # (3) fit MatrixLM on train
            (B_hat, SE_hat, tstat, fitobj) = fit_matrixlm_identity(X_tr, Y_tr)

            # (4) Bayes fit on train
            (B_bayes, SE_bayes, res_by_cov) =
                fit_bayes_all_covariates(B_hat, SE_hat, subclass_of_met, H;
                    mu0 = bayes_mu0,
                    s0 = bayes_s0,
                    halfcauchy_scale = bayes_halfcauchy_scale,
                    n_iter = bayes_n_iter,
                    burnin = bayes_burnin,
                    thin = bayes_thin,
                    seed0 = sim_seed + 2000,
                    keep_results = false
                )

            # (5) coef MSE vs truth
            cmse_mlm = coef_mse(Float64.(B_hat),  Float64.(B_true))
            cmse_bys = coef_mse(Float64.(B_bayes), Float64.(B_true))

            # (6) test MSE on held-out rows (test_mse returns (scalar, vec, vec))
            tmse_mlm = test_mse(Float64.(Y_te), Float64.(X_te), Float64.(B_hat))
            tmse_bys = test_mse(Float64.(Y_te), Float64.(X_te), Float64.(B_bayes))
            tmse_mlm_scalar = tmse_mlm[1]
            tmse_bys_scalar = tmse_bys[1]

            push!(rows, (
                scenario = sid,
                rep = r,
                n = n,
                m = m,
                coef_mse_mlm = cmse_mlm,
                coef_mse_bayes = cmse_bys,
                test_mse_mlm = tmse_mlm_scalar,
                test_mse_bayes = tmse_bys_scalar
            ))

            println("scenario=$sid (n=$n,m=$m) rep=$r  coef_mse: mlm=$(cmse_mlm) bayes=$(cmse_bys)  test_mse: mlm=$(tmse_mlm_scalar) bayes=$(tmse_bys_scalar)")
        end
    end

    per_rep = DataFrame(rows)

    # Summary table: averages per (n,m)
    summary = combine(groupby(per_rep, [:n, :m]),
        :coef_mse_mlm   => mean => :coef_mse_mlm_mean,
        :coef_mse_bayes => mean => :coef_mse_bayes_mean,
        :test_mse_mlm   => mean => :test_mse_mlm_mean,
        :test_mse_bayes => mean => :test_mse_bayes_mean,
    )

    # Calculate relative improvements
    summary.relimp_coef_mse_mean =
        (summary.coef_mse_mlm_mean .- summary.coef_mse_bayes_mean) ./ summary.coef_mse_mlm_mean
    summary.relimp_test_mse_mean = (summary.test_mse_mlm_mean .- summary.test_mse_bayes_mean) ./ summary.test_mse_mlm_mean

    # Calculate deltas (Bayes - MLM)
    #summary.delta_coef_mse_mean = summary.coef_mse_bayes_mean .- summary.coef_mse_mlm_mean
    #summary.delta_test_mse_mean = summary.test_mse_bayes_mean .- summary.test_mse_mlm_mean

    return (per_rep=per_rep, summary=summary)
end

run_simstudy_grid_1level (generic function with 1 method)

In [27]:
sim_results_1 = run_simstudy_grid_1level(nrep=10)

Done covariate k=1 / 10
Done covariate k=2 / 10
Done covariate k=3 / 10
Done covariate k=4 / 10
Done covariate k=5 / 10
Done covariate k=6 / 10
Done covariate k=7 / 10
Done covariate k=8 / 10
Done covariate k=9 / 10
Done covariate k=10 / 10
scenario=1 (n=60,m=300) rep=1  coef_mse: mlm=0.031653504369724014 bayes=0.007412082906699721  test_mse: mlm=1.3219588107737774 bayes=1.0671892015508595
Done covariate k=1 / 10
Done covariate k=2 / 10
Done covariate k=3 / 10
Done covariate k=4 / 10
Done covariate k=5 / 10
Done covariate k=6 / 10
Done covariate k=7 / 10
Done covariate k=8 / 10
Done covariate k=9 / 10
Done covariate k=10 / 10
scenario=1 (n=60,m=300) rep=2  coef_mse: mlm=0.030759089369564795 bayes=0.007696223702773454  test_mse: mlm=1.3037829992063268 bayes=1.066607335950528
Done covariate k=1 / 10
Done covariate k=2 / 10
Done covariate k=3 / 10
Done covariate k=4 / 10
Done covariate k=5 / 10
Done covariate k=6 / 10
Done covariate k=7 / 10
Done covariate k=8 / 10
Done covariate k=9 / 10

(per_rep = [1m100×8 DataFrame[0m
[1m Row [0m│[1m scenario [0m[1m rep   [0m[1m n     [0m[1m m     [0m[1m coef_mse_mlm [0m[1m coef_mse_bayes [0m[1m test_mse_m[0m ⋯
     │[90m Int64    [0m[90m Int64 [0m[90m Int64 [0m[90m Int64 [0m[90m Float64      [0m[90m Float64        [0m[90m Float64   [0m ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │        1      1     60    300    0.0316535       0.00741208       1.321 ⋯
   2 │        1      2     60    300    0.0307591       0.00769622       1.303
   3 │        1      3     60    300    0.0337424       0.00837021       1.289
   4 │        1      4     60    300    0.0328124       0.00749406       1.310
   5 │        1      5     60    300    0.041053        0.00895717       1.320 ⋯
   6 │        1      6     60    300    0.0318049       0.00871622       1.265
   7 │        1      7     60    300    0.0331196       0.00798991       1.354
   8 │        1      8     60    300  

In [28]:
sim_results_1.summary

Row,n,m,coef_mse_mlm_mean,coef_mse_bayes_mean,test_mse_mlm_mean,test_mse_bayes_mean,relimp_coef_mse_mean,relimp_test_mse_mean
Unnamed: 0_level_1,Int64,Int64,Float64,Float64,Float64,Float64,Float64,Float64
1,60,300,0.0330105,0.00788443,1.31266,1.07294,0.761154,0.182619
2,60,770,0.0320911,0.00749922,1.3272,1.0735,0.766315,0.191157
3,60,1200,0.0328411,0.00750844,1.31952,1.07283,0.771371,0.186951
4,98,300,0.0175266,0.00535771,1.18294,1.0556,0.69431,0.107647
5,98,770,0.0174957,0.00515362,1.17148,1.04807,0.705435,0.105349
6,98,1200,0.0176252,0.00508787,1.17156,1.0492,0.711331,0.104441
7,200,300,0.00761958,0.00359791,1.07195,1.0337,0.527808,0.0356838
8,200,770,0.00772324,0.00357037,1.08107,1.03835,0.537711,0.0395188
9,200,1200,0.00783407,0.00356144,1.07894,1.03647,0.545391,0.0393586
10,400,770,0.00369911,0.00235635,1.03642,1.02337,0.362995,0.0125941


In [29]:
# Two-heterogeneity simulation study:
# Same (n,m) grid, but two settings for (tau_v, tau_w) (moderate vs large).

function run_simstudy_grid_1level_2hetero(;
    nrep::Int = 10,

    # same (n,m) grid for both heterogeneity regimes
    nm_grid = [
        (n=60,  m=300),
        (n=60,  m=770),
        #(n=60,  m=1200),
        #(n=98,  m=300),
        #(n=98,  m=770),
        #(n=98,  m=1200),
        (n=200, m=300),
        (n=200, m=770),
        #(n=200, m=1200),
        (n=400, m=770),
    ],

    # Base sim params (shared unless overridden per regime)
    sim_base_kwargs = (; p=10, H=4,
        prop_sub=[0.55, 0.25, 0.15, 0.05],
        theta0_scale=0.2f0,
        sigma_y=1.0f0, T=Float32
    ),

    # Two regimes of heterogeneity
    # Interpretation:
    #   tau_v = between-subclass SD (how different subclass means are)
    #   tau_w = within-subclass SD (how noisy metabolite effects are around subclass mean)
    #
    # "Large heterogeneity" usually means larger tau_v (and/or tau_w)
    # tau_v moderate-to-large, tau_w smaller (tight within-subclass clustering).
    hetero_regimes = [
        (name="moderate", tau_v=0.12f0, tau_w=0.08f0),
        (name="large",    tau_v=0.30f0, tau_w=0.12f0),
    ],

    split_frac::Float64 = 0.70,
    seed_base::Int = 1,

    # Bayes settings
    bayes_s0::Float64 = 1.0,
    bayes_halfcauchy_scale::Float64 = 1.0,
    bayes_n_iter::Int = 2000,
    bayes_burnin::Int = 100,
    bayes_thin::Int = 1,
    bayes_mu0::Float64 = 0.0
)

    rows = Vector{NamedTuple}()

    for (rid, reg) in enumerate(hetero_regimes)
        reg_name = reg.name

        # Per-regime heterogeneity params
        tau_v = reg.tau_v
        tau_w = reg.tau_w

        # Possible H override
        H_reg = hasproperty(reg, :H) ? getfield(reg, :H) : sim_base_kwargs.H

        # Possible prop_sub override
        prop_sub_reg = hasproperty(reg, :prop_sub) ? getfield(reg, :prop_sub) : sim_base_kwargs.prop_sub

        for (sid, nm) in enumerate(nm_grid)
            n = nm.n
            m = nm.m

            sim_kwargs = merge(sim_base_kwargs, (; n=n, m=m, H=H_reg, prop_sub=prop_sub_reg, tau_v=tau_v, tau_w=tau_w))

            for r in 1:nrep
                sim_seed = seed_base + 10_000_000*rid + 1_000_000*sid + 10_000*r

                # (1) simulate
                sim = simulate_bilinear_identity_1level(; seed=sim_seed, sim_kwargs...)
                X = sim.X
                Y = sim.Y
                B_true = sim.B_true
                subclass_of_met = sim.subclass_of_met
                H = sim_kwargs.H

                # (2) split
                (Y_tr, X_tr, Y_te, X_te, _, _) =
                    train_test_split_rows(Y, X; train_frac=split_frac, seed=sim_seed + 777)

                # (3) fit MatrixLM on train
                (B_hat, SE_hat, tstat, fitobj) = fit_matrixlm_identity(X_tr, Y_tr)

                # (4) Bayes fit on train
                (B_bayes, SE_bayes, res_by_cov) =
                    fit_bayes_all_covariates(B_hat, SE_hat, subclass_of_met, H;
                        mu0 = bayes_mu0,
                        s0 = bayes_s0,
                        halfcauchy_scale = bayes_halfcauchy_scale,
                        n_iter = bayes_n_iter,
                        burnin = bayes_burnin,
                        thin = bayes_thin,
                        seed0 = sim_seed + 2000,
                        keep_results = false
                    )

                # (5) coef MSE vs truth
                cmse_mlm = coef_mse(Float64.(B_hat),   Float64.(B_true))
                cmse_bys = coef_mse(Float64.(B_bayes), Float64.(B_true))

                # (6) test MSE on held-out rows
                tmse_mlm = test_mse(Float64.(Y_te), Float64.(X_te), Float64.(B_hat))
                tmse_bys = test_mse(Float64.(Y_te), Float64.(X_te), Float64.(B_bayes))
                tmse_mlm_scalar = tmse_mlm[1]
                tmse_bys_scalar = tmse_bys[1]

                push!(rows, (
                    regime = reg_name,
                    tau_v = Float64(tau_v),
                    tau_w = Float64(tau_w),
                    H = H_reg,
                    scenario = sid,
                    rep = r,
                    n = n,
                    m = m,
                    coef_mse_mlm = cmse_mlm,
                    coef_mse_bayes = cmse_bys,
                    test_mse_mlm = tmse_mlm_scalar,
                    test_mse_bayes = tmse_bys_scalar
                ))

                println("regime=$(reg_name) (tau_v=$(tau_v), tau_w=$(tau_w)) (n=$n,m=$m) rep=$r  coef_mse: mlm=$(cmse_mlm) bayes=$(cmse_bys)  test_mse: mlm=$(tmse_mlm_scalar) bayes=$(tmse_bys_scalar)")
            end
        end
    end

    per_rep = DataFrame(rows)

    #combine(groupby(per_rep, [:regime, :tau_v, :tau_w, :H, :n, :m])

    summary = combine(groupby(per_rep, [:regime, :n, :m]),
        :coef_mse_mlm   => mean => :coef_mse_mlm_mean,
        :coef_mse_bayes => mean => :coef_mse_bayes_mean,
        :test_mse_mlm   => mean => :test_mse_mlm_mean,
        :test_mse_bayes => mean => :test_mse_bayes_mean,
    )
    # Calculate relative improvements
    #summary.relimp_coef_mse_mean =
        #(summary.coef_mse_mlm_mean .- summary.coef_mse_bayes_mean) ./ summary.coef_mse_mlm_mean
    #summary.relimp_test_mse_mean = (summary.test_mse_mlm_mean .- summary.test_mse_bayes_mean) ./ summary.test_mse_mlm_mean

    # Calculate ratio of MSEs(MLM / Bayes)
    summary.ratio_coef_mse_mean =
        summary.coef_mse_mlm_mean ./ summary.coef_mse_bayes_mean
    summary.ratio_test_mse_mean =  
        summary.test_mse_mlm_mean ./ summary.test_mse_bayes_mean
    
    # Calculate deltas (Bayes - MLM)
    #summary.delta_coef_mse_mean = summary.coef_mse_bayes_mean .- summary.coef_mse_mlm_mean
    #summary.delta_test_mse_mean = summary.test_mse_bayes_mean .- summary.test_mse_mlm_mean

    rename!(summary,
    :regime => :Heterogeneity,
    :coef_mse_mlm_mean   => :CoefMSE_MatrixLM,
    :coef_mse_bayes_mean => :CoefMSE_Bayes,
    :test_mse_mlm_mean   => :TestMSE_MatrixLM,
    :test_mse_bayes_mean => :TestMSE_Bayes
)

    return (per_rep=per_rep, summary=summary)
end

run_simstudy_grid_1level_2hetero (generic function with 1 method)

In [30]:
sim_results_2 = run_simstudy_grid_1level_2hetero(nrep=10)

Done covariate k=1 / 10
Done covariate k=2 / 10
Done covariate k=3 / 10
Done covariate k=4 / 10
Done covariate k=5 / 10
Done covariate k=6 / 10
Done covariate k=7 / 10
Done covariate k=8 / 10
Done covariate k=9 / 10
Done covariate k=10 / 10
regime=moderate (tau_v=0.12, tau_w=0.08) (n=60,m=300) rep=1  coef_mse: mlm=0.04043996628769653 bayes=0.009150102751953229  test_mse: mlm=1.476061775010522 bayes=1.1085503496218023
Done covariate k=1 / 10
Done covariate k=2 / 10
Done covariate k=3 / 10
Done covariate k=4 / 10
Done covariate k=5 / 10
Done covariate k=6 / 10
Done covariate k=7 / 10
Done covariate k=8 / 10
Done covariate k=9 / 10
Done covariate k=10 / 10
regime=moderate (tau_v=0.12, tau_w=0.08) (n=60,m=300) rep=2  coef_mse: mlm=0.028241845250040257 bayes=0.0072952516935881435  test_mse: mlm=1.20847511804905 bayes=1.0150861852539286
Done covariate k=1 / 10
Done covariate k=2 / 10
Done covariate k=3 / 10
Done covariate k=4 / 10
Done covariate k=5 / 10
Done covariate k=6 / 10
Done covariat

(per_rep = [1m100×12 DataFrame[0m
[1m Row [0m│[1m regime   [0m[1m tau_v   [0m[1m tau_w   [0m[1m H     [0m[1m scenario [0m[1m rep   [0m[1m n     [0m[1m m     [0m[1m coef_[0m ⋯
     │[90m String   [0m[90m Float64 [0m[90m Float64 [0m[90m Int64 [0m[90m Int64    [0m[90m Int64 [0m[90m Int64 [0m[90m Int64 [0m[90m Float[0m ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ moderate     0.12     0.08      4         1      1     60    300    0.0 ⋯
   2 │ moderate     0.12     0.08      4         1      2     60    300    0.0
   3 │ moderate     0.12     0.08      4         1      3     60    300    0.0
   4 │ moderate     0.12     0.08      4         1      4     60    300    0.0
   5 │ moderate     0.12     0.08      4         1      5     60    300    0.0 ⋯
   6 │ moderate     0.12     0.08      4         1      6     60    300    0.0
   7 │ moderate     0.12     0.08      4         1      7     60    300    0.0
   

In [31]:
table2 = sim_results_2.summary

Row,Heterogeneity,n,m,CoefMSE_MatrixLM,CoefMSE_Bayes,TestMSE_MatrixLM,TestMSE_Bayes,ratio_coef_mse_mean,ratio_test_mse_mean
Unnamed: 0_level_1,String,Int64,Int64,Float64,Float64,Float64,Float64,Float64,Float64
1,moderate,60,300,0.0318988,0.00784756,1.33542,1.07822,4.0648,1.23854
2,moderate,60,770,0.0313851,0.00740609,1.30092,1.07087,4.23775,1.21483
3,moderate,200,300,0.0080361,0.00367262,1.08206,1.03887,2.18811,1.04157
4,moderate,200,770,0.00777311,0.00356542,1.07931,1.03603,2.18014,1.04177
5,moderate,400,770,0.00366314,0.00235994,1.03713,1.02447,1.55221,1.01236
6,large,60,300,0.0317238,0.0117395,1.32118,1.11418,2.70232,1.18578
7,large,60,770,0.0326392,0.0116277,1.34641,1.12419,2.80702,1.19767
8,large,200,300,0.00755452,0.00503011,1.07082,1.04624,1.50186,1.0235
9,large,200,770,0.00772092,0.00507325,1.07956,1.05306,1.52189,1.02517
10,large,400,770,0.00372141,0.00295975,1.03473,1.02747,1.25734,1.00706


In [32]:
# Round all numeric columns to 4 decimal places
df_rounded = transform(table2, names(table2, Real) .=> ByRow(x -> round(x, digits = 3)) .=> names(table2, Real))
# Pick required columns
df_rounded_sel = select!(df_rounded, [:Heterogeneity, :n, :m, :ratio_coef_mse_mean, :ratio_test_mse_mean])

Row,Heterogeneity,n,m,ratio_coef_mse_mean,ratio_test_mse_mean
Unnamed: 0_level_1,String,Float64,Float64,Float64,Float64
1,moderate,60.0,300.0,4.065,1.239
2,moderate,60.0,770.0,4.238,1.215
3,moderate,200.0,300.0,2.188,1.042
4,moderate,200.0,770.0,2.18,1.042
5,moderate,400.0,770.0,1.552,1.012
6,large,60.0,300.0,2.702,1.186
7,large,60.0,770.0,2.807,1.198
8,large,200.0,300.0,1.502,1.024
9,large,200.0,770.0,1.522,1.025
10,large,400.0,770.0,1.257,1.007


In [33]:
# Export as LaTeX
pretty_table(
    df_rounded_sel;
    backend = Val(:latex),
    tf = tf_latex_booktabs,
    alignment = :c,
    show_subheader = false
)

\begin{tabular}{ccccc}
  \toprule
  \textbf{Heterogeneity} & \textbf{n} & \textbf{m} & \textbf{ratio\_coef\_mse\_mean} & \textbf{ratio\_test\_mse\_mean} \\\midrule
  moderate & 60.0 & 300.0 & 4.065 & 1.239 \\
  moderate & 60.0 & 770.0 & 4.238 & 1.215 \\
  moderate & 200.0 & 300.0 & 2.188 & 1.042 \\
  moderate & 200.0 & 770.0 & 2.18 & 1.042 \\
  moderate & 400.0 & 770.0 & 1.552 & 1.012 \\
  large & 60.0 & 300.0 & 2.702 & 1.186 \\
  large & 60.0 & 770.0 & 2.807 & 1.198 \\
  large & 200.0 & 300.0 & 1.502 & 1.024 \\
  large & 200.0 & 770.0 & 1.522 & 1.025 \\
  large & 400.0 & 770.0 & 1.257 & 1.007 \\\bottomrule
\end{tabular}
