## BIOSTAT 257: Homework 5
### Joanna Boland

Again, we continue with the linear mixed effects model (LMM)
$$
    \mathbf{Y}_i = \mathbf{X}_i \boldsymbol{\beta} + \mathbf{Z}_i \boldsymbol{\gamma}_i + \boldsymbol{\epsilon}_i, \quad i=1,\ldots,m,
$$
where   
- $\mathbf{Y}_i \in \mathbb{R}^{n_i}$ is the response vector of the $i$-th individual,  
- $\mathbf{X}_i \in \mathbb{R}^{n_i \times p}$ is the fixed effects predictor matrix of the $i$-th individual,  
- $\mathbf{Z}_i \in \mathbb{R}^{n_i \times q}$ is the random effects predictor matrix of the $i$-th individual,  
- $\boldsymbol{\epsilon}_i \in \mathbb{R}^{n_i}$ are multivariate normal $N(\mathbf{0}_{n_i},\sigma^2 \mathbf{I}_{n_i})$,  
- $\boldsymbol{\beta} \in \mathbb{R}^p$ are fixed effects, and  
- $\boldsymbol{\gamma}_i \in \mathbb{R}^q$ are random effects assumed to be $N(\mathbf{0}_q, \boldsymbol{\Sigma}_{q \times q})$ independent of $\boldsymbol{\epsilon}_i$.

The log-likelihood of the $i$-th datum $(\mathbf{y}_i, \mathbf{X}_i, \mathbf{Z}_i)$ is 
$$
    \ell_i(\boldsymbol{\beta}, \boldsymbol{\Sigma}, \sigma^2) = - \frac{n_i}{2} \log (2\pi) - \frac{1}{2} \log \det \boldsymbol{\Omega}_i - \frac{1}{2} (\mathbf{y}_i - \mathbf{X}_i \boldsymbol{\beta})^T \boldsymbol{\Omega}_i^{-1} (\mathbf{y}_i - \mathbf{X}_i \boldsymbol{\beta}),
$$
where
$$
    \boldsymbol{\Omega}_i = \sigma^2 \mathbf{I}_{n_i} + \mathbf{Z}_i \boldsymbol{\Sigma} \mathbf{Z}_i^T.
$$
Given $m$ independent data points $(\mathbf{y}_i, \mathbf{X}_i, \mathbf{Z}_i)$, $i=1,\ldots,m$, we seek the maximum likelihood estimate (MLE) by maximizing the log-likelihood
$$
\ell(\boldsymbol{\beta}, \boldsymbol{\Sigma}, \sigma^2) = \sum_{i=1}^m \ell_i(\boldsymbol{\beta}, \boldsymbol{\Sigma}, \sigma^2).
$$
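As a correctness reference, $\ell_i$ can be evaluated directly by forming the $n_i \times n_i$ matrix $\boldsymbol{\Omega}_i$. The sketch below is a hypothetical helper, `naive_logl`, that is not part of the assignment. It costs $O(n_i^3)$, so it is only meant for checking a faster evaluator on moderate $n_i$.

```julia
using LinearAlgebra

"""
    naive_logl(y, X, Z, β, Σ, σ²)

Hypothetical dense reference for the log-likelihood of one LMM datum:
forms Ω = σ²I + ZΣZᵀ explicitly and uses its Cholesky factorization.
"""
function naive_logl(y, X, Z, β, Σ, σ²)
    n = length(y)
    Ω = Symmetric(σ² * I + Z * Σ * Z')   # n-by-n marginal covariance
    C = cholesky(Ω)                      # Ω = UᵀU
    r = y - X * β                        # marginal residual
    # -½ [n log(2π) + log det Ω + rᵀ Ω⁻¹ r]
    return -(n * log(2π) + logdet(C) + dot(r, C \ r)) / 2
end
```

If both are correct, `naive_logl(y, X, Z, β, Σ, σ²)` at the test datum should agree with `logl!` up to rounding error.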

In HW4, we used the nonlinear programming (NLP) approach (Newton-type algorithms) for optimization. In this assignment, we derive and implement an expectation-maximization (EM) algorithm for the same problem.

```julia
# load necessary packages; make sure to install them first
using BenchmarkTools, Distributions, LinearAlgebra, Random, Revise
```

### Question 1: Refresher on Normal-Normal Model

Assume the conditional distribution
$$
\mathbf{y} \mid \boldsymbol{\gamma} \sim N(\mathbf{X} \boldsymbol{\beta} + \mathbf{Z} \boldsymbol{\gamma}, \sigma^2 \mathbf{I}_n)
$$
and the prior distribution
$$
\boldsymbol{\gamma} \sim N(\mathbf{0}_q, \boldsymbol{\Sigma}).
$$
By the Bayes theorem, the posterior distribution is
\begin{eqnarray*}
f(\boldsymbol{\gamma} \mid \mathbf{y}) &=& \frac{f(\mathbf{y} \mid \boldsymbol{\gamma}) \times f(\boldsymbol{\gamma})}{f(\mathbf{y})}, \end{eqnarray*}
where $f$ denotes the corresponding density. 

Note that
\begin{eqnarray*}
f(\boldsymbol{\gamma}) &\propto& \text{exp}\Bigg(-\frac{1}{2}\boldsymbol{\gamma}^T\boldsymbol{\Sigma}^{-1} \boldsymbol{\gamma}\Bigg), \\ 
f(\mathbf{y} \mid \boldsymbol{\gamma}) &\propto& \text{exp}\Bigg(-\frac{1}{2}(\mathbf{y} - \mathbf{X} \boldsymbol{\beta} - \mathbf{Z} \boldsymbol{\gamma})^T(\sigma^2 \mathbf{I}_n)^{-1} (\mathbf{y} - \mathbf{X} \boldsymbol{\beta} - \mathbf{Z} \boldsymbol{\gamma})\Bigg) \\
f(\mathbf{y} \mid \boldsymbol{\gamma}) &\propto& \text{exp}\Bigg(-\frac{1}{2}\sigma^{-2}\boldsymbol{\gamma}^T\mathbf{Z}^T\mathbf{Z}\boldsymbol{\gamma} + \sigma^{-2}\boldsymbol{\gamma}^T\mathbf{Z}^T(\mathbf{y} - \mathbf{X} \boldsymbol{\beta})\Bigg) \\
f(\boldsymbol{\gamma} \mid \mathbf{y}) &\propto& f(\mathbf{y} \mid \boldsymbol{\gamma}) f(\boldsymbol{\gamma}) \\
f(\boldsymbol{\gamma} \mid \mathbf{y}) &\propto& \text{exp}\Bigg(-\frac{1}{2}\boldsymbol{\gamma}^T\boldsymbol{\Sigma}^{-1} \boldsymbol{\gamma} -\frac{1}{2}\sigma^{-2}\boldsymbol{\gamma}^T\mathbf{Z}^T\mathbf{Z}\boldsymbol{\gamma} + \sigma^{-2}\boldsymbol{\gamma}^T\mathbf{Z}^T(\mathbf{y} - \mathbf{X} \boldsymbol{\beta})\Bigg) \\
f(\boldsymbol{\gamma} \mid \mathbf{y}) &\propto& \text{exp}\Bigg(-\frac{1}{2}\boldsymbol{\gamma}^T(\boldsymbol{\Sigma}^{-1} + \sigma^{-2}\mathbf{Z}^T\mathbf{Z}) \boldsymbol{\gamma} + \sigma^{-2}\boldsymbol{\gamma}^T\mathbf{Z}^T(\mathbf{y} - \mathbf{X} \boldsymbol{\beta})\Bigg)
\end{eqnarray*}

Therefore, completing the square in $\boldsymbol{\gamma}$ and using properties of the normal distribution, we obtain

$$\boldsymbol{\gamma} \mid \mathbf{y} \sim N(A^{-1}b, A^{-1}),$$

where
$$A = \boldsymbol{\Sigma}^{-1} + \sigma^{-2}\mathbf{Z}^T\mathbf{Z}, \quad b = \sigma^{-2}\mathbf{Z}^T(\mathbf{y} - \mathbf{X} \boldsymbol{\beta})$$
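With $A$ and $b$ as defined above ($A$ is symmetric), the completing-the-square step is the identity

\begin{eqnarray*}
-\frac{1}{2}\boldsymbol{\gamma}^T A \boldsymbol{\gamma} + \boldsymbol{\gamma}^T b &=& -\frac{1}{2}(\boldsymbol{\gamma} - A^{-1}b)^T A (\boldsymbol{\gamma} - A^{-1}b) + \frac{1}{2} b^T A^{-1} b.
\end{eqnarray*}

The last term does not depend on $\boldsymbol{\gamma}$, so the posterior kernel is that of a normal density with mean $A^{-1}b$ and covariance $A^{-1}$.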

Hence the posterior variance is $A^{-1}$, which by the Woodbury identity can be written as

$$\text{Var} (\boldsymbol{\gamma} \mid \mathbf{y}) = (\boldsymbol{\Sigma}^{-1} + \sigma^{-2}\mathbf{Z}^T\mathbf{Z})^{-1} = \boldsymbol{\Sigma} - \boldsymbol{\Sigma}\mathbf{Z}^T(\sigma^{2}\mathbf{I} + \mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T)^{-1}\mathbf{Z}\boldsymbol{\Sigma},$$

and additionally

\begin{eqnarray*}
\mathbb{E} (\boldsymbol{\gamma} \mid \mathbf{y}) &=& \sigma^{-2} (\sigma^{-2} \mathbf{Z}^T \mathbf{Z} + \boldsymbol{\Sigma}^{-1})^{-1 } \mathbf{Z}^T (\mathbf{y} - \mathbf{X} \boldsymbol{\beta}) \\
&=& \sigma^{-2}(\boldsymbol{\Sigma} - \boldsymbol{\Sigma}\mathbf{Z}^T(\sigma^{2}\mathbf{I} + \mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T)^{-1}\mathbf{Z}\boldsymbol{\Sigma})\mathbf{Z}^T (\mathbf{y} - \mathbf{X} \boldsymbol{\beta}) \\
&=& \boldsymbol{\Sigma}\mathbf{Z}^T(\sigma^{-2}\mathbf{I} - \sigma^{-2}(\sigma^{2}\mathbf{I} + \mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T)^{-1}\mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T)(\mathbf{y} - \mathbf{X} \boldsymbol{\beta}) \\
&=& \boldsymbol{\Sigma}\mathbf{Z}^T(\sigma^{2}\mathbf{I} + \mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T)^{-1}(\sigma^{-2}\mathbf{I}(\sigma^{2}\mathbf{I} + \mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T) - \sigma^{-2}\mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T)(\mathbf{y} - \mathbf{X} \boldsymbol{\beta}) \\
&=& \boldsymbol{\Sigma}\mathbf{Z}^T(\sigma^{2}\mathbf{I} + \mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T)^{-1}(\mathbf{I} + \sigma^{-2}\mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T - \sigma^{-2}\mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T)(\mathbf{y} - \mathbf{X} \boldsymbol{\beta}) \\
&=& \boldsymbol{\Sigma}\mathbf{Z}^T(\sigma^{2}\mathbf{I} + \mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T)^{-1}(\mathbf{y} - \mathbf{X} \boldsymbol{\beta})
\end{eqnarray*}
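Both forms of the posterior variance and mean can be checked numerically. The sketch below uses a hypothetical helper, `posterior_moments`, on random data. The precision form inverts a $q \times q$ matrix, while the Woodbury form solves with the $n \times n$ matrix $\sigma^2\mathbf{I} + \mathbf{Z}\boldsymbol{\Sigma}\mathbf{Z}^T$. The derivation above predicts that the two agree up to rounding error.

```julia
using LinearAlgebra

# Hypothetical helper: posterior moments of γ | y computed two ways.
function posterior_moments(y, X, Z, β, Σ, σ²)
    r  = y - X * β
    # precision form: A = Σ⁻¹ + σ⁻² ZᵀZ, Var = A⁻¹, E = σ⁻² A⁻¹ Zᵀ r
    A  = inv(Σ) + Z' * Z / σ²
    V1 = inv(Symmetric(A))
    m1 = V1 * (Z' * r) / σ²
    # Woodbury form with M = σ²I + ZΣZᵀ
    M  = Symmetric(σ² * I + Z * Σ * Z')
    V2 = Σ - Σ * Z' * (M \ (Z * Σ))
    m2 = Σ * Z' * (M \ r)
    return V1, m1, V2, m2
end
```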

### Question 2: Derive EM Algorithm

1. Write down the complete log-likelihood

\begin{eqnarray*}
\sum_{i=1}^m \log f(\mathbf{y}_i, \boldsymbol{\gamma}_i \mid \boldsymbol{\beta}, \boldsymbol{\Sigma}, \sigma^2) &=& \sum_{i=1}^m [\log f(\mathbf{y}_i \mid  \boldsymbol{\gamma}_i,\boldsymbol{\beta}, \sigma^2) + \log f( \boldsymbol{\gamma}_i \mid \boldsymbol{\Sigma})] \\
&=& \sum_{i=1}^m \Bigg[- \frac{n_i}{2} \log (2\pi) - \frac{1}{2} \log \det (\sigma^{2}\mathbf{I}_{n_i}) - \frac{1}{2} (\mathbf{y}_i - \mathbf{X}_i \boldsymbol{\beta} - \mathbf{Z}_i \boldsymbol{\gamma}_i)^T \sigma^{-2}\mathbf{I}_{n_i} (\mathbf{y}_i - \mathbf{X}_i \boldsymbol{\beta} - \mathbf{Z}_i \boldsymbol{\gamma}_i) \\
&& \qquad - \frac{q}{2} \log (2\pi) - \frac{1}{2} \log \det \boldsymbol{\Sigma} - \frac{1}{2} \boldsymbol{\gamma}_i^T \boldsymbol{\Sigma}^{-1}\boldsymbol{\gamma}_i\Bigg] \\
&=& \sum_{i=1}^m \Bigg[- \frac{n_i}{2} \log (2\pi) - \frac{n_i}{2} \log \sigma^{2} - \frac{1}{2\sigma^{2}} \|\mathbf{y}_i - \mathbf{X}_i \boldsymbol{\beta} - \mathbf{Z}_i \boldsymbol{\gamma}_i\|_2^2 \\
&& \qquad - \frac{1}{2} \boldsymbol{\gamma}_i^T \boldsymbol{\Sigma}^{-1}\boldsymbol{\gamma}_i\Bigg] - \frac{qm}{2} \log (2\pi) - \frac{m}{2} \log \det \boldsymbol{\Sigma}
\end{eqnarray*}

2. Derive the $Q$ function (E-step).

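\begin{eqnarray*}
Q(\boldsymbol{\beta}, \boldsymbol{\Sigma}, \sigma^2 \mid \boldsymbol{\beta}^{(t)}, \boldsymbol{\Sigma}^{(t)}, \sigma^{2(t)}) &=& \sum_{i=1}^m \mathbb{E}\Big[\log f(\mathbf{y}_i, \boldsymbol{\gamma}_i \mid \boldsymbol{\beta}, \boldsymbol{\Sigma}, \sigma^2) \,\Big|\, \mathbf{y}_i, \boldsymbol{\beta}^{(t)}, \boldsymbol{\Sigma}^{(t)}, \sigma^{2(t)}\Big],
\end{eqnarray*}
where the expectation is taken over the posterior distribution of $\boldsymbol{\gamma}_i$ given $\mathbf{y}_i$ (Question 1) at the current iterate.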
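The following sketch carries out the expectation explicitly. Let $\boldsymbol{\mu}_i^{(t)}$ and $\boldsymbol{\nu}_i^{(t)}$ denote the posterior mean and variance of $\boldsymbol{\gamma}_i$ from Question 1, evaluated at $(\boldsymbol{\beta}^{(t)}, \boldsymbol{\Sigma}^{(t)}, \sigma^{2(t)})$:
$$
\boldsymbol{\mu}_i^{(t)} = \boldsymbol{\Sigma}^{(t)}\mathbf{Z}_i^T(\sigma^{2(t)}\mathbf{I} + \mathbf{Z}_i\boldsymbol{\Sigma}^{(t)}\mathbf{Z}_i^T)^{-1}(\mathbf{y}_i - \mathbf{X}_i \boldsymbol{\beta}^{(t)}), \quad \boldsymbol{\nu}_i^{(t)} = \big((\boldsymbol{\Sigma}^{(t)})^{-1} + (\sigma^{2(t)})^{-1}\mathbf{Z}_i^T\mathbf{Z}_i\big)^{-1}.
$$
Using $\mathbb{E}\|\mathbf{r} - \mathbf{Z}_i\boldsymbol{\gamma}_i\|_2^2 = \|\mathbf{r} - \mathbf{Z}_i\boldsymbol{\mu}_i^{(t)}\|_2^2 + \text{tr}(\mathbf{Z}_i^T\mathbf{Z}_i\boldsymbol{\nu}_i^{(t)})$ and $\mathbb{E}(\boldsymbol{\gamma}_i\boldsymbol{\gamma}_i^T) = \boldsymbol{\mu}_i^{(t)}\boldsymbol{\mu}_i^{(t)T} + \boldsymbol{\nu}_i^{(t)}$ gives

\begin{eqnarray*}
Q(\boldsymbol{\beta}, \boldsymbol{\Sigma}, \sigma^2 \mid \boldsymbol{\beta}^{(t)}, \boldsymbol{\Sigma}^{(t)}, \sigma^{2(t)}) &=& \sum_{i=1}^m \Bigg[- \frac{n_i}{2} \log (2\pi) - \frac{n_i}{2} \log \sigma^{2} - \frac{1}{2\sigma^{2}} \Big(\|\mathbf{y}_i - \mathbf{X}_i \boldsymbol{\beta} - \mathbf{Z}_i \boldsymbol{\mu}_i^{(t)}\|_2^2 + \text{tr}(\mathbf{Z}_i^T \mathbf{Z}_i \boldsymbol{\nu}_i^{(t)})\Big) \\
&& \qquad - \frac{1}{2} \text{tr}\Big(\boldsymbol{\Sigma}^{-1}\big(\boldsymbol{\mu}_i^{(t)} \boldsymbol{\mu}_i^{(t)T} + \boldsymbol{\nu}_i^{(t)}\big)\Big)\Bigg] - \frac{qm}{2} \log (2\pi) - \frac{m}{2} \log \det \boldsymbol{\Sigma}.
\end{eqnarray*}

The per-datum terms can be evaluated with the hypothetical helpers below: `complete_logl` is $\log f(\mathbf{y}_i, \boldsymbol{\gamma}_i)$ and `q_term` is one summand of $Q$. As a sanity check, setting `ν = 0` makes `q_term` coincide with `complete_logl` at `γ = μ`.

```julia
using LinearAlgebra

# Hypothetical helper: complete-data log-likelihood log f(yᵢ, γᵢ) of one datum.
function complete_logl(y, X, Z, β, Σ, σ², γ)
    n, q = length(y), length(γ)
    r = y - X * β - Z * γ                 # conditional residual
    C = cholesky(Symmetric(Σ))            # for log det Σ and Σ⁻¹ solves
    return -(n + q) / 2 * log(2π) - n / 2 * log(σ²) - dot(r, r) / (2 * σ²) -
           logdet(C) / 2 - dot(γ, C \ γ) / 2
end

# Hypothetical helper: one datum's contribution to Q, given the posterior mean μ
# and variance ν of γᵢ computed at the current iterate.
function q_term(y, X, Z, β, Σ, σ², μ, ν)
    n, q = length(y), length(μ)
    r  = y - X * β - Z * μ
    C  = cholesky(Symmetric(Σ))
    ss = dot(r, r) + tr(Z' * Z * ν)       # E‖yᵢ - Xᵢβ - Zᵢγᵢ‖²
    return -(n + q) / 2 * log(2π) - n / 2 * log(σ²) - ss / (2 * σ²) -
           logdet(C) / 2 - tr(C \ (μ * μ' + ν)) / 2
end
```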


### Question 3: Objective of a single datum

We modify the code from HW4 to evaluate the objective, the conditional mean of $\boldsymbol{\gamma}$, and the conditional variance of $\boldsymbol{\gamma}$. Start-up code is provided below. You do _not_ have to use this code.

```julia
# define a type that holds an LMM datum
struct LmmObs{T <: AbstractFloat}
    # data
    y          :: Vector{T}
    X          :: Matrix{T}
    Z          :: Matrix{T}
    # posterior mean and variance of random effects γ
    μγ         :: Vector{T} # posterior mean of random effects
    νγ         :: Matrix{T} # posterior variance of random effects
    # pre-allocated intermediate arrays
    yty        :: T
    rtr        :: Vector{T}
    xty        :: Vector{T}
    zty        :: Vector{T}
    ztr        :: Vector{T}
    ltztr      :: Vector{T}
    xtr        :: Vector{T}
    storage_p  :: Vector{T}
    storage_q  :: Vector{T}
    xtx        :: Matrix{T}
    ztx        :: Matrix{T}
    ztz        :: Matrix{T}
    ltztzl     :: Matrix{T}
    storage_qq :: Matrix{T}
    I3         :: Matrix{T} # q-by-q identity matrix
    Linv       :: Matrix{T}
    storage_qq2:: Matrix{T}
end

"""
    LmmObs(y::Vector, X::Matrix, Z::Matrix)

Create an LMM datum of type `LmmObs`.
"""
function LmmObs(
    y::Vector{T}, 
    X::Matrix{T}, 
    Z::Matrix{T}) where T <: AbstractFloat
    n, p, q = size(X, 1), size(X, 2), size(Z, 2)
    μγ         = Vector{T}(undef, q)
    νγ         = Matrix{T}(undef, q, q)
    yty        = abs2(norm(y))
    rtr        = Vector{T}(undef, 1)
    xty        = transpose(X) * y
    zty        = transpose(Z) * y
    ztr        = similar(zty)
    ltztr      = similar(zty)
    xtr        = Vector{T}(undef, p)
    storage_p  = similar(xtr)
    storage_q  = Vector{T}(undef, q)
    xtx        = transpose(X) * X
    ztx        = transpose(Z) * X
    ztz        = transpose(Z) * Z
    ltztzl     = similar(ztz)
    storage_qq = similar(ztz)
    I3         = Matrix{T}(I, q, q)
    Linv       = Matrix{T}(undef, q, q)
    storage_qq2 = Matrix{T}(undef, q, q)
    LmmObs(y, X, Z, μγ, νγ, 
        yty, rtr, xty, zty, ztr, ltztr, xtr,
        storage_p, storage_q, 
        xtx, ztx, ztz, ltztzl, storage_qq, 
        I3, Linv, storage_qq2)
end

"""
    logl!(obs::LmmObs, β, Σ, L, σ², updater = false)

Evaluate the log-likelihood of a single LMM datum at parameter values `β`, `Σ`, 
and `σ²`. The lower triangular Cholesky factor `L` of `Σ` must be supplied too.
The fields `obs.μγ` and `obs.νγ` are overwritten by the posterior mean and 
posterior variance of random effects. If `updater==true`, fields `obs.ztr`, 
`obs.xtr`, and `obs.rtr` are updated according to input parameter values. 
Otherwise, it assumes these three fields are pre-computed. 
"""
function logl!(
        obs     :: LmmObs{T}, 
        β       :: Vector{T}, 
        Σ       :: Matrix{T},
        L       :: Matrix{T},
        σ²      :: T,
        updater :: Bool = false
        ) where T <: AbstractFloat
    n, p, q = size(obs.X, 1), size(obs.X, 2), size(obs.Z, 2)
    σ²inv   = inv(σ²)
    ####################
    # Evaluate objective
    ####################
    # form the q-by-q matrix: Lt Zt Z L
    copy!(obs.ltztzl, obs.ztz)
    BLAS.trmm!('L', 'L', 'T', 'N', T(1), L, obs.ltztzl) # O(q^3) obs.ltztzl = Lt Zt Z
    BLAS.trmm!('R', 'L', 'N', 'N', T(1), L, obs.ltztzl) # O(q^3) obs.ltztzl = Lt Zt Z L
    # form the q-by-q matrix: M = σ² I + Lt Zt Z L
    copy!(obs.storage_qq, obs.ltztzl)
    @inbounds for j in 1:q
        obs.storage_qq[j, j] += σ² # obs.storage_qq = σ² I + Lt Zt Z L
    end
    LAPACK.potrf!('U', obs.storage_qq) # O(q^3) # upper triangle of obs.storage_qq = R, where M = Rt R
    # Zt * res
    updater && BLAS.gemv!('N', T(-1), obs.ztx, β, T(1), copy!(obs.ztr, obs.zty)) # O(pq)
    # Lt * (Zt * res)
    BLAS.trmv!('L', 'T', 'N', L, copy!(obs.ltztr, obs.ztr))    # O(q^2)
    # storage_q = Rt \ (Lt * (Zt * res))
    BLAS.trsv!('U', 'T', 'N', obs.storage_qq, copy!(obs.storage_q, obs.ltztr)) # O(q^2)
    # Xt * res = Xt * y - Xt * X * β
    updater && BLAS.gemv!('N', T(-1), obs.xtx, β, T(1), copy!(obs.xtr, obs.xty))
    # squared l2 norm of residual vector: yt y - 2 βt Xt y + βt Xt X β
    updater && (obs.rtr[1] = obs.yty - dot(obs.xty, β) - dot(obs.xtr, β))
    # assemble pieces
    logl::T = n * log(2π) + (n - q) * log(σ²) # constant term
    @inbounds for j in 1:q # log det term
        logl += 2log(obs.storage_qq[j, j])
    end
    qf    = abs2(norm(obs.storage_q)) # quadratic form term
    logl += (obs.rtr[1] - qf) * σ²inv 
    logl /= -2
    ############################################
    # Evaluate posterior mean and variance of γ
    ############################################
    # Posterior variance: νγ = inv(inv(Σ) + σ⁻² Zt Z)
    # Linv = inv(L), so Linvt Linv = inv(Σ)
    BLAS.trsm!('R', 'L', 'N', 'N', T(1), L, copy!(obs.Linv, obs.I3))
    # storage_qq2 = inv(Σ) + σ⁻² Zt Z
    BLAS.gemm!('T', 'N', T(1), obs.Linv, obs.Linv, σ²inv, copy!(obs.storage_qq2, obs.ztz))
    # lower triangle of storage_qq2 = C, where inv(Σ) + σ⁻² Zt Z = C Ct
    LAPACK.potrf!('L', obs.storage_qq2)
    # storage_qq = inv(C); storage_qq is no longer needed for the objective
    BLAS.trsm!('R', 'L', 'N', 'N', T(1), obs.storage_qq2, copy!(obs.storage_qq, obs.I3))
    # νγ = inv(C)t inv(C); copy inv(C) into νγ first so trmm! never aliases A and B
    BLAS.trmm!('L', 'L', 'T', 'N', T(1), obs.storage_qq, copy!(obs.νγ, obs.storage_qq))

    # Posterior mean: μγ = σ⁻² νγ Zt res
    BLAS.gemv!('N', σ²inv, obs.νγ, obs.ztr, T(0), obs.μγ)
    
    ###################
    # Return
    ###################        
    return logl
end
```

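For reference, here are the identities that `logl!` relies on. They are reconstructed from the code rather than stated in the assignment. Let $\boldsymbol{\Sigma} = \mathbf{L}\mathbf{L}^T$, $\mathbf{r}_i = \mathbf{y}_i - \mathbf{X}_i\boldsymbol{\beta}$, and $\mathbf{M}_i = \sigma^2\mathbf{I}_q + \mathbf{L}^T\mathbf{Z}_i^T\mathbf{Z}_i\mathbf{L} = \mathbf{R}_i^T\mathbf{R}_i$. Applying the matrix determinant lemma and the Woodbury identity to $\boldsymbol{\Omega}_i = \sigma^2\mathbf{I}_{n_i} + (\mathbf{Z}_i\mathbf{L})(\mathbf{Z}_i\mathbf{L})^T$ gives

\begin{eqnarray*}
\log\det\boldsymbol{\Omega}_i &=& (n_i - q)\log\sigma^2 + \log\det\mathbf{M}_i = (n_i - q)\log\sigma^2 + 2\sum_{j=1}^q \log (\mathbf{R}_i)_{jj}, \\
\mathbf{r}_i^T\boldsymbol{\Omega}_i^{-1}\mathbf{r}_i &=& \sigma^{-2}\Big(\mathbf{r}_i^T\mathbf{r}_i - \|\mathbf{R}_i^{-T}\mathbf{L}^T\mathbf{Z}_i^T\mathbf{r}_i\|_2^2\Big).
\end{eqnarray*}

The products $\mathbf{X}_i^T\mathbf{X}_i$, $\mathbf{Z}_i^T\mathbf{X}_i$, $\mathbf{Z}_i^T\mathbf{Z}_i$, $\mathbf{X}_i^T\mathbf{y}_i$, and $\mathbf{Z}_i^T\mathbf{y}_i$ are precomputed in the constructor. As a result, each call costs $O(p^2 + pq + q^3)$, independent of $n_i$.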

It is a good idea to test the correctness and efficiency of the single-datum objective/posterior mean/posterior variance evaluator here. It's the same test datum as in HW2 and HW4.

```julia
Random.seed!(257)
# dimensions
n, p, q = 2000, 5, 3
# predictors
X = [ones(n) randn(n, p - 1)]
Z = [ones(n) randn(n, q - 1)]
# parameter values
β  = [2.0; -1.0; rand(p - 2)]
σ² = 1.5
Σ  = fill(0.1, q, q) + 0.9I # compound symmetry 
L  = Matrix(cholesky(Symmetric(Σ)).L)
# generate y
y  = X * β + Z * rand(MvNormal(Σ)) + sqrt(σ²) * randn(n)

# form the LmmObs object
obs = LmmObs(y, X, Z);
```

#### Correctness

```julia
@show logl = logl!(obs, β, Σ, L, σ², true)
@show obs.μγ
@show obs.νγ;
```

```
logl = logl!(obs, β, Σ, L, σ², true) = -3247.4568580638243
obs.μγ = [-1.7352999248283547, -1.2234665777048983, -0.25020190407763465]
obs.νγ = [0.0007495521480103862 4.188026819522356e-6 8.595028349011145e-6; 4.188026819522356e-6 0.0007599372708603274 -1.0092121486077345e-5; 8.595028349011145e-6 -1.0092121486077345e-5 0.0007370698232610101]
```


You will lose all 20 points if the following statement throws `AssertionError`.

```julia
@assert abs(logl - (-3247.4568580638247)) < 1e-8
@assert norm(obs.μγ - [-1.7352999248278138, 
        -1.2234665777052611, -0.25020190407767146]) < 1e-8
@assert norm(obs.νγ - [0.0007495521482876466 4.188026899159083e-6 8.595028393969659e-6; 
        4.1880268803062436e-6 0.0007599372708508531 -1.0092121451703577e-5; 
        8.595028373480989e-6 -1.009212147054782e-5 0.0007370698230021235]) < 1e-8
```

#### Efficiency
Benchmark for efficiency.

```julia
bm_obj = @benchmark logl!($obs, $β, $Σ, $L, $σ², true)
```

```
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.320 μs (0.00% GC)
  median time:      1.360 μs (0.00% GC)
  mean time:        1.497 μs (0.00% GC)
  maximum time:     8.090 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10
```

```julia
clamp(10 / (median(bm_obj).time / 1e3) * 10, 0, 10)
```

```
10.0
```
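As a worked example of the scoring cell: BenchmarkTools reports `median(bm_obj).time` in nanoseconds. Dividing by `1e3` therefore gives the median time $t$ in microseconds, and the score is

$$
\text{score} = \min\Big\{\max\Big\{\frac{10}{t} \times 10,\; 0\Big\},\; 10\Big\} = \min\Big\{\frac{100}{t},\; 10\Big\}, \quad t > 0.
$$

Full marks therefore require $t \le 10$ μs. With the median of 1.360 μs reported above, $100 / 1.36 \approx 73.5$, which is clamped to 10.0.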