Skip to content


Merge 71926d2 into b328c0b
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Jul 27, 2018
2 parents b328c0b + 71926d2 commit db83720
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 210 deletions.
11 changes: 0 additions & 11 deletions src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,3 @@ function withtagfor(context::Context, f)

# Return the first type parameter `N` of a `Context` type (or of any subtype of
# `Context{N}`).
function nametype(::Type{<:Context{N}}) where {N}
    return N
end

# `Fallback` #

struct Fallback{F,C<:Context}

# Calling a `Fallback` object directly always raises an error; it is meant to be
# intercepted by a context-specific `execute` method instead. The message must
# interpolate the callable's own binding `f`; the previous text referenced an
# undefined variable `p`, which raised `UndefVarError` instead of this message.
# NOTE(review): assumes `Fallback` has fields `func` and `context` — confirm against
# the struct definition.
(f::Fallback)(args...) = error("Cassette.Fallback($(f.func), $(f.context)) can only be executed in a context with the same type as $(f.context)")
7 changes: 3 additions & 4 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ macro context(Ctx)
$Ctx(; kwargs...) = $Cassette.Context($CtxName(); kwargs...)

@inline $Cassette.execute(::C, ::$Typ($Cassette.Tag), ::Type{N}, ::Type{X}) where {C<:$Ctx,N,X} = $Cassette.Tag(N, X, $Cassette.tagtype(C))
@inline $Cassette.execute(ctx::C, f::$Cassette.Fallback{F,C}, args...) where {F,C<:$Ctx} = $Cassette.fallback(ctx, f.func, args...)

# TODO: There are certain non-`Core.Builtin` functions which the compiler often
# relies upon constant propagation to infer, such as `isdispatchtuple`. Such
Expand Down Expand Up @@ -73,10 +72,10 @@ end
Cassette.@overdub(ctx, expression)
A convenience macro for executing `expression` within the context `ctx`. This macro roughly
expands to `Cassette.recurse(ctx, () -> expression)`.
expands to `Cassette.overdub(ctx, () -> expression)`.
macro overdub(ctx, expr)
return :($Cassette.recurse($(esc(ctx)), () -> $(esc(expr))))
return :($Cassette.overdub($(esc(ctx)), () -> $(esc(expr))))

