add lazy sub-language, fuse into op-by-op computations
Also removes the DeviceConstant system.
mattjj committed Dec 5, 2019
1 parent 0c0137d commit 5a5db5f
Showing 12 changed files with 583 additions and 331 deletions.
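Only two of the twelve changed files are shown below, so as a rough illustration of the commit title, here is a minimal, self-contained sketch of the idea of a lazy sub-language fused into op-by-op computations: cheap re-indexing operations such as broadcasts and transposes are recorded symbolically on the array and only applied as part of the next real computation, never materialized on their own. The names LazyArray, lazy_broadcast, force, and add are invented for illustration and are not the API this commit adds.

import numpy as np

class LazyArray:
  """A concrete buffer plus a pending, un-materialized re-indexing of it."""
  def __init__(self, buf, shape, dims):
    self.buf = buf      # concrete backing buffer
    self.shape = shape  # logical shape after the lazy re-indexing
    self.dims = dims    # per logical axis: source axis in buf, or None if broadcast

  def force(self):
    # Materialize: permute the real axes into place, then broadcast the rest.
    src = [d for d in self.dims if d is not None]
    out = np.transpose(self.buf, src)
    out = out.reshape([1 if d is None else self.buf.shape[d] for d in self.dims])
    return np.broadcast_to(out, self.shape)

def lazy_broadcast(x, shape, broadcast_dims):
  # Record a broadcast without touching the data.
  x = np.asarray(x)
  dims = [None] * len(shape)
  for src_axis, dst_axis in enumerate(broadcast_dims):
    dims[dst_axis] = src_axis
  return LazyArray(x, shape, dims)

def add(x, y):
  # An op-by-op primitive: any pending lazy re-indexing is applied as part of
  # this computation rather than as a separate materializing op.
  fx = x.force() if isinstance(x, LazyArray) else np.asarray(x)
  fy = y.force() if isinstance(y, LazyArray) else np.asarray(y)
  return fx + fy

b = lazy_broadcast(np.arange(3.0), (4, 3), broadcast_dims=(1,))  # no (4, 3) buffer allocated
print(add(b, np.ones((4, 3))))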
8 changes: 6 additions & 2 deletions jax/core.py
@@ -116,14 +116,18 @@ def __init__(self, val):
if type(val) in literalable_types:
try:
self.hash = hash((val.item(), val.dtype))
except (TypeError, AttributeError):
except (TypeError, AttributeError, ValueError):
self.hash = None

def __hash__(self):
return id(self.val) if self.hash is None else self.hash

def __eq__(self, other):
return self.val is other.val if self.hash is None else self.val == other.val
if self.hash is None:
return self.val is other.val
else:
return (self.val == other.val or
self.val != self.val and other.val != other.val) # nans are equal

def __repr__(self):
if self.hash is None:
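For context on the __eq__ change above: IEEE-754 NaN compares unequal to itself, so without the extra clause two literals wrapping NaN would compare unequal even when they stand for the same constant (presumably breaking the hash/equality contract the class relies on; that motivation is an inference, only the "nans are equal" comment appears in the diff). A standalone check of the predicate:

import numpy as np

def literal_vals_equal(a, b):
  # Mirrors the new __eq__ branch: equal values, or both values are NaN.
  return bool(a == b or (a != a and b != b))

x = np.float32('nan')
print(x == x)                                                # False under IEEE-754
print(literal_vals_equal(x, x))                              # True: "nans are equal"
print(literal_vals_equal(np.float32(1.0), np.float32(1.0)))  # True
print(literal_vals_equal(np.float32(1.0), np.float32(2.0)))  # False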
20 changes: 14 additions & 6 deletions jax/interpreters/pxla.py
@@ -377,11 +377,14 @@ def _shard_sharded_device_array(x, devices, assignments):
return (xla.device_put(x[assignments[r]], devices[r]) for r in range(n))
shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array

def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
return c.Constant(onp.asarray(val), canonicalize_types=canonicalize_types)
xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler)

core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
xla.device_put_handlers[ShardedDeviceArray] = xla._device_put_array
xla.pytype_aval_mappings[ShardedDeviceArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[ShardedDeviceArray] = identity
xb.register_constant_handler(ShardedDeviceArray, xla._device_array_constant_handler)


class ChunkedDeviceArray(ShardedDeviceArray):
@@ -414,9 +417,10 @@ def xla_pmap_impl(fun, *args, **params):
backend = params.pop('backend', None)
assert not params

abstract_args = map(xla.abstractify, args)
# TODO(mattjj): support laziness here, don't just force every argument
avals = [xla.abstractify(xla.force(x)) for x in args]
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size, devices,
*abstract_args)
*avals)
return compiled_fun(*args)

@lu.cache
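The hunk above swaps map(xla.abstractify, args) for xla.abstractify(xla.force(x)): any argument still carrying a lazy expression is forced to a concrete array before its abstract value is taken for compilation. A toy version of that pattern, with force and abstractify as stand-ins (the real xla.force / xla.abstractify signatures are not shown in this diff):

import numpy as np
from collections import namedtuple

ShapedArray = namedtuple('ShapedArray', ['shape', 'dtype'])

def force(x):
  # Stand-in: materialize a deferred value; concrete arrays pass through unchanged.
  return x() if callable(x) else x

def abstractify(x):
  # Read off the abstract value (shape and dtype) of a concrete array.
  return ShapedArray(np.shape(x), np.result_type(x))

args = [np.ones((8, 4)), lambda: np.zeros((8, 2))]  # the second argument is "lazy"
avals = [abstractify(force(x)) for x in args]       # mirrors the new line in the hunk
print(avals)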
@@ -449,8 +453,8 @@ def dynamic_fun(dummy, *args):
with extend_dynamic_axis_env(axis_name, dummy.trace, global_axis_size):
return fun.call_wrapped(*args)

avals = tuple(map(partial(shard_aval, axis_size), avals))
pvals = [PartialVal((aval, core.unit)) for aval in avals]
sharded_avals = [shard_aval(axis_size, aval) for aval in avals]
pvals = [PartialVal((aval, core.unit)) for aval in sharded_avals]
pval = PartialVal([core.abstract_unit, core.unit]) # dummy value for axis env
with core.new_master(JaxprTrace, True) as master:
jaxpr, (out_pvals, consts, env) = \
@@ -478,7 +482,7 @@ def dynamic_fun(dummy, *args):

c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
xla_consts = _map(c.Constant, consts)
xla_args = xla._xla_callable_args(c, avals, tuple_args)
xla_args = _pmap_callable_args(c, sharded_avals, tuple_args)
out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, (), *xla_args)
built = c.Build(c.Tuple(*out_nodes))
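The rename above keeps the unsharded avals available alongside sharded_avals, and the tracing and XLA argument building now run against the per-device sharded_avals. As a rough sketch of what sharding an abstract value for pmap means (an assumption based on pmap mapping over the leading axis, not something shown in this diff), each per-replica shape drops the mapped axis of size axis_size:

from collections import namedtuple

ShapedArray = namedtuple('ShapedArray', ['shape', 'dtype'])

def shard_aval_sketch(axis_size, aval):
  # Assumed behavior: strip the mapped leading axis of size axis_size, so each
  # replica is compiled against the per-device shape.
  assert aval.shape and aval.shape[0] == axis_size, "leading axis must be the mapped one"
  return ShapedArray(aval.shape[1:], aval.dtype)

avals = [ShapedArray((8, 128), 'float32'), ShapedArray((8,), 'int32')]
sharded_avals = [shard_aval_sketch(8, aval) for aval in avals]  # mirrors the hunk above
print(sharded_avals)  # per-device shapes (128,) and ()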

@@ -517,6 +521,10 @@ def dynamic_fun(dummy, *args):
class ResultToPopulate(object): pass
result_to_populate = ResultToPopulate()

def _pmap_callable_args(c, avals, tuple_args):
# TODO(mattjj): support laziness for broadcasted axes to map
return xla._xla_callable_args(c, avals, tuple_args)

def _pvals_to_results_handler(size, nrep, out_pvals):
nouts = len(out_pvals)
handlers = [_pval_to_result_handler(size, nrep, pval) for pval in out_pvals]
