Skip to content

Commit

Permalink
roll-forward #11952
Browse files Browse the repository at this point in the history
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468835674
  • Loading branch information
froystig authored and jax authors committed Aug 20, 2022
1 parent 78cfbeb commit 9789e83
Show file tree
Hide file tree
Showing 18 changed files with 1,127 additions and 261 deletions.
11 changes: 10 additions & 1 deletion jax/_src/api.py
Expand Up @@ -1107,6 +1107,9 @@ def _check_scalar(x):
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
_check_arg(x)
aval = core.get_aval(x)
if core.aval_has_custom_eltype(aval):
raise TypeError(
f"{name} with input element type {core.aval_eltype(aval).name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
Expand All @@ -1125,6 +1128,9 @@ def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):

def _check_output_dtype_revderiv(name, holomorphic, x):
aval = core.get_aval(x)
if core.aval_has_custom_eltype(aval):
raise TypeError(
f"{name} with output element type {core.aval_eltype(aval).name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
Expand Down Expand Up @@ -1202,6 +1208,9 @@ def jacfun(*args, **kwargs):
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
_check_arg(x)
aval = core.get_aval(x)
if core.aval_has_custom_eltype(aval):
raise TypeError(
f"jacfwd with input element type {core.aval_eltype(aval).name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
Expand Down Expand Up @@ -2957,7 +2966,7 @@ class ShapeDtypeStruct:
__slots__ = ["shape", "dtype", "named_shape"]
def __init__(self, shape, dtype, named_shape=None):
  """Construct a ShapeDtypeStruct.

  Args:
    shape: the array shape (stored as given).
    dtype: either a numpy-coercible dtype, or a custom element type
      (per `core.is_custom_eltype`), which is stored unwrapped since
      `np.dtype` cannot represent it.
    named_shape: optional mapping of axis names to sizes; copied into a
      fresh dict (empty if None).
  """
  self.shape = shape
  # Custom element types (e.g. extended/opaque dtypes) are not numpy
  # dtypes; pass them through as-is rather than coercing with np.dtype,
  # which would raise. (Diff residue removed: the unconditional
  # `self.dtype = np.dtype(dtype)` pre-image line preceded this one.)
  self.dtype = dtype if core.is_custom_eltype(dtype) else np.dtype(dtype)
  self.named_shape = {} if named_shape is None else dict(named_shape)

size = property(lambda self: prod(self.shape))
Expand Down
7 changes: 6 additions & 1 deletion jax/_src/dispatch.py
Expand Up @@ -112,9 +112,14 @@ def apply_primitive(prim, *args, **params):
**params)
return compiled_fun(*args)

# TODO(phawkins): update code referring to xla.apply_primitive to point here.
# TODO(phawkins,frostig,mattjj): update code referring to
# xla.apply_primitive to point here, or use simple_impl if that's why
# it is using apply_primitive to begin with
xla.apply_primitive = apply_primitive

def simple_impl(prim):
  """Register `apply_primitive`, partially applied to `prim`, as the
  primitive's impl rule (the default eager-dispatch implementation)."""
  impl_rule = partial(apply_primitive, prim)
  prim.def_impl(impl_rule)

RuntimeToken = Any

class RuntimeTokenSet(threading.local):
Expand Down
7 changes: 5 additions & 2 deletions jax/_src/lax/control_flow/loops.py
Expand Up @@ -1574,9 +1574,12 @@ def _pred_bcast_select_mhlo(
assert x.type == y.type, (x.type, y.type)
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
pred_aval.shape, x_y_aval)
x_y_type = mlir.aval_to_ir_type(x_y_aval)
bcast_pred_type = ir.RankedTensorType.get(
x_y_type.shape, mlir.dtype_to_ir_type(np.dtype(np.bool_)))
bcast_pred = mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
bcast_pred_type, pred,
mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
return mhlo.SelectOp(bcast_pred, x, y).results

### fori_loop
Expand Down
11 changes: 7 additions & 4 deletions jax/_src/lax/lax.py
Expand Up @@ -1239,11 +1239,14 @@ def stop_gradient(x: T) -> T:
DeviceArray(0., dtype=float32, weak_type=True)
"""
def stop(x):
  """Per-leaf stop_gradient: bind the primitive only where it matters.

  Only bind the stop_gradient primitive on inexact (floating/complex)
  dtypes, to avoid some staging; values with a custom element type, and
  integer/bool values, have no gradient to stop and pass through
  unchanged. (Diff residue removed: a stale pre-image `if (dtypes...`
  line with an unclosed paren, and a duplicate trailing `return x`.)
  """
  if core.has_custom_eltype(x):
    return x
  elif (dtypes.issubdtype(_dtype(x), np.floating) or
        dtypes.issubdtype(_dtype(x), np.complexfloating)):
    return ad_util.stop_gradient_p.bind(x)
  else:
    return x
return tree_map(stop, x)

def reduce_precision(operand: Union[float, Array],
Expand Down Expand Up @@ -1504,7 +1507,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs):
return result_dtype(*avals)


def _broadcasting_shape_rule(name, *avals):
def broadcasting_shape_rule(name, *avals):
shapes = [aval.shape for aval in avals if aval.shape]
if not shapes:
return ()
Expand Down Expand Up @@ -1545,7 +1548,7 @@ def _naryop_weak_type_rule(name, *avals, **kwargs):

def naryop(result_dtype, accepted_dtypes, name):
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name)
shape_rule = partial(_broadcasting_shape_rule, name)
shape_rule = partial(broadcasting_shape_rule, name)
weak_type_rule = partial(_naryop_weak_type_rule, name)
prim = standard_primitive(shape_rule, dtype_rule, name,
weak_type_rule=weak_type_rule)
Expand Down
38 changes: 21 additions & 17 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -527,7 +527,7 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=None):

@_wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC)
def transpose(a, axes=None):
  # Accept "stackable" objects as well as ordinary array-likes: the
  # array-like check runs only when `a` is not stackable (short-circuit).
  # (Diff residue removed: the unconditional pre-image
  # `_check_arraylike("transpose", a)` line duplicated this check.)
  _stackable(a) or _check_arraylike("transpose", a)
  # Default to reversing all axes, matching np.transpose.
  axes = np.arange(ndim(a))[::-1] if axes is None else axes
  return lax.transpose(a, axes)

Expand Down Expand Up @@ -5107,27 +5107,31 @@ def _set_shaped_array_attributes(shaped_array):
_set_shaped_array_attributes(DShapedArray)


def _set_device_array_base_attributes(device_array, include=None):
  # Forward operators, methods, and properties on DeviceArray to lax_numpy
  # functions (with no Tracers involved; this forwarding is direct).
  # If `include` is given, only the named attributes are forwarded.
  # (Diff residue removed: the stale pre-image def line without `include`,
  # and the pre-image `setattr(...)` calls that bypassed the filter.)
  def maybe_setattr(attr_name, target):
    # Forward unconditionally when no allow-list was supplied.
    if not include or attr_name in include:
      setattr(device_array, attr_name, target)

  for operator_name, function in _operators.items():
    maybe_setattr(f"__{operator_name}__", function)
  for method_name in _nondiff_methods + _diff_methods:
    maybe_setattr(method_name, globals()[method_name])
  # TODO(jakevdp): remove tile method after August 2022
  maybe_setattr("tile", _deprecate_function(tile, "arr.tile(...) is deprecated and will be removed. Use jnp.tile(arr, ...) instead."))
  maybe_setattr("reshape", _reshape)
  maybe_setattr("transpose", _transpose)
  maybe_setattr("flatten", ravel)
  maybe_setattr("flat", property(_notimplemented_flat))
  maybe_setattr("T", property(transpose))
  maybe_setattr("real", property(real))
  maybe_setattr("imag", property(imag))
  maybe_setattr("astype", _astype)
  maybe_setattr("view", _view)
  maybe_setattr("nbytes", property(_nbytes))
  maybe_setattr("itemsize", property(_itemsize))
  maybe_setattr("clip", _clip)

_set_device_array_base_attributes(device_array.DeviceArray)
_set_device_array_base_attributes(Array)
Expand Down

0 comments on commit 9789e83

Please sign in to comment.