diff --git a/.gitignore b/.gitignore index bcfd702a..4b598840 100644 --- a/.gitignore +++ b/.gitignore @@ -269,3 +269,5 @@ TSWLatexianTemp* doc/src/img/*.dot.png /doc/src/img/*.dot.svg /.ipynb_checkpoints/* +**/.ipynb_checkpoints +/doc/build diff --git a/Project.toml b/Project.toml index 1a44e63e..77468e94 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,14 @@ authors = ["James Fairbanks "] version = "0.1.0" [deps] +Cassette = "7057c7e9-c182-5462-911a-8362d720325c" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/README.md b/README.md index d29f3a1b..2e8e32f3 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,12 @@ There is a docs folder which contains the documentation, including reports sent Documentation is currently published at jpfairbanks.com/doc/aske and jpfairbanks.com/doc/aske/slides.pdf +### Examples + +In addition to the examples in the documentation, there are fully worked out examples in the folder +https://github.com/jpfairbanks/SemanticModels.jl/tree/master/examples/. Each subdirectory represents one self contained +example, starting with `epicookbook`. + ## Concepts This package enables representation of complex and diverse model structure in the type system of julia. This will allow generic programing and API development for these complex models. @@ -56,4 +62,3 @@ You can use the `Extractor` type to pull knowledge elements from an artifact. 
Th - Code - Model - Paper - diff --git a/doc/make.jl b/doc/make.jl index 409c2960..efba4483 100644 --- a/doc/make.jl +++ b/doc/make.jl @@ -28,6 +28,7 @@ pages = Any[ "Approaches" => "approach.md", "Library Reference" => "library.md", "Slides" => "slides.md", + "Dubstep" => "dubstep.md", "Flu Model" => "FluModel.md" # "Model Types" => "types.md", # # "Reading / Writing Models" => "persistence.md", diff --git a/doc/src/dubstep.md b/doc/src/dubstep.md new file mode 100644 index 00000000..b4433c31 --- /dev/null +++ b/doc/src/dubstep.md @@ -0,0 +1,258 @@ +# Dubstep + +This module uses [Cassette.jl](https://github.com/jrevels/Cassette.jl) ([Zenodo](https://zenodo.org/record/1806173)) to modify programs by overdubbing their executions in a context. + +## TraceCtx + +Builds hierarchical runtime value traces by running the program you pass it. +You can swap out the metadata that you pass in order to collect different information. The default is `Any[]`. + +## LPCtx + +Replaces all calls to `norm(x,p)` with `norm(x,q)`, where `q = get(ctx.metadata, p, p)`, so you can change the norms that a code uses to +compute. + +### Example +Here is an example of changing an internal component of a mathematical operation using Cassette to rewrite the norm function. + + +First we define a function that uses norm, and +another function that calls it. +```julia +subg(x,y) = norm([x x x]/6 - [y y y]/2, 2) +function g() + a = 5+7 + b = 3+4 + c = subg(a,b) + return c +end +``` + +We use the Dubstep.LPCtx which is shown here. + +```julia +Cassette.@context LPCtx + +function Cassette.execute(ctx::LPCtx, args...) + if Cassette.canoverdub(ctx, args...) + newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata) + return Cassette.overdub(newctx, args...) + else + return Cassette.fallback(ctx, args...) 
+ end +end + +using LinearAlgebra +function Cassette.execute(ctx::LPCtx, f::typeof(norm), arg, power) + return f(arg, get(ctx.metadata, power, power)) +end +``` + +Note the method definition of `Cassette.execute` +for LPCtx when called with the function +`LinearAlgebra.norm`. + +We then construct an instance of the context that +configures how we want to do the substitution. +```julia +@testset "LP" begin +@test 2.5980 < g() < 2.599 +ctx = Dubstep.LPCtx(metadata=Dict(1=>2, 2=>1, Inf=>1)) +@test Cassette.overdub(ctx, g) == 4.5 +end #LP +``` + +And just like that, we can control the execution +of a program without rewriting it at the lexical level. + + +## Transformations + +You can also transform a model by executing it in a +context that changes the function calls. +Eventually we will support writing compiler passes +for modifying models at the expression level, but +for now function calls are a good entry point. + +### Example: Perturbations + +This example comes from the unit tests `test/transform/ode.jl`. + +The first step is to define a context for solving +models. + +```julia +module ODEXform +using DifferentialEquations +using Cassette +using SemanticModels.Dubstep + +Cassette.@context SolverCtx +function Cassette.execute(ctx::SolverCtx, args...) + if Cassette.canoverdub(ctx, args...) + #newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata) + return Cassette.overdub(ctx, args...) + else + return Cassette.fallback(ctx, args...) + end +end + +function Cassette.execute(ctx::SolverCtx, f::typeof(Base.vect), args...) + @info "constructing a vector length $(length(args))" + return Cassette.fallback(ctx, f, args...) +end + +# We don't need to overdub basic math. this hopefully makes execution faster. +# if these overloads don't actually make it faster, they can be deleted. +function Cassette.execute(ctx::SolverCtx, f::typeof(+), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(-), args...) 
+ return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(*), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(/), args...) + return Cassette.fallback(ctx, f, args...) +end +end #module +``` + +Then we define our RHS of the differential +equation that is `du/dt = sir_ode(du, u, p, t)`. +This function needs to be defined before we define +the method for `Cassette.execute` with the +signature: +`Cassette.execute(ctx::ODEXform.SolverCtx, f::typeof(sir_ode), args...)` +because we need to have the function we want to +overdub defined before we can specify how to +overdub it. + +```julia +using LinearAlgebra +using Test +using Cassette +using DifferentialEquations +using SemanticModels.Dubstep + +""" sir_ode(du,u,p,t) + +computes the du/dt array for the SIR system. parameters p is b,g = beta,gamma. +""" +sir_ode(du,u,p,t) = begin + S,I,R = u + b,g = p + du[1] = -b*S*I + du[2] = b*S*I-g*I + du[3] = g*I +end + +function Cassette.execute(ctx::ODEXform.SolverCtx, f::typeof(sir_ode), args...) + y = Cassette.fallback(ctx, f, args...) + # add a lagniappe of infection + extra = args[1][1] * ctx.metadata.factor + push!(ctx.metadata.extras, extra) + args[1][1] += extra + args[1][2] -= extra + return y +end +``` + +The key thing is that we define the execute method +by specifying that we want to execute `sir_ode` +then compute the extra amount (the lagniappe) and +add that extra amount to the `dS/dt`. The SIR +model has an invariant that `dI/dt = -dS/dt - dR/dt` +so we adjust the `dI/dt` accordingly. + +The rest of this code runs the model in the +context. 
+ +```julia +function g() + parms = [0.1,0.05] + init = [0.99,0.01,0.0] + tspan = (0.0,200.0) + sir_prob = Dubstep.construct(ODEProblem,sir_ode,init,tspan,parms) + return sir_prob +end + +function h() + prob = g() + return solve(prob, alg=Vern7()) +end + +#precompile +@time sol1 = h() +#timeit +@time sol1 = h() +``` + +We define a perturbation function that handles +setting up the context and collecting the results. +Note that we store the extras in the +`ctx.metadata` using the mutating function `push!`. + +```julia +""" perturb(f, factor) + +run the function f with a perturbation specified by factor. +""" +function perturb(f, factor) + t = (factor=factor,extras=Float64[]) + ctx = ODEXform.SolverCtx(metadata = t) + val = Cassette.overdub(ctx, f) + return val, t +end +``` + +We collect the traces `t` and solutions `s` in +order to quantify the effect of our perturbation +on the answer computed by `solve`. We test to make +sure that the bigger the perturbation, the bigger +the error. + +```julia +traces = Any[] +solns = Any[] +for f in [0.0, 0.01, 0.05, 0.10] + val, t = perturb(h, f) + push!(traces, t) + push!(solns, val) +end + +for (i, s) in enumerate(solns) + @show s(100) + @show traces[i].factor + @show traces[i].extras[5] + @show sum(traces[i].extras)/length(traces[i].extras) +end + +@testset "ODE perturbation" begin + +@test norm(sol1(100) .- solns[1](100),2) < 1e-6 +@test norm(sol1(100) .- solns[2](100),2) > 1e-6 +@test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[3](100),2) +@test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[4](100),2) + +end +``` + +This example illustrates how you can use a +Cassette.Context to hijack the execution of a +scientific model in order to change the execution +in a meaningful way. We also see how the execution +allows us to examine the sensitivity of the +solution with respect to the derivative. 
This +technique allows scientists to answer +counterfactual questions about the execution of +codes, such as "what if the model had a slightly +different RHS?" + +## Reference + + +```@autodocs +Modules = [SemanticModels.Dubstep] +``` diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..753e9ccb --- /dev/null +++ b/examples/README.md @@ -0,0 +1,10 @@ +# Examples + +This folder contains examples of how to use SemanticModels.jl. + +Each subfolder contains a README.md and should have the same layout in terms of + + - src/ + - data/ + - notebooks/ + - docs/ diff --git a/examples/epicookbook/README.md b/examples/epicookbook/README.md new file mode 100644 index 00000000..33d33ef8 --- /dev/null +++ b/examples/epicookbook/README.md @@ -0,0 +1,22 @@ +# Epicookbook example + +This folder contains an example of using the SemanticModels.jl package for analyzing scientific codes and building +scientist augmentation programs. + +## Original Sources + +We are grateful to the authors of all these sources; without their contributions to open science, we would not be able +to build this technology. + +- http://epirecip.es/epicookbook/ + +## Motivation + +Epidemiological models are flexible and well specified. The epicookbook represents a secondary text that explains how to +build epidemiological models from the ground up. The models increase in complexity and have code associated with +them. These code examples are often written in multiple languages, which makes them a great test case for our software. 
+ +## Getting Started + +Here is how to run the example: + diff --git a/src/SemanticModels.jl b/src/SemanticModels.jl index f6c06cfb..5919f23c 100644 --- a/src/SemanticModels.jl +++ b/src/SemanticModels.jl @@ -82,6 +82,7 @@ end include("diffeq.jl") include("regression.jl") # include("grfn.jl") +include("cassette.jl") """ CombinedModel diff --git a/src/cassette.jl b/src/cassette.jl new file mode 100644 index 00000000..0d95da12 --- /dev/null +++ b/src/cassette.jl @@ -0,0 +1,111 @@ +module Dubstep + +using Cassette +export construct, TracedRun, trace, TraceCtx, LPCtx, replacenorm + +""" construct(T::Type, args...) + +build an instance of `T` by calling `T(args...)`. Routing constructor calls through this +named function gives contexts such as `TraceCtx` a single call to dispatch on. +""" +function construct(T::Type, args...) + @info "constructing a model $T" + return T(args...) +end + +Cassette.@context TraceCtx + +# Record each call as `args => subtrace`. Calls made while overdubbing the callee are +# pushed onto `subtrace`, which is what makes the resulting trace hierarchical. +function Cassette.execute(ctx::TraceCtx, args...) + subtrace = Any[] + push!(ctx.metadata, args => subtrace) + if Cassette.canoverdub(ctx, args...) + newctx = Cassette.similarcontext(ctx, metadata = subtrace) + return Cassette.overdub(newctx, args...) + else + return Cassette.fallback(ctx, args...) + end +end + +function Cassette.execute(ctx::TraceCtx, f::typeof(Base.vect), args...) + @info "constructing a vector" + push!(ctx.metadata, (f, args)) + return Cassette.fallback(ctx, f, args...) +end + +function Cassette.execute(ctx::TraceCtx, f::typeof(Core.apply_type), args...) + # @info "applying a type $(args)" + push!(ctx.metadata, (f, args)) + return Cassette.fallback(ctx, f, args...) +end + +# TODO: support calls like construct(T, a, f(b)) +function Cassette.execute(ctx::TraceCtx, f::typeof(construct), args...) + # `f` is always `construct` here; the type being built is the first argument. + @info "constructing with type $(first(args))" + push!(ctx.metadata, (f, args)) + y = Cassette.fallback(ctx, f, args...) + @info "constructed model: $y" + return y +end + +""" TracedRun{T,V} + +captures the dataflow of a code execution. We store the trace and the value. + +see also `trace`. +""" +struct TracedRun{T,V} + trace::T + value::V +end + +""" trace(f) + +run the function f and return a TracedRun containing the trace and the output. 
+""" +function trace(f::Function) + trace = Any[] + val = Cassette.overdub(TraceCtx(metadata=trace), f) + return TracedRun(trace, val) +end + + + +Cassette.@context LPCtx + +""" LPCtx + +replaces every two-argument call `norm(x, p)` with `norm(x, get(ctx.metadata, p, p))`: `p` is remapped through the metadata dictionary and left unchanged when it has no entry. + +This context is useful for modifying statistical codes or machine learning regularizers. +""" +LPCtx + +function Cassette.execute(ctx::LPCtx, args...) + if Cassette.canoverdub(ctx, args...) + newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata) + return Cassette.overdub(newctx, args...) + else + return Cassette.fallback(ctx, args...) + end +end + +using LinearAlgebra +function Cassette.execute(ctx::LPCtx, f::typeof(norm), arg, power) + p = get(ctx.metadata, power, power) + return f(arg, p) +end + +""" replacenorm(f::Function, d::AbstractDict) + +run f, but replace every call to norm using the mapping in d. +""" +function replacenorm(f::Function, d::AbstractDict) + ctx = LPCtx(metadata=d) + return Cassette.overdub(ctx, f) +end + +end #module + diff --git a/test/cassette.jl b/test/cassette.jl new file mode 100644 index 00000000..e75b9fba --- /dev/null +++ b/test/cassette.jl @@ -0,0 +1,131 @@ +module TraceTest + +using Cassette +using Test +using SemanticModels.Dubstep +using LinearAlgebra + +struct ODEProblem{T,U,V,W} + ode::T + init::U + tspan::V + parms::W +end + +# Define == for ODEProblems to mean they are semantically equivalent. The +# normal definition for structs is that equality of fields is done with === +# which means structs that contain arrays, cannot be == if those arrays are +# not === perhaps we should use the AutoHashEquals.jl package here. +function Base.:(==)(p::ODEProblem, q::ODEProblem) + return p.ode == q.ode && p.init == q.init && p.tspan == q.tspan && p.parms == q.parms +end + +# Keep `hash` consistent with the custom `==`, so that equal problems also collide as Dict/Set keys. +function Base.hash(p::ODEProblem, seed::UInt) + return hash(p.parms, hash(p.tspan, hash(p.init, hash(p.ode, seed)))) +end + +""" sir_ode2(du,u,p,t) + + computes the du/dt array for the SIR system. parameters p is b,g = beta,gamma. 
+ """ +sir_ode2(du,u,p,t) = begin + S,I,R = u + b,g = p + du[1] = -b*S*I + du[2] = b*S*I-g*I + du[3] = g*I +end + +# Cassette.@context TraceCtx + +# function Cassette.execute(ctx::TraceCtx, args...) +# subtrace = Any[] +# push!(ctx.metadata, args => subtrace) +# if Cassette.canoverdub(ctx, args...) +# newctx = Cassette.similarcontext(ctx, metadata = subtrace) +# return Cassette.overdub(newctx, args...) +# else +# return Cassette.fallback(ctx, args...) +# end +# end + +trace = Any[] +x, y, z = rand(3) +f(x, y, z) = x*y + y*z +Cassette.overdub(Dubstep.TraceCtx(metadata = trace), () -> f(x, y, z)) +@testset "Cassette" begin + @test trace == Any[ + (f,x,y,z) => Any[ + (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]] + (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]] + (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]] + ] + ] +end + + +function g() + parms = [0.1,0.05] + init = [0.99,0.01,0.0] + tspan = (0.0,200.0) + sir_prob2 = construct(ODEProblem,sir_ode2,init,tspan,parms) + # sir_sol = solve(sir_prob2,saveat = 0.1) + return sir_prob2 +end + +function h() + parms = [0.1,0.05] + init = [0.99,0.01,0.0] + tspan = (0.0,200.0) + sir_prob2 = ODEProblem(sir_ode2,init,tspan,parms) + # sir_sol = solve(sir_prob2,saveat = 0.1) + return sir_prob2 +end + +# these are equal, because we wrote a custom == function for ODEProblem. 
+@testset "Dubstep" begin + @testset "SIR" begin + + @info "Tracing implementation with construct call" + trace1 = Dubstep.trace(g) + @info "Tracing implementation without construct call" + trace2 = Dubstep.trace(h) + @test trace2.value == trace1.value + + for i in 1:3 + @test trace1.trace[i] == trace2.trace[i] + end + + @info "Trace step for construct call" + @show trace1.trace[4] + @test trace1.trace[4][1] == construct + @info "Trace step for direct apply type" + @show trace2.trace[4] + @test last(trace2.trace[4])[1][1] == Core.apply_type + +end #SIR + +end #Dubstep + + +subg(x,y) = norm([x x x]/6 - [y y y]/2, 2) +function g() + a = 5+7 + b = 3+4 + c = subg(a,b) + return c +end +function h() + a = 5+7 + c = subg(a,(3+4)) + return c +end + +@testset "LP" begin +@test 2.5980 < g() < 2.599 +ctx = Dubstep.LPCtx(metadata=Dict(1=>2, 2=>1, Inf=>1)) +@test Cassette.overdub(ctx, g) == 4.5 +end #LP + +end #module diff --git a/test/runtests.jl b/test/runtests.jl index 19d04185..9f057389 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,9 @@ using GLM using DataFrames using Plots +include("cassette.jl") +include("transform/ode.jl") + stripunits(x) = uconvert(NoUnits, x) @testset "spring models" begin diff --git a/test/transform/ode.jl b/test/transform/ode.jl new file mode 100644 index 00000000..c7f2eddd --- /dev/null +++ b/test/transform/ode.jl @@ -0,0 +1,118 @@ +module ODEXform +using DifferentialEquations +using Cassette +using SemanticModels.Dubstep + +Cassette.@context SolverCtx +function Cassette.execute(ctx::SolverCtx, args...) + if Cassette.canoverdub(ctx, args...) + #newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata) + return Cassette.overdub(ctx, args...) + else + return Cassette.fallback(ctx, args...) + end +end + +function Cassette.execute(ctx::SolverCtx, f::typeof(Base.vect), args...) + @info "constructing a vector length $(length(args))" + return Cassette.fallback(ctx, f, args...) +end + +# We don't need to overdub basic math. 
this hopefully makes execution faster. +# if these overloads don't actually make it faster, they can be deleted. +function Cassette.execute(ctx::SolverCtx, f::typeof(+), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(-), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(*), args...) + return Cassette.fallback(ctx, f, args...) +end +function Cassette.execute(ctx::SolverCtx, f::typeof(/), args...) + return Cassette.fallback(ctx, f, args...) +end + + +end #module + +using LinearAlgebra +using Test +using Cassette +using DifferentialEquations +using SemanticModels.Dubstep + +""" sir_ode(du,u,p,t) + +computes the du/dt array for the SIR system. parameters p is b,g = beta,gamma. +""" +sir_ode(du,u,p,t) = begin + S,I,R = u + b,g = p + du[1] = -b*S*I + du[2] = b*S*I-g*I + du[3] = g*I +end + +function Cassette.execute(ctx::ODEXform.SolverCtx, f::typeof(sir_ode), args...) + y = Cassette.fallback(ctx, f, args...) + # add a lagniappe of infection + extra = args[1][1] * ctx.metadata.factor + push!(ctx.metadata.extras, extra) + args[1][1] += extra + args[1][2] -= extra + return y +end + +function g() + parms = [0.1,0.05] + init = [0.99,0.01,0.0] + tspan = (0.0,200.0) + sir_prob = Dubstep.construct(ODEProblem,sir_ode,init,tspan,parms) + return sir_prob +end + +function h() + prob = g() + return solve(prob, alg=Vern7()) +end + +#precompile +@time sol1 = h() +#timeit +@time sol1 = h() + +""" perturb(f, factor) + +run the function f with a perturbation specified by factor. 
+""" +function perturb(f, factor) + t = (factor=factor,extras=Float64[]) + ctx = ODEXform.SolverCtx(metadata = t) + val = Cassette.overdub(ctx, f) + return val, t +end + +traces = Any[] +solns = Any[] +for f in [0.0, 0.01, 0.05, 0.10] + val, t = perturb(h, f) + push!(traces, t) + push!(solns, val) +end + +for (i, s) in enumerate(solns) + @show s(100) + @show traces[i].factor + @show traces[i].extras[5] + @show sum(traces[i].extras)/length(traces[i].extras) +end + +@testset "ODE perturbation" begin + +@test norm(sol1(100) .- solns[1](100),2) < 1e-6 +@test norm(sol1(100) .- solns[2](100),2) > 1e-6 +@test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[3](100),2) +@test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[4](100),2) + +end