Skip to content

Commit

Permalink
implement bint arrays (opaque dtypes), add padding rules
Browse files Browse the repository at this point in the history
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
  • Loading branch information
mattjj and sharadmv committed Oct 9, 2022
1 parent 2693afa commit 6d2aaac
Show file tree
Hide file tree
Showing 13 changed files with 273 additions and 183 deletions.
46 changes: 27 additions & 19 deletions jax/_src/dispatch.py
Expand Up @@ -513,6 +513,10 @@ def lower_xla_callable(
axis_env = xla.AxisEnv(nreps, (), ())
name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
closed_out_type = [
(a.update(shape=tuple(pe.InDBIdx(d.val - len(consts))
if type(d) is pe.InDBIdx else d for d in a.shape))
if type(a) is core.DShapedArray else a, b) for a, b in out_type]
module_name = f"jit_{fun.__name__}"
unordered_effects = [eff for eff in closed_jaxpr.effects
if eff not in core.ordered_effects]
Expand All @@ -526,8 +530,8 @@ def lower_xla_callable(
lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks)
return XlaComputation(
name, module, False, donated_invars, fun.in_type, out_type, nreps=nreps,
device=device, backend=backend, tuple_args=tuple_args,
name, module, False, donated_invars, fun.in_type, tuple(closed_out_type),
nreps=nreps, device=device, backend=backend, tuple_args=tuple_args,
in_avals=abstract_args, out_avals=out_avals,
has_unordered_effects=bool(unordered_effects),
ordered_effects=ordered_effects, kept_var_idx=kept_var_idx,
Expand Down Expand Up @@ -565,10 +569,21 @@ def jaxpr_has_primitive(jaxpr, prim_name: str):
return False

def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
return (any(type(v.aval) is core.AbstractBInt for v in jaxpr.invars) or
any(type(v.aval) is core.AbstractBInt
return (any(type(v.aval.dtype) is core.bint for v in jaxpr.invars
if isinstance(v.aval, core.UnshapedArray)) or
any(_is_bint_axis_size(d)
for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
for e in j.eqns for v in e.outvars))
for e in j.eqns for v in e.outvars
if isinstance(v.aval, core.DShapedArray) for d in v.aval.shape))

def _is_bint_axis_size(d: core.AxisSize) -> bool:
if isinstance(d, core.DArray):
assert not d.shape
return type(d.dtype) is core.bint
elif isinstance(d, core.Var):
return (isinstance(d.aval, core.DShapedArray) and
type(d.aval.dtype) is core.bint)
return False

def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
Expand Down Expand Up @@ -658,7 +673,6 @@ def aval_to_num_buffers(aval: core.AbstractValue) -> int:
num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.DShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
num_buffers_handlers[core.AbstractBInt] = lambda _: 1


def _input_handler(backend: Backend,
Expand Down Expand Up @@ -797,17 +811,14 @@ def _dynamic_array_result_handler(sticky_device, aval, env, buf):
shape = [in_env[d.val] if type(d) is core.InDBIdx else
out_env[d.val] if type(d) is core.OutDBIdx else d
for d in aval.shape]
if all(type(d) is int for d in shape):
aval = core.ShapedArray(tuple(shape), aval.dtype)
if all(type(d) is int for d in shape) and type(aval.dtype) is not core.bint:
aval = core.ShapedArray(tuple(shape), buf.dtype)
return maybe_create_array_from_da(buf, aval, sticky_device)
elif any(type(d) is core.BInt for d in shape):
padded_shape = [d.bound if type(d) is core.BInt else d for d in shape]
buf_aval = core.ShapedArray(tuple(padded_shape), aval.dtype, aval.weak_type)
data = maybe_create_array_from_da(buf, buf_aval, sticky_device)
return core.PaddedArray(aval.update(shape=tuple(shape)), data)
else:
aval = core.ShapedArray(tuple(shape), aval.dtype)
return maybe_create_array_from_da(buf, aval, sticky_device)
pad_shape = [d.dtype.bound if _is_bint_axis_size(d) else d for d in shape]
buf_aval = core.ShapedArray(tuple(pad_shape), buf.dtype, aval.weak_type)
data = maybe_create_array_from_da(buf, buf_aval, sticky_device)
return core.DArray(aval.update(shape=tuple(shape)), data)



Expand All @@ -818,8 +829,6 @@ def _dynamic_array_result_handler(sticky_device, aval, env, buf):
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.DShapedArray] = dynamic_array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler
result_handlers[core.AbstractBInt] = \
lambda _, a: lambda _, b: core.BInt(int(b), a.bound)


def needs_check_special():
Expand Down Expand Up @@ -1228,15 +1237,14 @@ def _device_put_token(_, device):
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
device_put_handlers[core.Token] = _device_put_token
device_put_handlers[core.BInt] = lambda x, d: _device_put_scalar(x.val, d)


def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
x = _copy_device_array_to_device(x, device)
return (x.device_buffer,)
for t in device_array.device_array_types:
device_put_handlers[t] = _device_put_device_array
device_put_handlers[core.PaddedArray] = lambda x, d: device_put(x._data, d)
device_put_handlers[core.DArray] = lambda x, d: device_put(x._data, d)

def _copy_device_array_to_device(
x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray],
Expand Down
9 changes: 7 additions & 2 deletions jax/_src/dtypes.py
Expand Up @@ -237,7 +237,7 @@ def _issubclass(a, b):
except TypeError:
return False

def issubdtype(a, b):
def issubdtype(a, b) -> bool:
if a == "bfloat16":
a = bfloat16
if a == bfloat16:
Expand All @@ -251,7 +251,10 @@ def issubdtype(a, b):
# interacts badly with JAX's custom scalar types. As a workaround,
# explicitly cast the second argument to a NumPy type object.
b = np.dtype(b).type
return np.issubdtype(a, b)
try:
return np.issubdtype(a, b)
except TypeError: # e.g. if 'a' is not a np.dtype
return False

can_cast = np.can_cast
issubsctype = np.issubsctype
Expand Down Expand Up @@ -436,6 +439,8 @@ def dtype(x, *, canonicalize=False):
dt = python_scalar_dtypes[x]
elif type(x) in python_scalar_dtypes:
dt = python_scalar_dtypes[type(x)]
elif jax.core.is_opaque_dtype(getattr(x, 'dtype', None)):
dt = x.dtype
else:
dt = np.result_type(x)
if dt not in _jax_dtype_set:
Expand Down

0 comments on commit 6d2aaac

Please sign in to comment.