In [1]:
using DifferentialEquations
using DifferentialEquations.EnsembleAnalysis
using StatsBase, DataFrames
using Parameters, Plots, BenchmarkTools

### Jump Problem Setup

In [None]:
# L11-23 from src/single_jump_instead_of_many.jl
"""
    affect!(integrator)

Jump effect: choose a target particle and a source particle uniformly at random from
`1:integrator.p.N` and overwrite the target's value with the larger of the two.
Returns the value written.
"""
function affect!(integrator)
    npart = integrator.p.N
    target = rand(1:npart)   # drawn first, as in the original
    source = rand(1:npart)
    state = integrator.u
    state[target] = max(state[target], state[source])
end

# In-place drift term of the SDE: every component of `du` is set to the constant `p.μ`.
# Returns `du`.
function μ_SDE(du, u, p, t)
    drift = p.μ
    du .= drift
    return du
end

# In-place noise term of the SDE: every component of `du` is set to the constant `p.σ`.
# Returns `du`.
function σ_SDE(du, u, p, t)
    amplitude = p.σ
    du .= amplitude
    return du
end

In [None]:
# Model parameters as a keyword-constructible NamedTuple (Parameters.jl `@with_kw`).
# `μ` and `σ` feed `μ_SDE` / `σ_SDE`, i.e. the drift and noise terms of the SDE.
params = @with_kw (
    μ = 0.01, # drift: constant deterministic trend added to every particle
    σ = 0.1, # diffusion: constant noise amplitude for every particle
    N = 10, # number of particles
    β = 0.2, # jump rate per particle (the aggregated jump rate is β * N)
    t = 0.0:0.01:10.0, # time grid at which the moments are recorded (saveat)
    moments = Vector{Vector{Float64}}()) # rows [min, mean, max, growth] pushed by save_func

p = params()

In [None]:
# Draw one random initial condition (one value per particle). It is drawn once here,
# so every ensemble trajectory starts from this same `x_iv`.
x_iv = rand(p.N)

# Diffusion part: drift `μ_SDE`, noise `σ_SDE`, integrated over [0, p.t[end]].
prob = SDEProblem(μ_SDE, σ_SDE, x_iv, (0.0, p.t[end]), p)

# Jump part: a single constant-rate jump with total rate β * N. Presumably this stands
# in for N per-particle jumps of rate β; `affect!` picks the particles at random.
rate(u, p, t) = p.β * p.N
jump = ConstantRateJump(rate, affect!)

# Jump-diffusion problem: the jump is coupled to the SDE via the `Direct` aggregator.
jump_prob = JumpProblem(prob, Direct(), jump)

### Callback Setup

In [None]:
"""
    save_func(u, t, integrator)

Record summary statistics of the state `u` by pushing the row
`[minimum(u), mean(u), maximum(u), g]` onto `integrator.p.moments`.

`g` is the change of the mean relative to the previously stored row, divided by
`step(integrator.p.t)`; it is `0.0` for the first row. Returns the `moments` vector.
"""
function save_func(u, t, integrator)
    par = integrator.p
    history = par.moments
    current_mean = mean(u)
    growth = isempty(history) ? 0.0 : (current_mean - history[end][2]) / step(par.t)
    return push!(history, [minimum(u), current_mean, maximum(u), growth])
end

In [None]:
# Call `save_func` at the times in `p.t` (plus once at the start), not after every
# integrator step, so the recorded moments line up with the `p.t` grid.
cb = FunctionCallingCallback(
    save_func;
    funcat = p.t,
    func_start = true,
    func_everystep = false,
    tdir = 1,
)

### Ensemble Setup

In [None]:
"""
    output_func(sol, i)

Ensemble output hook: overwrite `sol.t` with the moment time grid `sol.prob.p.t` and
`sol.u` with the rows recorded in `sol.prob.p.moments`, both in place. Returns
`(sol, false)`; per the `EnsembleProblem` `output_func` interface, the `false` means
trajectory `i` is not re-run.
"""
function output_func(sol, i)
    times, states = sol.t, sol.u
    empty!(times)
    empty!(states)
    # NOTE(review): assumes save_func ran exactly once per point of `p.t`, so `times`
    # and `states` end up the same length — confirm against the callback settings.
    append!(times, sol.prob.p.t)
    append!(states, sol.prob.p.moments)
    rerun = false
    return sol, rerun
end

In [None]:
# Wrap the jump-diffusion `jump_prob`; wrapping the bare SDE `prob` silently drops the
# jumps, so `affect!` would never fire.
# `safetycopy = true` deep-copies the problem for every trajectory, so each one gets
# its own empty `p.moments` instead of all trajectories appending to the single
# vector created by `params()` (which mixed their moments and growth values).
ensemble_prob = EnsembleProblem(jump_prob, output_func = output_func, safetycopy = true)

### Solve and Plot

In [None]:
# Two serial trajectories with the SRIW1 SDE solver. Per-step saving is off; the
# moments come from `cb`, and `output_func` moves them into each solution.
sim = solve(
    ensemble_prob,
    SRIW1(),
    EnsembleSerial();
    trajectories = 2,
    callback = cb,
    save_everystep = false,
)

In [None]:
# Plot every recorded moment of every trajectory.
sim |> plot

In [None]:
# Per-time-point summary of the moments across all trajectories.
summ = sim |> EnsembleSummary

In [None]:
# One panel per moment, in the column order save_func stores them, with error bands off.
panel_titles = ("Min", "Mean", "Max", "Growth")
p1, p2, p3, p4 = [
    plot(summ, idxs = [k], error_style = :none, title = panel_titles[k]) for k in 1:4
]

plot(p1, p2, p3, p4)