# Extracting a Graph
We first take a script and add Cassette logic that records every function call made while the script runs, together with its argument types and return type.

In [1]:
using Cassette;
using DifferentialEquations;

In [2]:
function main()

    # In-place right-hand side of the SIR model: reads the state `u` and the
    # parameters `p`, and writes the derivatives into `du`.
    function sir_ode(du, u, p, t)
        # per-capita infection rate
        β = p[1]
        # per-capita recovery rate
        γ = p[2]
        # susceptible individuals
        susceptible = u[1]
        # infected individuals
        infected = u[2]

        du[1] = -β * susceptible * infected
        du[2] = β * susceptible * infected - γ * infected
        du[3] = γ * infected
    end

    # parameters = (per-capita infection rate, per-capita recovery rate)
    params = [0.1,0.05]
    # initial state = (susceptible, infected, recovered)
    u0 = [0.99,0.01,0.0]
    tspan = (0.0,200.0)

    # build the ODE problem and hand back its solution
    problem = ODEProblem(sir_ode, u0, tspan, params)
    return solve(problem)

end

main (generic function with 1 method)

In [3]:
""" 
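    trace_collector(func, args, ret, subtrace)

A node of a recursive call trace: it holds the called function, a description of its arguments, a description of its return value, and the trace nodes of the calls made while it ran.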
"""
mutable struct trace_collector
    # the callee this node describes
    func::Any
    # description of the call's arguments (left untyped on purpose)
    args::Any
    # description of the call's result; mutable so it can be set after the call
    ret::Any
    # child nodes, one per call recorded underneath this one
    subtrace::Vector{trace_collector}
end

trace_collector

In [4]:
"""    
    trace_collect(func, args...)

Create a new `trace_collector` that records the function `func` and the types of `args`. The `ret` field starts as `nothing`; you have to set it after you call the function.
The constructor also creates an empty `subtrace` vector for use with `Cassette.similarcontext`.
"""
function trace_collect(func, args...)
    # keep only the types of the arguments, not their values
    argtypes = map(typeof, args)
    # `ret` is a placeholder until the caller fills it in; the child list starts empty
    return trace_collector(func, argtypes, nothing, Vector{trace_collector}())
end

trace_collect

In [5]:
# define the Cassette context type used to trace calls
Cassette.@context typeCtx;

In [6]:
# top-level list of trace nodes; starts empty and is used as the context metadata
extractor = Vector{trace_collector}()

0-element Array{trace_collector,1}

In [7]:
# tracing context whose metadata is the `extractor` vector
ctx = typeCtx(; metadata = extractor)

Cassette.Context{nametype(typeCtx),Array{trace_collector,1},Nothing,getfield(Cassette, Symbol("##PassType#363")),Nothing,Nothing}(nametype(typeCtx)(), trace_collector[], nothing, getfield(Cassette, Symbol("##PassType#363"))(), nothing, nothing)

In [8]:
# Overdub hook for typeCtx: every intercepted call gets a trace node that is
# appended to the metadata of the context it was made in.
function Cassette.overdub(ctx::typeCtx, args...)
    # args[1] is the callee; trace_collect stores it with the argument types
    record = trace_collect(args...)
    push!(ctx.metadata, record)
    result = if Cassette.canrecurse(ctx, args...)
        # trace into the callee, collecting its calls in this node's subtrace
        inner = Cassette.similarcontext(ctx, metadata = record.subtrace)
        Cassette.recurse(inner, args...)
    else
        # no recursion into this call: run it through Cassette's fallback
        Cassette.fallback(ctx, args...)
    end
    # the return type is only known once the call has finished
    record.ret = typeof(result)
    return result
end

In [9]:
# Do not trace into ODEProblem: with canrecurse false, overdub still records
# the call but runs it through Cassette.fallback.
Cassette.canrecurse(::typeCtx, ::typeof(ODEProblem), args...) = false

# Do not trace into Base.vect: with canrecurse false, overdub still records
# the call but runs it through Cassette.fallback.
Cassette.canrecurse(::typeCtx, ::typeof(Base.vect), args...) = false

In [10]:
# Run main under the tracing context: the node for `main` is pushed onto
# ctx.metadata (the `extractor` vector) and nested calls go into its subtrace.
Cassette.overdub(ctx,main);

# Create a graph

In [11]:
using LightGraphs;
using MetaGraphs;

In [20]:
# Graph that display_extractor fills when no explicit graph is passed.
g = MetaDiGraph()

"""
    display_extractor(collector::trace_collector, graph = g)

Add the call trace rooted at `collector` to `graph` and return `graph`.

Each trace node contributes two vertices, whose `:name` properties are
`collector.args` and `collector.ret`, joined by an edge whose `:name` property is
`collector.func`. Every node in `collector.subtrace` is then added recursively to
the same graph.

# Arguments
- `collector`: root of the trace to add.
- `graph`: metagraph to extend. It defaults to the global `g`, so existing
  one-argument calls behave as before; pass a fresh `MetaDiGraph()` to build an
  independent graph.
"""
function display_extractor(collector::trace_collector, graph = g)
    add_vertex!(graph, :name, collector.args)
    add_vertex!(graph, :name, collector.ret)
    # the two vertices just added are the last two in the graph
    ret_vertex = nv(graph)
    add_edge!(graph, ret_vertex - 1, ret_vertex, :name, collector.func)
    for frame in collector.subtrace
        # pass the graph along so the whole recursion extends the same object
        display_extractor(frame, graph)
    end
    return graph
end

display_extractor (generic function with 1 method)

In [21]:
# build the graph from the first top-level trace node
mg = display_extractor(first(extractor))

{230592, 115296} directed Int64 metagraph with Float64 weights defined by :weight (default weight 1.0)

# Display the Graph

In [14]:
using GraphPlot;

In [30]:
# one label per vertex, taken from the :name property stored on it
nodelabels = [get_prop(mg, v, :name) for v in vertices(mg)];

In [None]:
# NOTE(review): mg was reported above with 230592 vertices, so drawing the
# whole graph may be very slow - consider plotting a subgraph.
gplot(mg; nodelabel = nodelabels)