Skip to content

Commit

Permalink
Switch find_first_continuous_callback to use a generated implementation.
Browse files Browse the repository at this point in the history
As mentioned in SciML/DifferentialEquations.jl#971, the current
recursive method for identifying the first continuous callback can cause
the compiler to give up on type inference, especially when there are
many callbacks. The fallback then allocates.

This switches this function to using a generated function (along with an
inline function that takes splatted tuples). Because this generated
function explicitly unrolls the tuple, there are no type inference
problems.

I added a test that allocates using the old implementation (about 19kb
allocations!) but does not with the new system.
  • Loading branch information
meson800 committed Aug 22, 2023
1 parent a2ac2da commit 1cce932
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 82 deletions.
161 changes: 79 additions & 82 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ Recursively apply `initialize!` and return whether any modified u
"""
function initialize!(cb::CallbackSet, u, t, integrator::DEIntegrator)
initialize!(u, t, integrator, false, cb.continuous_callbacks...,
cb.discrete_callbacks...)
cb.discrete_callbacks...)
end
initialize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false
function initialize!(u, t, integrator::DEIntegrator, any_modified::Bool,
c::DECallback, cs::DECallback...)
c::DECallback, cs::DECallback...)
c.initialize(c, u, t, integrator)
initialize!(u, t, integrator, any_modified || integrator.u_modified, cs...)
end
function initialize!(u, t, integrator::DEIntegrator, any_modified::Bool,
c::DECallback)
c::DECallback)
c.initialize(c, u, t, integrator)
any_modified || integrator.u_modified
end
Expand All @@ -29,12 +29,12 @@ function finalize!(cb::CallbackSet, u, t, integrator::DEIntegrator)
end
finalize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false
function finalize!(u, t, integrator::DEIntegrator, any_modified::Bool,
c::DECallback, cs::DECallback...)
c::DECallback, cs::DECallback...)
c.finalize(c, u, t, integrator)
finalize!(u, t, integrator, any_modified || integrator.u_modified, cs...)
end
function finalize!(u, t, integrator::DEIntegrator, any_modified::Bool,
c::DECallback)
c::DECallback)
c.finalize(c, u, t, integrator)
any_modified || integrator.u_modified
end
Expand Down Expand Up @@ -109,56 +109,53 @@ function get_condition(integrator::DEIntegrator, callback, abst)
integrator.sol.stats.ncondition += 1
if callback isa VectorContinuousCallback
callback.condition(@view(integrator.callback_cache.tmp_condition[1:(callback.len)]),
tmp, abst, integrator)
tmp, abst, integrator)
return @view(integrator.callback_cache.tmp_condition[1:(callback.len)])
else
return callback.condition(tmp, abst, integrator)
end
end

# Use Recursion to find the first callback for type-stability

# Base Case: Only one callback
function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback)
(find_callback_time(integrator, callback, 1)..., 1, 1)
end

# Starting Case: Compute on the first callback
function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback,
args...)
find_first_continuous_callback(integrator,
find_callback_time(integrator, callback, 1)..., 1, 1,
args...)
# Use a generated function for type stability even when many callbacks are given
@inline function find_first_continuous_callback(integrator,
callbacks::Vararg{
AbstractContinuousCallback,
N}) where {N}
find_first_continuous_callback(integrator, tuple(callbacks...))
end

function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number,
event_occurred::Bool, event_idx::Int, idx::Int,
counter::Int,
callback2)
counter += 1 # counter is idx for callback2.
tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator,
callback2, counter)

if event_occurred2 && (tmin2 < tmin || !event_occurred)
return tmin2, upcrossing2, true, event_idx2, counter, counter
else
return tmin, upcrossing, event_occurred, event_idx, idx, counter
@generated function find_first_continuous_callback(integrator,
callbacks::NTuple{N,
AbstractContinuousCallback
}) where {N}
ex = quote
tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator,
callbacks[1], 1)
identified_idx = 1
end
end

function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number,
event_occurred::Bool, event_idx::Int, idx::Int,
counter::Int, callback2, args...)
find_first_continuous_callback(integrator,
find_first_continuous_callback(integrator, tmin,
upcrossing,
event_occurred,
event_idx, idx, counter,
callback2)..., args...)
for i in 2:N
ex = quote
$ex
tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator,
callbacks[$i],
$i)
if event_occurred2 && (tmin2 < tmin || !event_occurred)
tmin = tmin2
upcrossing = upcrossing2
event_occurred = true
event_idx = event_idx2
identified_idx = $i
end
end
end
ex = quote
$ex
return tmin, upcrossing, event_occurred, event_idx, identified_idx, $N
end
ex
end

@inline function determine_event_occurance(integrator, callback::VectorContinuousCallback,
counter)
counter)
event_occurred = false
if callback.interp_points != 0
addsteps!(integrator)
Expand All @@ -182,10 +179,10 @@ end

if callback.idxs === nothing
callback.condition(previous_condition, integrator.uprev, integrator.tprev,
integrator)
integrator)
else
callback.condition(previous_condition, integrator.uprev[callback.idxs],
integrator.tprev, integrator)
integrator.tprev, integrator)
end
integrator.sol.stats.ncondition += 1

Expand All @@ -195,7 +192,7 @@ end

if integrator.event_last_time == counter &&
minimum(ODE_DEFAULT_NORM(ArrayInterface.allowed_getindex(previous_condition,
ivec), integrator.t)) <=
ivec), integrator.t)) <=
100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t)

# If there was a previous event, utilize the derivative at the start to
Expand Down Expand Up @@ -226,7 +223,7 @@ end
@. next_sign = sign(next_condition)

event_idx = findall_events!(next_sign, callback.affect!, callback.affect_neg!,
prev_sign)
prev_sign)
if sum(event_idx) != 0
event_occurred = true
interp_index = callback.interp_points
Expand All @@ -239,7 +236,7 @@ end
abst = ts[i]
copyto!(next_sign, get_condition(integrator, callback, abst))
_event_idx = findall_events!(next_sign, callback.affect!, callback.affect_neg!,
prev_sign)
prev_sign)
if sum(_event_idx) != 0
event_occurred = true
event_idx = _event_idx
Expand All @@ -259,7 +256,7 @@ end
next_condition = get_condition(integrator, callback, abst)
@. next_sign = sign(next_condition)
event_idx = findall_events!(next_sign, callback.affect!, callback.affect_neg!,
prev_sign)
prev_sign)
interp_index = callback.interp_points
end
end
Expand All @@ -268,7 +265,7 @@ end
end

@inline function determine_event_occurance(integrator, callback::ContinuousCallback,
counter)
counter)
event_occurred = false
if callback.interp_points != 0
addsteps!(integrator)
Expand All @@ -291,10 +288,10 @@ end
# Check if the event occurred
if callback.idxs === nothing
previous_condition = callback.condition(integrator.uprev, integrator.tprev,
integrator)
integrator)
else
@views previous_condition = callback.condition(integrator.uprev[callback.idxs],
integrator.tprev, integrator)
integrator.tprev, integrator)
end
integrator.sol.stats.ncondition += 1

Expand Down Expand Up @@ -360,15 +357,15 @@ end
# then either f(nextfloat(r)) == 0 or f(nextfloat(r)) * f(r) < 0
# note: not really using bisection - uses the ITP method
function bisection(f, tup, t_forward::Bool, rootfind::SciMLBase.RootfindOpt, abstol, reltol;
maxiters = 1000)
maxiters = 1000)
if rootfind == SciMLBase.LeftRootFind
solve(IntervalNonlinearProblem{false}(f, tup),
InternalITP(), abstol = abstol,
reltol = reltol).left
InternalITP(), abstol = abstol,
reltol = reltol).left
else
solve(IntervalNonlinearProblem{false}(f, tup),
InternalITP(), abstol = abstol,
reltol = reltol).right
InternalITP(), abstol = abstol,
reltol = reltol).right
end
end

Expand All @@ -379,7 +376,7 @@ Modifies `next_sign` to be an array of booleans for if there is a sign change
in the interval between prev_sign and next_sign
"""
function findall_events!(next_sign::Union{Array, SubArray}, affect!::F1, affect_neg!::F2,
prev_sign::Union{Array, SubArray}) where {F1, F2}
prev_sign::Union{Array, SubArray}) where {F1, F2}
@inbounds for i in 1:length(prev_sign)
next_sign[i] = ((prev_sign[i] < 0 && affect! !== nothing) ||
(prev_sign[i] > 0 && affect_neg! !== nothing)) &&
Expand All @@ -398,8 +395,8 @@ end

