Skip to content

Commit

Permalink
Compute ptilde, refine interface
Browse files Browse the repository at this point in the history
  • Loading branch information
mschauer committed Aug 26, 2017
1 parent c69a6be commit dde46d1
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 24 deletions.
3 changes: 2 additions & 1 deletion docs/src/library.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ Bridge.Ptilde
```@docs
GuidedProp
Bridge.GuidedBridge
BridgePre
Bridge.Mdb
bridge
bridge!
Bridge.Vs
Bridge.r
Bridge.gpK!
Bridge.gpHinv!
```

## Unsorted
Expand Down
4 changes: 2 additions & 2 deletions src/Bridge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export LinPro, Wiener, WienerBridge, CSpline
export sample, sample!, .., quvar, ito, bracket, lp, llikelihood, transitionprob, girsanov


export BridgeProp, DHBridgeProp, FilterProp, PBridgeProp, GuidedProp, innovations, innovations!, lptilde
export BridgeProp, DHBridgeProp, FilterProp, PBridgeProp, GuidedProp, GuidedBridge, innovations, innovations!, lptilde
export ubridge!, ubridge

export ullikelihood, ullikelihoodtrapez, uinnovations!, ubridge
Expand All @@ -18,7 +18,7 @@ export LevyProcess, GammaProcess, GammaBridge, VarianceGammaProcess, LocalGammaP
export mcstart, mcnext, mcbandmean, mcband

# euler
export SDESolver, Euler, EulerMaruyama, EulerMaruyama!, StochasticHeun,
export SDESolver, Euler, BridgePre, EulerMaruyama, EulerMaruyama!, StochasticHeun,
StochasticRungeKutta, bridge!, bridge, solve, solve!

# ode
Expand Down
33 changes: 30 additions & 3 deletions src/euler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ Euler-Maruyama scheme. `Euler` is defined as alias.
Euler, EulerMaruyama


"""
BridgePre() <: SDESolver
Precomputed Euler-Maruyama scheme for bridges using `bi`.
"""
struct BridgePre <: SDESolver
end

"""
StochasticHeun() <: SDESolver
Expand Down Expand Up @@ -178,10 +186,9 @@ end
Integrate with `method`, where `P is a bridge proposal.
"""
bridge(method::SDESolver, W, P) = bridge!(method, copy(W), W, P)
bridge!(::Euler, Y, W::SamplePath, P::ContinuousTimeProcess) = bridge!(BridgePre(), Y, W, P)



function bridge!(::Euler, Y, W::SamplePath, P::ContinuousTimeProcess{T}) where {T}
function bridge!(::BridgePre, Y, W::SamplePath, P::ContinuousTimeProcess{T}) where {T}
W.tt === P.tt && error("Time axis mismatch between bridge P and driving W.") # not strictly an error

N = length(W)
Expand Down Expand Up @@ -280,6 +287,26 @@ function innovations!(::EulerMaruyama, W, Y::SamplePath, P)
W
end

function innovations!(::BridgePre, W, Y::SamplePath, P)

N = length(W)
N != length(Y) && error("Y and W differ in length.")

yy = Y.yy
tt = Y.tt
ww = W.yy
W.tt[:] = Y.tt

w = zero(ww[.., 1])

for i in 1:N-1
ww[.., i] = w
w = w + σ(tt[i], yy[.., i], P)\(yy[.., i+1] - yy[.., i] - bi(i, yy[.., i], P)*(tt[i+1]-tt[i]))
end
ww[.., N] = w
W
end

function innovations!(::Mdb, W, Y::SamplePath, P)

