Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lax.dynamic_slice inside jit #1007

Closed
shoyer opened this issue Jul 9, 2019 · 9 comments
Closed

lax.dynamic_slice inside jit #1007

shoyer opened this issue Jul 9, 2019 · 9 comments
Assignees
Labels
documentation question Questions for the JAX team

Comments

@shoyer
Copy link
Member

shoyer commented Jul 9, 2019

Should this work?

import jax
import jax.numpy as np

@jax.jit
def sum_first_k(a, k):
  return np.sum(lax.dynamic_slice(a, (0,), (k,)))

sum_first_k(np.arange(3.0), 2)

Here's the traceback I get:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-167-645715d2be42> in <module>()
----> 1 sum_first_k(np.arange(3.0), 2)

13 frames
/usr/local/lib/python3.6/dist-packages/jax/api.py in f_jitted(*args, **kwargs)
    121     _check_args(args_flat)
    122     flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
--> 123     out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
    124     return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
    125 

/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    661   if top_trace is None:
    662     with new_sublevel():
--> 663       ans = primitive.impl(f, *args, **params)
    664   else:
    665     tracers = map(top_trace.full_raise, args)

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
    604 def xla_call_impl(fun, *args, **params):
    605   device_values = FLAGS.jax_device_values and params.pop('device_values')
--> 606   compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
    607   try:
    608     return compiled_fun(*args)

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(f, *args)
    206       if len(cache) > max_size:
    207         cache.popitem(last=False)
--> 208       ans = call(f, *args)
    209       cache[key] = (ans, f)
    210     return ans

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
    617   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
    618   with core.new_master(pe.JaxprTrace, True) as master:
--> 619     jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    620     assert not env  # no subtraces here (though cond might eventually need them)
    621     compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    145 
    146     del gen
--> 147     ans = self.f(*args, **dict(self.params, **kwargs))
    148     del args
    149     while stack:

<ipython-input-165-9a17ef1ee2d8> in sum_first_k(a, k)
      1 @jax.jit
      2 def sum_first_k(a, k):
----> 3   return np.sum(lax.dynamic_slice(a, (0,), (k,)))

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in dynamic_slice(operand, start_indices, slice_sizes)
    607   return dynamic_slice_p.bind(
    608       operand, start_indices, slice_sizes=tuple(slice_sizes),
--> 609       operand_shape=operand.shape)
    610 
    611 def dynamic_update_slice(operand, update, start_indices):

/usr/local/lib/python3.6/dist-packages/jax/core.py in bind(self, *args, **kwargs)
    145 
    146     tracers = map(top_trace.full_raise, args)
--> 147     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    148     return full_lower(out_tracer)
    149 

/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
    100       tracers = map(self.instantiate_const, tracers)
    101       avals = [t.aval for t in tracers]
--> 102       out_aval = primitive.abstract_eval(*avals, **params)
    103       eqn = JaxprEqn(tracers, None, primitive, (), False, False, params)
    104       return JaxprTracer(self, PartialVal((out_aval, unit)), eqn)

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs)
   1405     return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
   1406   elif least_specialized is ShapedArray:
-> 1407     return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
   1408   elif least_specialized is UnshapedArray:
   1409     return UnshapedArray(dtype_rule(*args, **kwargs))

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in _dynamic_slice_shape_rule(operand, start_indices, slice_sizes, operand_shape)
   2608            "start_indices, got start_inidices length {} and slice_sizes {}.")
   2609     raise TypeError(msg.format(len(start_indices), slice_sizes))
-> 2610   if not onp.all(onp.less_equal(slice_sizes, operand.shape)):
   2611     msg = ("slice slice_sizes must be less than or equal to operand shape, "
   2612            "got slice_sizes {} for operand shape {}.")

/usr/local/lib/python3.6/dist-packages/jax/core.py in __bool__(self)
    340   def __getitem__(self, idx): return self.aval._getitem(self, idx)
    341   def __nonzero__(self): return self.aval._nonzero(self)
--> 342   def __bool__(self): return self.aval._bool(self)
    343   def __float__(self): return self.aval._float(self)
    344   def __int__(self): return self.aval._int(self)

/usr/local/lib/python3.6/dist-packages/jax/abstract_arrays.py in error(self, *args)
     36 def concretization_function_error(fun):
     37   def error(self, *args):
---> 38     raise TypeError(concretization_err_msg(fun))
     39   return error
     40 

TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

I know XLA can't have variable sized outputs, but here I'm summing the outputs, so in principle that shouldn't be an issue.

@mattjj
Copy link
Member

mattjj commented Jul 10, 2019

No, it shouldn't work: actually it's not just that XLA (and JAX's jit, which is what's actually raising the error here for tracing reasons) require fixed output shapes, but all the shapes of the intermediates need to be fixed too. So summing the output doesn't help; that lax.dynamic_slice alone is a problem.

Here are two alternatives, both of which you probably know about:

from __future__ import print_function
from functools import partial

import jax
import jax.numpy as np

@partial(jax.jit, static_argnums=(1,))
def sum_first_k(a, k):
  return np.sum(jax.lax.dynamic_slice(a, (0,), (k,)))

print(sum_first_k(np.arange(3.0), 2))


@jax.jit
def sum_first_k(a, k):
  n = len(a)
  return np.sum(np.where(np.arange(n) < k, a, 0))

