From the previous threads, we know that the main factor behind slow inference is a "change in input tensor size causing recompilation". If you don't mind, I would like to break this statement down more clearly.
Consider the following code (note: it is the Colab format I am using).
Concretely, in which call must the input tensor size be held fixed to avoid recompilation? (Is it inside inference(), in highlight 1, or somewhere else?)
Where exactly does compilation/recompilation happen?
Conjecture:
If it is jax.jit that causes compilation, then presumably, from highlight 2, a compiled version of model_apply is returned. After this, no other jax.jit is called; we simply enter a for loop that repeatedly calls inference(). Every time inference() is called, it uses the pre-compiled version of model_apply and has no access to the outer jax.jit. So where exactly does this recompilation stem from?
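To make the conjecture concrete, here is a simplified, hypothetical model (pure Python with NumPy, not JAX's real implementation) of how a jitted closure caches compiled executables per input signature. Even though the wrapper is created only once, a call with a previously unseen shape misses the cache and triggers a fresh "compile":

```python
import numpy as np

trace_count = 0  # counts how many times "compilation" runs

class FakeJit:
    """Simplified, hypothetical stand-in for jax.jit's caching:
    one compiled executable is kept per (shape, dtype) signature
    of the argument."""

    def __init__(self, fn):
        self.fn = fn
        self.cache = {}  # (shape, dtype) -> compiled executable

    def _compile(self, shape, dtype):
        global trace_count
        trace_count += 1  # stand-in for an expensive XLA compile
        return self.fn    # pretend this is now shape-specialized

    def __call__(self, x):
        key = (x.shape, str(x.dtype))
        if key not in self.cache:      # unseen signature -> recompile
            self.cache[key] = self._compile(*key)
        return self.cache[key](x)

double = FakeJit(lambda x: x * 2)
double(np.zeros((8, 3)))   # first call: compiles
double(np.ones((8, 3)))    # same shape/dtype: cache hit, no compile
double(np.zeros((16, 3)))  # new shape: triggers a second compile
```

In this model, "wrapping with jit only once" does not mean "compiling only once": the compiled artifact lives in the closure's cache keyed by input signature, so every new shape recompiles.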
Many thanks to anyone who reads this through!
This is standard JAX behavior. A function wrapped in jit is compiled the first time it runs, and the compiled version is cached. If the jitted function is then called with arguments of the same shapes and dtypes, the cached version is used; otherwise it must recompile. Every function in Python has associated state (i.e., jax.jit returns a function which is a closure), so references to the compiled computation graph are stored in the state associated with the function.
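Since the cache is keyed by shape and dtype, one common way to avoid recompilation is to pad varying-length inputs up to a fixed size before calling the jitted function. A minimal sketch with a hypothetical `pad_to_fixed` helper (not part of the colab code):

```python
import numpy as np

def pad_to_fixed(x, target_len):
    """Hypothetical helper: zero-pad the leading axis of x up to
    target_len so that every call into a jitted function sees one
    and the same input shape, hitting the compile cache instead of
    retracing. Returns the padded array and the original length,
    so the valid rows can be sliced back out afterwards."""
    n = x.shape[0]
    if n > target_len:
        raise ValueError("input longer than the fixed size")
    pad_width = [(0, target_len - n)] + [(0, 0)] * (x.ndim - 1)
    return np.pad(x, pad_width), n

# Varying-length inputs all become shape (8, 2):
batch, valid = pad_to_fixed(np.arange(6.0).reshape(3, 2), 8)
# batch.shape == (8, 2); only batch[:valid] holds real data
```

The trade-off is wasted compute on the padding rows, in exchange for compiling once per fixed shape instead of once per input size.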
In the future, this kind of question might be a better fit for the JAX repository, as this behavior isn't specific to TAP/TAPIR.