Expand Down Expand Up @@ -109,7 +108,7 @@ macro pass(transform)
return esc(quote
struct $Pass <: $Cassette.AbstractPass end
(::Type{$Pass})(ctxtype, signature, codeinfo) = $transform(ctxtype, signature, codeinfo)
Core.eval($Cassette, $Cassette.recurse_definition($name, $line, $file))
Core.eval($Cassette, $Cassette.overdub_definition($name, $line, $file))
241 changes: 141 additions & 100 deletions src/overdub.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

# Default `posthook`: accepts any context plus any trailing arguments and does nothing.
@inline function posthook(::Context, ::Vararg{Any})
    return nothing
end

struct RecurseInstead end
@inline execute(ctx::Context, args...) = RecurseInstead()
struct OverdubInstead end
# Default `execute`: returns the `OverdubInstead` sentinel. Callers test the result with
# `isa(output, OverdubInstead)` and, when it matches, call `overdub` on the same
# arguments instead.
@inline function execute(ctx::Context, args...)
    return OverdubInstead()
end

# Default `fallback`: forwards the context and all arguments to `call` unchanged.
@inline function fallback(ctx::Context, args...)
    return call(ctx, args...)
end

Expand All @@ -21,79 +21,62 @@ struct RecurseInstead end
# `call` methods for `Core.apply_type` applied to exactly two `Type` arguments: both
# invoke `f(A, B)` with the types taken from the method's static parameters.
# NOTE(review): the two methods have identical bodies; presumably both exist to resolve
# dispatch ambiguities with other `call` methods outside this view — confirm.
@inline function call(::ContextWithTag{Nothing}, f::typeof(Core.apply_type), ::Type{A}, ::Type{B}) where {A,B}
    return f(A, B)
end
@inline function call(::Context, f::typeof(Core.apply_type), ::Type{A}, ::Type{B}) where {A,B}
    return f(A, B)
end

@inline canrecurse(ctx::Context, f, args...) = !isa(untag(f, ctx), Core.Builtin)
# Return `true` unless `f`, once untagged with respect to `ctx`, is a `Core.Builtin`.
# The trailing `args` are accepted but not inspected.
@inline function canoverdub(ctx::Context, f, args...)
    return !(untag(f, ctx) isa Core.Builtin)
end

# overdub #

# An alternative approach is to define `execute(args...) = recurse(args...)` by default,
# instead of using the `RecurseInstead` sentinel type. While cleaner, that approach triggers
# the compiler's recursion limiting heuristic. This sentinel type + control flow approach
# avoids that recursion, and thus avoids related inference problems.
@inline function overdub(ctx::Context, args...)
prehook(ctx, args...)
output = execute(ctx, args...)
output = isa(output, RecurseInstead) ? recurse(ctx, args...) : output
posthook(ctx, output, args...)
return output

# This is essentially implementing:
# function overdub(ctx::Context, ::typeof(Core._apply), f, args...)
# return overdub(ctx, f, apply_args(ctx, args...)...)
# end
# but the extra indirection of calling `overdub` there triggers the compiler's recursion
# limiting heuristic. This implementation avoids that problem by manually inlining the
# above expression by a single call level.
@inline function overdub(ctx::Context, ::typeof(Core._apply), f, _args...)
args = apply_args(ctx, _args...)
prehook(ctx, f, args...)
output = execute(ctx, f, args...)
output = isa(output, RecurseInstead) ? recurse(ctx, f, args...) : output
posthook(ctx, output, f, args...)
return output

@inline apply_args(::ContextWithTag{Nothing}, args...) = Core._apply(Core.tuple, args...)
@inline apply_args(ctx::Context, args...) = tagged_apply_args(ctx, args...)

# recurse #
# Unique (`gensym`-generated) symbols used by the overdub pass. The context and arguments
# symbols name the two parameters of the generated `overdub` method, and all three are
# installed as slot names in the overdubbed `CodeInfo` (the temporary one is appended
# as the last slot).
const OVERDUB_CTX_SYMBOL = gensym("overdub_context")
const OVERDUB_ARGS_SYMBOL = gensym("overdub_arguments")
const OVERDUB_TMP_SYMBOL = gensym("overdub_tmp")

const RECURSE_CTX_SYMBOL = gensym("recurse_context")
const RECURSE_ARGS_SYMBOL = gensym("recurse_arguments")

# The `recurse` pass has four intertwined tasks:
# The `overdub` pass has four intertwined tasks:
# 1. Apply the user-provided pass, if one is given
# 2. Munge the reflection-generated IR into a valid form for returning from
# `overdub_generator` (e.g. add new argument slots, substitute static
# parameters, destructure overdub arguments into underlying method slots, etc.)
# 3. Translate all function calls to `overdub` calls
# 3. Replace all calls of the form `output = f(args...)` with:
# ```
# prehook(ctx, f, args...)
# tmp = execute(ctx, f, args...)
# tmp = isa(tmp, OverdubInstead) ? overdub(ctx, f, args...) : tmp
# posthook(ctx, tmp, f, args...)
# output = tmp
# ```
# 4. If tagging is enabled, do the necessary IR transforms for the metadata tagging system
function recurse_pass!(reflection::Reflection,
function overdub_pass!(reflection::Reflection,
pass_type::DataType = NoPass)
signature = reflection.signature
method = reflection.method
static_params = reflection.static_params
code_info = reflection.code_info

#=== 1. Execute user-provided pass (is a no-op by default) ===#
# TODO: This `iskwfunc` is part of a hack that `overdub_pass!` implements in order to fix
# jrevels/Cassette.jl#48. The assumptions made by this hack are quite fragile, so we
# should eventually get Base to expose a standard/documented API for this. Here, we see
# this hack's first assumption: that `Core.kwfunc(f)` is going to return a function whose
# type name is prefixed by `#kw##`. More assumptions for this hack will be commented on
# as we go.
# `signature.parameters[1]` is the type of the function being overdubbed; the `#kw##`
# prefix is looked up on that type's name. The previous form had unbalanced parentheses
# and passed the prefix to `String` rather than to `startswith`.
iskwfunc = startswith(String(signature.parameters[1].name.name), "#kw##")
istaggingenabled = has_tagging_enabled(context_type)

code_info = pass_type(context_type, signature, code_info)
#=== execute user-provided pass (is a no-op by default) ===#

#=== 2. Munge the code into a valid form for `recurse_generator` ===#
if !iskwfunc
code_info = pass_type(context_type, signature, code_info)

#=== munge the code into a valid form for `overdub_generator` ===#

# construct new slotnames/slotflags for added slots
code_info.slotnames = Any[:recurse, RECURSE_CTX_SYMBOL, RECURSE_ARGS_SYMBOL, code_info.slotnames...]
code_info.slotflags = UInt8[0x00, 0x00, 0x00, code_info.slotflags...]
n_overdub_slots = 3
code_info.slotnames = Any[:overdub, OVERDUB_CTX_SYMBOL, OVERDUB_ARGS_SYMBOL, code_info.slotnames..., OVERDUB_TMP_SYMBOL]
code_info.slotflags = UInt8[0x00, 0x00, 0x00, code_info.slotflags..., 0x00]
n_prepended_slots = 3
overdub_ctx_slot = SlotNumber(2)
overdub_args_slot = SlotNumber(3)
overdub_tmp_slot = SlotNumber(length(code_info.slotnames))

# For the sake of convenience, the rest of this pass will translate `code_info`'s fields
# into these overdubbed equivalents instead of updating `code_info` in-place. Then, at
Expand All @@ -105,7 +88,7 @@ function recurse_pass!(reflection::Reflection,
n_actual_args = fieldcount(signature)
n_method_args = Int(method.nargs)
for i in 1:n_method_args
slot = i + n_overdub_slots
slot = i + n_prepended_slots
actual_argument = Expr(:call, GlobalRef(Core, :getfield), overdub_args_slot, i)
push!(overdubbed_code, :($(SlotNumber(slot)) = $actual_argument))
push!(overdubbed_codelocs, code_info.codelocs[1])
Expand All @@ -130,64 +113,111 @@ function recurse_pass!(reflection::Reflection,
push!(overdubbed_codelocs, code_info.codelocs[1])
push!(trailing_arguments.args, SSAValue(length(overdubbed_code)))
push!(overdubbed_code, Expr(:(=), SlotNumber(n_method_args + n_overdub_slots), trailing_arguments))
push!(overdubbed_code, Expr(:(=), SlotNumber(n_method_args + n_prepended_slots), trailing_arguments))
push!(overdubbed_codelocs, code_info.codelocs[1])

#=== 3. Translate function calls to `overdub` calls ===#
original_arg_slots = [SlotNumber(i + n_prepended_slots) for i in 1:n_method_args]

#=== finish initialization of `overdubbed_code`/`overdubbed_codelocs` ===#

# substitute static parameters, offset slot numbers by number of added slots, and
# offset statement indices by the number of additional statements
Base.Meta.partially_inline!(code_info.code, Any[], method.sig, static_params,
n_overdub_slots, length(overdubbed_code), :propagate)
n_prepended_slots, length(overdubbed_code), :propagate)

# For the rest of the statements in `code_info.code`, intercept every applicable call
# expression and replace it with a corresponding call to `Cassette.overdub`.
old_code_start_index = length(overdubbed_code) + 1

# TODO: This `iskwfunc` is a hack in order to implement a fix for jrevels/Cassette.jl#48.
# It assumes that 1) `Core.kwfunc(f)` is going to return a function whose type name
# is prefixed by `#kw##` and 2) that the second to last statement in the lowered IR
# for `Core.kwfunc(f)` is the call to the "underlying" non-kwargs form of `f`. These
# assumptions are obviously quite fragile, so we should eventually get Base to expose
# a standard/documented API for this.
# `signature.parameters[1]` is the type of the function being overdubbed; the `#kw##`
# prefix is looked up on that type's name. The previous form had unbalanced parentheses
# and passed the prefix to `String` rather than to `startswith`.
iskwfunc = startswith(String(signature.parameters[1].name.name), "#kw##")
for i in 1:length(code_info.code)
stmnt = code_info.code[i]
replaceable = Base.Meta.isexpr(stmnt, :foreigncall) ? view(stmnt.args, 2:length(stmnt.args)) : stmnt
replacement = iskwfunc && i !== (length(code_info.code)-1) ? :call : :overdub
replace_match!(is_call, replaceable) do call
call.args = Any[GlobalRef(Cassette, replacement), overdub_ctx_slot, call.args...]
return call
push!(overdubbed_code, stmnt)
push!(overdubbed_codelocs, code_info.codelocs[i])
append!(overdubbed_code, code_info.code)
append!(overdubbed_codelocs, code_info.codelocs)

#=== 4. IR transforms for the metadata tagging system ===#
#=== TODO: perform tagged module transformation if tagging is enabled ===#

if has_tagging_enabled(context_type) && !iskwfunc
# changemap = fill(0, length(code_info.code))
# Scan the IR for `Module`s in the first argument position for `GlobalRef`s.
# For every unique such `Module`, make a new `SSAValue` at the top of the method body
# corresponding to `Cassette.fetch_tagged_module` called with the given context and
# module. Then, replace all `GlobalRef`-loads with the corresponding
# `Cassette._tagged_global_ref` invocation. All `GlobalRef`-stores must be preserved
# as-is, but need a follow-up statement calling `Cassette._tagged_global_ref_set_meta!`
# on the relevant arguments.

# TODO: Scan the IR for `Module`s in the first argument position for `GlobalRef`s.
# For every unique such `Module`, make a new `SSAValue` at the top of the method body
# corresponding to `Cassette.fetch_tagged_module` called with the given context and
# module. Then, replace all `GlobalRef`-loads with the corresponding
# `Cassette._tagged_global_ref` invocation. All `GlobalRef`-stores must be preserved
# as-is, but need a follow-up statement calling
# `Cassette._tagged_global_ref_set_meta!` on the relevant arguments.
#=== TODO: untag all `ccall` SSAValue arguments if tagging is enabled ===#

replace_match!(is_new, overdubbed_code) do x
return Expr(:call, GlobalRef(Cassette, :tagged_new), overdub_ctx_slot, x.args...)
#=== untag `gotoifnot` conditionals if tagging is enabled ===#

# this sentinel is consumed by a call-replacement pass below; we use
# it so that we don't accidentally overdub the calls we've inserted
untag_call_sentinel = :REPLACE_ME_WITH_CASSETTE_UNTAG
if istaggingenabled && !iskwfunc
insert_ir_elements!(overdubbed_code, overdubbed_codelocs, 1,
(x, i) -> Base.Meta.isexpr(x, :gotoifnot),
(x, i) -> [
Expr(:call, untag_call_sentinel, x.args[1], overdub_ctx_slot),
Expr(:gotoifnot, SSAValue(i), x.args[2])

#=== replace `Expr(:call, ...)` with `Expr(:call, :overdub, ...)` calls ===#

if iskwfunc
# Another assumption of this `iskwfunc` hack is that the second to last statement in
# the lowered IR for `Core.kwfunc(f)` is the call to the "underlying" non-kwargs form
# of `f`. Thus, we `overdub` that call instead of replacing it with `call`.
for i in 1:length(overdubbed_code)
stmt = overdubbed_code[i]
replacewith = i === (length(overdubbed_code) - 1) ? :overdub : :call
if Base.Meta.isexpr(stmt, :(=))
replacein = stmt.args
replaceat = 2
replacein = overdubbed_code
replaceat = i
stmt = replacein[replaceat]
if Base.Meta.isexpr(stmt, :call)
replacein[replaceat] = Expr(:call, GlobalRef(Cassette, replacewith), overdub_ctx_slot, stmt.args...)
predicate = (x, i) -> begin
i >= old_code_start_index || return false
stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
return Base.Meta.isexpr(stmt, :call) && stmt.args[1] !== untag_call_sentinel
itemfunc = (x, i) -> begin
callstmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
execstmt = Expr(:call, GlobalRef(Cassette, :execute), overdub_ctx_slot, callstmt.args...)
overdubstmt = Expr(:call, GlobalRef(Cassette, :overdub), overdub_ctx_slot, callstmt.args...)
return [
Expr(:call, GlobalRef(Cassette, :prehook), overdub_ctx_slot, callstmt.args...),
Expr(:(=), overdub_tmp_slot, execstmt),
Expr(:call, GlobalRef(Core, :isa), overdub_tmp_slot, GlobalRef(Cassette, :OverdubInstead)),
Expr(:gotoifnot, SSAValue(i + 2), i + 6),
Expr(:(=), overdub_tmp_slot, overdubstmt),
Expr(:call, GlobalRef(Cassette, :posthook), overdub_ctx_slot, overdub_tmp_slot, callstmt.args...),
Base.Meta.isexpr(x, :(=)) ? Expr(:(=), x.args[1], overdub_tmp_slot) : overdub_tmp_slot
insert_ir_elements!(overdubbed_code, overdubbed_codelocs, 6, predicate, itemfunc)

# TODO: appropriately untag all `gotoifnot` conditionals
#=== replace `untag_call_sentinel` with `GlobalRef(Cassette, :untag)` ===#

# TODO: appropriately untag all `ccall` arguments
if istaggingenabled && !iskwfunc
replace_match!(x -> Base.Meta.isexpr(x, :call) && x.args[1] == untag_call_sentinel, overdubbed_code) do x
return Expr(:call, GlobalRef(Cassette, :untag), x.args[2:end]...)

#=== replace `Expr(:new, ...)` with `Expr(:call, :tagged_new)` if tagging is enabled ===#

# Core.Compiler.renumber_ir_elements!(overdubbed_code, changemap)
if istaggingenabled && !iskwfunc
replace_match!(x -> Base.Meta.isexpr(x, :new), overdubbed_code) do x
return Expr(:call, GlobalRef(Cassette, :tagged_new), overdub_ctx_slot, x.args...)

#=== 5. Set `code_info`/`reflection` fields accordingly ===#
#=== set `code_info`/`reflection` fields accordingly ===#

code_info.code = overdubbed_code
code_info.codelocs = overdubbed_codelocs
Expand All @@ -199,13 +229,13 @@ function recurse_pass!(reflection::Reflection,

# `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)`
function recurse_generator(pass_type, self, context_type, args::Tuple)
if !(nfields(args) > 1 && args[1] <: Core.Builtin)
function overdub_generator(pass_type, self, context_type, args::Tuple)
if !(nfields(args) > 0 && args[1] <: Core.Builtin)
untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,)
reflection = reflect(untagged_args)
if isa(reflection, Reflection)
recurse_pass!(reflection, context_type, pass_type)
overdub_pass!(reflection, context_type, pass_type)
body = reflection.code_info
@safe_debug "returning overdubbed CodeInfo" args body
return body
Expand All @@ -220,25 +250,36 @@ function recurse_generator(pass_type, self, context_type, args::Tuple)
@safe_debug "no CodeInfo found; executing via fallback" args
return quote
$(Expr(:meta, :inline))

function recurse_definition(pass, line, file)
# Build the argument tuple for an intercepted `Core._apply` call.
# Without tagging (`ContextWithTag{Nothing}`), every element of `args` is splatted into
# one tuple via `Core._apply(Core.tuple, ...)`; any other context delegates to
# `tagged_apply_args`.
@inline function apply_args(::ContextWithTag{Nothing}, args...)
    return Core._apply(Core.tuple, args...)
end
@inline function apply_args(ctx::Context, args...)
    return tagged_apply_args(ctx, args...)
end

function overdub_definition(pass, line, file)
return quote
function recurse($RECURSE_CTX_SYMBOL::ContextWithPass{pass}, $RECURSE_ARGS_SYMBOL...) where {pass<:$pass}
function overdub($OVERDUB_CTX_SYMBOL::ContextWithPass{pass}, $OVERDUB_ARGS_SYMBOL...) where {pass<:$pass}
@inline function overdub(ctx::ContextWithPass{pass}, ::typeof(Core._apply), f, _args...) where {pass<:$pass}
args = apply_args(ctx, _args...)
prehook(ctx, f, args...)
output = execute(ctx, f, args...)
output = isa(output, OverdubInstead) ? overdub(ctx, f, args...) : output
posthook(ctx, output, f, args...)
return output

@eval $(recurse_definition(:NoPass, @__LINE__, @__FILE__))
@eval $(overdub_definition(:NoPass, @__LINE__, @__FILE__))
5 changes: 3 additions & 2 deletions src/tagged.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ end

# TODO: For fast methods (~ns), this fetch can cost drastically more than the primal method
# invocation. We easily have the module at compile time, but we don't have access to the
# actual context object. This `@pure` is vtjnash-approved. It should allow the compiler to
# optimize away the fetch once we have support for it, e.g. loop invariant code motion.
# actual context object (just the type). This `@pure` is vtjnash-approved. It should allow
# the compiler to optimize away the fetch once we have support for it, e.g. loop invariant
# code motion.
Base.@pure @noinline function fetch_tagged_module(context::Context, m::Module)
bindings = get!(() -> BindingMetaDict(), context.bindings, m)
return Tagged(context, m, Meta(NoMetaData(), ModuleMeta(NOMETA, bindings)))
Expand Down

0 comments on commit db83720

Please sign in to comment.