N = length(W)
Expand Down
20 changes: 15 additions & 5 deletions src/guip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ struct GuidedBridge{T,S,R2,R} <: ContinuousTimeProcess{T}
v::Tuple{T,T}
K::Vector{S}
V::Vector{T}
lp::Float64
"""
GuidedBridge(tt, (u, v), P, Pt)
Expand All @@ -115,9 +116,20 @@ the time grid `tt` using guiding term derived from linear process `Pt`.
S = typeof(Bridge.outer(zero(T)))
K = SamplePath(tt, zeros(S, N))
V = SamplePath(tt, zeros(T, N))
gpK!(K, Pt)
gpHinv!(K, Pt)
gpV!(V, v[2], Pt)
new{T,S,R2,R}(P, Pt, tt, v, K.yy, V.yy)
lp = logpdfnormal(v[2] - gpmu(tt, v[1], Pt), gpK(tt, zero(S), Pt))
new{T,S,R2,R}(P, Pt, tt, v, K.yy, V.yy, lp)
end
function GuidedBridge(tt_, v::Tuple{T,T}, P::R, Pt::R2, KT) where {T,R,R2}
tt = collect(tt_)
N = length(tt)
S = typeof(Bridge.outer(zero(T)))
K = SamplePath(tt, zeros(S, N))
V = SamplePath(tt, zeros(T, N))
gpHinv!(K, Pt, KT)
gpV!(V, v[2], Pt)
new{T,S,R2,R}(P, Pt, tt, v, K.yy, V.yy, NaN)
end
end

Expand All @@ -129,9 +141,7 @@ a(t, x, P::GuidedBridge) = a(t, x, P.Target)
constdiff(P::GuidedBridge) = constdiff(P.Target) && constdiff(P.Pt)
btilde(t, x, P::GuidedBridge) = b(t, x, P.Pt)
atilde(t, x, P::GuidedBridge) = a(t, x, P.Pt)
function lptilde(P::GuidedBridge)
lp(P.tt[1], P.v[1], P.tt[end], P.v[end], P.Pt)
end
lptilde(P::GuidedBridge) = P.lp



Expand Down
6 changes: 5 additions & 1 deletion src/linpro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ mutable struct Ptilde{T} <: ContinuousTimeProcess{T}
Ptilde{T}(cs, σ) where T = new(cs, σ, σ*σ', inv*σ'))
end
b(t, x, P::Ptilde) = P.cs(t)

B(t, P::Ptilde) = 0.0
β(t, P::Ptilde) = P.cs(t)

mu(s, x, t, P::Ptilde) = x + integrate(P.cs, s, t)
σ(t, x, P::Ptilde) = P.σ
a(t, x, P::Ptilde) = P.a
a(t, P::Ptilde) = P.a
gamma(t, x, P::Ptilde) = P.Γ
constdiff(::Ptilde) = true

Expand Down Expand Up @@ -111,7 +116,6 @@ function V(t, T, v, P::LinPro)
phim*(v - P.μ) + P.μ
end


function dotV(t, T, v, P::LinPro)
expm(-(T-t)*P.B)*P.B*(v - P.μ)
end
Expand Down
9 changes: 9 additions & 0 deletions src/mclog.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ function mcnext(mc, x)
m, m2, n
end

function mcnext(mc, x::Vector{<:AbstractArray}) # fix me: use covariance
m, m2, n = mc
delta = x - m
n = n + 1
m = m + delta*(1/n)
m2 = m2 + map((x,y)->x.*y, delta, (x - m))
m, m2, n
end

"""
mcmeanband(mc)
Expand Down
36 changes: 33 additions & 3 deletions src/ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,24 @@ function kernelbs3(f, t, y, dt, P, k = f(t, y, P))
yº, k4, err
end

@inline _dK(t, K, P) = B(t, P)*K + K*B(t, P)' - a(t, P)
@inline _dHinv(t, K, P) = B(t, P)*K + K*B(t, P)' - a(t, P)

@inline _dK(t, K, P) = B(t, P)*K + K*B(t, P)' + a(t, P)

"""
gpK!(K::SamplePath, P)
gpHinv!(K::SamplePath, P, v=zero(T))
Precompute ``K = H^{-1}`` from ``(d/dt)K = BK + KB' + a`` for a guided proposal.
"""
gpK!(K::SamplePath{T}, P) where {T} = _solvebackward!(R3(), _dK, K, zero(T), P)
gpHinv!(K::SamplePath{T}, P, v=zero(T)) where {T} = _solvebackward!(R3(), _dHinv, K, v, P)
gpV!(V::SamplePath{T}, v::T, P) where {T} = _solvebackward!(R3(), _F, V, v, P)


gpmu(tt, u::T, P) where {T} = solve(R3(), _F, tt, u, P)
gpK(tt, u::T, P) where {T} = solve(R3(), _dK, tt, u, P)



function solvebackward!(method, F, X, xT, P)
_solvebackward!(method, F, X, xT, P)
end
Expand Down Expand Up @@ -106,6 +114,16 @@ end
X
end

function solve(::R3, F, tt, x0::T, P) where {T}
y::T = x0
for i in 2:length(tt)
y = kernelr3(F, tt[i-1], y, tt[i] - tt[i-1], P)
end
y
end



solve!(method::ODESolver, X, x0, F::Function) = solve!(method, _F, X, x0, F)
solve!(method::ODESolver, X, x0, P) = solve!(method, b, X, x0, P)

Expand All @@ -126,3 +144,15 @@ solve!(method::ODESolver, X, x0, P) = solve!(method, b, X, x0, P)
X, err
end

@inline function solve(::BS3, F, tt, x0::T, P) where {T}
0 < length(tt) || throw(ArgumentError("length(X) == 0"))
y::T = x0
length(tt) == 1 && return y, 0.0
y, k, e = kernelbs3(F, tt[1], y, tt[2] - tt[1], P)
err = norm(e, 1)
for i in 3:length(tt)
y, k, e = kernelbs3(F, tt[i-1], y, tt[i] - tt[i-1], P, k)
err = err + norm(e, 1)
end
y, err
end
1 change: 1 addition & 0 deletions test/VHK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ t, x = 0.0, v
@test norm(Bridge.b(t, x, Po) - Bridge.b(t, x, Ptarget) - Bridge.a(t, x, Ptarget)*Bridge.r(t, x, T, v, Pt)) < 1e-10
@test norm(Bridge.b(t, x, Po) - Bridge.b(t, x, Ptarget) - a*(GP.K[1]\(GP.V[1] - x))) < 1e-5

