Skip to content

Commit

Permalink
Add sample_prediction method to likelihood interface
Browse files Browse the repository at this point in the history
  • Loading branch information
bgroenks96 committed May 27, 2024
1 parent ccb506e commit ce81dd8
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 23 deletions.
5 changes: 2 additions & 3 deletions ext/SimulationBasedInferenceGenExt/gen_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ function SimulationBasedInference.SimulatorForwardProblem(
gen_fn::GenerativeFunction,
gen_fn_args::Tuple,
initial_params::ComponentVector,
observables::SimulatorObservable...;
retval_fn=trace -> Gen.get_retval(trace),
observables::SimulatorObservable...
)
f = function(θ)
params = copyto!(similar(initial_params), θ)
trace = simulate(gen_fn, (params, gen_fn_args...))
return retval_fn(trace)
return trace
end

return SimulatorForwardProblem(f, initial_params, observables...)
Expand Down
3 changes: 1 addition & 2 deletions ext/SimulationBasedInferenceGenExt/gen_prior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function GenSimulatorPrior(gf::Gen.GenerativeFunction, args::Tuple)
end

function (prior::GenSimulatorPrior)(θ::AbstractVector{T}) where {T}
choices = Gen.from_array(prior.choicemap_proto, θ)
choices = Gen.from_array(prior.choicemap_proto, Vector(θ))
trace, _ = generate(prior.gf, prior.args, choices)
return Gen.get_retval(trace)
end
Expand All @@ -35,7 +35,6 @@ SimulationBasedInference.forward_map(prior::GenSimulatorPrior, θ::AbstractVecto
SimulationBasedInference.logprob(prior::GenSimulatorPrior, θ::NamedTuple) = logprob(prior.model, ComponentVector(θ))
function SimulationBasedInference.logprob(prior::GenSimulatorPrior, θ::AbstractVector)
choices = Gen.from_array(prior.choicemap_proto, Vector(θ))
println(choices)
weight, retval = assess(prior.gf, prior.args, choices)
return weight
end
Expand Down
9 changes: 6 additions & 3 deletions ext/pysbi/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ function pysimulator(inference_prob::SimulatorInferenceProblem, transform, pred_
ϕ = SBI.forward_map(inference_prob.prior, θ)
solve(inference_prob.forward_prob, inference_prob.forward_solver, p=ϕ.model)
obs_vecs = map(inference_prob.likelihoods) do lik
dist = SBI.predictive_distribution(lik, ϕ[nameof(lik)])
pred_transform(rand(rng, dist))
# pred_transform(mean(dist))
if hasproperty(ϕ, nameof(lik))
y_pred = SBI.sample_prediction(lik, ϕ[nameof(lik)])
else
y_pred = SBI.sample_prediction(lik)
end
pred_transform(y_pred)
end
if return_py
return Py(reduce(vcat, obs_vecs)).to_numpy()
Expand Down
26 changes: 17 additions & 9 deletions ext/pysbi/pypriors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,24 @@ end

function PyPrior(prior::AbstractSimulatorPrior)
function pylogprob(x)
x = PyArray(x)
x = collect(PyArray(x))
binv = inverse(bijector(prior))
if length(size(x)) == 1
return SBI.logprob(prior, binv(x)) + SBI.logabsdetjac(binv, x)
elseif length(size(x)) == 2
xs = map(binv, eachrow(x))
lp = transpose(reduce(hcat, map(xᵢ -> [SBI.logprob(prior, xᵢ) + SBI.logabsdetjac(binv, xᵢ)], xs)))
return Py(lp).to_numpy()
else
error("invalid sample shape: $(size(x))")
try
if length(size(x)) == 1
return SBI.logprob(prior, binv(x)) + SBI.logabsdetjac(binv, x)
elseif length(size(x)) == 2
xs = map(binv, eachrow(x))
lp = transpose(reduce(hcat, map(xᵢ -> [SBI.logprob(prior, xᵢ) + SBI.logabsdetjac(binv, xᵢ)], xs)))
return Py(lp).to_numpy()
else
error("invalid sample shape: $(size(x))")
end
catch ex
st = stacktrace(catch_backtrace())
@error "$ex on input of type $(typeof(x)) with shape $(size(x))"
showerror(stderr, ex)
show(stderr, "text/plain", st)
rethrow(ex)
end
end

Expand Down
8 changes: 6 additions & 2 deletions src/likelihoods/implicit_likelihood.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
abstract type ImplicitDistribution end

"""
ImplicitLikelihood(
obs,
Expand All @@ -12,6 +14,8 @@ ImplicitLikelihood(
obs,
data,
name=nameof(obs),
) = SimulatorLikelihood(:Implicit, obs, data, nothing, name)
) = SimulatorLikelihood(ImplicitDistribution, obs, data, nothing, name)

predictive_distribution(::SimulatorLikelihood{ImplicitDistribution}) = error("predictive distribution not defined for implicit likelihoods")

predictive_distribution(::SimulatorLikelihood{:Implicit}) = error("predictive distribution not defined for implicit likelihoods")
sample_prediction(rng::AbstractRNG, lik::SimulatorLikelihood{ImplicitDistribution}, args...) = vec(retrieve(lik.obs))
10 changes: 7 additions & 3 deletions src/likelihoods/joint_prior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
Constructs a `JointPrior` from the given prior and likelihoods.
"""
function JointPrior(modelprior::AbstractSimulatorPrior, liks::SimulatorLikelihood...)
lik_priors = map(prior, with_names(liks))
lik_priors = with_names(filter(!isnothing, map(prior, liks)))
param_nt = merge(
(model=rand(modelprior),),
map(d -> rand(d), lik_priors),
Expand Down Expand Up @@ -47,8 +47,12 @@ end
function logprob(jp::JointPrior, θ::ComponentVector)
lp_model = logprob(jp.model, θ.model)
liknames = collect(keys(jp.lik))
lp_lik = sum(map((d,n) -> logprob(d, getproperty(θ, n)), collect(jp.lik), liknames))
return lp_model + lp_lik
if length(liknames) > 0
lp_lik = sum(map((d,n) -> logprob(d, getproperty(θ, n)), collect(jp.lik), liknames))
return lp_model + lp_lik
else
return lp_model
end
end
logprob(jp::JointPrior, θ::AbstractVector) = logprob(jp, ComponentVector(θ, jp.ax))

Expand Down
13 changes: 12 additions & 1 deletion src/likelihoods/likelihoods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,22 @@ prior(lik::SimulatorLikelihood) = lik.prior
"""
predictive_distribution(lik::SimulatorLikelihood, args...)
Builds the predictive distribution of `lik` given the parameters in `args`.
Builds the predictive distribution of `lik` given the parameters in `args`. This method is mandatory
for all specializations of `SimulatorLikelihood`.
"""
predictive_distribution(lik::SimulatorLikelihood, args...) = error("not implemented")
predictive_distribution(lik::SimulatorLikelihood, p::NamedTuple) = predictive_distribution(lik, p...)

"""
sample_prediction([rng::AbstractRNG], lik::SimulatorLikelihood, args...)
Samples the conditional predictive distribution `p(y|u)` where `u` is the current value of the likelihood
observable. This method is optional for specializations; the default implementation simply invokes `rand`
on the `predictive_distribution(lik, args...)`.
"""
sample_prediction(lik::SimulatorLikelihood, args...) = sample_prediction(Random.default_rng(), lik, args...)
sample_prediction(rng::AbstractRNG, lik::SimulatorLikelihood, args...) = rand(rng, predictive_distribution(lik, args...))

"""
loglikelihood(lik::SimulatorLikelihood, args...)
Expand Down

0 comments on commit ce81dd8

Please sign in to comment.