Zygote-style compilation with added optimizations. Perhaps it's best to try this in Zygote itself. If I'm not mistaken, there was a plan to move Zygote to Cassette, which would make experimenting with Yota-style optimizations in Zygote easier.
Re-tracing on every call. We need to measure how much overhead this creates in practice and think through caching layers. E.g., if re-tracing produces a graph we've already seen, we can skip recalculating gradient nodes and just retrieve the full tape from the cache. Caching tracer results is a harder task, though.
All in all, it should be possible to have two workflows: fast static graphs or slower dynamic graphs.
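To illustrate the trade-off between the two workflows, here is a toy sketch. Everything in it (replay, toy_trace, make_static, make_dynamic) is hypothetical and not Yota's API, and it only replays the forward pass without computing gradients. The static variant records a tape once and replays it, which is fast but bakes in whichever branch was taken during tracing; the dynamic variant re-traces on every call and so follows value-dependent control flow.

```julia
# Hypothetical toy setup, not Yota's API. A "tape" here is just the list of
# primitive operations that were executed, stored as closures.
replay(tape, x) = foldl((v, op) -> op(v), tape; init = x)

# Toy tracer for f(x) = x > 0 ? 2x : -x. Which ops land on the tape depends
# on the input value, just as with real value-dependent control flow.
function toy_trace(x::Real)
    tape = Function[]
    if x > 0
        push!(tape, v -> 2v)
    else
        push!(tape, v -> -v)
    end
    return replay(tape, x), tape
end

# Static workflow: trace once on an example input, then replay that tape on
# every call. Fast, but the branch taken during tracing is baked in.
function make_static(x0)
    _, tape = toy_trace(x0)
    return x -> replay(tape, x)
end

# Dynamic workflow: re-trace on every call. Slower, but always follows the
# branch appropriate for the current input.
make_dynamic() = x -> first(toy_trace(x))

f_static = make_static(1.0)   # traced with x > 0
f_dynamic = make_dynamic()

f_static(-3.0)    # replays the x > 0 branch: -6.0 (wrong for this input)
f_dynamic(-3.0)   # re-traces and takes the x <= 0 branch: 3.0
```

For inputs that take the same branch as the tracing input, both variants agree; the static one is only unsafe when control flow depends on values.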
Re-tracing looks somewhat promising. "Promising" means it's not as terribly slow as I feared; "somewhat" means tracing is still ~10-20 times slower than normal execution.
(For now, I assume we can cache the gradient calculation for a known tape, so that it comes almost for free.)
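The caching assumed here can be sketched as a dictionary keyed by the structure of the traced graph rather than by input values. This is a minimal sketch under stated assumptions: the Op type, build_grad_tape and cached_grad_tape are hypothetical stand-ins, not Yota's API, and build_grad_tape only appends placeholder adjoint ops instead of real derivative rules. Note that it still requires the trace itself on every call; it only saves the gradient construction step.

```julia
# Hypothetical, simplified tape entry: which primitive was called and which
# earlier tape positions it consumed. Runtime values are deliberately not
# stored, so traces with different inputs but the same structure compare equal.
struct Op
    fn::Symbol
    args::Vector{Int}
end

# Structural equality and hashing, so a whole tape can serve as a Dict key.
Base.:(==)(a::Op, b::Op) = a.fn == b.fn && a.args == b.args
Base.hash(o::Op, h::UInt) = hash(o.args, hash(o.fn, h))

# Counts how often the (expensive) gradient construction actually runs.
const BUILD_COUNT = Ref(0)

# Stand-in for building gradient nodes: appends one placeholder adjoint op
# per forward op, in reverse order.
function build_grad_tape(tape::Vector{Op})
    BUILD_COUNT[] += 1
    adjoints = [Op(Symbol(:grad_, op.fn), op.args) for op in reverse(tape)]
    return vcat(tape, adjoints)
end

# Cache keyed by tape structure: a re-trace that yields a known graph
# skips gradient construction and gets the full tape back.
const GRAD_CACHE = Dict{Vector{Op}, Vector{Op}}()
cached_grad_tape(tape::Vector{Op}) = get!(() -> build_grad_tape(tape), GRAD_CACHE, tape)

# Two traces of the same loss with different input values (structure only).
t1 = [Op(:mul, [1, 2]), Op(:add, [3, 4]), Op(:sum, [5])]
t2 = [Op(:mul, [1, 2]), Op(:add, [3, 4]), Op(:sum, [5])]
g1 = cached_grad_tape(t1)  # builds gradient nodes
g2 = cached_grad_tape(t2)  # cache hit: same graph, no rebuild
```

After both calls, BUILD_COUNT[] is 1 and g2 is the very tape object stored for t1.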
Here's a short test I used:
using BenchmarkTools
using Profile
using Yota: trace   # the tracer being benchmarked

struct Linear{T}
    W::AbstractMatrix{T}
    b::AbstractVector{T}
end

Base.show(io::IO, m::Linear) = print(io, "Linear($(reverse(size(m.W))))")

(m::Linear)(x::AbstractArray) = m.W * x .+ m.b
forward(m::Linear, x) = m.W * x .+ m.b
loss(m::Linear, x::AbstractArray) = sum(forward(m, x))

function main()
    m = Linear(rand(128, 784), rand(128))
    x = rand(784, 100)
    @btime loss($m, $x)          # plain execution; $ interpolates locals
    @btime trace(loss, $m, $x)   # tracing the same call
    @profile trace(loss, m, x)   # collect profiler samples for the tracer
    Profile.print()
end

main()
The profiler shows that most of the time is spent in Cassette internals: roughly half in overdub/recurse and half in tagging. Presumably these will be improved in the future, and the tracer will become faster automatically.
I'm going to try out dynamic re-tracing in a branch to see if there are any significant gotchas.