Skip to content

Commit

Permalink
Allow for custom step function in forward ODE solver
Browse files Browse the repository at this point in the history
  • Loading branch information
bgroenks96 committed May 22, 2024
1 parent 2b87df1 commit 3532b8f
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/forward_solve_ode.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
const AbstractODEProblem = SciMLBase.AbstractODEProblem
const AbstractODEIntegrator = SciMLBase.AbstractODEIntegrator

struct SimulatorODEConfig{F}
obs_to_prob_time::F
struct SimulatorODEConfig{F,T}
stepfunc::F
obs_to_prob_time::T
end

default_time_converter(::AbstractODEProblem) = identity

function SimulatorForwardProblem(prob::AbstractODEProblem, observables::SimulatorObservable...; obs_to_prob_time=default_time_converter(prob))
function SimulatorForwardProblem(
prob::AbstractODEProblem,
observables::SimulatorObservable...;
stepfunc=step!,
obs_to_prob_time=default_time_converter(prob)
)
named_observables = (; map(x -> nameof(x) => x, observables)...)
return SimulatorForwardProblem(prob, named_observables, SimulatorODEConfig(obs_to_prob_time))
return SimulatorForwardProblem(prob, named_observables, SimulatorODEConfig(stepfunc, obs_to_prob_time))
end

"""
Expand Down Expand Up @@ -80,18 +86,18 @@ function CommonSolve.init(
end

function CommonSolve.step!(forward::SimulatorODEForwardSolver)
if forward.step_idx > length(forward.tstops)
return step!(forward.integrator)
end
# extract fields from forward integrator and compute dt
prob = forward.prob
integrator = forward.integrator
forwardstep! = prob.config.stepfunc
t = forward.tstops[forward.step_idx]
dt = adstrip(t - integrator.t)
retval = if dt > 0
# step to next t if dt > 0
step!(integrator, dt, true)
# if there are no more stopping points, just forward to the integrator and return
if forward.step_idx > length(forward.tstops)
return forwardstep!(forward.integrator)
end
# otherwise, evaluate the next step and observables
retval = forwardstep!(integrator, dt, true)
# iterate over observables and update those for which t is a sample point
for obs in prob.observables
if t map(prob.config.obs_to_prob_time, sampletimes(obs))
Expand Down

0 comments on commit 3532b8f

Please sign in to comment.