-
Notifications
You must be signed in to change notification settings - Fork 47
/
gausshermite.jl
92 lines (75 loc) · 2.75 KB
/
gausshermite.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
using StaticArrays, LinearAlgebra
"""
GaussHermiteQuadrature
As described in
* [Gauss-Hermite quadrature on Wikipedia](http://en.wikipedia.org/wiki/Gauss-Hermite_quadrature)
*Gauss-Hermite* quadrature uses a weighted sum of values of `f(x)` at specific `x` values to approximate
```math
\\int_{-\\infty}^\\infty f(x) e^{-x^2} dx
```
An `n`-point rule, as returned by `hermite(n)` from the
[`GaussQuadrature`](https://github.com/billmclean/GaussQuadrature.jl) package, provides `n` abscissa
values (i.e. values of `x`) and `n` weights.
As noted in the Wikipedia article, a modified version can be used to evaluate the expectation `E[h(x)]`
with respect to a `Normal(μ, σ)` density as
```julia
using MixedModels
gn5 = GHnorm(5)
μ = 3.
σ = 2.
sum(@. abs2(σ*gn5.z + μ)*gn5.w) # E[X^2] where X ∼ N(μ, σ)
```
For evaluation of the log-likelihood of a GLMM, the integrand to evaluate for each level of the grouping
factor is approximately Gaussian-shaped.
"""
GaussHermiteQuadrature
"""
    GaussHermiteNormalized{K}

A struct with 2 `SVector{K,Float64}` members
- `z`: abscissae for the K-point Gauss-Hermite quadrature rule on the Z (standard normal) scale
- `w`: Gauss-Hermite weights normalized to sum to unity

For a rule `g`, `sum(f.(g.z) .* g.w)` approximates `E[f(Z)]` for `Z ∼ N(0, 1)`; rescale the
abscissae as `σ .* g.z .+ μ` to approximate expectations under `N(μ, σ)`.
"""
struct GaussHermiteNormalized{K}
    z::SVector{K,Float64}
    w::SVector{K,Float64}
end
"""
    GaussHermiteNormalized(k::Integer)

Compute the `k`-point Gauss-Hermite rule normalized for the standard normal density.

This uses the Golub-Welsch algorithm. The abscissae are the eigenvalues of the symmetric
tridiagonal Jacobi matrix of the probabilists' Hermite polynomials (zero diagonal, off-diagonal
`sqrt.(1:k-1)`). The weights are the squared first components of the corresponding unit
eigenvectors, rescaled to sum to one.

Throws an `ArgumentError` if `k < 1`.
"""
function GaussHermiteNormalized(k::Integer)
    # Without this check k == 0 fails later with an obscure BoundsError and k < 0 inside `zeros`.
    k ≥ 1 || throw(ArgumentError("number of quadrature points k must be ≥ 1, got $k"))
    # Jacobi matrix of the recurrence He_{n+1}(x) = x*He_n(x) - n*He_{n-1}(x).
    sytr = SymTridiagonal(zeros(k), sqrt.(1:k-1))
    ev = eigen(sytr)
    w = abs2.(ev.vectors[1, :])
    # The exact rule is symmetric about zero, so averaging each vector with its reverse
    # removes round-off asymmetry in both abscissae and weights.
    return GaussHermiteNormalized(
        SVector{k}((ev.values .- reverse(ev.values)) ./ 2),
        SVector{k}(normalize((w .+ reverse(w)) ./ 2, 1)),
    )
end
# Iterating over a rule yields `(z = abscissa, w = weight)` named tuples, in index order.
Base.iterate(g::GaussHermiteNormalized{K}, i=1) where {K} =
    (K < i ? nothing : ((z = g.z[i], w = g.w[i]), i + 1))

Base.length(g::GaussHermiteNormalized{K}) where {K} = K

# Declaring the element type lets `collect(g)` build a concretely typed vector
# instead of falling back to `Vector{Any}`.
Base.eltype(::Type{<:GaussHermiteNormalized}) = NamedTuple{(:z, :w),Tuple{Float64,Float64}}
"""
    GHnormd

Memoized values of [`GHnorm`](@ref) stored as a `Dict{Int,GaussHermiteNormalized}`.

The dictionary is pre-seeded with the closed-form rules for `k = 1, 2, 3`. Rules for other `k`
are computed and inserted the first time `GHnorm(k)` is called.
"""
const GHnormd = Dict{Int,GaussHermiteNormalized}(
    1 => GaussHermiteNormalized(SVector{1}(0.0), SVector{1}(1.0)),
    2 => GaussHermiteNormalized(SVector{2}(-1.0, 1.0), SVector{2}(0.5, 0.5)),
    3 => GaussHermiteNormalized(
        SVector{3}(-sqrt(3), 0.0, sqrt(3)),
        SVector{3}(1 / 6, 2 / 3, 1 / 6),
    ),
)
"""
    GHnorm(k::Int)
    GHnorm(k)

Return the (unique) `GaussHermiteNormalized{k}` object.

The first request for a given `k` computes the rule and stores it in [`GHnormd`](@ref).
Later requests for the same `k` are a single dictionary lookup. Non-`Int` arguments are
converted with `Int(k)` before the lookup.
"""
function GHnorm(k::Int)
    compute_rule() = GaussHermiteNormalized(k)
    return get!(compute_rule, GHnormd, k)
end

GHnorm(k) = GHnorm(Int(k))