In [3]:
using Cassette;
using DifferentialEquations;

In [4]:
"""
    main()

Solve the SIR epidemic model for `t ∈ [0, 200]` and return the solution object produced
by `solve`.
"""
function main()

    # In-place right-hand side of the SIR model, in the `f(du, u, p, t)` form expected by
    # `ODEProblem`.
    #   u = [S, I, R]: susceptible, infected, recovered
    #     (presumably population fractions, since `init` below sums to 1 — confirm)
    #   p = [β, γ]: transmission rate and recovery rate
    function sir_ode(du, u, p, t)
        # Transmission (infection) rate
        β = p[1]
        # Recovery rate
        γ = p[2]
        # Susceptible compartment
        S = u[1]
        # Infected compartment
        I = u[2]

        du[1] = -β * S * I
        du[2] = β * S * I - γ * I
        du[3] = γ * I
        # In-place RHS: results are written into `du`; return nothing explicitly.
        return nothing
    end

    # Parameters: pram = [β (transmission rate), γ (recovery rate)]
    pram = [0.1, 0.05]
    # Initial state: init = [S0, I0, R0]
    init = [0.99, 0.01, 0.0]
    tspan = (0.0, 200.0)

    # Build the problem and solve it with the default algorithm choice.
    sir_prob = ODEProblem(sir_ode, init, tspan, pram)
    solution = solve(sir_prob)
    return solution
end

main (generic function with 1 method)

In [5]:
""" 
trace_collector(func, args, ret, subtrace)

a structure to hold metadata for recursive type information
"""
mutable struct trace_collector
    func
    args
    ret
    subtrace::Vector{trace_collector}
end

trace_collector

In [6]:
"""    
trace_collect(func, args...)

creates a new trace_collector logging the input argument types and function name. You have to set the `ret` field after you call the function. 
This constructor creates the subtrace field for use in Cassette.similarcontext.
"""
function trace_collect(func, args...)
    return trace_collector(func, typeof.(args), nothing, trace_collector[])
end

trace_collect

In [7]:
# Define a new Cassette context type, `typeCtx`. Methods of `Cassette.overdub` and
# `Cassette.canrecurse` specialised on it control how calls are traced.
Cassette.@context typeCtx;

In [8]:
extractor = Vector{trace_collector}()  # root of the recorded trace (empty to start)

0-element Array{trace_collector,1}

In [9]:
ctx = typeCtx(; metadata = extractor)  # root context: top-level frames go into `extractor`

Cassette.Context{nametype(typeCtx),Array{trace_collector,1},Nothing,getfield(Cassette, Symbol("##PassType#363")),Nothing,Nothing}(nametype(typeCtx)(), trace_collector[], nothing, getfield(Cassette, Symbol("##PassType#363"))(), nothing, nothing)

In [10]:
# Tracing hook: every call executed under a `typeCtx` passes through this method.
# The callee and argument types are recorded as a new frame in the current context's
# metadata vector. When Cassette is allowed to recurse into the callee, the callee runs
# in a child context whose metadata is that frame's `subtrace`, so nested calls are
# recorded under their caller. Otherwise the call is made untraced via `fallback`.
# Either way, the frame's `ret` is set to the type of the returned value.
function Cassette.overdub(ctx::typeCtx, args...)
    frame = trace_collect(args...)
    push!(ctx.metadata, frame)
    result = if Cassette.canrecurse(ctx, args...)
        Cassette.recurse(Cassette.similarcontext(ctx, metadata = frame.subtrace), args...)
    else
        Cassette.fallback(ctx, args...)
    end
    frame.ret = typeof(result)
    return result
end

In [11]:
# Stop the trace from descending into `ODEProblem` construction and array literals
# (`Base.vect`). Returning `false` makes the `typeCtx` overdub take its
# `Cassette.fallback` branch, so these calls are still recorded as frames but get no
# subtrace of their own.
Cassette.canrecurse(::typeCtx, ::typeof(ODEProblem), args...) = false

Cassette.canrecurse(::typeCtx, ::typeof(Base.vect), args...) = false

In [13]:
# Run `main` under the tracing context; its call tree is recorded into `extractor`,
# and the ODE solution returned by `main` is displayed as the cell result.
Cassette.overdub(ctx,main)

retcode: Success
Interpolation: 3rd order Hermite
t: 16-element Array{Float64,1}:
   0.0                
   0.12810751461600167
   1.4091826607760183 
   5.1680782805989995 
  10.994513112907462  
  18.301871797637222  
  27.798934064171767  
  39.529419388483916  
  54.032659926311176  
  71.0333307422792    
  90.35176885029713   
 110.76614128323645   
 137.20279004817846   
 160.26481040857357   
 195.13090370189352   
 200.0                