function find_callback_time(integrator, callback::ContinuousCallback, counter)
event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx = determine_event_occurance(integrator,
callback,
counter)
callback,
counter)
if event_occurred
if callback.condition === nothing
new_t = zero(typeof(integrator.t))
Expand Down Expand Up @@ -429,7 +426,7 @@ function find_callback_time(integrator, callback::ContinuousCallback, counter)
error("Double callback crossing floating pointer reducer errored. Report this issue.")
end
Θ = bisection(zero_func, (bottom_t, top_t), isone(integrator.tdir),
callback.rootfind, callback.abstol, callback.reltol)
callback.rootfind, callback.abstol, callback.reltol)
integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ), Θ)
end
#Θ = prevfloat(...)
Expand Down Expand Up @@ -457,8 +454,8 @@ end

function find_callback_time(integrator, callback::VectorContinuousCallback, counter)
event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx = determine_event_occurance(integrator,
callback,
counter)
callback,
counter)
if event_occurred
if callback.condition === nothing
new_t = zero(typeof(integrator.t))
Expand All @@ -478,8 +475,8 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun
if ArrayInterface.allowed_getindex(event_idx, idx) != 0
function zero_func(abst, p = nothing)
ArrayInterface.allowed_getindex(get_condition(integrator,
callback,
abst), idx)
callback,
abst), idx)
end
if zero_func(top_t) == 0
Θ = top_t
Expand All @@ -500,11 +497,11 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun
end

