Skip to content

Commit

Permalink
performance
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Dec 4, 2019
1 parent d996829 commit e8b4686
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions jax/interpreters/xla.py
Expand Up @@ -181,17 +181,20 @@ def stage_lazy_expr(c, lazy_expr, x):
ArgSpec = namedtuple("ArgSpec", ["aval", "lazy_expr", "xla_shape"])

def arg_spec(x):
# For performance reasons, the return type of this function is polymorphic:
# if the input has a lazy_expr attached, the return type is a raw tuple
# modeling an ArgSpec namedtuple; otherwise, it's an AbstractValue.
aval = abstractify(x)
try:
lazy_expr = x._lazy_expr
except AttributeError:
return ArgSpec(aval, None, aval_to_xla_shape(aval))
return aval
else:
if x.device_buffer is device_constant:
xla_shape = None
else:
xla_shape = x.device_buffer.shape()
return ArgSpec(aval, x._lazy_expr, xla_shape)
return aval, x._lazy_expr, xla_shape


### handlers
Expand Down Expand Up @@ -282,6 +285,12 @@ def apply_primitive(prim, *args, **params):

@cache()
def xla_primitive_callable(prim, *arg_specs, **params):
# For performance, we allow the elements of arg_specs either to be triples
# modeling ArgSpec instances or abstract values. Here we canonicalize them.
arg_specs = [ArgSpec(x, None, aval_to_xla_shape(x))
if isinstance(x, core.AbstractValue) else ArgSpec(*x)
for x in arg_specs]

if FLAGS.jax_log_compiles:
print("Compiling {} for args {}.".format(prim.name, arg_specs))
backend = params.get('backend', None)
Expand Down

0 comments on commit e8b4686

Please sign in to comment.