Skip to content

Commit

Permalink
remove _convert_element_type from public jax.lax module
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Mar 10, 2022
1 parent caf094d commit 8f93629
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 45 deletions.
40 changes: 24 additions & 16 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -37,25 +37,26 @@
import opt_einsum

import jax
from jax import jit, custom_jvp
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.util import _wraps
from jax._src.numpy.vectorize import vectorize
from jax import custom_jvp, jit
from jax import core
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax import errors
from jax import lax
from jax.core import ShapedArray, DShapedArray, ConcreteArray, canonicalize_shape
from jax.config import config
from jax.interpreters import pxla
from jax import lax
from jax.tree_util import tree_leaves, tree_flatten, tree_map

from jax._src import device_array
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.lax.lax import _array_copy, _sort_lt_comparator, _sort_le_comparator
from jax._src.lax import lax as lax_internal
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.util import _wraps
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
from jax.tree_util import tree_leaves, tree_flatten, tree_map

newaxis = None

Expand Down Expand Up @@ -252,7 +253,8 @@ def _promote_dtypes(*args):
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
return [lax_internal._convert_element_type(x, to_dtype, weak_type)
for x in args]

def _promote_dtypes_inexact(*args):
"""Convenience function to apply Numpy argument dtype promotion.
Expand All @@ -262,7 +264,8 @@ def _promote_dtypes_inexact(*args):
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype_inexact = _to_inexact_dtype(to_dtype)
weak_type = (weak_type and to_dtype == to_dtype_inexact)
return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args]
return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
for x in args]

def _to_inexact_dtype(dtype):
"""Promotes a dtype into an inexact dtype, if it is not already one."""
Expand Down Expand Up @@ -2815,7 +2818,8 @@ def _check_no_padding(axis_padding, mode):
def _pad_constant(array, pad_width, constant_values):
nd = ndim(array)
constant_values = broadcast_to(asarray(constant_values), (nd, 2))
constant_values = lax._convert_element_type(constant_values, array.dtype, dtypes.is_weakly_typed(array))
constant_values = lax_internal._convert_element_type(
constant_values, array.dtype, dtypes.is_weakly_typed(array))
for i in range(nd):
widths = [(0, 0, 0)] * nd
widths[i] = (pad_width[i, 0], 0, 0)
Expand Down Expand Up @@ -2926,7 +2930,8 @@ def _pad_linear_ramp(array, pad_width, end_values):
dtype=array.dtype,
axis=axis
)
ramp_before = lax._convert_element_type(ramp_before, weak_type=dtypes.is_weakly_typed(array))
ramp_before = lax_internal._convert_element_type(
ramp_before, weak_type=dtypes.is_weakly_typed(array))
ramp_after = linspace(
start=end_values[axis][1],
stop=edge_after.squeeze(axis), # Dimension is replaced by linspace
Expand All @@ -2935,7 +2940,8 @@ def _pad_linear_ramp(array, pad_width, end_values):
dtype=array.dtype,
axis=axis
)
ramp_after = lax._convert_element_type(ramp_after, weak_type=dtypes.is_weakly_typed(array))
ramp_after = lax_internal._convert_element_type(
ramp_after, weak_type=dtypes.is_weakly_typed(array))

# Reverse linear space in appropriate dimension
ramp_after = flip(ramp_after, axis)
Expand Down Expand Up @@ -2969,8 +2975,10 @@ def _pad_stats(array, pad_width, stat_length, stat_func):
stat_before = round(stat_before)
stat_after = round(stat_after)

stat_before = lax._convert_element_type(stat_before, array.dtype, dtypes.is_weakly_typed(array))
stat_after = lax._convert_element_type(stat_after, array.dtype, dtypes.is_weakly_typed(array))
stat_before = lax_internal._convert_element_type(
stat_before, array.dtype, dtypes.is_weakly_typed(array))
stat_after = lax_internal._convert_element_type(
stat_after, array.dtype, dtypes.is_weakly_typed(array))

npad_before, npad_after = pad_width[i]
pad_before = repeat(stat_before, npad_before, axis=i)
Expand Down Expand Up @@ -3400,7 +3408,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):

raise TypeError("Unexpected input type for array: {}".format(type(object)))

out = lax._convert_element_type(out, dtype, weak_type=weak_type)
out = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
if ndmin > ndim(out):
out = lax.expand_dims(out, range(ndmin - ndim(out)))
return out
Expand Down
8 changes: 5 additions & 3 deletions jax/_src/numpy/linalg.py
Expand Up @@ -22,11 +22,12 @@

from jax import jit, custom_jvp
from jax import lax

from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src import dtypes
from jax._src.numpy.util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.util import canonicalize_axis

_T = lambda x: jnp.swapaxes(x, -1, -2)
Expand All @@ -39,7 +40,8 @@ def _promote_arg_dtypes(*args):
if not jnp.issubdtype(dtype, jnp.inexact):
dtype, weak_type = jnp.float_, False
dtype = dtypes.canonicalize_dtype(dtype)
args = [lax._convert_element_type(arg, dtype, weak_type) for arg in args]
args = [lax_internal._convert_element_type(arg, dtype, weak_type)
for arg in args]
if len(args) == 1:
return args[0]
else:
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/ops/scatter.py
Expand Up @@ -21,9 +21,11 @@

from jax import core
from jax import lax

from jax._src import dtypes
from jax._src.numpy import lax_numpy as jnp
from jax._src import util
from jax._src.lax import lax as lax_internal
from jax._src.numpy import lax_numpy as jnp


Array = Any
Expand Down Expand Up @@ -109,7 +111,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
unique_indices=indexer.unique_indices or unique_indices,
mode=mode)
return lax._convert_element_type(out, dtype, weak_type)
return lax_internal._convert_element_type(out, dtype, weak_type)



Expand Down
10 changes: 7 additions & 3 deletions jax/_src/scipy/sparse/linalg.py
Expand Up @@ -18,11 +18,14 @@

import numpy as np
import jax.numpy as jnp
from jax import device_put
from jax import lax
from jax import scipy as jsp
from jax import lax, device_put
from jax.tree_util import (tree_leaves, tree_map, tree_multimap, tree_structure,
tree_reduce, Partial)

from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src.util import safe_map as map


Expand Down Expand Up @@ -170,7 +173,7 @@ def body_fun(value):
return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_

r0 = _sub(b, A(x0))
rho0 = alpha0 = omega0 = lax._convert_element_type(
rho0 = alpha0 = omega0 = lax_internal._convert_element_type(
1, *dtypes._lattice_result_type(*tree_leaves(b)))
initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0)

Expand Down Expand Up @@ -519,7 +522,8 @@ def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
unit_residual,
)
dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(b))
H = lax._convert_element_type(jnp.eye(restart, restart + 1, dtype=dtype), weak_type=weak_type)
H = lax_internal._convert_element_type(
jnp.eye(restart, restart + 1, dtype=dtype), weak_type=weak_type)

def loop_cond(carry):
_, _, breakdown, k = carry
Expand Down
1 change: 0 additions & 1 deletion jax/lax/__init__.py
Expand Up @@ -82,7 +82,6 @@
conj as conj,
conj_p as conj_p,
convert_element_type as convert_element_type,
_convert_element_type as _convert_element_type,
convert_element_type_p as convert_element_type_p,
cos as cos,
cos_p as cos_p,
Expand Down
9 changes: 5 additions & 4 deletions tests/dtypes_test.py
Expand Up @@ -24,9 +24,10 @@

import jax
from jax._src import dtypes
from jax import lax
from jax import numpy as jnp

from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal

from jax.config import config
config.parse_flags_with_absl()
Expand Down Expand Up @@ -376,7 +377,7 @@ def testBinaryPromotionJitInvariance(self, xtype, ytype, xfun, yfun):
)
def testUnaryPromotion(self, dtype, weak_type):
# Regression test for https://github.com/google/jax/issues/6051
x = lax._convert_element_type(0, dtype, weak_type=weak_type)
x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
if weak_type:
expected = dtypes.canonicalize_dtype(
dtypes._default_types['f' if x.dtype == 'bfloat16' else x.dtype.kind])
Expand All @@ -392,7 +393,7 @@ def testUnaryPromotion(self, dtype, weak_type):
)
def testBinaryNonPromotion(self, dtype, weak_type):
# Regression test for https://github.com/google/jax/issues/6051
x = lax._convert_element_type(0, dtype, weak_type=weak_type)
x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
y = (x + x)
assert x.dtype == y.dtype
assert dtypes.is_weakly_typed(y) == dtypes.is_weakly_typed(x)
Expand All @@ -404,7 +405,7 @@ def testBinaryNonPromotion(self, dtype, weak_type):
for weak_type in [True, False]
)
def testDeviceArrayRepr(self, dtype, weak_type):
val = lax._convert_element_type(0, dtype, weak_type=weak_type)
val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
rep = repr(val)
self.assertStartsWith(rep, 'DeviceArray(')
if weak_type:
Expand Down
5 changes: 3 additions & 2 deletions tests/lax_numpy_indexing_test.py
Expand Up @@ -27,12 +27,13 @@
import numpy as np

import jax
from jax import lax
from jax import numpy as jnp
from jax import ops

from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lax import lax as lax_internal

from jax.config import config
config.parse_flags_with_absl()
Expand Down Expand Up @@ -933,7 +934,7 @@ def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32))

def testIndexingWeakTypes(self):
x = lax._convert_element_type(jnp.arange(5), int, weak_type=True)
x = lax_internal._convert_element_type(jnp.arange(5), int, weak_type=True)

a = x.at[0].set(1.0)
self.assertEqual(a.dtype, x.dtype)
Expand Down
18 changes: 11 additions & 7 deletions tests/lax_numpy_test.py
Expand Up @@ -38,14 +38,16 @@
import jax.ops
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax._src import device_array
from jax._src import dtypes
from jax import tree_util
from jax.test_util import check_grads
from jax._src.util import prod, safe_zip
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps

from jax._src import device_array
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.util import prod, safe_zip

from jax.config import config
config.parse_flags_with_absl()
Expand Down Expand Up @@ -3284,7 +3286,8 @@ def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type,
if numpy_version < (1, 19) and out_shape == ():
raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")
rng = jtu.rand_default(self.rng())
x = lax._convert_element_type(rng(shape, in_dtype), weak_type=weak_type)
x = lax_internal._convert_element_type(rng(shape, in_dtype),
weak_type=weak_type)
fun = lambda x: getattr(jnp, func)(x, *args, dtype=out_dtype, shape=out_shape)
expected_weak_type = weak_type and (out_dtype is None)
self.assertEqual(dtypes.is_weakly_typed(fun(x)), expected_weak_type)
Expand Down Expand Up @@ -3316,7 +3319,8 @@ def testArrayWeakType(self, funcname, input_type, val, dtype):
for slc in [slice(None), slice(0), slice(3), 0, ...]))
def testSliceWeakTypes(self, shape, dtype, weak_type, slc):
rng = jtu.rand_default(self.rng())
x = lax._convert_element_type(rng(shape, dtype), weak_type=weak_type)
x = lax_internal._convert_element_type(rng(shape, dtype),
weak_type=weak_type)
op = lambda x: x[slc]
self.assertEqual(op(x).aval.weak_type, weak_type)
self.assertEqual(jax.jit(op)(x).aval.weak_type, weak_type)
Expand Down
16 changes: 9 additions & 7 deletions tests/lax_test.py
Expand Up @@ -233,7 +233,7 @@ def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
def testConvertElementType(self, from_dtype, to_dtype, weak_type):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng((2, 3), from_dtype)]
op = lambda x: lax._convert_element_type(x, to_dtype, weak_type)
op = lambda x: lax_internal._convert_element_type(x, to_dtype, weak_type)
self._CompileAndCheck(op, args_maker)

x = rng((1,), from_dtype)
Expand Down Expand Up @@ -288,8 +288,8 @@ def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype):
for weak_type in [True, False]))
def testBitcastConvertWeakType(self, from_dtype, to_dtype, weak_type):
rng = jtu.rand_default(self.rng())
x_in = lax._convert_element_type(rng((2, 3), from_dtype),
weak_type=weak_type)
x_in = lax_internal._convert_element_type(rng((2, 3), from_dtype),
weak_type=weak_type)
op = lambda x: lax.bitcast_convert_type(x, to_dtype)
self.assertEqual(dtypes.is_weakly_typed(x_in), weak_type)
x_out = op(x_in)
Expand Down Expand Up @@ -1741,8 +1741,9 @@ def testReduce(self, op, init_val, shape, dtype, dims):
for init_weak_type in [True, False]))
def testReduceWeakType(self, op_namespace, op, arr_weak_type, init_weak_type):
op = getattr(op_namespace, op)
arr = lax._convert_element_type(np.arange(10), int, weak_type=arr_weak_type)
init = lax._convert_element_type(1, int, weak_type=init_weak_type)
arr = lax_internal._convert_element_type(np.arange(10), int,
weak_type=arr_weak_type)
init = lax_internal._convert_element_type(1, int, weak_type=init_weak_type)
fun = lambda arr, init: lax.reduce(arr, init, op, (0,))
out = fun(arr, init)
self.assertEqual(dtypes.is_weakly_typed(out), arr_weak_type and init_weak_type)
Expand Down Expand Up @@ -2625,7 +2626,7 @@ def test_const(self, dtype, weak_type):
if dtype in set(python_scalar_types):
val = dtype(0)
else:
val = lax._convert_element_type(0, dtype, weak_type=weak_type)
val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)

const = lax_internal._const(val, 0)
self.assertEqual(dtypes.dtype(val, canonicalize=True),
Expand Down Expand Up @@ -2810,7 +2811,8 @@ def testArgMinMaxInvalidAxisError(self, jax_fn):
for weak_type in [True, False]))
def testArgMinMaxWeakType(self, jax_fn, weak_type):
op = lambda x: jax_fn(x, axis=0, index_dtype=np.int32)
x_in = lax._convert_element_type(np.ones((2, 2)), weak_type=weak_type)
x_in = lax_internal._convert_element_type(np.ones((2, 2)),
weak_type=weak_type)
self.assertEqual(dtypes.is_weakly_typed(x_in), weak_type)
x_out = op(x_in)
self.assertEqual(dtypes.is_weakly_typed(x_out), False)
Expand Down

0 comments on commit 8f93629

Please sign in to comment.