From e67e40a8490762f05235342eb72158dcf6bbf623 Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Fri, 7 Dec 2018 20:18:28 -0500 Subject: [PATCH 01/11] ignore ipynb checkpoints --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index bcfd702a..9c9e94f3 100644 --- a/.gitignore +++ b/.gitignore @@ -269,3 +269,4 @@ TSWLatexianTemp* doc/src/img/*.dot.png /doc/src/img/*.dot.svg /.ipynb_checkpoints/* +**/.ipynb_checkpoints From 305a98615388b074ebfcfab19bf9a2c18711ec7b Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Fri, 7 Dec 2018 20:19:33 -0500 Subject: [PATCH 02/11] ignore documentation build directory. --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 9c9e94f3..4b598840 100644 --- a/.gitignore +++ b/.gitignore @@ -270,3 +270,4 @@ doc/src/img/*.dot.png /doc/src/img/*.dot.svg /.ipynb_checkpoints/* **/.ipynb_checkpoints +/doc/build From 010ede356c9ad3be148bfed5f3acbe766b8fc050 Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Sun, 9 Dec 2018 16:49:23 -0500 Subject: [PATCH 03/11] add a TracedRun cassette context --- src/cassette.jl | 176 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 src/cassette.jl diff --git a/src/cassette.jl b/src/cassette.jl new file mode 100644 index 00000000..f92394b3 --- /dev/null +++ b/src/cassette.jl @@ -0,0 +1,176 @@ +module CassTest + +# this recursion based solution is too fancy for me to understand. +using Cassette + +function construct(T::Type, args...) + @info "constructing a model $T" + return T(args...) +end + +Cassette.@context TraceCtx + +function Cassette.execute(ctx::TraceCtx, args...) + subtrace = Any[] + push!(ctx.metadata, args => subtrace) + if Cassette.canoverdub(ctx, args...) + newctx = Cassette.similarcontext(ctx, metadata = subtrace) + return Cassette.overdub(newctx, args...) + else + return Cassette.fallback(ctx, args...) 
+ end +end + +function Cassette.execute(ctx::TraceCtx, f::typeof(Base.vect), args...) + @info "constructing a vector" + push!(ctx.metadata, (f, args)) + return Cassette.fallback(ctx, f, args...) +end + +function Cassette.execute(ctx::TraceCtx, f::typeof(Core.apply_type), args...) + @info "applying a type $f" + push!(ctx.metadata, (f, args)) + return Cassette.fallback(ctx, f, args...) +end + +# TODO: support calls like construct(T, a, f(b)) +function Cassette.execute(ctx::TraceCtx, f::typeof(construct), args...) + @info "constructing with type $f" + push!(ctx.metadata, (f, args)) + y = Cassette.fallback(ctx, f, args...) + @info "constructed model: $y" + return y +end + +struct TracedRun{T,V} + trace::T + value::V +end + + +function buildtrace(f::Function) + trace = Any[] + val = Cassette.overdub(TraceCtx(metadata=trace), f) + return TracedRun(trace, val) +end + + + +trace = Any[] +x, y, z = rand(3) +f(x, y, z) = x*y + y*z +Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z)) + +# returns `true` +trace == Any[ + (f,x,y,z) => Any[ + (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]] + (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]] + (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]] + ] +] +# using Cassette + +# Cassette.@context TraceCtx + +# mutable struct Trace +# current::Vector{Any} +# stack::Vector{Any} +# Trace() = new(Any[], Any[]) +# end + +# function enter!(t::Trace, args...) +# pair = args => Any[] +# push!(t.current, pair) +# push!(t.stack, t.current) +# t.current = pair.second +# return nothing +# end + +# function exit!(t::Trace) +# t.current = pop!(t.stack) +# return nothing +# end + +# Cassette.prehook(ctx::TraceCtx, args...) = enter!(ctx.metadata, args...) +# Cassette.posthook(ctx::TraceCtx, args...) 
= exit!(ctx.metadata) + +# trace = Trace() +# x, y, z = rand(3) +# f(x, y, z) = x*y + y*z +# Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z)) + +# # returns `true` +# trace.current == Any[ +# (f,x,y,z) => Any[ +# (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]] +# (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]] +# (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]] +# ] +# ] + +struct ODEProblem{T,U,V,W} + ode::T + init::U + tspan::V + parms::W +end + +# Define == for ODEProblems to mean they are semantically equivalent. +# The normal definition for structs is that equality of fields is done with === +# which means structs that contain arrays, cannot be == if those arrays are not === +# perhaps we should use the AutoHashEquals.jl package here. +function Base.:(==)(p::ODEProblem, q::ODEProblem) + return p.ode == q.ode && p.init == q.init && p.tspan == q.tspan && p.parms == q.parms +end + +""" sir_ode2(du,u,p,t) + +computes the du/dt array for the SIR system. parameters p is b,g = beta,gamma. +""" +sir_ode2(du,u,p,t) = begin + S,I,R = u + b,g = p + du[1] = -b*S*I + du[2] = b*S*I-g*I + du[3] = g*I +end + +function g() + parms = [0.1,0.05] + init = [0.99,0.01,0.0] + tspan = (0.0,200.0) + sir_prob2 = construct(ODEProblem,sir_ode2,init,tspan,parms) + # sir_sol = solve(sir_prob2,saveat = 0.1) + return sir_prob2 +end + +function h() + parms = [0.1,0.05] + init = [0.99,0.01,0.0] + tspan = (0.0,200.0) + sir_prob2 = ODEProblem(sir_ode2,init,tspan,parms) + # sir_sol = solve(sir_prob2,saveat = 0.1) + return sir_prob2 +end + +# these should be equal, but we have to write a custom == function for ODEProblem. 
+ +@info "Tracing implementation with construct call" +trace1 = buildtrace(g) +@info "Tracing implementation without construct call" +trace2 = buildtrace(h) +@assert trace2.value == trace1.value + +for i in 1:3 + @assert trace1.trace[i] == trace2.trace[i] +end + +@info "Trace step for construct call" +@show trace1.trace[4] +@assert trace1.trace[4][1] == construct +@info "Trace step for direct apply type" +@show trace2.trace[4] +@assert last(trace2.trace[4])[1][1] == Core.apply_type + +end From aabf5ffefeeb1edd890496774389222df5143920 Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Sun, 9 Dec 2018 17:40:36 -0500 Subject: [PATCH 04/11] add Dubstep module and tests --- Project.toml | 1 + src/SemanticModels.jl | 3 +- src/cassette.jl | 125 +----------------------------------------- test/cassette.jl | 103 ++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 110 insertions(+), 123 deletions(-) create mode 100644 test/cassette.jl diff --git a/Project.toml b/Project.toml index 1a44e63e..e505818b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["James Fairbanks "] version = "0.1.0" [deps] +Cassette = "7057c7e9-c182-5462-911a-8362d720325c" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" diff --git a/src/SemanticModels.jl b/src/SemanticModels.jl index 6bb6299e..5919f23c 100644 --- a/src/SemanticModels.jl +++ b/src/SemanticModels.jl @@ -81,7 +81,8 @@ end include("diffeq.jl") include("regression.jl") -include("grfn.jl") +# include("grfn.jl") +include("cassette.jl") """ CombinedModel diff --git a/src/cassette.jl b/src/cassette.jl index f92394b3..a8656308 100644 --- a/src/cassette.jl +++ b/src/cassette.jl @@ -1,7 +1,7 @@ -module CassTest +module Dubstep -# this recursion based solution is too fancy for me to understand. 
using Cassette +export construct, TracedRun, trace, TraceCtx function construct(T::Type, args...) @info "constructing a model $T" @@ -48,129 +48,10 @@ struct TracedRun{T,V} end -function buildtrace(f::Function) +function trace(f::Function) trace = Any[] val = Cassette.overdub(TraceCtx(metadata=trace), f) return TracedRun(trace, val) end - - -trace = Any[] -x, y, z = rand(3) -f(x, y, z) = x*y + y*z -Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z)) - -# returns `true` -trace == Any[ - (f,x,y,z) => Any[ - (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]] - (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]] - (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]] - ] -] -# using Cassette - -# Cassette.@context TraceCtx - -# mutable struct Trace -# current::Vector{Any} -# stack::Vector{Any} -# Trace() = new(Any[], Any[]) -# end - -# function enter!(t::Trace, args...) -# pair = args => Any[] -# push!(t.current, pair) -# push!(t.stack, t.current) -# t.current = pair.second -# return nothing -# end - -# function exit!(t::Trace) -# t.current = pop!(t.stack) -# return nothing -# end - -# Cassette.prehook(ctx::TraceCtx, args...) = enter!(ctx.metadata, args...) -# Cassette.posthook(ctx::TraceCtx, args...) = exit!(ctx.metadata) - -# trace = Trace() -# x, y, z = rand(3) -# f(x, y, z) = x*y + y*z -# Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z)) - -# # returns `true` -# trace.current == Any[ -# (f,x,y,z) => Any[ -# (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]] -# (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]] -# (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]] -# ] -# ] - -struct ODEProblem{T,U,V,W} - ode::T - init::U - tspan::V - parms::W -end - -# Define == for ODEProblems to mean they are semantically equivalent. -# The normal definition for structs is that equality of fields is done with === -# which means structs that contain arrays, cannot be == if those arrays are not === -# perhaps we should use the AutoHashEquals.jl package here. 
-function Base.:(==)(p::ODEProblem, q::ODEProblem) - return p.ode == q.ode && p.init == q.init && p.tspan == q.tspan && p.parms == q.parms -end - -""" sir_ode2(du,u,p,t) - -computes the du/dt array for the SIR system. parameters p is b,g = beta,gamma. -""" -sir_ode2(du,u,p,t) = begin - S,I,R = u - b,g = p - du[1] = -b*S*I - du[2] = b*S*I-g*I - du[3] = g*I -end - -function g() - parms = [0.1,0.05] - init = [0.99,0.01,0.0] - tspan = (0.0,200.0) - sir_prob2 = construct(ODEProblem,sir_ode2,init,tspan,parms) - # sir_sol = solve(sir_prob2,saveat = 0.1) - return sir_prob2 -end - -function h() - parms = [0.1,0.05] - init = [0.99,0.01,0.0] - tspan = (0.0,200.0) - sir_prob2 = ODEProblem(sir_ode2,init,tspan,parms) - # sir_sol = solve(sir_prob2,saveat = 0.1) - return sir_prob2 -end - -# these should be equal, but we have to write a custom == function for ODEProblem. - -@info "Tracing implementation with construct call" -trace1 = buildtrace(g) -@info "Tracing implementation without construct call" -trace2 = buildtrace(h) -@assert trace2.value == trace1.value - -for i in 1:3 - @assert trace1.trace[i] == trace2.trace[i] -end - -@info "Trace step for construct call" -@show trace1.trace[4] -@assert trace1.trace[4][1] == construct -@info "Trace step for direct apply type" -@show trace2.trace[4] -@assert last(trace2.trace[4])[1][1] == Core.apply_type - end diff --git a/test/cassette.jl b/test/cassette.jl new file mode 100644 index 00000000..2c1ca96e --- /dev/null +++ b/test/cassette.jl @@ -0,0 +1,103 @@ +module TraceTest + +using Cassette +using Test +using SemanticModels.Dubstep + +struct ODEProblem{T,U,V,W} + ode::T + init::U + tspan::V + parms::W +end + +# Define == for ODEProblems to mean they are semantically equivalent. The +# normal definition for structs is that equality of fields is done with === +# which means structs that contain arrays, cannot be == if those arrays are +# not === perhaps we should use the AutoHashEquals.jl package here. 
+function Base.:(==)(p::ODEProblem, q::ODEProblem) + return p.ode == q.ode && p.init == q.init && p.tspan == q.tspan && p.parms == q.parms +end + +""" sir_ode2(du,u,p,t) + + computes the du/dt array for the SIR system. parameters p is b,g = beta,gamma. + """ +sir_ode2(du,u,p,t) = begin + S,I,R = u + b,g = p + du[1] = -b*S*I + du[2] = b*S*I-g*I + du[3] = g*I +end + +# Cassette.@context TraceCtx + +# function Cassette.execute(ctx::TraceCtx, args...) +# subtrace = Any[] +# push!(ctx.metadata, args => subtrace) +# if Cassette.canoverdub(ctx, args...) +# newctx = Cassette.similarcontext(ctx, metadata = subtrace) +# return Cassette.overdub(newctx, args...) +# else +# return Cassette.fallback(ctx, args...) +# end +# end + +trace = Any[] +x, y, z = rand(3) +f(x, y, z) = x*y + y*z +Cassette.overdub(Dubstep.TraceCtx(metadata = trace), () -> f(x, y, z)) +@testset "Cassette" begin + @test trace == Any[ + (f,x,y,z) => Any[ + (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]] + (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]] + (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]] + ] + ] +end + + +function g() + parms = [0.1,0.05] + init = [0.99,0.01,0.0] + tspan = (0.0,200.0) + sir_prob2 = construct(ODEProblem,sir_ode2,init,tspan,parms) + # sir_sol = solve(sir_prob2,saveat = 0.1) + return sir_prob2 +end + +function h() + parms = [0.1,0.05] + init = [0.99,0.01,0.0] + tspan = (0.0,200.0) + sir_prob2 = ODEProblem(sir_ode2,init,tspan,parms) + # sir_sol = solve(sir_prob2,saveat = 0.1) + return sir_prob2 +end + +# these are equal, because we wrote a custom == function for ODEProblem. 
+@testset "Dubstep" begin + @testset "SIR" begin + + @info "Tracing implementation with construct call" + trace1 = Dubstep.trace(g) + @info "Tracing implementation without construct call" + trace2 = Dubstep.trace(h) + @test trace2.value == trace1.value + + for i in 1:3 + @test trace1.trace[i] == trace2.trace[i] + end + + @info "Trace step for construct call" + @show trace1.trace[4] + @test trace1.trace[4][1] == construct + @info "Trace step for direct apply type" + @show trace2.trace[4] + @test last(trace2.trace[4])[1][1] == Core.apply_type + +end #SIR +end #Dubstep +end #module diff --git a/test/runtests.jl b/test/runtests.jl index 19d04185..d7f8419d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,7 @@ using GLM using DataFrames using Plots +include("cassette.jl") stripunits(x) = uconvert(NoUnits, x) @testset "spring models" begin From 9750e8ac0c207f4d37374f0af9814eaff5e476f4 Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Mon, 10 Dec 2018 12:40:08 -0500 Subject: [PATCH 05/11] add cassette docs and contexts --- doc/make.jl | 1 + doc/src/dubstep.md | 25 +++++++++++++++++++++++++ src/cassette.jl | 27 +++++++++++++++++++++++++++ test/cassette.jl | 23 +++++++++++++++++++++++ 4 files changed, 76 insertions(+) create mode 100644 doc/src/dubstep.md diff --git a/doc/make.jl b/doc/make.jl index 409c2960..efba4483 100644 --- a/doc/make.jl +++ b/doc/make.jl @@ -28,6 +28,7 @@ pages = Any[ "Approaches" => "approach.md", "Library Reference" => "library.md", "Slides" => "slides.md", + "Dubstep" => "dubstep.md", "Flu Model" => "FluModel.md" # "Model Types" => "types.md", # # "Reading / Writing Models" => "persistence.md", diff --git a/doc/src/dubstep.md b/doc/src/dubstep.md new file mode 100644 index 00000000..2b8dc72c --- /dev/null +++ b/doc/src/dubstep.md @@ -0,0 +1,25 @@ +# Dubstep + +This module uses Cassette.jl to modify programs by overdubbing their executions in a context. 
+ +## TraceCtx + +Builds hierarchical runtime value traces by running the program you pass it. You can change the metadata. +You can change out the metadata that you pass in order to collect different information. The default is Any[]. + +## LPCtx + +Replaces all calls to `norm(x,p)` which `norm(x,ctx.metadata[p])` so you can change the norms that a code uses to +compute. + +## Reference + + +```@autodocs +Modules = [SemanticModels.Dubstep] +``` + +## Index + +```@index +``` diff --git a/src/cassette.jl b/src/cassette.jl index a8656308..206524b2 100644 --- a/src/cassette.jl +++ b/src/cassette.jl @@ -54,4 +54,31 @@ function trace(f::Function) return TracedRun(trace, val) end + + +Cassette.@context LPCtx + +""" LPCtx + +replaces all calls to `LinearAlgebra.norm` with a different `p`. + +This context is useful for modifying statistical codes or machine learning regularizers. +""" +LPCtx + +function Cassette.execute(ctx::LPCtx, args...) + if Cassette.canoverdub(ctx, args...) + newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata) + return Cassette.overdub(newctx, args...) + else + return Cassette.fallback(ctx, args...) 
+ end +end + +using LinearAlgebra +function Cassette.execute(ctx::LPCtx, f::typeof(norm), arg, power) + return f(arg, ctx.metadata[power]) +end + end + diff --git a/test/cassette.jl b/test/cassette.jl index 2c1ca96e..e75b9fba 100644 --- a/test/cassette.jl +++ b/test/cassette.jl @@ -3,6 +3,7 @@ module TraceTest using Cassette using Test using SemanticModels.Dubstep +using LinearAlgebra struct ODEProblem{T,U,V,W} ode::T @@ -99,5 +100,27 @@ end @test last(trace2.trace[4])[1][1] == Core.apply_type end #SIR + end #Dubstep + + +subg(x,y) = norm([x x x]/6 - [y y y]/2, 2) +function g() + a = 5+7 + b = 3+4 + c = subg(a,b) + return c +end +function h() + a = 5+7 + c = subg(a,(3+4)) + return c +end + +@testset "LP" begin +@test 2.5980 < g() < 2.599 +ctx = Dubstep.LPCtx(metadata=Dict(1=>2, 2=>1, Inf=>1)) +@test Cassette.overdub(ctx, g) == 4.5 +end #LP + end #module From ebc49a755c19a04c76884e5f8b8925cdeddf05d1 Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Mon, 10 Dec 2018 13:43:06 -0500 Subject: [PATCH 06/11] add docs for replacenorm --- doc/src/dubstep.md | 57 ++++++++++++++++++++++++++++++++++++++++++---- src/cassette.jl | 23 +++++++++++++++++-- 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/doc/src/dubstep.md b/doc/src/dubstep.md index 2b8dc72c..9a234c98 100644 --- a/doc/src/dubstep.md +++ b/doc/src/dubstep.md @@ -12,14 +12,61 @@ You can change out the metadata that you pass in order to collect different info Replaces all calls to `norm(x,p)` which `norm(x,ctx.metadata[p])` so you can change the norms that a code uses to compute. -## Reference +### Example +Here is an example of changing an internal component of a mathematical operation using cassette to rewrite the norm function. -```@autodocs -Modules = [SemanticModels.Dubstep] +First we define a function that uses norm, and +another function that calls it. 
+```julia +subg(x,y) = norm([x x x]/6 - [y y y]/2, 2) +function g() + a = 5+7 + b = 3+4 + c = subg(a,b) + return c +end +``` + +We use the Dubstep.LPCtx which is shown here. + +```julia +Cassette.@context LPCtx + +function Cassette.execute(ctx::LPCtx, args...) + if Cassette.canoverdub(ctx, args...) + newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata) + return Cassette.overdub(newctx, args...) + else + return Cassette.fallback(ctx, args...) + end +end + +using LinearAlgebra +function Cassette.execute(ctx::LPCtx, f::typeof(norm), arg, power) + return f(arg, ctx.metadata[power]) +end ``` -## Index +Note the method definition of `Cassette.execute` +for LPCtx when called with the function +`LinearAlgebra.norm`. -```@index +We then construct an instance of the context that +configures how we want to do the substitution. +```julia +@testset "LP" begin +@test 2.5980 < g() < 2.599 +ctx = Dubstep.LPCtx(metadata=Dict(1=>2, 2=>1, Inf=>1 +@test Cassette.overdub(ctx, g) == 4.5 +``` + +And just like that, we can control the execution +of a program without rewriting it at the lexical level. + +## Reference + + +```@autodocs +Modules = [SemanticModels.Dubstep] ``` diff --git a/src/cassette.jl b/src/cassette.jl index 206524b2..f0eba967 100644 --- a/src/cassette.jl +++ b/src/cassette.jl @@ -1,7 +1,7 @@ module Dubstep using Cassette -export construct, TracedRun, trace, TraceCtx +export construct, TracedRun, trace, TraceCtx, LPCtx, replacenorm function construct(T::Type, args...) @info "constructing a model $T" @@ -42,12 +42,21 @@ function Cassette.execute(ctx::TraceCtx, f::typeof(construct), args...) return y end +""" TracedRun{T,V} + +captures the dataflow of a code execution. We store the trace and the value. + +see also `trace`. +""" struct TracedRun{T,V} trace::T value::V end +""" trace(f) +run the function f and return a TracedRun containing the trace and the output. 
+""" function trace(f::Function) trace = Any[] val = Cassette.overdub(TraceCtx(metadata=trace), f) @@ -77,8 +86,18 @@ end using LinearAlgebra function Cassette.execute(ctx::LPCtx, f::typeof(norm), arg, power) - return f(arg, ctx.metadata[power]) + p = get(ctx.metadata, power, power) + return f(arg, p) end +""" replacenorm(f::Function, d::AbstractDict) + +run f, but replace every call to norm using the mapping in d. +""" +function replacenorm(f::Function, d::AbstractDict) + ctx = LPCtx(metadata=d) + return Cassette.overdub(ctx, f) end +end #module + From 5c3f9af70794c1e1ac2979823672aa0baf64dd33 Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Mon, 10 Dec 2018 19:07:33 -0500 Subject: [PATCH 07/11] add a Context to add a lagniappe to the derivative. --- src/cassette.jl | 2 +- test/transform/ode.jl | 116 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 test/transform/ode.jl diff --git a/src/cassette.jl b/src/cassette.jl index f0eba967..0d95da12 100644 --- a/src/cassette.jl +++ b/src/cassette.jl @@ -28,7 +28,7 @@ function Cassette.execute(ctx::TraceCtx, f::typeof(Base.vect), args...) end function Cassette.execute(ctx::TraceCtx, f::typeof(Core.apply_type), args...) - @info "applying a type $f" + # @info "applying a type $(args)" push!(ctx.metadata, (f, args)) return Cassette.fallback(ctx, f, args...) end diff --git a/test/transform/ode.jl b/test/transform/ode.jl new file mode 100644 index 00000000..986303f0 --- /dev/null +++ b/test/transform/ode.jl @@ -0,0 +1,116 @@ +module ODEXform +using DifferentialEquations +using Cassette +using SemanticModels.Dubstep + +Cassette.@context SolverCtx +function Cassette.execute(ctx::SolverCtx, args...) + if Cassette.canoverdub(ctx, args...) + #newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata) + return Cassette.overdub(ctx, args...) + else + return Cassette.fallback(ctx, args...) 
+ end +end + +function Cassette.execute(ctx::SolverCtx, f::typeof(Base.vect), args...) + @info "constructing a vector length $(length(args))" + return Cassette.fallback(ctx, f, args...) +end + +# We don't need to overdub basic math. this hopefully makes execution faster. +# if these overloads don't actually make it faster, they can be deleted. +function Cassette.execute(ctx::SolverCtx, f::typeof(+), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(-), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(*), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(/), args...) + return Cassette.fallback(ctx, f, args...) +end + + +end #module + +using SemanticModels.Dubstep +using DifferentialEquations +using Cassette +using Test + +""" sir_ode(du,u,p,t) + +computes the du/dt array for the SIR system. parameters p is b,g = beta,gamma. +""" +sir_ode(du,u,p,t) = begin + S,I,R = u + b,g = p + du[1] = -b*S*I + du[2] = b*S*I-g*I + du[3] = g*I +end + +function Cassette.execute(ctx::ODEXform.SolverCtx, f::typeof(sir_ode), args...) + y = Cassette.fallback(ctx, f, args...) + # add a lagniappe of infection + extra = args[1][1] * ctx.metadata.factor + push!(ctx.metadata.extras, extra) + args[1][1] += extra + args[1][2] -= extra + return y +end + +function g() + parms = [0.1,0.05] + init = [0.99,0.01,0.0] + tspan = (0.0,200.0) + sir_prob = Dubstep.construct(ODEProblem,sir_ode,init,tspan,parms) + return sir_prob +end + +function h() + prob = g() + return solve(prob, alg=Vern7()) +end + +#precompile +@time sol1 = h() +#timeit +@time sol1 = h() + +""" perturb(f, factor) + +run the function f with a perturbation specified by factor. 
+""" +function perturb(f, factor) + t = (factor=factor,extras=Float64[]) + ctx = ODEXform.SolverCtx(metadata = t) + val = Cassette.overdub(ctx, f) + return val, t +end + +traces = Any[] +solns = Any[] +for f in [0.0, 0.01, 0.05, 0.10] + val, t = perturb(h, f) + push!(traces, t) + push!(solns, val) +end + +for (i, s) in enumerate(solns) + @show s(100) + @show traces[i].factor + @show traces[i].extras[5] + @show sum(traces[i].extras)/length(traces[i].extras) +end + +using LinearAlgebra +using Test + +@test norm(sol1(100) .- solns[1](100),2) < 1e-6 +@test norm(sol1(100) .- solns[2](100),2) > 1e-6 +@test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[3](100),2) +@test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[4](100),2) From 5e3c726da712b60ed4e006994c6cb7d68ed9357b Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Mon, 10 Dec 2018 20:06:37 -0500 Subject: [PATCH 08/11] add transform example to docs and runtests.jl --- doc/src/dubstep.md | 185 ++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 + test/transform/ode.jl | 12 +-- 3 files changed, 194 insertions(+), 5 deletions(-) diff --git a/doc/src/dubstep.md b/doc/src/dubstep.md index 9a234c98..01538999 100644 --- a/doc/src/dubstep.md +++ b/doc/src/dubstep.md @@ -64,6 +64,191 @@ ctx = Dubstep.LPCtx(metadata=Dict(1=>2, 2=>1, Inf=>1 And just like that, we can control the execution of a program without rewriting it at the lexical level. + +## Transformations + +You can also transform model by executing it in a +context that changes the function calls. +Eventually we will support writing compiler passes +for modifying models at the expression level, but +for now function calls are a good entry point. + +### Example: Perturbations + +This example comes from the unit tests `test/transform/ode.jl`. + +The first step is to define a context for solving +models. 
+ +```julia +module ODEXform +using DifferentialEquations +using Cassette +using SemanticModels.Dubstep + +Cassette.@context SolverCtx +function Cassette.execute(ctx::SolverCtx, args...) + if Cassette.canoverdub(ctx, args...) + #newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata) + return Cassette.overdub(ctx, args...) + else + return Cassette.fallback(ctx, args...) + end +end + +function Cassette.execute(ctx::SolverCtx, f::typeof(Base.vect), args...) + @info "constructing a vector length $(length(args))" + return Cassette.fallback(ctx, f, args...) +end + +# We don't need to overdub basic math. this hopefully makes execution faster. +# if these overloads don't actually make it faster, they can be deleted. +function Cassette.execute(ctx::SolverCtx, f::typeof(+), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(-), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(*), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(/), args...) + return Cassette.fallback(ctx, f, args...) +end +end #module +``` + +Then we define our RHS of the differential +equation that is `du/dt = sir_ode(du, u, p, t)`. +This function needs to be defined before we define +the method for `Cassette.execute` with the +signature: +`Cassette.execute(ctx::ODEXform.SolverCtx, f::typeof(sir_ode), args...)` +because we need to have the function we want to +overdub defined before we can specify how to +overdub it. + +```julia +using LinearAlgebra +using Test +using Cassette +using DifferentialEquations +using SemanticModels.Dubstep + +""" sir_ode(du,u,p,t) + +computes the du/dt array for the SIR system. parameters p is b,g = beta,gamma. 
+""" +sir_ode(du,u,p,t) = begin + S,I,R = u + b,g = p + du[1] = -b*S*I + du[2] = b*S*I-g*I + du[3] = g*I +end + +function Cassette.execute(ctx::ODEXform.SolverCtx, f::typeof(sir_ode), args...) + y = Cassette.fallback(ctx, f, args...) + # add a lagniappe of infection + extra = args[1][1] * ctx.metadata.factor + push!(ctx.metadata.extras, extra) + args[1][1] += extra + args[1][2] -= extra + return y +end +``` + +The key thing is that we define the execute method +by specifying that we want to execute `sir_ode` +then compute the extra amount (the lagniappe) and +add that extra amount to the `dS/dt`. The SIR +model has an invariant that `dI/dt = -dS/dt - dR/dt` +so we adjust the `dI/dt` accordingly. + +The rest of this code runs the model in the +context. + +```julia +function g() + parms = [0.1,0.05] + init = [0.99,0.01,0.0] + tspan = (0.0,200.0) + sir_prob = Dubstep.construct(ODEProblem,sir_ode,init,tspan,parms) + return sir_prob +end + +function h() + prob = g() + return solve(prob, alg=Vern7()) +end + +#precompile +@time sol1 = h() +#timeit +@time sol1 = h() +``` + +We define a perturbation function that handles +setting up the context and collecting the results. +Note that we store the extras in the +context.metadata using a modifying operator push!. + +```julia +""" perturb(f, factor) + +run the function f with a perturbation specified by factor. +""" +function perturb(f, factor) + t = (factor=factor,extras=Float64[]) + ctx = ODEXform.SolverCtx(metadata = t) + val = Cassette.overdub(ctx, f) + return val, t +end +``` + +We collect the traces `t` and solutions `s` in +order to quantify the effect of our perturbation +on the answer computed by `solve`. We test to make +sure that the bigger the perturbation, the bigger +the error. 
+ +```julia +traces = Any[] +solns = Any[] +for f in [0.0, 0.01, 0.05, 0.10] + val, t = perturb(h, f) + push!(traces, t) + push!(solns, val) +end + +for (i, s) in enumerate(solns) + @show s(100) + @show traces[i].factor + @show traces[i].extras[5] + @show sum(traces[i].extras)/length(traces[i].extras) +end + +@testset "ODE perturbation" begin + +@test norm(sol1(100) .- solns[1](100),2) < 1e-6 +@test norm(sol1(100) .- solns[2](100),2) > 1e-6 +@test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[3](100),2) +@test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[4](100),2) + +end +``` + +This example illustrates how you can use a +Cassette.Context to hijack the execution of a +scientific model in order to change the execution +in a meaningful way. We also see how the execution +allows us to examine the sensitivity of the +solution with respect to the derivative. This +technique allows scientists to answer +counterfactual questions about the execution of +codes, such as "what if the model had a slightly +different RHS?" 
+ ## Reference diff --git a/test/runtests.jl b/test/runtests.jl index d7f8419d..9f057389 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,8 @@ using DataFrames using Plots include("cassette.jl") +include("transform/ode.jl") + stripunits(x) = uconvert(NoUnits, x) @testset "spring models" begin diff --git a/test/transform/ode.jl b/test/transform/ode.jl index 986303f0..c7f2eddd 100644 --- a/test/transform/ode.jl +++ b/test/transform/ode.jl @@ -36,10 +36,11 @@ end end #module -using SemanticModels.Dubstep -using DifferentialEquations -using Cassette +using LinearAlgebra using Test +using Cassette +using DifferentialEquations +using SemanticModels.Dubstep """ sir_ode(du,u,p,t) @@ -107,10 +108,11 @@ for (i, s) in enumerate(solns) @show sum(traces[i].extras)/length(traces[i].extras) end -using LinearAlgebra -using Test +@testset "ODE perturbation" begin @test norm(sol1(100) .- solns[1](100),2) < 1e-6 @test norm(sol1(100) .- solns[2](100),2) > 1e-6 @test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[3](100),2) @test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[4](100),2) + +end From c47fd82b5c8c1d55adfc41c4307d6609f4d00101 Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Tue, 11 Dec 2018 14:07:58 -0500 Subject: [PATCH 09/11] add a home for examples and document it. 
--- README.md | 7 ++++++- examples/README.md | 10 ++++++++++ examples/epicookbook/README.md | 22 ++++++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 examples/README.md create mode 100644 examples/epicookbook/README.md diff --git a/README.md b/README.md index 25515d17..697cb58f 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,12 @@ There is a docs folder which contains the documentation, including reports sent Documentation is currently published at jpfairbanks.com/doc/aske and jpfairbanks.com/doc/aske/slides.pdf +### Examples + +In addition to the examples in the documentation, there are fully worked out examples in the folder +https://github.com/jpfairbanks/SemanticModels.jl/tree/master/examples/. Each subdirectory represents one self contained +example, starting with `epicookbook`. + ## Concepts This package enables representation of complex and diverse model structure in the type system of julia. This will allow generic programing and API development for these complex models. @@ -54,4 +60,3 @@ You can use the `Extractor` type to pull knowledge elements from an artifact. Th - Code - Model - Paper - diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..753e9ccb --- /dev/null +++ b/examples/README.md @@ -0,0 +1,10 @@ +# Examples + +This folder contains examples of how to use SemanticModels.jl + +Each subfolder contains a README.md and should have the same layout in terms of + + - src/ + - data/ + - notebooks/ + - docs/ diff --git a/examples/epicookbook/README.md b/examples/epicookbook/README.md new file mode 100644 index 00000000..33d33ef8 --- /dev/null +++ b/examples/epicookbook/README.md @@ -0,0 +1,22 @@ +# Epicookbook example + +This folder contains an example of using the SemanticModels.jl package for analyzing scientific codes and building +scientist augmentation programs. 
+ +## Original Sources + +We are grateful to the authors of all these sources; without their contributions to open science, we would not be able +to build this technology. + +- http://epirecip.es/epicookbook/ + +## Motivation + +Epidemiological models are flexible and well specified. The epicookbook represents a secondary text that explains how to +build epidemiological models from the ground up. The models are in increasing complexity and have code associated with +them. These code examples are often written in multiple languages, which makes them a great test case for our software. + +## Getting Started + +Here is how to run the example: + From 6644b65bcb4dec228a57b4bddc031b02aecbe623 Mon Sep 17 00:00:00 2001 From: James Date: Thu, 13 Dec 2018 13:24:31 -0500 Subject: [PATCH 10/11] add source citation --- doc/src/dubstep.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/src/dubstep.md b/doc/src/dubstep.md index 01538999..b4433c31 100644 --- a/doc/src/dubstep.md +++ b/doc/src/dubstep.md @@ -1,6 +1,6 @@ # Dubstep -This module uses Cassette.jl to modify programs by overdubbing their executions in a context. +This module uses [Cassette.jl](https://github.com/jrevels/Cassette.jl) ([Zenodo](https://zenodo.org/record/1806173)) to modify programs by overdubbing their executions in a context. 
## TraceCtx From c6cf278fb9a5f71c1592d0981b8664a0d8446e9e Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Sat, 15 Dec 2018 08:06:34 -0500 Subject: [PATCH 11/11] add LinearAlgebra to the Project.toml --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index e505818b..77468e94 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"