@test norm(GP.lp - Bridge.lp(t, u, T, v, Pt)) < 1e-5


S = SVector{1, Float64}
Expand Down
11 changes: 6 additions & 5 deletions test/linpro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ dt = 1e-6
n2 = 150
tt = linspace(t, T, n2)
K = SamplePath(tt, zeros(M, length(tt)))
Bridge.gpK!(K, P)
Bridge.gpHinv!(K, P)
V = SamplePath(tt, zeros(S, length(tt)))
Bridge.gpK!(K, P) # warm up
Bridge.gpHinv!(K, P) # warm up
Bridge.gpV!(V, v, P)
Mu = SamplePath(tt, zeros(S, length(tt)))
Mu2 = SamplePath(tt, zeros(S, length(tt)))

solve!(Bridge.R3(), Bridge._F, Mu, u, P)
solve!(BS3(), Bridge._F, Mu2, u, P)
@test (@allocated Bridge.gpK!(K, P)) == 0
@test (@allocated Bridge.gpHinv!(K, P)) == 0
@test (@allocated Bridge.gpV!(V, v, P)) == 0
@test norm(K.yy[1]*Bridge.H(t, T, P) - I) < 10/n2^3
@test norm(V.yy[1] - Bridge.V(t, T, v, P)) < 10/n2^3
Expand All @@ -56,6 +56,7 @@ solve!(BS3(), Bridge._F, Mu2, u, P)
@test norm(Mu.yy[end] - Bridge.mu(t, u, T, P)) < 10/n2^3
@test norm(Mu2.yy[end] - Bridge.mu(t, u, T, P)) < 10/n2^3

@test norm(Bridge.K(t, T, P) - solve(Bridge.R3(), Bridge._dK, tt, zero(M), P)) < 10/n2^3

# Normal(mu, lambda) is the stationary distribution. check by starting in stationary distribution and evolve 20 time units
X = Bridge.mat(S[solve(EulerMaruyama(), mu + chol(P.lambda)*randn(S), sample(tt, Wiener{S}()),P).yy[end] - mu for i in 1:m])
Expand All @@ -77,6 +78,6 @@ theta = 0.7
end


P = Bridge.Ptilde(Bridge.CSpline(tt[1], tt[end], 1.0, 0.0, 0.0, 1.0), sigma)
Pt = Bridge.Ptilde(Bridge.CSpline(tt[1], tt[end], 1.0, 0.0, 0.0, 1.0), sigma)

@test Bridge.gamma(t, v, P) inv(sigma*sigma')
@test Bridge.gamma(t, v, Pt) inv(sigma*sigma')
4 changes: 2 additions & 2 deletions test/perf/runbench.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ tt = linspace(t, T, n2)
K = SamplePath(tt, zeros(length(tt)))
V = SamplePath(tt, zeros(length(tt)))
Mu = SamplePath(tt, zeros(length(tt)))
suite["solver"]["gpK!"] = @benchmarkable Bridge.gpK!(K, P)
suite["solver"]["gpK!"] = @benchmarkable Bridge.gpHinv!(K, P)
suite["solver"]["gpV!"] = @benchmarkable Bridge.gpV!(V, v, P)
suite["solver"]["solve!(::R3, ...)"] = @benchmarkable Bridge.solve!(Bridge.R3(), Bridge._F, Mu, u, P)
suite["solver"]["solve!(::BS3, ...)"] = @benchmarkable Bridge.solve!(Bridge.BS3(), Bridge._F, Mu, u, P)
Expand Down Expand Up @@ -75,7 +75,7 @@ tt = linspace(t, T, n2)
K = SamplePath(tt, zeros(SM, length(tt)))
V = SamplePath(tt, zeros(SV, length(tt)))
Mu = SamplePath(tt, zeros(SV, length(tt)))
suite["solverSA"]["gpK!"] = @benchmarkable Bridge.gpK!(K, P)
suite["solverSA"]["gpK!"] = @benchmarkable Bridge.gpHinv!(K, P)
suite["solverSA"]["gpV!"] = @benchmarkable Bridge.gpV!(V, v, P)
suite["solverSA"]["solve!(::R3, ...)"] = @benchmarkable Bridge.solve!(Bridge.R3(), Bridge._F, Mu, u, P)
suite["solverSA"]["solve!(::BS3, ...)"] = @benchmarkable Bridge.solve!(Bridge.BS3(), Bridge._F, Mu, u, P)
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
include(joinpath("..", "docs", "make.jl"))
include(joinpath("..", "docs", "make.jl")) # this may change rng state

srand(joinpath(@__DIR__,"SEED"),1)

Expand All @@ -12,4 +12,4 @@ include("linpro.jl")
include("timechange.jl")
include("uniformscaling.jl")
include("gamma.jl")
include("with_srand.jl") #run last
include("with_srand.jl") # run last

0 comments on commit dde46d1

Please sign in to comment.