diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 35419e85c0d3..9eb8bf2278dc 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -181,17 +181,20 @@ def stage_lazy_expr(c, lazy_expr, x): ArgSpec = namedtuple("ArgSpec", ["aval", "lazy_expr", "xla_shape"]) def arg_spec(x): + # For performance reasons, the return type of this function is polymorphic: + # if the input has a lazy_expr attached, the return type is a raw tuple + # modeling an ArgSpec namedtuple; otherwise, it's an AbstractValue. aval = abstractify(x) try: lazy_expr = x._lazy_expr except AttributeError: - return ArgSpec(aval, None, aval_to_xla_shape(aval)) + return aval else: if x.device_buffer is device_constant: xla_shape = None else: xla_shape = x.device_buffer.shape() - return ArgSpec(aval, x._lazy_expr, xla_shape) + return aval, x._lazy_expr, xla_shape ### handlers @@ -282,6 +285,12 @@ def apply_primitive(prim, *args, **params): @cache() def xla_primitive_callable(prim, *arg_specs, **params): + # For performance, we allow the elements of arg_specs either to be triples + # modeling ArgSpec instances or abstract values. Here we canonicalize them. + arg_specs = [ArgSpec(x, None, aval_to_xla_shape(x)) + if isinstance(x, core.AbstractValue) else ArgSpec(*x) + for x in arg_specs] + if FLAGS.jax_log_compiles: print("Compiling {} for args {}.".format(prim.name, arg_specs)) backend = params.get('backend', None)