
what is going on when "recompilation" happens in a for loop #44

Closed

CHYjeremy opened this issue Aug 6, 2023 · 2 comments

Comments

CHYjeremy commented Aug 6, 2023

Hi everyone,

From the previous threads, we know that the main cause of slow inference is "a change in input tensor size causing recompilation". I would like to break this statement down more clearly.

Consider the following code (this is the Colab format I am using):

def inference(frames, query_points):
    ...
    rng = jax.random.PRNGKey(42)
    outputs, _ = model_apply(params, state, rng, frames, query_points)   ## highlight 1
    ...
    return ...

model = hk.transform_with_state(build_model)
model_apply = jax.jit(model.apply)   ## highlight 2

for video in videos:
    ...
    tracks, visibles = inference(frames, query_points)  ## highlight 3

May I ask:

  1. Concretely, in which call must the input tensor size be fixed to avoid recompilation? (Is it inside inference(), at highlight 1, or somewhere else?)
  2. Where exactly does compilation/recompilation happen?

Conjecture:

If it is jax.jit that causes compilation, then presumably highlight 2 returns a compiled version of model_apply. After this, no other jax.jit is called; we simply enter a for loop that repeatedly calls inference(). Every time inference() is called, it uses the pre-compiled model_apply and has no access to the outer jax.jit. So where exactly does this recompilation stem from?
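To make the question concrete, here is a stripped-down, runnable mirror of the loop above (apply_fn and the shapes are invented purely for illustration):

import jax
import jax.numpy as jnp

def apply_fn(frames):
    # stand-in for model.apply; the real call also takes params, state,
    # rng, and query_points
    return frames.mean(axis=0)

model_apply = jax.jit(apply_fn)   # highlight 2: jit wraps the function exactly once

# two videos whose frame counts differ
videos = [jnp.ones((24, 64, 64, 3)), jnp.ones((32, 64, 64, 3))]

for frames in videos:
    out = model_apply(frames)     # highlights 1/3: same function object, but
                                  # the input shape changes between videos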

Many thanks to anyone who reads through!

cdoersch (Collaborator) commented Aug 9, 2023

This is standard JAX behavior. A function wrapped in jit is compiled the first time it runs, and the compiled version is cached. If the jitted function is then called with arguments of the same shapes and dtypes, the cached version is used; otherwise it needs to recompile. Every function in Python has associated state (i.e., jax.jit returns a function that is a closure), so references to the compiled computation graph are stored in the state associated with that function.
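A minimal standalone sketch of this caching behavior (not TAP-specific; the shapes are arbitrary):

import jax
import jax.numpy as jnp

def f(x):
    # Python side effects run only while JAX traces the function, so this
    # print fires exactly when (re)compilation happens.
    print("tracing for shape", x.shape)
    return x * 2

g = jax.jit(f)
g(jnp.ones(3))   # prints: first call with this shape triggers compilation
g(jnp.ones(3))   # silent: same shape and dtype, cached executable is reused
g(jnp.ones(4))   # prints: new shape, so it recompiles

h = jax.jit(f)   # a fresh jitted closure starts with its own empty cache
h(jnp.ones(3))   # prints again, even though g already compiled this shape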

In the future, this kind of question might be a better fit for the JAX repository, since this behavior isn't specific to TAP/TAPIR.

CHYjeremy (Author)

Thanks @cdoersch, I will pay closer attention to where I post issues. Thank you for the clear and concise explanation.
