Skip to content

Commit

Permalink
Don't call dtypes.result_type() unnecessarily on the type of an arr…
Browse files Browse the repository at this point in the history
…ay during abstractification.

Remove make_shaped_array since it has no more non-test users.

```
name        old cpu/op  new cpu/op  delta
device_put  69.4µs ± 6%  63.5µs ± 3%  -8.56%  (p=0.000 n=10+10)

name        old time/op             new time/op             delta
device_put  69.4µs ± 6%             63.5µs ± 3%  -8.56%        (p=0.000 n=10+10)
```

PiperOrigin-RevId: 491795793
  • Loading branch information
hawkinsp authored and jax authors committed Nov 30, 2022
1 parent 22f67d6 commit 6bda0d2
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 12 deletions.
5 changes: 0 additions & 5 deletions jax/_src/abstract_arrays.py
Expand Up @@ -31,11 +31,6 @@
canonicalize_shape = core.canonicalize_shape
raise_to_shaped = core.raise_to_shaped


def make_shaped_array(x):
dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
return ShapedArray(np.shape(x), dtype)

def zeros_like_array(x):
dtype, weak_type = dtypes._lattice_result_type(x)
dtype = dtypes.canonicalize_dtype(dtype)
Expand Down
5 changes: 5 additions & 0 deletions jax/_src/dtypes.py
Expand Up @@ -444,6 +444,11 @@ def is_python_scalar(x: Any) -> bool:
except AttributeError:
return type(x) in python_scalar_dtypes

def check_valid_dtype(dtype: DType) -> None:
if dtype not in _jax_dtype_set:
raise TypeError(f"Dtype {dtype} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")

def dtype(x: Any, *, canonicalize: bool = False) -> DType:
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
if x is None:
Expand Down
18 changes: 15 additions & 3 deletions jax/interpreters/xla.py
Expand Up @@ -32,8 +32,7 @@
from jax._src import device_array
from jax._src import dtypes
from jax._src import source_info_util
from jax._src.abstract_arrays import (make_shaped_array, array_types,
numpy_scalar_types)
from jax._src.abstract_arrays import numpy_scalar_types
from jax.core import (ConcreteArray, ShapedArray, str_eqn_compact)
import jax._src.pretty_printer as pp
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
Expand Down Expand Up @@ -287,12 +286,25 @@ def _make_abstract_python_scalar(typ, val):
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)

def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))

def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))


pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {}
for t in device_array.device_array_types:
pytype_aval_mappings[t] = operator.attrgetter('aval')
pytype_aval_mappings[core.DArray] = operator.attrgetter('_aval')
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
for t in numpy_scalar_types)
pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
pytype_aval_mappings.update(
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)

Expand Down
7 changes: 3 additions & 4 deletions tests/core_test.py
Expand Up @@ -37,7 +37,6 @@

from jax._src import util
from jax._src import test_util as jtu
from jax._src.abstract_arrays import make_shaped_array
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow

Expand Down Expand Up @@ -478,16 +477,16 @@ def new_jaxpr():

jaxpr = new_jaxpr()
# int, not float!
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(jnp.int32(2))
jaxpr.eqns[0].outvars[0].aval = core.ShapedArray((), jnp.dtype(jnp.int32))
self.assertRaisesRegex(
core.JaxprTypeError,
r"Value for variable 'b' inconsistently typed as f32\[\] "
r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a",
lambda: core.check_jaxpr(jaxpr))

jaxpr = new_jaxpr()
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(
np.ones((2, 3), dtype=jnp.float32))
jaxpr.eqns[0].outvars[0].aval = core.ShapedArray((2, 3),
jnp.dtype(jnp.float32))
self.assertRaisesRegex(
core.JaxprTypeError,
r"Value for variable 'b' inconsistently typed as f32\[\] "
Expand Down

0 comments on commit 6bda0d2

Please sign in to comment.