`julia> derivative(x -> fs[readline()](x), 1)` keeps running and won't stop #4
Comments
You know that
Would you post all the information I need to reproduce this example?
Keno added a commit that referenced this issue on Feb 24, 2019:
Right now Zygote inserts stacks whenever it needs to use an SSA value not defined in the first basic block. This is of course unnecessary. The condition for needing stacks is that the basic block that defines the value is self-reachable (i.e. in a loop). Otherwise, we can simply insert phi nodes to thread the desired SSA value through to the exit block (we don't need to do anything in the adjoint, since the reversal of the CFG ensures dominance). Removing stacks both allows more efficient code generation and enables higher-order auto-diff (since we use control flow in Zygote, but can't handle differentiating code that contains stacks).

The headline example is something like the following:

```
function foo(b, x)
    if b
        sin(x)
    else
        cos(x)
    end
end
```

Then looking at `@code_typed derivative(x->foo(true, x), 1.0)`, we get:

Before:

```
CodeInfo(
1 ── %1  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Int8,1}, svec(Any, Int64), :(:ccall), 2, Array{Int8,1}, 0, 0))::Array{Int8,1}
│    %2  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %3  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %4  = Base.sin::typeof(sin)
│          invoke %4(_3::Float64)::Float64
│    %6  = %new(##334#335{Float64}, x)::##334#335{Float64}
│    %7  = %new(##758#back#336{##334#335{Float64}}, %6)::##758#back#336{##334#335{Float64}}
[snip]
23 ─ %52 = invoke %47(1::Int8)::Tuple{Nothing,Nothing,Any}
│    %53 = Base.getfield(%52, 3, true)::Any
└───       goto #24
24 ─       return %53
) => Any
```

After:

```
CodeInfo(
1 ─ %1 = Base.sin::typeof(sin)
│        invoke %1(_3::Float64)::Float64
│   %3 = Core.Intrinsics.not_int(true)::Bool
└──      goto #3 if not %3
2 ─      invoke Zygote.notnothing(nothing::Nothing)::Union{}
└──      $(Expr(:unreachable))::Union{}
3 ┄ %7 = invoke Zygote.cos(_3::Float64)::Float64
│   %8 = Base.mul_float(1.0, %7)::Float64
└──      goto #4
4 ─      goto #5
5 ─      goto #6
6 ─      goto #7
7 ─      return %8
) => Float64
```

This is essentially perfect (there's a bit of junk left over, but LLVM can take care of that). The only thing that doesn't get removed is the useless invocation of `sin`, but that's a separate and known issue.
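The rule in this commit message can be illustrated with a toy model in plain Julia. This is a sketch under stated assumptions, not Zygote's implementation, and it uses no Zygote API: the adjacency-list CFG (`succs[i]` lists the successors of block `i`) and the names `self_reachable`, `foo_pullback`, and `sinpow_pullback` are all hypothetical. `self_reachable` marks the blocks that lie on a cycle, which is the condition the commit gives for needing a stack. The two pullback functions contrast a branch, where a single variable merging both arms plays the role of a phi node, with a loop, where each iteration's intermediate has to be pushed onto a stack and replayed in reverse.

```julia
# Toy model only: hypothetical CFG representation, not Zygote's IR types.
# succs[i] lists the successor blocks of basic block i.

"""
    self_reachable(succs)

Return a Bool vector marking blocks that can reach themselves (i.e. sit in a loop).
"""
function self_reachable(succs::Vector{Vector{Int}})
    n = length(succs)
    result = falses(n)
    for start in 1:n
        seen = falses(n)
        worklist = copy(succs[start])
        while !isempty(worklist)
            b = pop!(worklist)
            if b == start
                result[start] = true  # found a path back to `start`
                break
            end
            seen[b] && continue
            seen[b] = true
            append!(worklist, succs[b])
        end
    end
    return result
end

# Branch, no loop: the pullback chosen in either arm is merged into one
# variable, the source-level analogue of a phi node. No stack is needed.
function foo_pullback(b, x)
    if b
        y = sin(x)
        back = dy -> dy * cos(x)
    else
        y = cos(x)
        back = dy -> -dy * sin(x)
    end
    return y, back
end

# Loop: the body is self-reachable, so every iteration's input must be
# recorded (here in a vector used as a stack) and replayed in reverse.
function sinpow_pullback(x, n)
    stack = typeof(float(x))[]
    y = float(x)
    for _ in 1:n
        push!(stack, y)   # save this iteration's input to sin
        y = sin(y)
    end
    function back(dy)
        for v in reverse(stack)  # last iteration first
            dy *= cos(v)
        end
        return dy
    end
    return y, back
end
```

For an acyclic branch that rejoins, such as `self_reachable([[2, 3], [4], [4], Int[]])`, no block is marked, so nothing would need a stack. This sketch runs one graph search per block; a strongly-connected-components pass would scale better on large functions.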
Keno added another commit that referenced this issue on Mar 6, 2019.
https://gist.github.com/hpoit/dc0b2db9f4dffb77403497f673b7e26d