
Commit

Rolling forward #12707 after a rollback caused by relatively trivial jax.numpy shape validation code failing some downstream user tests.

PiperOrigin-RevId: 480229237
mattjj authored and jax authors committed Oct 11, 2022
1 parent 9b3e864 commit df5f7cb
Showing 12 changed files with 273 additions and 191 deletions.
50 changes: 29 additions & 21 deletions jax/_src/dispatch.py
@@ -513,6 +513,10 @@ def lower_xla_callable(
   axis_env = xla.AxisEnv(nreps, (), ())
   name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
   closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
+  closed_out_type = [
+      (a.update(shape=tuple(pe.InDBIdx(d.val - len(consts))
+                            if type(d) is pe.InDBIdx else d for d in a.shape))
+       if type(a) is core.DShapedArray else a, b) for a, b in out_type]
   module_name = f"jit_{fun.__name__}"
   unordered_effects = [eff for eff in closed_jaxpr.effects
                        if eff not in core.ordered_effects]
@@ -526,8 +530,8 @@ def lower_xla_callable(
       lowering_result.module, lowering_result.keepalive,
       lowering_result.host_callbacks)
   return XlaComputation(
-      name, module, False, donated_invars, fun.in_type, out_type, nreps=nreps,
-      device=device, backend=backend, tuple_args=tuple_args,
+      name, module, False, donated_invars, fun.in_type, tuple(closed_out_type),
+      nreps=nreps, device=device, backend=backend, tuple_args=tuple_args,
       in_avals=abstract_args, out_avals=out_avals,
       has_unordered_effects=bool(unordered_effects),
       ordered_effects=ordered_effects, kept_var_idx=kept_var_idx,
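
Note on the two hunks above: if I read the new closed_out_type comprehension right, InDBIdx axis sizes that counted the consts among the jaxpr's inputs are shifted down by len(consts) once the jaxpr is closed over those consts, and the rebased output types are what get threaded into XlaComputation. A toy sketch of that kind of index shift in plain Python (the InDBIdx here is a made-up stand-in, not pe.InDBIdx):

from dataclasses import dataclass

@dataclass
class InDBIdx:  # stand-in: an axis size given as an index into the input list
  val: int

def rebase_shape(shape, num_consts):
  # Re-point input-indexed axis sizes after the first num_consts inputs
  # become closed-over constants and drop out of the explicit input list.
  return tuple(InDBIdx(d.val - num_consts) if isinstance(d, InDBIdx) else d
               for d in shape)

print(rebase_shape((InDBIdx(2), 3), num_consts=1))  # (InDBIdx(val=1), 3)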
@@ -565,10 +569,21 @@ def jaxpr_has_primitive(jaxpr, prim_name: str):
   return False
 
 def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
-  return (any(type(v.aval) is core.AbstractBInt for v in jaxpr.invars) or
-          any(type(v.aval) is core.AbstractBInt
+  return (any(type(v.aval.dtype) is core.bint for v in jaxpr.invars
+              if isinstance(v.aval, core.UnshapedArray)) or
+          any(_is_bint_axis_size(d)
               for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
-              for e in j.eqns for v in e.outvars))
+              for e in j.eqns for v in e.outvars
+              if isinstance(v.aval, core.DShapedArray) for d in v.aval.shape))
+
+def _is_bint_axis_size(d: core.AxisSize) -> bool:
+  if isinstance(d, core.DArray):
+    assert not d.shape
+    return type(d.dtype) is core.bint
+  elif isinstance(d, core.Var):
+    return (isinstance(d.aval, core.DShapedArray) and
+            type(d.aval.dtype) is core.bint)
+  return False
 
 def _prune_unused_inputs(
     jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
@@ -658,7 +673,6 @@ def aval_to_num_buffers(aval: core.AbstractValue) -> int:
 num_buffers_handlers[core.ShapedArray] = lambda _: 1
 num_buffers_handlers[core.DShapedArray] = lambda _: 1
 num_buffers_handlers[core.ConcreteArray] = lambda _: 1
-num_buffers_handlers[core.AbstractBInt] = lambda _: 1
 
 
 def _input_handler(backend: Backend,
@@ -776,7 +790,7 @@ def aval_to_result_handler(sticky_device: Optional[Device],
 
 def array_result_handler(sticky_device: Optional[Device],
                          aval: core.ShapedArray):
-  if aval.dtype == dtypes.float0:
+  if not core.is_opaque_dtype(aval.dtype) and aval.dtype == dtypes.float0:
     return lambda _, __: np.zeros(aval.shape, dtypes.float0)
   aval = core.raise_to_shaped(aval)
   if core.is_opaque_dtype(aval.dtype):
@@ -787,7 +801,7 @@ def array_result_handler(sticky_device: Optional[Device],
 
 def dynamic_array_result_handler(sticky_device: Optional[Device],
                                  aval: core.DShapedArray):
-  if aval.dtype == dtypes.float0:
+  if not core.is_opaque_dtype(aval.dtype) and aval.dtype == dtypes.float0:
     return lambda _: np.zeros(aval.shape, dtypes.float0)  # type: ignore
   else:
     return partial(_dynamic_array_result_handler, sticky_device, aval)
@@ -797,17 +811,14 @@ def _dynamic_array_result_handler(sticky_device, aval, env, buf):
   shape = [in_env[d.val] if type(d) is core.InDBIdx else
            out_env[d.val] if type(d) is core.OutDBIdx else d
            for d in aval.shape]
-  if all(type(d) is int for d in shape):
-    aval = core.ShapedArray(tuple(shape), aval.dtype)
+  if all(type(d) is int for d in shape) and type(aval.dtype) is not core.bint:
+    aval = core.ShapedArray(tuple(shape), buf.dtype)
     return maybe_create_array_from_da(buf, aval, sticky_device)
-  elif any(type(d) is core.BInt for d in shape):
-    padded_shape = [d.bound if type(d) is core.BInt else d for d in shape]
-    buf_aval = core.ShapedArray(tuple(padded_shape), aval.dtype, aval.weak_type)
-    data = maybe_create_array_from_da(buf, buf_aval, sticky_device)
-    return core.PaddedArray(aval.update(shape=tuple(shape)), data)
   else:
-    aval = core.ShapedArray(tuple(shape), aval.dtype)
-    return maybe_create_array_from_da(buf, aval, sticky_device)
+    pad_shape = [d.dtype.bound if _is_bint_axis_size(d) else d for d in shape]
+    buf_aval = core.ShapedArray(tuple(pad_shape), buf.dtype, aval.weak_type)
+    data = maybe_create_array_from_da(buf, buf_aval, sticky_device)
+    return core.DArray(aval.update(shape=tuple(shape)), data)



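The else branch above pads dynamically sized results out to their bint bound and wraps the padded buffer together with the true shape (core.DArray). A rough, NumPy-only analogy of that padded-buffer idea (BoundedDim and PaddedResult are made-up stand-ins, not JAX classes):

import numpy as np
from dataclasses import dataclass

@dataclass
class BoundedDim:
  val: int    # the dynamic size actually in use
  bound: int  # the static bound the buffer is padded to

@dataclass
class PaddedResult:
  shape: tuple       # entries are ints or BoundedDim
  data: np.ndarray   # buffer padded out to each dimension's bound

  def unpad(self):
    # Slice the valid prefix along every dynamically sized axis.
    idx = tuple(slice(d.val) if isinstance(d, BoundedDim) else slice(None)
                for d in self.shape)
    return self.data[idx]

n = BoundedDim(val=3, bound=5)
res = PaddedResult(shape=(n,), data=np.arange(5))
print(res.unpad())  # [0 1 2] -- only the valid prefix of the padded buffer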
@@ -818,8 +829,6 @@ def _dynamic_array_result_handler(sticky_device, aval, env, buf):
 result_handlers[core.ShapedArray] = array_result_handler
 result_handlers[core.DShapedArray] = dynamic_array_result_handler
 result_handlers[core.ConcreteArray] = array_result_handler
-result_handlers[core.AbstractBInt] = \
-    lambda _, a: lambda _, b: core.BInt(int(b), a.bound)
 
 
 def needs_check_special():
@@ -1228,15 +1237,14 @@ def _device_put_token(_, device):
 device_put_handlers.update((t, _device_put_array) for t in array_types)
 device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
 device_put_handlers[core.Token] = _device_put_token
-device_put_handlers[core.BInt] = lambda x, d: _device_put_scalar(x.val, d)
 
 
 def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
   x = _copy_device_array_to_device(x, device)
   return (x.device_buffer,)
 for t in device_array.device_array_types:
   device_put_handlers[t] = _device_put_device_array
-device_put_handlers[core.PaddedArray] = lambda x, d: device_put(x._data, d)
+device_put_handlers[core.DArray] = lambda x, d: device_put(x._data, d)
 
 def _copy_device_array_to_device(
     x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray],
9 changes: 7 additions & 2 deletions jax/_src/dtypes.py
@@ -237,7 +237,7 @@ def _issubclass(a, b):
   except TypeError:
     return False
 
-def issubdtype(a, b):
+def issubdtype(a, b) -> bool:
   if a == "bfloat16":
     a = bfloat16
   if a == bfloat16:
@@ -251,7 +251,10 @@ def issubdtype(a, b):
     # interacts badly with JAX's custom scalar types. As a workaround,
     # explicitly cast the second argument to a NumPy type object.
     b = np.dtype(b).type
-  return np.issubdtype(a, b)
+  try:
+    return np.issubdtype(a, b)
+  except TypeError:  # e.g. if 'a' is not a np.dtype
+    return False
 
 can_cast = np.can_cast
 issubsctype = np.issubsctype
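
The try/except added to issubdtype makes the predicate total: an input NumPy cannot interpret as a dtype now yields False instead of raising. A minimal standalone sketch of the same guard, using plain NumPy rather than JAX's wrapper:

import numpy as np

def issubdtype_or_false(a, b) -> bool:
  # Same guard as above: swallow NumPy's TypeError for non-dtype inputs.
  try:
    return np.issubdtype(a, b)
  except TypeError:  # e.g. 'a' cannot be interpreted as a dtype
    return False

print(issubdtype_or_false(np.float32, np.floating))  # True
print(issubdtype_or_false(object(), np.floating))    # False, not an exception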
@@ -436,6 +439,8 @@ def dtype(x, *, canonicalize=False):
     dt = python_scalar_dtypes[x]
   elif type(x) in python_scalar_dtypes:
     dt = python_scalar_dtypes[type(x)]
+  elif jax.core.is_opaque_dtype(getattr(x, 'dtype', None)):
+    dt = x.dtype
   else:
     dt = np.result_type(x)
   if dt not in _jax_dtype_set:
