In [1]:
using Cassette;
using DifferentialEquations;

In [63]:
# Define the Cassette context type `typCtx`; the `overdub`/`canrecurse` methods
# below dispatch on it, and tracing is started with `typCtx()`.
# NOTE(review): `ctx` only captures whatever the macro expansion evaluates to and
# is not referenced in the visible cells (they construct `typCtx()` directly) —
# confirm nothing else uses it before removing the assignment.
ctx = Cassette.@context typCtx

# add boilerplate for functionality
# Catch-all interception point for calls made while tracing under `typCtx`.
# More specific `overdub` methods (e.g. for `ODEProblem` and `solve`) take
# precedence. When Cassette may step into the callee, tracing continues inside it
# with a context carrying the same metadata; otherwise the call runs natively
# through `fallback`.
function Cassette.overdub(ctx::typCtx, args...)
    Cassette.canrecurse(ctx, args...) || return Cassette.fallback(ctx, args...)
    child = Cassette.similarcontext(ctx, metadata = ctx.metadata)
    return Cassette.recurse(child, args...)
end
    
# Never trace into `ODEProblem` constructors; those calls are handled opaquely.
# Dispatch on the constructor's own type: `typeof(ODEProblem)` is merely
# `UnionAll`, which would also match every other parametric type constructor
# (`Vector`, `Dict`, ...). `Type{<:ODEProblem}` matches `ODEProblem` itself and
# partially parameterized forms such as `ODEProblem{true}`.
function Cassette.canrecurse(ctx::typCtx, ::Type{<:ODEProblem}, args...)
    return false
end

# Keep array literals opaque: `[a, b, ...]` lowers to a `Base.vect` call, which is
# executed natively instead of being traced through.
Cassette.canrecurse(::typCtx, ::typeof(Base.vect), args...) = false
    
# Intercept `ODEProblem` construction while tracing under `typCtx`.
# Print the types of every constructor argument after the first (in `main` the
# first one is the ODE right-hand-side function), then build the problem
# natively, without tracing the constructor's body.
# Dispatching on `Type{<:ODEProblem}` rather than `typeof(ODEProblem)` (which is
# just `UnionAll`) stops this method from hijacking calls to every other
# parametric type constructor.
function Cassette.overdub(ctx::typCtx, P::Type{<:ODEProblem}, args...)
    println("ODE Formulation:")
    println((src=typeof(args[2:end]), dst=nothing, func=typeof(P)))
    return P(args...)
end

# Intercept `solve` while tracing under `typCtx`: solve natively, print the type
# of the solution's `(t, u)` pair, and return the solution to the traced caller.
function Cassette.overdub(ctx::typCtx, ::typeof(solve), args...)
    sol = solve(args...)
    println("Solver:")
    # NOTE: `src` is a fixed label. `typeof(ODEProblem)` is always `UnionAll`,
    # whatever problem was actually passed in `args`.
    println((src=typeof(ODEProblem), dst=typeof((sol.t, sol.u)), func=typeof(solve)))
    # Return the solution explicitly; otherwise the value of `println` (`nothing`)
    # would become the result of `solve` inside traced code such as `main`.
    return sol
end

In [64]:
"""
    main()

Build a susceptible–infected–recovered (SIR) ODE problem over `t ∈ [0, 200]`,
solve it with `solve`, and return the resulting solution.
"""
function main()
    # In-place right-hand side of the SIR system: writes derivatives into `du`.
    # State `u = [S, I, R]`, parameters `p = [β, γ]`; `t` is unused (autonomous system).
    function sir_ode!(du, u, p, t)
        β = p[1]  # per-capita infection rate
        γ = p[2]  # per-capita recovery rate
        S = u[1]  # susceptible
        I = u[2]  # infected
        # u[3] (recovered) does not feed back into the dynamics.

        du[1] = -β * S * I
        du[2] = β * S * I - γ * I
        du[3] = γ * I
        return nothing
    end

    # Parameters: (infection rate β, recovery rate γ).
    pram = [0.1, 0.05]
    # Initial state: (susceptible S, infected I, recovered R).
    init = [0.99, 0.01, 0.0]
    tspan = (0.0, 200.0)

    sir_prob = ODEProblem(sir_ode!, init, tspan, pram)
    return solve(sir_prob)
end

main (generic function with 1 method)

In [65]:
# Run `main` under a fresh `typCtx` context so the `overdub` methods defined above
# intercept its `ODEProblem` and `solve` calls and print their type information.
Cassette.overdub(typCtx(),main)

ODE Formulation:
(src = Tuple{Array{Float64,1},Tuple{Float64,Float64},Array{Float64,1}}, dst = nothing, func = UnionAll)
Solver:
(src = UnionAll, dst = Tuple{Array{Float64,1},Array{Array{Float64,1},1}}, func = typeof(solve))
