Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Update samplers to allow Options parameter.

  • Loading branch information...
commit a11f5260d927961bf7c5d5c4c88fd4786be0f303 1 parent a5e129e
@doobwa authored
Showing with 78 additions and 9 deletions.
  1. +63 −1 src/hmcsampler.jl
  2. +3 −8 src/slicesampler.jl
  3. +12 −0 test/hmc.jl
View
64 src/hmcsampler.jl
@@ -10,7 +10,7 @@ function hmc_sampler(x::Array{Float64}, dd::DifferentiableDensity, params::Array
hmc_sampler(x,dd,params[1],L) # TODO Use a different data structure for params
end
-function hmc_sampler(current_q::Array{Float64}, U::Function, grad_U::Function, epsilon::Float64, L::Int64)
+function hmc_sampler(current_q::Array{Float64}, U::Function, grad_U::Function, epsilon::Float64, L::Int64, )
q = current_q
p = randn(length(q))' # independent standard normal variates (row vector)
@@ -49,3 +49,65 @@ function hmc_sampler(current_q::Array{Float64}, U::Function, grad_U::Function, e
return current_q, current_U # reject
end
end
+
+function bounded_hmc_sampler(x::Array{Float64}, dd::DifferentiableDensity, opts::Options)
+ bounded_hmc_sampler(x,dd.f,dd.gradient,opts)
+end
+
+function bounded_hmc_sampler(current_q::Array{Float64}, U::Function, grad_U::Function, opts::Options)
+
+ @defaults opts stepsize=.01 numsteps=50 bounded=[] lower_bounds=[] upper_bounds=[]
+
+ epsilon = stepsize
+ L = numsteps
+
+ q = current_q
+ p = randn(length(q))' # independent standard normal variates (row vector)
+ current_p = p
+
+ # Make a half step for momentum at the beginning
+ p = p - epsilon * grad_U(q) / 2
+
+ # Alternate full steps for position and momentum
+ for i in 1:L
+ # Make a full step for the position
+ q = q + epsilon * p
+ # Make a full step for the momentum, except at end of trajectory
+ if i!=L
+ p = p - epsilon * grad_U(q)
+ end
+
+ # If out of bounds, reflect
+ for j in bounded
+ if q[j] < lower_bounds[j]
+ q[j] = lower_bounds[j] + (lower_bounds[j] - q[j])
+ p[j] = -p[j]
+ end
+ if q[j] > upper_bounds[j]
+ q[j] = upper_bounds[j] - (q[j] - upper_bounds[j])
+ p[j] = -p[j]
+ end
+ end
+ end
+
+ # Make a half step for momentum at the end.
+ p = p - epsilon * grad_U(q) / 2
+
+ # Negate momentum at end of trajectory to make the proposal symmetric
+ p = -p
+
+ # Evaluate potential and kinetic energies at start and end of trajectory
+ current_U = U(current_q)
+ current_K = sum(current_p.^2) / 2
+ proposed_U = U(q)
+ proposed_K = sum(p.^2) / 2
+
+ @check_used opts
+ # Accept or reject the state at end of trajectory, returning either
+ # the position at the end of the trajectory or the initial position
+ if rand() < exp(current_U-proposed_U+current_K-proposed_K)
+ return q, proposed_U # accept
+ else
+ return current_q, current_U # reject
+ end
+end
View
11 src/slicesampler.jl
@@ -98,12 +98,7 @@ slice_sampler(x0::Float64, d::Density, w::Float64, m::Int64, lower::Float64, upp
slice_sampler(x0::Float64, d::Density, gx0::Float64) = slice_sampler(x0::Float64, d.f, gx0::Float64)
-# TODO: Any easy/useful assertions to make
-function test_slice_sampler()
- function g(x)
- log(dnorm(x,0,.5))
- end
- x0 = 0.0
- x,lp = slice_sampler(x0,g,1.0,10000,-Inf,Inf)
+function slice_sampler(x0::Float64, d::Density, opts::Options)
+ gx0 = g(x0)
+ slice_sampler(x0::Float64, d.f, opts[:w],opts[:m],opts[:lower],opts[:upper],gx0::Float64)
end
-
View
12 test/hmc.jl
@@ -19,8 +19,20 @@ x,gx = hmc_sampler(x0, f, gradient, epsilon, L)
@assert isa(gx,Float64)
# Basic tests on Density interface
+dd = DifferentiableDensity(f,gradient)
x,gx = hmc_sampler(x0, dd, epsilon, L)
@assert isa(x,Array{Float64,2})
@assert isa(gx,Float64)
+# Example of using Options with an HMC sampler allowing bounds
+opts = Options(:stepsize,0.01,
+ :numsteps,int(50),
+ :bounded,[1], # list of dimensions that are bounded
+ :lower_bounds,[0], # lower bound for 1st dim
+ :upper_bounds,[Inf]) # upper bound for 1st dim
+x,gx = bounded_hmc_sampler(x0, dd, opts)
+
+@assert isa(x,Array{Float64,2})
+@assert isa(gx,Float64)
+@assert x[1] > 0
Please sign in to comment.
Something went wrong with that request. Please try again.