In [1]:
using Cassette;
using DifferentialEquations;

The model that we would like to edit.

In [2]:
function main()
    
    # define our ode
    function sir_ode(du, u, p, t)  
        #Infected per-Capita Rate
        β = p[1]
        #Recover per-capita rate
        γ = p[2]
        #Susceptible Individuals
        S = u[1]
        #Infected by Infected Individuals
        I = u[2]

        du[1] = -β * S * I
        du[2] = β * S * I - γ * I
        du[3] = γ * I
    end

    #Pram = (Infected Per Capita Rate, Recover Per Capita Rate)
    pram = [0.1,0.05]
    #Initial Prams = (Susceptible Individuals, Infected by Infected Individuals)
    init = [0.99,0.01,0.0]
    tspan = (0.0,200.0)
    
    # create a var to our problem
    sir_prob = ODEProblem(sir_ode, init, tspan, pram)
    solution = solve(sir_prob)
    
end

main (generic function with 1 method)

Define a data structure that we will use to collect the edge information we would like to have.

In [3]:
""" 
trace_collector(func, args, ret, subtrace)

a structure to hold metadata for recursive type information
"""
mutable struct trace_collector
    func
    args
    ret
    subtrace::Vector{trace_collector}
end

trace_collector

In [4]:
"""    
trace_collect(func, args...)

creates a new trace_collector logging the input argument types and function name. You have to set the `ret` field after you call the function. 
This constructor creates the subtrace field for use in Cassette.similarcontext.
"""
function trace_collect(func, args...)
    return trace_collector(func, typeof.(args), nothing, trace_collector[])
end

trace_collect

In [2]:
# define the context we will use
ctx = Cassette.@context typCtx()

# add boilerplate for functionality
function Cassette.overdub(ctx::typCtx, args...)
    @show args
    if Cassette.canrecurse(ctx, args...)
        newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata)
        return Cassette.recurse(newctx, args...)
    else
        return Cassette.fallback(ctx, args...)
    end
end
    
function Cassette.canrecurse(ctx::typCtx,::typeof(ODEProblem),args...)
    return false
end

function Cassette.canrecurse(ctx::typCtx,::typeof(Base.vect),args...)
    return false
end
    
function Cassette.overdub(ctx::typCtx,::typeof(ODEProblem),args...)
    #println("ODE Formulation:")
    @show args
    #println((src=typeof.(args[2:end]),dst=nothing,func=typeof(ODEProblem)))
    return ODEProblem(args...) 
end

function Cassette.overdub(ctx::typCtx,::typeof(solve),args...)
    @show args
    #println("Solver:")
    sol = solve(args...)
    #println((src=typeof(ODEProblem),dst=typeof.((sol.t,sol.u)),func=typeof(solve)))
end

In [6]:
Cassette.overdub(typCtx(),main)

args = (main,)
args = (Base.vect, 0.1, 0.05)
args = (Base.vect, 0.99, 0.01, 0.0)
args = (tuple, 0.0, 200.0)
args = (getfield(Main, Symbol("#sir_ode#3"))(), [0.99, 0.01, 0.0], (0.0, 200.0), [0.1, 0.05])
args = ([36mODEProblem[0m with uType [36mArray{Float64,1}[0m and tType [36mFloat64[0m. In-place: [36mtrue[0m
timespan: (0.0, 200.0)
u0: [0.99, 0.01, 0.0],)


retcode: Success
Interpolation: 3rd order Hermite
t: 16-element Array{Float64,1}:
   0.0                
   0.12810751461600167
   1.4091826607760183 
   5.168078275288366  
  10.994513104681563  
  18.30187178378411   
  27.798934045417262  
  39.52941936332723   
  54.03265989500447   
  71.03333070372756   
  90.35176881316265   
 110.76614123722325   
 137.20279000396786   
 160.26481035931778   
 195.1309036063155    
 200.0                
u: 16-element Array{Array{Float64,1},1}:
 [0.99, 0.01, 0.0]                 
 [0.989873, 0.010063, 6.42552e-5]  
 [0.988557, 0.0107138, 0.000729462]
 [0.984197, 0.0128634, 0.00293922] 
 [0.975719, 0.0170157, 0.00726505] 
 [0.961344, 0.0239699, 0.0146865]  
 [0.934383, 0.0367076, 0.0289091]  
 [0.883849, 0.0594416, 0.0567092]  
 [0.789657, 0.0972901, 0.113053]   
 [0.643584, 0.14109, 0.215326]     
 [0.478491, 0.15798, 0.363529]     
 [0.354016, 0.131801, 0.514183]    
 [0.26819, 0.0788203, 0.652989]    
 [0.233414, 0.0441549, 0.722431]   
 [0.2

In [3]:
function test_main()
    
    gen_p() = rand()
    
    function f(du,u,p,t)
        p = gen_p()
        du = u + exp(p)
    end
        
    u0 = 1/2
    tspan = (0.0,1.0)
    prob = ODEProblem(f,u0,tspan)
    sol = solve(prob)
    
end    

test_main (generic function with 1 method)

In [4]:
Cassette.overdub(typCtx(),test_main)

args = (test_main,)
args = (typeof, getfield(Main, Symbol("#gen_p#4"))())
args = (Core.apply_type, getfield(Main, Symbol("#f#5")), getfield(Main, Symbol("#gen_p#4")))
args = (/, 1, 2)
args = (float, 1)
args = (AbstractFloat, 1)
args = (Float64, 1)
args = (sitofp, Float64, 1)
args = (float, 2)
args = (AbstractFloat, 2)
args = (Float64, 2)
args = (sitofp, Float64, 2)
args = (/, 1.0, 2.0)
args = (div_float, 1.0, 2.0)
args = (tuple, 0.0, 1.0)
args = (getfield(Main, Symbol("#f#5")){getfield(Main, Symbol("#gen_p#4"))}(getfield(Main, Symbol("#gen_p#4"))()), 0.5, (0.0, 1.0))
args = ([36mODEProblem[0m with uType [36mFloat64[0m and tType [36mFloat64[0m. In-place: [36mtrue[0m
timespan: (0.0, 1.0)
u0: 0.5,)


MethodError: MethodError: no method matching similar(::Float64)
Closest candidates are:
  similar(!Matched::ZMQ.Message, !Matched::Type{T}, !Matched::Tuple{Vararg{Int64,N}} where N) where T at /home/infvie/.julia/packages/ZMQ/ABGOx/src/message.jl:93
  similar(!Matched::DataStructures.IntSet) at deprecated.jl:53
  similar(!Matched::Sundials.NVector) at /home/infvie/.julia/packages/Sundials/KYRgQ/src/nvector_wrapper.jl:69
  ...

In [8]:
gen_p() = rand()

function f(du,u,p,t)
    du = p*u + exp(gen_p())
end

u0 = 1/2
tspan = (0.0,1.0)
prob = ODEProblem(f,u0,tspan)
sol = solve(prob)

MethodError: MethodError: no method matching similar(::Float64)
Closest candidates are:
  similar(!Matched::ZMQ.Message, !Matched::Type{T}, !Matched::Tuple{Vararg{Int64,N}} where N) where T at /home/infvie/.julia/packages/ZMQ/ABGOx/src/message.jl:93
  similar(!Matched::DataStructures.IntSet) at deprecated.jl:53
  similar(!Matched::Sundials.NVector) at /home/infvie/.julia/packages/Sundials/KYRgQ/src/nvector_wrapper.jl:69
  ...