print(sum_first_k(np.arange(3.0), 2))

The first is a way of solving the problem with recompilation. The second is a way to solve it with masking, for which XLA can still generate very efficient code by fusing the selection into the reduction rather than round-tripping several arrays to memory. A third strategy is to use a loop construct.

WDYT?

@mattjj mattjj added the question Questions for the JAX team label Jul 10, 2019
@mattjj mattjj self-assigned this Jul 10, 2019
@shoyer
Copy link
Member Author

shoyer commented Jul 10, 2019 via email

@viktor2
Copy link

viktor2 commented Dec 18, 2019

Is there a way to make the construct with static_argnums work with sum_first_k inside vmap?

Calling sum_first_k as a standalone function works fine:
print(sum_first_k(np.arange(3.0), 2))

But it doesn't work when calling it from vmap
vmap_sum_first_k = jax.vmap(sum_first_k,(None,0))
print(vmap_sum_first_k(np.arange(10.0), np.arange(4)))

TypeError: Abstract value passed to bool, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using jit, try using static_argnums or applying jit to smaller subfunctions instead.

@shoyer
Copy link
Member Author

shoyer commented Dec 18, 2019

It's a little awkward, but you can make something like this work using explicit masking:

@jax.jit
def sum_first_k(a, k):
  return np.sum(a * (np.arange(a.size) < k))

@viktor2
Copy link

viktor2 commented Dec 18, 2019

Ok, thanks, that works. But it is slow in my use case: a dimension is roughly 3e6 x 100 while each k is a slice, not a number, with different slice lengths where median slice length is about 50e3. So explicit masking results in huge number of multiplications by 0. Which is too bad - other parts of my program are unbelievably fast using jax with gpu, much faster than anything else I tried.

@juesato
Copy link
Contributor

juesato commented May 12, 2020

I think it's worth adding that slice_sizes needs to be static to the dynamic_slice() docstring. I can send in a PR if that sounds good, WDYT?

I ran into the same issue as shoyer@ above, where I want dynamic slice_sizes() and first tried indexing a[:k], then was told to use dynamic_slice(), got this error message, poked around a bit, and then ended up here.

shoyer added a commit to shoyer/jax that referenced this issue Jul 19, 2020
Fixes googleGH-1007

This should clarify the underlying concerns from googleGH-1007.

It might be worth mentioning masking, but that's a little big for fitting into
an error message. Maybe once the masking transformation is non-experimental
or if we had a dedicated doc page.
shoyer added a commit to shoyer/jax that referenced this issue Jul 19, 2020
This should clarify the underlying issues from google#1007 and google#3794.

It might be worth mentioning masking, but that's a little big for fitting into
an error message. Maybe once the masking transformation is non-experimental or
if we had a dedicated doc page.
shoyer added a commit to shoyer/jax that referenced this issue Jul 19, 2020
This should clarify the underlying issues from google#1007 and google#3794.

It might be worth mentioning masking, but that's a little big for fitting into
an error message. Maybe once the masking transformation is non-experimental or
if we had a dedicated doc page.
hawkinsp pushed a commit that referenced this issue Jul 20, 2020
This should clarify the underlying issues from #1007 and #3794.

It might be worth mentioning masking, but that's a little big for fitting into
an error message. Maybe once the masking transformation is non-experimental or
if we had a dedicated doc page.
@shoyer
Copy link
Member Author

shoyer commented Jul 20, 2020

I think it's worth adding that slice_sizes needs to be static to the dynamic_slice() docstring. I can send in a PR if that sounds good, WDYT?

This is a good suggestion! I added clarifications to both the dynamic_slice documentation and this error message in: #3795

I'm closing this issue, since hopefully users will see the more descriptive errors in the future and won't be misled.

@shoyer shoyer closed this as completed Jul 20, 2020
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this issue Jul 21, 2020
This should clarify the underlying issues from google#1007 and google#3794.

It might be worth mentioning masking, but that's a little big for fitting into
an error message. Maybe once the masking transformation is non-experimental or
if we had a dedicated doc page.
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this issue Jul 24, 2020
This should clarify the underlying issues from google#1007 and google#3794.

It might be worth mentioning masking, but that's a little big for fitting into
an error message. Maybe once the masking transformation is non-experimental or
if we had a dedicated doc page.
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this issue Jul 24, 2020
This should clarify the underlying issues from google#1007 and google#3794.

It might be worth mentioning masking, but that's a little big for fitting into
an error message. Maybe once the masking transformation is non-experimental or
if we had a dedicated doc page.
@pharringtonp19
Copy link

@shoyer Is there a workaround for situations where masking doesn't apply? My end goal here is to define a train/test split function that is jit-able and which I could apply vmap over.

It's a little awkward, but you can make something like this work using explicit masking:

@jax.jit
def sum_first_k(a, k):
return np.sum(a * (np.arange(a.size) < k))

@shoyer
Copy link
Member Author

shoyer commented Aug 2, 2021

@shoyer Is there a workaround for situations where masking doesn't apply? My end goal here is to define a train/test split function that is jit-able and which I could apply vmap over.

There are probably work-arounds in other specific cases, but this is a pretty fundamental limitation in XLA.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation question Questions for the JAX team
Projects
None yet
Development

No branches or pull requests

5 participants