lax.dynamic_slice inside jit #1007
No, it shouldn't work: actually it's not just that XLA (and JAX's jit, which is what's actually raising the error here for tracing reasons) require fixed output shapes, but all the shapes of the intermediates need to be fixed too. So summing the output doesn't help; that lax.dynamic_slice alone is a problem. Here are two alternatives, both of which you probably know about:

```python
from __future__ import print_function
from functools import partial
import jax
import jax.numpy as np

@partial(jax.jit, static_argnums=(1,))
def sum_first_k(a, k):
    return np.sum(jax.lax.dynamic_slice(a, (0,), (k,)))

print(sum_first_k(np.arange(3.0), 2))

@jax.jit
def sum_first_k(a, k):
    n = len(a)
    return np.sum(np.where(np.arange(n) < k, a, 0))

print(sum_first_k(np.arange(3.0), 2))
```

The first is a way of solving the problem with recompilation. The second is a way to solve it with masking, for which XLA can still generate very efficient code by fusing the selection into the reduction rather than round-tripping several arrays to memory. A third strategy is to use a loop construct. WDYT?
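The third strategy mentioned above (a loop construct) can be sketched with lax.fori_loop, which accepts a traced trip count, so k can stay a dynamic argument under jit. This is my own sketch, not code from the thread, and the name sum_first_k_loop is hypothetical:

```python
import jax
import jax.numpy as np  # the thread's alias; newer code typically uses jnp
from jax import lax

@jax.jit
def sum_first_k_loop(a, k):
    # fori_loop lowers to a while loop when the trip count is traced,
    # so a dynamic k works under jit without per-k recompilation
    return lax.fori_loop(0, k, lambda i, s: s + a[i], 0.0)

print(sum_first_k_loop(np.arange(3.0), 2))  # 1.0
```

Note the trade-off: the sequential loop avoids both recompilation and the full-array mask, but it serializes the reduction, so it is usually slower on accelerators than the fused masked reduction.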
It would be nice if the error message said something like this, rather than sending me down a rabbit hole. What actually happened is that I first tried using indexing like a[:k], which generated an error encouraging me to try lax.dynamic_slice.
…On Tue, Jul 9, 2019 at 9:49 PM Matthew Johnson wrote:
Is there a way to make the construct with static_argnums work with sum_first_k inside vmap? Calling sum_first_k as a standalone function works fine, but it doesn't work when calling it from vmap: TypeError: Abstract value passed to
It's a little awkward, but you can make something like this work using explicit masking:

```python
@jax.jit
def sum_first_k(a, k):
    return np.sum(a * (np.arange(a.size) < k))
```
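Because the masked version treats k as an ordinary traced argument (no static_argnums), it also composes with vmap, which the static_argnums version cannot. A sketch of that usage; the batch of lengths ks is my own illustration:

```python
import jax
import jax.numpy as np

@jax.jit
def sum_first_k(a, k):
    # mask out entries at index >= k instead of slicing dynamically
    return np.sum(a * (np.arange(a.size) < k))

a = np.arange(3.0)
ks = np.array([1, 2, 3])
# one k per batch element, a single compilation for all of them
print(jax.vmap(lambda k: sum_first_k(a, k))(ks))  # [0. 1. 3.]
```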
Ok, thanks, that works. But it is slow in my use case: the array is roughly 3e6 x 100, and each k is a slice, not a number, with different slice lengths (the median slice length is about 50e3). So explicit masking results in a huge number of multiplications by 0. Which is too bad: other parts of my program are unbelievably fast using jax with gpu, much faster than anything else I tried.
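When the slice lengths have a known static upper bound, one way to cut the wasted multiplications is to combine lax.dynamic_slice (whose start indices may be traced; only the sizes must be static) with a mask over just the fixed-size window, so the masked work is O(MAX_LEN) rather than O(n) per slice. This is my own sketch, not from the thread; MAX_LEN and sum_slice are hypothetical names:

```python
import jax
import jax.numpy as np
from jax import lax

MAX_LEN = 4  # assumed static upper bound on any slice length

@jax.jit
def sum_slice(a, start, length):
    # dynamic_slice allows a traced start; only the window size is
    # static, so mask within the MAX_LEN window rather than all of a
    window = lax.dynamic_slice(a, (start,), (MAX_LEN,))
    return np.sum(window * (np.arange(MAX_LEN) < length))

print(sum_slice(np.arange(8.0), 2, 3))  # 2 + 3 + 4 = 9.0
```

One caveat: dynamic_slice clamps out-of-bounds start indices, so slices near the end of the array silently shift; that may or may not matter for a given workload.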
I think it's worth adding that I ran into the same issue as shoyer@ above, where I want dynamic
Fixes GH-1007. This should clarify the underlying concerns from GH-1007. It might be worth mentioning masking, but that's a little big for fitting into an error message. Maybe once the masking transformation is non-experimental or if we had a dedicated doc page.
This should clarify the underlying issues from google#1007 and google#3794. It might be worth mentioning masking, but that's a little big for fitting into an error message. Maybe once the masking transformation is non-experimental or if we had a dedicated doc page.
This is a good suggestion! I added clarifications to both error messages. I'm closing this issue, since hopefully users will see the more descriptive errors in the future and won't be misled.
@shoyer Is there a workaround for situations where masking doesn't apply? My end goal here is to define a train/test split function that is
There are probably workarounds in other specific cases, but this is a pretty fundamental limitation in XLA.
Should this work?
Here's the traceback I get:
I know XLA can't have variable-sized outputs, but here I'm summing the outputs, so in principle that shouldn't be an issue.
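The original code and traceback were not preserved in this thread; a minimal reproduction consistent with the rest of the discussion (my reconstruction) is a traced k used as a slice size inside jit, which fails at trace time even though only a scalar sum is returned:

```python
import jax
import jax.numpy as np
from jax import lax

@jax.jit
def sum_first_k(a, k):
    # under jit, k is a tracer, but dynamic_slice sizes must be
    # concrete Python ints, so tracing this function fails
    return np.sum(lax.dynamic_slice(a, (0,), (k,)))

try:
    sum_first_k(np.arange(3.0), 2)
except Exception as e:
    print("fails at trace time:", type(e).__name__)
```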