Skip to content

Commit

Permalink
Update samplers to allow Options parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
doobwa committed Jun 26, 2012
1 parent a5e129e commit a11f526
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 9 deletions.
64 changes: 63 additions & 1 deletion src/hmcsampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function hmc_sampler(x::Array{Float64}, dd::DifferentiableDensity, params::Array
hmc_sampler(x,dd,params[1],L) # TODO Use a different data structure for params
end

function hmc_sampler(current_q::Array{Float64}, U::Function, grad_U::Function, epsilon::Float64, L::Int64)
function hmc_sampler(current_q::Array{Float64}, U::Function, grad_U::Function, epsilon::Float64, L::Int64, )

q = current_q
p = randn(length(q))' # independent standard normal variates (row vector)
Expand Down Expand Up @@ -49,3 +49,65 @@ function hmc_sampler(current_q::Array{Float64}, U::Function, grad_U::Function, e
return current_q, current_U # reject
end
end

function bounded_hmc_sampler(x::Array{Float64}, dd::DifferentiableDensity, opts::Options)
bounded_hmc_sampler(x,dd.f,dd.gradient,opts)
end

function bounded_hmc_sampler(current_q::Array{Float64}, U::Function, grad_U::Function, opts::Options)

@defaults opts stepsize=.01 numsteps=50 bounded=[] lower_bounds=[] upper_bounds=[]

epsilon = stepsize
L = numsteps

q = current_q
p = randn(length(q))' # independent standard normal variates (row vector)
current_p = p

# Make a half step for momentum at the beginning
p = p - epsilon * grad_U(q) / 2

# Alternate full steps for position and momentum
for i in 1:L
# Make a full step for the position
q = q + epsilon * p
# Make a full step for the momentum, except at end of trajectory
if i!=L
p = p - epsilon * grad_U(q)
end

# If out of bounds, reflect
for j in bounded
if q[j] < lower_bounds[j]
q[j] = lower_bounds[j] + (lower_bounds[j] - q[j])
p[j] = -p[j]
end
if q[j] > upper_bounds[j]
q[j] = upper_bounds[j] - (q[j] - upper_bounds[j])
p[j] = -p[j]
end
end
end

# Make a half step for momentum at the end.
p = p - epsilon * grad_U(q) / 2

# Negate momentum at end of trajectory to make the proposal symmetric
p = -p

# Evaluate potential and kinetic energies at start and end of trajectory
current_U = U(current_q)
current_K = sum(current_p.^2) / 2
proposed_U = U(q)
proposed_K = sum(p.^2) / 2

@check_used opts
# Accept or reject the state at end of trajectory, returning either
# the position at the end of the trajectory or the initial position
if rand() < exp(current_U-proposed_U+current_K-proposed_K)
return q, proposed_U # accept
else
return current_q, current_U # reject
end
end
11 changes: 3 additions & 8 deletions src/slicesampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,7 @@ slice_sampler(x0::Float64, d::Density, w::Float64, m::Int64, lower::Float64, upp

slice_sampler(x0::Float64, d::Density, gx0::Float64) = slice_sampler(x0::Float64, d.f, gx0::Float64)

# TODO: Any easy/useful assertions to make
function test_slice_sampler()
function g(x)
log(dnorm(x,0,.5))
end
x0 = 0.0
x,lp = slice_sampler(x0,g,1.0,10000,-Inf,Inf)
function slice_sampler(x0::Float64, d::Density, opts::Options)
gx0 = g(x0)
slice_sampler(x0::Float64, d.f, opts[:w],opts[:m],opts[:lower],opts[:upper],gx0::Float64)
end

12 changes: 12 additions & 0 deletions test/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,20 @@ x,gx = hmc_sampler(x0, f, gradient, epsilon, L)
@assert isa(gx,Float64)

# Basic tests on Density interface
dd = DifferentiableDensity(f,gradient)
x,gx = hmc_sampler(x0, dd, epsilon, L)

@assert isa(x,Array{Float64,2})
@assert isa(gx,Float64)

# Example of using Options with an HMC sampler allowing bounds
opts = Options(:stepsize,0.01,
:numsteps,int(50),
:bounded,[1], # list of dimensions that are bounded
:lower_bounds,[0], # lower bound for 1st dim
:upper_bounds,[Inf]) # upper bound for 1st dim
x,gx = bounded_hmc_sampler(x0, dd, opts)

@assert isa(x,Array{Float64,2})
@assert isa(gx,Float64)
@assert x[1] > 0

0 comments on commit a11f526

Please sign in to comment.