In [1]:
# load necessary packages; make sure install them first
using BenchmarkTools, Distributions, LinearAlgebra, Random, Revise

In [3]:
struct LmmObs{T <: AbstractFloat}
    # data
    y          :: Vector{T}
    X          :: Matrix{T}
    Z          :: Matrix{T}
    # posterior mean and variance of random effects γ
    μγ         :: Vector{T} # posterior mean of random effects
    νγ         :: Matrix{T} # posterior variance of random effects
    # TODO: add whatever intermediate arrays you may want to pre-allocate
    yty        :: T
    rtr        :: Vector{T}
    xty        :: Vector{T}
    zty        :: Vector{T}
    ztr        :: Vector{T}
    ltztr      :: Vector{T}
    xtr        :: Vector{T}
    storage_p  :: Vector{T}
    storage_q  :: Vector{T}
    xtx        :: Matrix{T}
    ztx        :: Matrix{T}
    ztz        :: Matrix{T}
    ltztzl     :: Matrix{T}
    storage_qq :: Matrix{T}
end

"""
    LmmObs(y::Vector, X::Matrix, Z::Matrix)

Create an LMM datum of type `LmmObs`.
"""
function LmmObs(
    y::Vector{T}, 
    X::Matrix{T}, 
    Z::Matrix{T}) where T <: AbstractFloat
    n, p, q = size(X, 1), size(X, 2), size(Z, 2)
    μγ         = Vector{T}(undef, q)
    νγ         = Matrix{T}(undef, q, q)
    yty        = abs2(norm(y))
    rtr        = Vector{T}(undef, 1)
    xty        = transpose(X) * y
    zty        = transpose(Z) * y
    ztr        = similar(zty)
    ltztr      = similar(zty)
    xtr        = Vector{T}(undef, p)
    storage_p  = similar(xtr)
    storage_q  = Vector{T}(undef, q)
    xtx        = transpose(X) * X
    ztx        = transpose(Z) * X
    ztz        = transpose(Z) * Z
    ltztzl     = similar(ztz)
    storage_qq = similar(ztz)
    LmmObs(y, X, Z, μγ, νγ, 
        yty, rtr, xty, zty, ztr, ltztr, xtr,
        storage_p, storage_q, 
        xtx, ztx, ztz, ltztzl, storage_qq)
end


Random.seed!(257)
# dimension
n, p, q = 2000, 5, 3
# predictors
X = [ones(n) randn(n, p - 1)]
Z = [ones(n) randn(n, q - 1)]
# parameter values
β  = [2.0; -1.0; rand(p - 2)]
σ² = 1.5
Σ  = fill(0.1, q, q) + 0.9I # compound symmetry 
L  = Matrix(cholesky(Symmetric(Σ)).L)
# generate y
y  = X * β + Z * rand(MvNormal(Σ)) + sqrt(σ²) * randn(n)

# form the LmmObs object
obs = LmmObs(y, X, Z);

In [6]:
μγ         = Vector{Float64}(undef, q)
νγ         = Matrix{Float64}(undef, q, q)
yty        = abs2(norm(y))
rtr        = Vector{Float64}(undef, 1)
xty        = transpose(X) * y
zty        = transpose(Z) * y
ztr        = similar(zty)
ltztr      = similar(zty)
xtr        = Vector{Float64}(undef, p)
storage_p  = similar(xtr)
storage_q  = Vector{Float64}(undef, q)
xtx        = transpose(X) * X
ztx        = transpose(Z) * X
ztz        = transpose(Z) * Z
ltztzl     = similar(ztz)
storage_qq = similar(ztz)
Linv       = Matrix{Float64}(undef, q, q)

n, p, q = size(X, 1), size(X, 2), size(Z, 2)
σ²inv   = inv(σ²)

copy!(obs.ltztzl, obs.ztz)
BLAS.trmm!('L', 'L', 'T', 'N', T(1), L, obs.ltztzl) # O(q^3) obs.ltztzl = Zt Z L
BLAS.trmm!('R', 'L', 'N', 'N', T(1), L, obs.ltztzl) # O(q^3) obs.ltztzl = Lt Zt Z L
# form the q-by-q matrix: M = σ² I + Lt Zt Z L
copy!(obs.storage_qq, obs.ltztzl)
@inbounds for j in 1:q
    obs.storage_qq[j, j] += σ² # obs.storage_qq = σ² I + Lt Zt Z L
end
LAPACK.potrf!('U', obs.storage_qq) # O(q^3) # obs.storage_qq = Rt
# Zt * res
updater && BLAS.gemv!('N', -1.0, obs.ztx, β, 1.0, copy!(obs.ztr, obs.zty)) # O(pq)
# Lt * (Zt * res)
BLAS.trmv!('L', 'T', 'N', L, copy!(obs.ltztr, obs.ztr))    # O(q^2)
# storage_q = (Mchol.U') \ (Lt * (Zt * res))
BLAS.trsv!('U', 'T', 'N', obs.storage_qq, copy!(obs.storage_q, obs.ltztr)) # O(q^3)
# Xt * res = Xt * y - Xt * X * β
updater && BLAS.gemv!('N', -1.0, obs.xtx, β, 1.0, copy!(obs.xtr, obs.xty))
    # l2 norm of residual vector
    updater && (obs.rtr[1] = obs.yty - dot(obs.xty, β) - dot(obs.xtr, β))
    # assemble pieces
    logl::T = n * log(2π) + (n - q) * log(σ²) # constant term
    @inbounds for j in 1:q # log det term
        logl += 2log(obs.storage_qq[j, j])
    end
    qf    = abs2(norm(obs.storage_q)) # quadratic form term
    logl += (obs.rtr[1] - qf) * σ²inv 
    logl /= -2

UndefVarError: UndefVarError: T not defined