Skip to content

Commit

Permalink
make jit strict in its arguments (i.e. force args)
Browse files Browse the repository at this point in the history
This change is to avoid recompiles. See comment:
#1668 (comment)
Thanks @hawkinsp for help with this.

Also, make force(x) update x's device_buffer reference.
  • Loading branch information
mattjj committed Dec 5, 2019
1 parent 3682658 commit dc93867
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 39 deletions.
2 changes: 1 addition & 1 deletion jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def computation_maker(*args, **kwargs):
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
xla_consts = map(c.Constant, consts)
xla_args = xla._xla_callable_args(c, arg_specs, tuple_args)
xla_args = xla._xla_callable_args(c, avals, tuple_args)
outs = xla.jaxpr_subcomp(c, jaxpr, backend, axis_env_, xla_consts, (),
*xla_args)
return c.Build(c.Tuple(*outs))
Expand Down
6 changes: 2 additions & 4 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,8 @@ class ResultToPopulate(object): pass
result_to_populate = ResultToPopulate()

def _pmap_callable_args(c, avals, tuple_args):
# TODO(mattjj): support laziness here, don't just force every argument
arg_specs = [xla.ArgSpec(aval, None, aval_to_xla_shape(aval))
for aval in avals]
return xla._xla_callable_args(c, arg_specs, tuple_args)
# TODO(mattjj): support laziness for broadcasted axes to map
return xla._xla_callable_args(c, avals, tuple_args)

def _pvals_to_results_handler(size, nrep, out_pvals):
nouts = len(out_pvals)
Expand Down
56 changes: 33 additions & 23 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,15 +300,16 @@ def xla_primitive_callable(prim, *arg_specs, **params):

@cache()
def primitive_computation(prim, *avals, **params):
arg_specs = [ArgSpec(a, lazy_array(a.shape), aval_to_xla_shape(a)) for a in avals]
# This function is used when compiling sub-computations, like in reductions.
arg_specs = [ArgSpec(a, None, aval_to_xla_shape(a)) for a in avals]
return _primitive_computation(prim, *arg_specs, **params)

def _primitive_computation(prim, *arg_specs, **params):
c = xb.make_computation_builder("primitive_computation_{}".format(prim.name))
c.SetOpMetadata(xc.OpMetadata(op_type=prim.name, op_name=str(params)))
backend = params.pop("backend", None)
platform = xb.get_backend(backend).platform
xla_args = _xla_callable_args(c, arg_specs, False)
xla_args = _xla_callable_args_lazy(c, arg_specs)
if prim in backend_specific_translations[platform]:
rule = backend_specific_translations[platform][prim]
rule(c, *xla_args, **params) # return val set as a side-effect on c
Expand All @@ -332,6 +333,15 @@ def _primitive_computation(prim, *arg_specs, **params):
"https://github.com/google/jax/issues\n")
raise RuntimeError(msg)

def _xla_callable_args_lazy(c, arg_specs):
  """Build the XLA argument ops for a computation from lazy arg specs.

  Args:
    c: an XLA computation builder.
    arg_specs: sequence of spec objects exposing ``.aval``, ``.lazy_expr`` and
      ``.xla_shape`` attributes (presumably ``ArgSpec`` — see callers).

  Returns:
    A list with one XLA op per spec: the staged lazy expression for array
    arguments, or a fresh token op for token-typed arguments.
  """
  # Lazily create one XLA parameter per spec that actually carries data;
  # token-typed specs and specs with no xla_shape get no parameter.
  raw_args = (c.ParameterWithShape(s.xla_shape) for s in arg_specs
              if s.aval is not abstract_token and s.xla_shape is not None)
  # `s.xla_shape and next(raw_args)` draws a parameter from the generator only
  # when this spec has a shape, otherwise it passes the falsy xla_shape (e.g.
  # None) straight through to stage_lazy_expr. NOTE(review): this relies on
  # left-to-right evaluation to keep raw_args aligned with arg_specs.
  xla_args = [stage_lazy_expr(c, s.lazy_expr, s.xla_shape and next(raw_args))
              if s.aval is not abstract_token else c.CreateToken()
              for s in arg_specs]
  assert next(raw_args, None) is None  # every created parameter was consumed
  return xla_args

def _execute_compiled_primitive(prim, compiled, backend, result_handler, *args):
device, = compiled.local_devices()
input_bufs = [device_put(x, device) for x in args
Expand Down Expand Up @@ -520,7 +530,7 @@ def eqn_has_pmap(eqn):
def _xla_call_impl(fun, *args, **params):
device = params['device']
backend = params.get('backend', None)
compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
try:
return compiled_fun(*args)
except FloatingPointError:
Expand All @@ -529,12 +539,12 @@ def _xla_call_impl(fun, *args, **params):
return fun.call_wrapped(*args) # probably won't return

@lu.cache
def _xla_callable(fun, device, backend, *arg_specs):
def _xla_callable(fun, device, backend, *abstract_args):
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
"Compiling {} for args {}.".format(fun.__name__, arg_specs))
"Compiling {} for args {}.".format(fun.__name__, abstract_args))

pvals = [pe.PartialVal((s.aval, core.unit)) for s in arg_specs]
pvals = [pe.PartialVal((a, core.unit)) for a in abstract_args]
with core.new_master(pe.JaxprTrace, True) as master:
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
assert not env # no subtraces here
Expand All @@ -553,11 +563,11 @@ def _xla_callable(fun, device, backend, *arg_specs):
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")

tuple_args = len(arg_specs) > 100 # pass long arg lists as tuple for TPU
tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU

