Explore possible solutions for dynamic graph support #29

Closed
dfdx opened this issue Feb 28, 2019 · 2 comments

Owner

dfdx commented Feb 28, 2019

  1. Zygote-style compilation with added optimizations. Perhaps it's best to try this in Zygote itself. If I'm not mistaken, there was a plan to move Zygote to Cassette, which would make it easier to experiment with Yota-style optimizations in Zygote.
  2. Re-tracing on every call. We need to measure the overhead this creates in practice and think through caching layers. E.g. if tracing produces a graph we have already seen, we can skip recalculating the gradient nodes and retrieve the full tape from a cache. Caching the tracer's results is a harder task, though.

All in all, it should be possible to support two workflows: fast static graphs and slower dynamic graphs.
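The caching idea from option 2 can be sketched roughly as follows. This is only an illustration under stated assumptions: `Op`, `signature`, `build_gradient_tape` and `cached_gradient_tape` are hypothetical names made up for this sketch, not Yota's actual data structures or API. The point is that the cache key depends only on the structure of the traced graph (operations and wiring), not on the values that flowed through it.

```julia
# Hypothetical stand-in for a tape entry; NOT Yota's real tape type.
struct Op
    fn::Symbol           # name of the primitive, e.g. :*, :+, :sum
    args::Vector{Int}    # indices of the tape entries used as inputs
end

# Structural signature of a traced graph: ops and wiring only, no values.
# Built from tuples so that the key is immutable.
signature(tape::Vector{Op}) = Tuple((op.fn, Tuple(op.args)) for op in tape)

const GRAD_CACHE = Dict{Any, Vector{Op}}()
const BUILD_COUNT = Ref(0)   # counts how often the expensive step runs

# Placeholder for the expensive step (deriving gradient nodes for a tape).
function build_gradient_tape(tape::Vector{Op})
    BUILD_COUNT[] += 1
    return [Op(Symbol("∇", op.fn), op.args) for op in reverse(tape)]
end

# Trace on every call, but derive the gradient only for unseen graphs.
function cached_gradient_tape(tape::Vector{Op})
    get!(GRAD_CACHE, signature(tape)) do
        build_gradient_tape(tape)
    end
end

# Two separate traces with the same structure share one cached gradient tape.
t1 = [Op(:*, [1, 2]), Op(:+, [3, 4]), Op(:sum, [5])]
t2 = [Op(:*, [1, 2]), Op(:+, [3, 4]), Op(:sum, [5])]
cached_gradient_tape(t1) === cached_gradient_tape(t2)
```

With a scheme like this, the per-call cost of the dynamic workflow is the tracing itself plus one dictionary lookup, which is why the tracer's overhead is the number worth measuring.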

Owner Author

dfdx commented Mar 6, 2019

Re-tracing looks somewhat promising. "Promising" means that it's not as terribly slow as I thought; "somewhat" means that tracing is still ~10-20 times slower than normal execution.

(For now I assume we can cache gradient calculation for a known tape so that it comes almost for free).

Here's a short test I used:

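using BenchmarkTools
using Profile
using Yota: trace   # assumption: `trace` is the tracer from this package (Yota)

# A single dense layer: y = W * x .+ b
struct Linear{T}
    W::AbstractMatrix{T}
    b::AbstractVector{T}
end
Base.show(io::IO, m::Linear) = print(io, "Linear($(reverse(size(m.W))))")

(m::Linear)(x::AbstractArray) = m.W * x .+ m.b
forward(m::Linear, x) = m.W * x .+ m.b

loss(m::Linear, x::AbstractArray) = sum(forward(m, x))

function main()
    m = Linear(rand(128, 784), rand(128))
    x = rand(784, 100)

    # `$` interpolates the local variables into the benchmarked expression
    @btime loss($m, $x)          # normal execution
    @btime trace(loss, $m, $x)   # tracing the same call
    @profile trace(loss, m, x)   # see where the tracer spends its time
end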

The profiler shows that most of the time is spent in Cassette internals: roughly 50% on overdub/recurse and 50% on tagging. Presumably these will improve in the future, and the tracer will become faster automatically.

I'm going to try out dynamic re-tracing in a branch to see if there are any significant gotchas.

Owner Author

dfdx commented Mar 14, 2019

Implemented in #32.

@dfdx dfdx closed this as completed Mar 14, 2019