Skip to content

Commit

Permalink
Refactor Gaussian
Browse files Browse the repository at this point in the history
  • Loading branch information
mschauer committed Sep 26, 2017
1 parent 6473f09 commit 94c1a4c
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 40 deletions.
3 changes: 2 additions & 1 deletion docs/src/library.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ Bridge.outer
CSpline
Bridge.integrate
Bridge.logpdfnormal
Bridge.logpdfnormalprec
Bridge.runmean
Bridge.PSD
Bridge.Gaussian
```

## Online statistics
Expand Down
88 changes: 50 additions & 38 deletions src/gaussian.jl
Original file line number Diff line number Diff line change
@@ -1,66 +1,78 @@
# Gaussian
using Distributions
using Base.LinAlg: norm_sqr

import Base: rand
import Distributions: pdf, logpdf
import Distributions: pdf, logpdf, sqmahal
import Base: chol, size

"""
PSD{T}
Simple wrapper for the lower triangular Cholesky root of a positive (semi-)definite element `σ`.
"""
type PSD{T}
σ::T
PSD::T) where {T} = istril(σ) ? new{T}(σ) : throw(ArgumentError("Argument not lower triangular"))
end
chol(P::PSD) = P.σ'

sumlogdiag(a::Float64, d=1) = log(a)
sumlogdiag(A,d) = sum(log.(diag(A)))
sumlogdiag(Σ::Float64, d=1) = log(Σ)
sumlogdiag(Σ,d) = sum(log.(diag(Σ)))
sumlogdiag(J::UniformScaling, d)= log(J.λ)*d

_logdet(A, d) = logdet(A)

_logdet::PSD, d) = 2*sumlogdiag.σ, d)

_logdet(Σ, d) = logdet(Σ)
_logdet(J::UniformScaling, d) = log(J.λ) * d

_symmetric(A) = Symmetric(A)
_symmetric(Σ) = Symmetric(Σ)
_symmetric(J::UniformScaling) = J

import Distributions: logpdf, pdf
mutable struct Gaussian{T}
mu::T
a
sigma
Gaussian{T}(mu, a) where T = new(mu, a, chol(a)')
end
Gaussian(mu::T, a) where {T} = Gaussian{T}(mu, a)
"""
Gaussian(μ, Σ) -> P
rand(P::Gaussian) = P.mu + P.sigma*randn(typeof(P.mu))
rand(P::Gaussian{Vector{T}}) where {T} = P.mu + P.sigma*randn(T, length(P.mu))
function logpdf(P::Gaussian, x)
S = P.sigma
x = x - P.mu
d = length(x)
-((norm(S\x))^2 + 2sumlogdiag(S,d) + d*log(2pi))/2
Gaussian distribution with mean `μ`` and covariance `Σ`. Defines `rand(P)` and `(log-)pdf(P, x)`.
Designed to work with `Number`s, `UniformScaling`s, `StaticArrays` and `PSD`-matrices.
Implementation details: On `Σ` the functions `logdet`, `whiten` and `unwhiten`
(or `chol` as fallback for the latter two) are called.
"""
struct Gaussian{T,S}
μ::T
Σ::S
Gaussian::T, Σ::S) where {T,S} = new{T,S}(μ, Σ)
end
dim(P::Gaussian) = length(P.μ)
whiten::PSD, z) = Σ.σ\z
whiten(Σ, z) = chol(Σ)'\z
whiten::UniformScaling, z) = z/sqrt.λ)
sqmahal(P::Gaussian, x) = norm_sqr(whiten(P.Σ,x - P.μ))

rand(P::Gaussian) = P.μ + chol(P.Σ)'*randn(typeof(P.μ))
rand(P::Gaussian{Vector}) = P.μ + chol(P.Σ)'*randn(T, length(P.μ))

pdf(P::Gaussian,x) = exp(logpdf(P::Gaussian, x))
logpdf(P::Gaussian, x) = -(sqmahal(P,x) + _logdet(P.Σ, dim(P)) + dim(P)*log(2pi))/2
pdf(P::Gaussian, x) = exp(logpdf(P::Gaussian, x))

function Base.LinAlg.chol(u::SDiagonal{N,T}) where T<:Real where N
all(u.diag .>= zero(T)) || error(Base.LinAlg.PosDefException(1))
return SDiagonal(sqrt.(u.diag))
end

"""
logpdfnormal(x, A)
logpdfnormal(x, Σ)
logpdf of centered Gaussian with covariance A
logpdf of centered Gaussian with covariance Σ
"""
function logpdfnormal(x, A)
function logpdfnormal(x, Σ)

S = chol(_symmetric(A))'
S = chol(_symmetric(Σ))'

d = length(x)
-((norm(S\x))^2 + 2sumlogdiag(S,d) + d*log(2pi))/2
end
function logpdfnormal(x::Float64, a)
-(x^2/a + log(a) + log(2pi))/2
function logpdfnormal(x::Float64, Σ)
-(x^2/Σ + log(Σ) + log(2pi))/2
end

"""
logpdfnormalprec(x, A)
logpdf of centered gaussian with precision A
"""
function logpdfnormalprec(x, A)
d = length(x)
-(dot(x, S*x) - _logdet(A, d) + d*log(2pi))/2
end
logpdfnormalprec(x::Float64, a) = -(a*x^2 - log(a) + log(2pi))/2
5 changes: 4 additions & 1 deletion test/VHK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,7 @@ t, x = 0.0, v
@test norm(Bridge.b(t, x, Po) - Bridge.bi(1, x, GP)) < 1e-5

@test norm(Bridge.solvebackward!(Bridge.R3(), Bridge._F, SamplePath(tt,zeros(length(tt))), 2.0, ((t,x)->-x)).yy[1] -
2exp(tt[end]-tt[1]))<1e-5
2exp(tt[end]-tt[1]))<1e-5



43 changes: 43 additions & 0 deletions test/gaussian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using Bridge
using Distributions
using Base.Test
using Bridge: Gaussian, PSD
using StaticArrays


μ = rand()
x = rand()
σ = rand()
Σ = σ*σ'

p = pdf(Normal(μ, Σ), x)
@test pdf(Gaussian(μ, Σ), x) p
@test pdf(Gaussian(μ, Σ*I), x) p
@test pdf(Gaussian([μ], [σ]*[σ]'), x) p

@test pdf(Gaussian((@SVector [μ]), @SMatrix [Σ]), @SVector [x]) p

for d in 1: 3
μ = rand(d)
x = rand(d)
σ = tril(rand(d,d))
Σ = σ*σ'
p = pdf(MvNormal(μ, Σ), x)

@test pdf(Gaussian(μ, Σ), x) p
@test pdf(Gaussian(μ, PSD(σ)), x) p
@test pdf(Gaussian(SVector{d}(μ), SMatrix{d,d}(Σ)), x) p
@test pdf(Gaussian(SVector{d}(μ), PSD(SMatrix{d,d}(σ))), x) p
end

for d in 1: 3
μ = rand(d)
x = rand(d)
σ = rand()
Σ = eye(d)*σ^2
p = pdf(MvNormal(μ, Σ), x)

@test pdf(Gaussian(μ, σ^2*I), x) p
@test pdf(Gaussian(SVector{d}(μ), SDiagonal^2*ones(SVector{d}))), x) p
@test pdf(Gaussian(SVector{d}(μ), SMatrix{d,d}(Σ)), x) p
end

0 comments on commit 94c1a4c

Please sign in to comment.