c = xb.make_computation_builder("jit_{}".format(fun.__name__))
xla_consts = _map(c.Constant, consts)
xla_args = _xla_callable_args(c, arg_specs, tuple_args)
xla_args = _xla_callable_args(c, abstract_args, tuple_args)
out_nodes = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, (), *xla_args)
built = c.Build(c.Tuple(*out_nodes))

Expand All @@ -573,18 +583,16 @@ def _xla_callable(fun, device, backend, *arg_specs):
else:
return partial(_execute_replicated, compiled, backend, result_handlers, tuple_args)

def _xla_callable_args(c, arg_specs, tuple_args):
def _xla_callable_args(c, avals, tuple_args):
if not tuple_args:
raw_args = (c.ParameterWithShape(s.xla_shape) for s in arg_specs
if s.aval is not abstract_token and s.xla_shape is not None)
raw_args = (c.ParameterWithShape(aval_to_xla_shape(a)) for a in avals
if a is not abstract_token)
else:
elt_shapes = [s.xla_shape for s in arg_specs
if s.aval is not abstract_token and s.xla_shape is not None]
elt_shapes = [aval_to_xla_shape(a) for a in avals if a is not abstract_token]
tuple_param = c.ParameterWithShape(xc.Shape.tuple_shape(elt_shapes))
raw_args = iter(xla_destructure(c, tuple_param))
xla_args = [stage_lazy_expr(c, s.lazy_expr, s.xla_shape and next(raw_args))
if s.aval is not abstract_token else c.CreateToken()
for s in arg_specs]
xla_args = [next(raw_args) if a is not abstract_token else c.CreateToken()
for a in avals]
assert next(raw_args, None) is None
return xla_args

Expand All @@ -597,8 +605,7 @@ def _pval_to_result_handler(pval):

def _execute_compiled(compiled, backend, handlers, tuple_args, *args):
device, = compiled.local_devices()
input_bufs = [device_put(x, device) for x in args
if x is not token and not is_device_constant(x)]
input_bufs = [device_put(force(x), device) for x in args if x is not token]
if tuple_args:
input_bufs = [make_tuple(input_bufs, device, backend)]
out_bufs = compiled.Execute(input_bufs).destructure()
Expand Down Expand Up @@ -765,8 +772,7 @@ def _value(self):
if self.device_buffer is device_constant:
self._npy_value = eval_lazy_expr(self._lazy_expr, None)
else:
self.device_buffer = force(self).device_buffer
self._lazy_expr = lazy_array(self.aval.shape)
force(self) # sets self.device_buffer and self._lazy_expr
self._npy_value = self.device_buffer.to_py()
self._npy_value.flags.writeable = False
return self._npy_value
Expand Down Expand Up @@ -987,16 +993,20 @@ def _instantiate_device_constant(const, device=None, backend=None, cutoff=1e6):
compiled = c.Build(xla_const).Compile((), opts, backend=xb.get_backend(backend))
return compiled.Execute(())

# To force a DeviceArray to be materialized, we just apply an identity primitive
def force(x):
if type(x) is not DeviceArray:
return x
lexpr = x._lazy_expr
if (type(lexpr.input) is LazyArrayVar and lexpr.dims == tuple(range(x.ndim))):
return x # trivial lazy expr
return x # trivial lazy expr, no need to force
else:
return apply_primitive(force_p, x, aval=x.aval)
out = apply_primitive(force_p, x, aval=x.aval) # apply identity primitive
# side-effect: set x.device_buffer and x._lazy_expr
x.device_buffer = out.device_buffer
x._lazy_expr = out._lazy_expr
return out

force_p = core.Primitive('force')
force_p.def_abstract_eval(lambda *args, **kwargs: kwargs['aval'])
translations[force_p] = lambda c, x, aval: x
pe.custom_partial_eval_rules[force_p] = lambda trace, x: x
14 changes: 3 additions & 11 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,17 +1569,14 @@ class LazyTest(jtu.JaxTestCase):
def _check_num_eager_computations(self, num):
  """Context manager asserting exactly `num` eager compilations happen inside.

  Temporarily wraps xla.xla_primitive_callable with a counting shim, yields to
  the with-block, restores the original, then asserts the call count equals
  `num`. (NOTE(review): used via yield, so presumably decorated as a
  contextmanager at the definition site — decorator not visible here.)
  """
  xla_primitive_callable = xla.xla_primitive_callable
  count = [0]  # list cell so the closure below can mutate it

  def primitive_callable_and_count(*args, **kwargs):
    # Counting pass-through shim for xla_primitive_callable.
    count[0] += 1
    return xla_primitive_callable(*args, **kwargs)

  try:
    xla.xla_primitive_callable = primitive_callable_and_count
    yield
  finally:
    # Always restore the real callable, even if the with-block raised.
    xla.xla_primitive_callable = xla_primitive_callable

  self.assertEqual(count[0], num)

def test_lazy_reshape_multiply(self):
Expand All @@ -1605,14 +1602,9 @@ def test_lazy_eye(self):
expected = onp.eye(3, dtype=onp.float32) + 5
self.assertAllClose(z, expected, check_dtypes=True)

def test_lazy_jit_arguments(self):
x = np.arange(int(1e12)) # will likely oom if materialized
ans = jit(lambda x: x[0])(x)
self.assertEqual(ans, 0)

def test_lazy_jit_closured_over_values(self):
x = np.arange(int(1e12)) # will likely oom if materialized
ans = jit(lambda y: (x + y)[1])(x)
def test_lazy_jit_closed_over_values(self):
  """jit must not materialize a huge lazy array merely closed over by the fun."""
  y = np.arange(int(1e12))  # will likely oom if materialized
  ans = jit(lambda x: (x + y)[1])(1)
  self.assertEqual(ans, 2)

@parameterized.parameters(jtu.cases_from_list(range(10000)))
Expand Down

0 comments on commit dc93867

Please sign in to comment.