## BIOSTAT 257 Homework 2

Consider a linear mixed effects model

$$Y_i = X_i\beta + Z_i\gamma_i + \epsilon_i, \quad i = 1,\ldots,n$$

where

- $Y_i \in \mathbb{R}^{n_i}$ is the response vector of the $i$-th individual,
- $X_i \in \mathbb{R}^{n_i \times p}$ is the fixed effect predictor matrix of the $i$-th individual,
- $Z_i \in \mathbb{R}^{n_i \times q}$ is the random effect predictor matrix of the $i$-th individual,
- $\epsilon_i \in \mathbb{R}^{n_i}$ is multivariate normal $N(0_{n_i}, \sigma^2 I_{n_i})$,
- $\beta \in \mathbb{R}^{p}$ are fixed effects, and
- $\gamma_i \in \mathbb{R}^{q}$ are random effects assumed to be $N(0_{q}, \Sigma_{q \times q})$ independent of $\epsilon_i$.

### Question 1: Formula

Write down the log-likelihood of the $i$-th datum $(Y_i,X_i,Z_i)$ given parameters $(\beta,\Sigma,\sigma^2)$.

Marginally, $Y_i \sim N(X_i \beta, Z_i \Sigma Z_i^T + \sigma^2 I_{n_i})$, so

$$\ell(\beta,\Sigma,\sigma^2) = -\frac{n_i}{2}\text{log}(2\pi) - \frac{1}{2}\text{log}|Z_i \Sigma Z_i^T + \sigma^2 I_{n_i}| - \frac{1}{2}(Y_i - X_i \beta)^T(Z_i \Sigma Z_i^T + \sigma^2 I_{n_i})^{-1}(Y_i - X_i \beta)$$
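As a baseline, the formula can be evaluated directly by factoring the full $n_i \times n_i$ covariance. This is a minimal sketch, assuming only the `LinearAlgebra` standard library; `logl_dense` is a hypothetical helper name, not part of the assignment, and its $O(n_i^3)$ cost is exactly what the rest of this derivation avoids.

```julia
using LinearAlgebra

# Hypothetical helper: direct evaluation of the log-likelihood above.
# Cost is O(n_i^3) because the n_i×n_i covariance is factored.
function logl_dense(y, X, Z, β, Σ, σ²)
    n = length(y)
    Ω = Symmetric(Z * Σ * Z' + σ² * I)  # marginal covariance of Y_i
    F = cholesky(Ω)                     # Ω = UᵀU
    e = y - X * β                       # residual y - Xβ
    return -n / 2 * log(2π) - logdet(F) / 2 - dot(e, F \ e) / 2
end

# Tiny case with n_i = p = q = 1, so Ω = 1 * 1 * 1 + 1 = 2
ℓ = logl_dense([1.0], ones(1, 1), ones(1, 1), [0.0], ones(1, 1), 1.0)
```

In the tiny case the residual is $1$ and the variance is $2$, so the value should equal $-\tfrac12\log(2\pi) - \tfrac12\log 2 - \tfrac14$.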

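Dropping the subscript $i$ for readability, the most computationally demanding terms are the log determinant and the inverse of the $n_i \times n_i$ covariance matrix. Write the Cholesky decomposition $\Sigma = LL^T$ and let $R = ZL \in \mathbb{R}^{n_i \times q}$, so that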

$$Z\Sigma Z^T = Z L L^T Z^T = (ZL)(ZL)^T = RR^T$$

By the Woodbury identity, the inverse of the covariance matrix is

$$(RR^T + \sigma^2 I_{n_i})^{-1} = \frac{1}{\sigma^2}I -\frac{1}{\sigma^4}R\Bigg(I + \frac{1}{\sigma^2}R^TR\Bigg)^{-1}R^T.$$
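The identity can be checked numerically. This sketch assumes a random $R$ standing in for $ZL$ and arbitrary sizes $n_i = 50$, $q = 3$; the point is that the right-hand side only inverts a $q \times q$ matrix.

```julia
using LinearAlgebra, Random

Random.seed!(1)
n, q, σ² = 50, 3, 1.5
R = randn(n, q)                      # stands in for R = ZL, an n×q matrix

# left side: invert the n×n matrix directly, O(n³)
lhs = inv(R * R' + σ² * I)
# right side: Woodbury form, only the q×q matrix I + RᵀR/σ² is inverted
rhs = I / σ² - R * inv(I + R' * R / σ²) * R' / σ²^2
```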

The quadratic form then becomes

$$\frac{1}{\sigma^2}(y-X\beta)^T(y-X\beta) -\frac{1}{\sigma^4}(y-X\beta)^TR\Bigg(I + \frac{1}{\sigma^2}R^TR\Bigg)^{-1}R^T(y-X\beta)$$

Letting $C = R^T(y-X\beta) = L^TZ^T(y-X\beta) \in \mathbb{R}^q$, this becomes

$$\frac{1}{\sigma^2}(y-X\beta)^T(y-X\beta) -\frac{1}{\sigma^4}C^T\Bigg(I + \frac{1}{\sigma^2}R^TR\Bigg)^{-1}C$$

Let $M$ be the lower-triangular Cholesky factor of this $q \times q$ matrix:

$$I + \frac{1}{\sigma^2}R^TR = MM^T.$$

Substituting into the Woodbury expression gives

$$(RR^T + \sigma^2 I_{n_i})^{-1} = \frac{1}{\sigma^2}I -\frac{1}{\sigma^4}R\Bigg(MM^T\Bigg)^{-1}R^T = \frac{1}{\sigma^2}I -\frac{1}{\sigma^4}RM^{-T}M^{-1}R^T,$$

that is,

$$(RR^T + \sigma^2 I_{n_i})^{-1} = \frac{1}{\sigma^2}I -\frac{1}{\sigma^4}Q^TQ, \quad Q = M^{-1}R^T.$$

Writing $e = y - X\beta$, the quadratic form is

$$e^T\left(\frac{1}{\sigma^2}I -\frac{1}{\sigma^4}Q^TQ\right)e = \frac{1}{\sigma^2}e^Te - \frac{1}{\sigma^4} e^TQ^TQe.$$

The $q \times n_i$ matrix $Q$ never needs to be formed: $Qe = M^{-1}R^Te = M^{-1}C$ is a $q$-vector obtained from one triangular solve, and $e^TQ^TQe = \|M^{-1}C\|^2 = C^T(MM^T)^{-1}C$.
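A numerical check of the quadratic form, computing $Qe = M^{-1}C$ with a single triangular solve and comparing against a dense $n_i \times n_i$ solve. This is a sketch under assumed values: a random `Z`, a random vector `e` in place of $y - X\beta$, and the same $\Sigma$ pattern as the test data later on. In Julia, `M.L` is the lower-triangular factor that the text calls $M$.

```julia
using LinearAlgebra, Random

Random.seed!(2)
n, q, σ² = 200, 3, 1.5
Z = randn(n, q)
L = Matrix(cholesky([1.0 0.1 0.1; 0.1 1.0 0.1; 0.1 0.1 1.0]).L)
e = randn(n)                          # stands in for y - Xβ
R = Z * L

# factor the q×q matrix I + RᵀR/σ² = MMᵀ once
M = cholesky(Symmetric(I + R' * R / σ²))
C = L' * (Z' * e)                     # C = LᵀZᵀe, a q-vector
Qe = M.L \ C                          # Qe = M⁻¹C without forming the q×n Q

quad_fast  = dot(e, e) / σ² - dot(Qe, Qe) / σ²^2
quad_dense = dot(e, (R * R' + σ² * I) \ e)
```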

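The other expensive term is the determinant. By Sylvester's determinant identity $\det(I_{n_i} + AB) = \det(I_q + BA)$,

$$\det(\sigma^2I + Z\Sigma Z^T) = \det(\sigma^2I)\det\Bigg(I + \frac{1}{\sigma^2}R^TR\Bigg) = (\sigma^2)^{n_i}\det(MM^T),$$

and since $M$ is triangular, $\log\det(MM^T) = 2\sum_{j=1}^q \log M_{jj}$.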
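The determinant identity, and the fact that `logdet` applied to a Julia `Cholesky` object returns the log determinant of the factored matrix $MM^T$ (not of $M$), can be checked with a sketch like this; $R$ is again a random stand-in for $ZL$.

```julia
using LinearAlgebra, Random

Random.seed!(3)
n, q, σ² = 200, 3, 1.5
R = randn(n, q)                              # stands in for R = ZL

# direct log determinant of the n×n covariance, O(n³)
ld_dense = logdet(cholesky(Symmetric(R * R' + σ² * I)))

# Sylvester: n log σ² + log det(MMᵀ), with MMᵀ = I + RᵀR/σ² (q×q)
F = cholesky(Symmetric(I + R' * R / σ²))
ld_fast = n * log(σ²) + logdet(F)

# logdet of a Cholesky object equals 2 Σⱼ log Mⱼⱼ = log det(MMᵀ)
ld_diag = 2 * sum(log, diag(F.L))
```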

Thus the log-likelihood can be rewritten as

$$\ell(\beta,\Sigma,\sigma^2) = -\frac{n_i}{2}\log(2\pi) - \frac{1}{2}\log\left((\sigma^2)^{n_i}\det(MM^T)\right) - \frac{1}{2\sigma^2}e^Te + \frac{1}{2\sigma^4} e^TQ^TQe$$

$$\ell(\beta,\Sigma,\sigma^2) = -\frac{n_i}{2}\log(2\pi) - \frac{n_i}{2}\log(\sigma^2) -\frac{1}{2}\log\det(MM^T) - \frac{1}{2\sigma^2}e^Te + \frac{1}{2\sigma^4} e^TQ^TQe$$

Apart from $e$ and $Z^Te$, every quantity is $q$-dimensional, so with $Z^TZ$ precomputed one evaluation costs $O(n_i(p+q) + q^3)$ flops instead of the $O(n_i^3)$ of a dense factorization.
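Putting the pieces together, the final expression can be sketched as an allocating (non-in-place) reference function and compared against a dense evaluation. `logl_woodbury` is a hypothetical name, and the data below are simulated with arbitrary sizes and seed; this is a correctness check, not a tuned implementation.

```julia
using LinearAlgebra, Random

# Hypothetical allocating reference for the final expression; L is assumed
# to be the lower Cholesky factor of Σ.
function logl_woodbury(y, X, Z, β, L, σ²)
    n = length(y)
    e = y - X * β                              # residual
    R = Z * L                                  # n×q
    M = cholesky(Symmetric(I + R' * R / σ²))   # MMᵀ = I + RᵀR/σ²
    Qe = M.L \ (R' * e)                        # Qe = M⁻¹Rᵀe, a q-vector
    return -n / 2 * log(2π) - n / 2 * log(σ²) - logdet(M) / 2 -
        dot(e, e) / (2 * σ²) + dot(Qe, Qe) / (2 * σ²^2)
end

Random.seed!(4)
n, p, q, σ² = 300, 4, 3, 1.5
X = randn(n, p); Z = randn(n, q); β = randn(p)
Σ = fill(0.1, q, q) + 0.9I
L = Matrix(cholesky(Σ).L)
y = X * β + Z * (L * randn(q)) + sqrt(σ²) * randn(n)

# dense O(n³) evaluation for comparison
F = cholesky(Symmetric(Z * Σ * Z' + σ² * I))
e = y - X * β
ℓ_dense = -n / 2 * log(2π) - logdet(F) / 2 - dot(e, F \ e) / 2
ℓ_fast = logl_woodbury(y, X, Z, β, L, σ²)
```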

### Question 2: Start-up Code

Use the following template to define a type `LmmObs` that holds an LMM datum $(y_i,X_i,Z_i)$.

```julia
using BenchmarkTools, Distributions, LinearAlgebra, Random, SparseArrays, InteractiveUtils

# define a type that holds LMM datum
struct LmmObs{T <: AbstractFloat}
    # data
    y :: Vector{T}
    X :: Matrix{T}
    Z :: Matrix{T}
    # working arrays, pre-allocated to avoid allocations in logl!
    res         :: Vector{T}   # n-vector for the residual
    storage_q   :: Vector{T}   # q-vector workspace
    storage_q2  :: Vector{T}   # q-vector workspace
    ztz         :: Matrix{T}   # q×q matrix ZᵀZ, computed in the constructor
    storage_qq  :: Matrix{T}   # q×q workspace
    storage_qq2 :: Matrix{T}   # q×q workspace
end

# constructor
function LmmObs(
        y::Vector{T}, 
        X::Matrix{T}, 
        Z::Matrix{T}) where T <: AbstractFloat
    res         = similar(y)
    storage_q   = Vector{T}(undef, size(Z, 2))
    storage_q2  = Vector{T}(undef, size(Z, 2))
    ztz         = transpose(Z) * Z   # O(n q²), paid once per datum
    storage_qq  = similar(ztz)
    storage_qq2 = similar(ztz)
    LmmObs(y, X, Z, res, storage_q, storage_q2, ztz, storage_qq, storage_qq2)
end
```


Write a function with interface `logl!(obs, β, L, σ²)` that evaluates the log-likelihood of the $i$-th datum. Here `L` is the lower triangular Cholesky factor from the Cholesky decomposition `Σ=LL'`. Make your code efficient in the $n_i \gg q$ case; think of the intensive longitudinal measurement setting.

```julia
function logl!(
        obs :: LmmObs{T}, 
        β   :: Vector{T}, 
        L   :: Matrix{T}, 
        σ²  :: T) where T <: AbstractFloat
    n, q = size(obs.X, 1), size(obs.Z, 2)
    ## residual e = y - Xβ, computed in place as res ← -Xβ + y
    copyto!(obs.res, obs.y)
    mul!(obs.res, obs.X, β, -one(T), one(T))

    ## I + (1/σ²) RᵀR with R = ZL, i.e. I + (1/σ²) Lᵀ (ZᵀZ) L;
    ## obs.ztz is only read, so it stays valid across repeated calls
    mul!(obs.storage_qq, obs.ztz, L)
    mul!(obs.storage_qq2, transpose(L), obs.storage_qq)
    lmul!(inv(σ²), obs.storage_qq2)
    for i in 1:q
        obs.storage_qq2[i, i] += 1
    end

    ## in-place Cholesky factorization I + (1/σ²) RᵀR = MMᵀ
    M = cholesky!(Symmetric(obs.storage_qq2))

    ## C = Lᵀ Zᵀ e, then storage_q = (MMᵀ)⁻¹ C
    mul!(obs.storage_q, transpose(obs.Z), obs.res)
    mul!(obs.storage_q2, transpose(L), obs.storage_q)
    ldiv!(obs.storage_q, M, obs.storage_q2)

    ## logdet on a Cholesky object returns log det(MMᵀ)
    return -n / 2 * log(2π) - n / 2 * log(σ²) - logdet(M) / 2 -
        dot(obs.res, obs.res) / (2 * σ²) +
        dot(obs.storage_q2, obs.storage_q) / (2 * σ²^2)
end
```


Hint: This function shouldn't be very long. Mine, obeying the 80-character rule, is 25 lines. If you find yourself writing very long code, you're on the wrong track. Think about the algorithm first, then use BLAS functions to reduce memory allocations.
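One caveat with in-place BLAS/LAPACK-style calls is that they overwrite their arguments, so a cached buffer (such as a precomputed $Z^TZ$) should never be used as an output or passed to `cholesky!`. A minimal sketch with a made-up 2×2 matrix, showing that `cholesky!(Symmetric(A))` replaces the upper triangle of `A` with the Cholesky factor, and the copy-into-scratch pattern that avoids this:

```julia
using LinearAlgebra

# a cached q×q matrix, e.g. a precomputed ZᵀZ (values are made up)
A = [4.0 2.0; 2.0 3.0]
cached = copy(A)

# cholesky! factors in place: the upper triangle of A now holds U,
# so A no longer contains the original matrix
F = cholesky!(Symmetric(A))

# safe pattern: copy the cached matrix into a scratch buffer first
buf = similar(cached)
copyto!(buf, cached)
G = cholesky!(Symmetric(buf))
```

The same holds for `mul!` and `lmul!`: the first argument is the destination and is overwritten.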

### Question 3: Correctness

Compare your result (both accuracy and timing) to the Distributions.jl package using the following data.


```julia
Random.seed!(257)
# dimension
n, p, q = 2000, 5, 3
# predictors
X  = [ones(n) randn(n, p - 1)]
Z  = [ones(n) randn(n, q - 1)]
# parameter values
β  = [2.0; -1.0; rand(p - 2)]
σ² = 1.5
Σ  = fill(0.1, q, q) + 0.9I
# generate y
y  = X * β + Z * rand(MvNormal(Σ)) + sqrt(σ²) * randn(n)

# form an LmmObs object
obs = LmmObs(y, X, Z)
```


```julia
logl!(obs, β, Matrix(cholesky(Σ).L), σ²)
```

```
-3247.4568580638297
```

```julia
μ  = X * β
Ω  = Z * Σ * transpose(Z) + σ² * I
mvn = MvNormal(μ, Symmetric(Ω)) # MVN(μ, Ω)
logpdf(mvn, y)
```

```
-3247.456858063827
```

```julia
@assert logl!(obs, β, Matrix(cholesky(Σ).L), σ²) ≈ logpdf(mvn, y)
```

```julia
bm1 = @benchmark logpdf($mvn, $y)
```

```
BenchmarkTools.Trial: 
  memory estimate:  30.55 MiB
  allocs estimate:  5
  --------------
  minimum time:     7.384 ms (0.00% GC)
  median time:      7.941 ms (0.00% GC)
  mean time:        17.586 ms (4.79% GC)
  maximum time:     39.075 ms (1.37% GC)
  --------------
  samples:          285
  evals/sample:     1
```

```julia
L = Matrix(cholesky(Σ).L)
bm2 = @benchmark logl!($obs, $β, $L, $σ²)
```

```
BenchmarkTools.Trial: 
  memory estimate:  272 bytes
  allocs estimate:  7
  --------------
  minimum time:     7.400 μs (0.00% GC)
  median time:      12.350 μs (0.00% GC)
  mean time:        13.879 μs (0.00% GC)
  maximum time:     46.533 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     3
```

```julia
clamp(median(bm1).time / median(bm2).time / 1000 * 30, 0, 30)
```

```
19.289635627530362
```
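As a worked check of the speed score, plugging in the displayed median times (which are rounded, so the last digits differ slightly from the printed score) gives

$$\frac{7.941\ \text{ms}}{12.350\ \mu\text{s}} \approx 643, \qquad \frac{643}{1000} \times 30 \approx 19.29.$$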

```julia
clamp(30 - median(bm2).memory / 1024, 0, 30)
```

```
29.734375
```
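Likewise, the memory score follows directly from the 272-byte memory estimate of `logl!`:

$$30 - \frac{272}{1024} = 30 - 0.265625 = 29.734375.$$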