u: 16-element Array{Array{Float64,1},1}:
 [0.99, 0.01, 0.0]                 
 [0.989873, 0.010063, 6.42552e-5]  
 [0.988557, 0.0107138, 0.000729462]
 [0.984197, 0.0128634, 0.00293922] 
 [0.975719, 0.0170157, 0.00726505] 
 [0.961344, 0.0239699, 0.0146865]  
 [0.934383, 0.0367076, 0.0289091]  
 [0.883849, 0.0594416, 0.0567092]  
 [0.789657, 0.0972901, 0.113053]   
 [0.643584, 0.14109, 0.215326]     
 [0.478491, 0.15798, 0.363529]     
 [0.354016, 0.131801, 0.514183]    
 [0.26819, 0.0788203, 0.652989]    
 [0.233414, 0.0441549, 0.722431]   
 [0.2

In [16]:
# Display the root frame(s) recorded by the overdub call above.
extractor

1-element Array{trace_collector,1}:
 trace_collector(main, (), OrdinaryDiffEq.ODECompositeSolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,getfield(Main, Symbol("#sir_ode#3")),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},CompositeAlgorithm{Tuple{Tsit5,Rosenbrock23{0,false,LinSolveFactorize{typeof(LinearAlgebra.lu!)},DataType}},AutoSwitch{Tsit5,Rosenbrock23{0,false,LinSolveFactorize{typeof(LinearAlgebra.lu!)},DataType},Rational{Int64},Float64}},OrdinaryDiffEq.CompositeInterpolationData{ODEFunction{true,getfield(Main, Symbol("#sir_ode#3")),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.CompositeCache{

In [18]:
# Print only the top-level frames: the callee followed by its tuple of argument types.
foreach(frame -> println(frame.func, frame.args), extractor)

main()


In [23]:
"""
    foo(collector::trace_collector; io=stdout, maxdepth=Inf, indentwidth=0)

Print the call tree rooted at `collector` depth-first (pre-order). Each frame goes on its
own line as the callee followed by the tuple of its argument types, e.g.
`Base.vect(Float64, Float64)`.

# Keywords
- `io::IO`: stream to write to (default `stdout`).
- `maxdepth::Real`: deepest nesting level to print; the root is level 0. The default
  `Inf` prints the whole tree. Full traces can be enormous, so limiting depth keeps the
  output manageable.
- `indentwidth::Integer`: spaces of indentation per nesting level. The default `0`
  prints a flat listing.

# Throws
- `ArgumentError` if `indentwidth` is negative.
"""
function foo(collector::trace_collector; io::IO = stdout, maxdepth::Real = Inf,
             indentwidth::Integer = 0)
    indentwidth >= 0 || throw(ArgumentError("indentwidth must be >= 0, got $indentwidth"))
    _print_trace(io, collector, 0, maxdepth, indentwidth)
    return nothing
end

# Recursive worker for `foo`: print `frame` at nesting level `depth`, then its children.
function _print_trace(io::IO, frame::trace_collector, depth::Integer, maxdepth::Real,
                      indentwidth::Integer)
    # With indentwidth == 0 the prefix is "", giving the same line as println(func, args).
    println(io, " "^(depth * indentwidth), frame.func, frame.args)
    depth < maxdepth || return nothing
    for child in frame.subtrace
        _print_trace(io, child, depth + 1, maxdepth, indentwidth)
    end
    return nothing
end

foo (generic function with 1 method)

In [24]:
# Print the full recorded call tree under the first (root) frame, i.e. the `main` call.
foo(extractor[1])

main()
Base.vect(Float64, Float64)
Base.vect(Float64, Float64, Float64)
tuple(Float64, Float64)
ODEProblem(getfield(Main, Symbol("#sir_ode#3")), Array{Float64,1}, Tuple{Float64,Float64}, Array{Float64,1})
DiffEqBase.solve(ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,getfield(Main, Symbol("#sir_ode#3")),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},)
NamedTuple()
pairs(NamedTuple{(),Tuple{}},)
keys(NamedTuple{(),Tuple{}},)
Base.Iterators.Pairs(NamedTuple{(),Tuple{}}, Tuple{})
tuple(Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, typeof(solve), ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,getfield(Main, Symbol("#sir_ode#3")),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem})
DiffEqBase.#solve#425(Ba

Excessive output truncated after 524290 bytes.

.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(/),Tuple{Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}},Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}}}},)
Base.tail(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(/),Tuple{Array{Float64,1},Array{Float64,1}}},Float64},)
Base.argtail(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(/),Tuple{Array{Float64,1},Array{Float64,1}}}, Float64)
Base.Broadcast.preprocess_args(Array{Float64,1}, Tuple{Float64})
getindex(Tuple{Float64}, Int64)
getfield(Tuple{Float64}, Int64, Bool)
Base.Broadcast.preprocess(Array{Float64,1}, Float64)
Base.Broadcast.broadcast_unalias(Array{Float64,1}, Float64)
===(Array{Float64,1}, Float64)
Base.unalias(Array{Float64,1}, Float64)
Base.Broadcast.extrude(Float64,)
tuple(Float64,)
tuple(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(/),Tuple{Base.Broadcast.Extruded{Array{Fl