Θ = bisection(zero_func, (bottom_t, top_t),
isone(integrator.tdir), callback.rootfind,
callback.abstol, callback.reltol)
isone(integrator.tdir), callback.rootfind,
callback.abstol, callback.reltol)
if integrator.tdir * Θ < integrator.tdir * min_t
integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),
Θ)
Θ)
end
end
if integrator.tdir * Θ < integrator.tdir * min_t
Expand Down Expand Up @@ -545,12 +542,12 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun
end

function apply_callback!(integrator,
callback::Union{ContinuousCallback, VectorContinuousCallback},
cb_time, prev_sign, event_idx)
callback::Union{ContinuousCallback, VectorContinuousCallback},
cb_time, prev_sign, event_idx)
if isadaptive(integrator)
set_proposed_dt!(integrator,
integrator.tdir * max(nextfloat(integrator.opts.dtmin),
integrator.tdir * callback.dtrelax * integrator.dt))
integrator.tdir * max(nextfloat(integrator.opts.dtmin),
integrator.tdir * callback.dtrelax * integrator.dt))
end

change_t_via_interpolation!(integrator, integrator.tprev + cb_time)
Expand Down Expand Up @@ -618,21 +615,21 @@ end
#Starting: Get bool from first and do next
@inline function apply_discrete_callback!(integrator, callback::DiscreteCallback, args...)
apply_discrete_callback!(integrator, apply_discrete_callback!(integrator, callback)...,
args...)
args...)
end

@inline function apply_discrete_callback!(integrator, discrete_modified::Bool,
saved_in_cb::Bool, callback::DiscreteCallback,
args...)
saved_in_cb::Bool, callback::DiscreteCallback,
args...)
bool, saved_in_cb2 = apply_discrete_callback!(integrator,
apply_discrete_callback!(integrator,
callback)...,
args...)
apply_discrete_callback!(integrator,
callback)...,
args...)
discrete_modified || bool, saved_in_cb || saved_in_cb2
end

@inline function apply_discrete_callback!(integrator, discrete_modified::Bool,
saved_in_cb::Bool, callback::DiscreteCallback)
saved_in_cb::Bool, callback::DiscreteCallback)
bool, saved_in_cb2 = apply_discrete_callback!(integrator, callback)
discrete_modified || bool, saved_in_cb || saved_in_cb2
end
Expand Down Expand Up @@ -676,7 +673,7 @@ mutable struct CallbackCache{conditionType, signType}
end

function CallbackCache(u, max_len, ::Type{conditionType},
::Type{signType}) where {conditionType, signType}
::Type{signType}) where {conditionType, signType}
tmp_condition = similar(u, conditionType, max_len)
previous_condition = similar(u, conditionType, max_len)
next_sign = similar(u, signType, max_len)
Expand All @@ -685,7 +682,7 @@ function CallbackCache(u, max_len, ::Type{conditionType},
end

function CallbackCache(max_len, ::Type{conditionType},
::Type{signType}) where {conditionType, signType}
::Type{signType}) where {conditionType, signType}
tmp_condition = zeros(conditionType, max_len)
previous_condition = zeros(conditionType, max_len)
next_sign = zeros(signType, max_len)
Expand Down
Loading

0 comments on commit 1cce932

Please sign in to comment.