From 1fedf04ed538301655380c4241ce6f21104ce147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Wed, 9 Aug 2023 13:41:34 +0200 Subject: [PATCH] API: Remove NINF and PINF usages --- jax/_src/lax/ann.py | 2 +- jax/_src/numpy/lax_numpy.py | 6 +- jax/experimental/jax2tf/shape_poly.py | 14 ++-- .../jax2tf/tests/jax2tf_limitations.py | 4 +- .../jax2tf/tests/shape_poly_test.py | 70 +++++++++---------- jax/numpy/__init__.py | 6 +- 6 files changed, 51 insertions(+), 51 deletions(-) diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 6ee9a35efd27..059597812144 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -255,7 +255,7 @@ def _comparator_builder(op_type, is_max_k): def _get_init_val_literal(op_type, is_max_k): - return np.array(np.NINF if is_max_k else np.Inf, dtype=op_type) + return np.array(-np.inf if is_max_k else np.inf, dtype=op_type) def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k, reduction_dimension, recall_target, is_max_k, diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index cc22ca355566..fe0e2af9ff59 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -100,9 +100,9 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: e = np.e euler_gamma = np.euler_gamma inf = np.inf -NINF = np.NINF -PZERO = np.PZERO -NZERO = np.NZERO +NINF = -np.inf # TODO: removed in Numpy 1.26 +PZERO = 0.0 # TODO: removed in Numpy 1.26 +NZERO = -0.0 # TODO: removed in Numpy 1.26 nan = np.nan # NumPy utility functions diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py index 0414b2a7db23..494833ee61c2 100644 --- a/jax/experimental/jax2tf/shape_poly.py +++ b/jax/experimental/jax2tf/shape_poly.py @@ -190,18 +190,18 @@ def __lt__(self, other: '_DimAtom'): def bounds(self) -> tuple[float, float]: """Returns the lower and upper bounds, or -+ inf.""" if self.var is not None: - return (1, np.PINF) # variables are assumed to be >= 1 + return (1, np.inf) # variables are assumed to be >= 1 opnd_bounds = [opnd.bounds() for opnd in self.operands] if self.operation == _DimAtom.FLOORDIV: # a // b (a_l, a_u), (b_l, b_u) = opnd_bounds def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf assert b != 0 if not np.isinf(b): # divisor is finite - return math.floor(a / b) if not np.isinf(a) else np.NINF if (a >= 0) != (b >= 0) else np.PINF + return math.floor(a / b) if not np.isinf(a) else -np.inf if (a >= 0) != (b >= 0) else np.inf elif not np.isinf(a): # dividend is finite and divisor is infinite return -1 if (a >= 0) != (b >= 0) else 0 else: # both dividend and divisor are infinite - return np.NINF if (a >= 0) != (b >= 0) else np.PINF + return -np.inf if (a >= 0) != (b >= 0) else np.inf # Same reasoning as for multiplication: the bounds are among the cross-product # of the bounds. @@ -216,7 +216,7 @@ def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf elif b_u < 0: # negative divisor return (b_l + 1, 0) else: - return (np.NINF, np.PINF) + return (-np.inf, np.inf) elif self.operation == _DimAtom.NON_NEGATIVE: (b_l, b_h), = opnd_bounds @@ -668,7 +668,7 @@ def bounds(self) -> tuple[float, float]: lb = lb + min(item_l, item_u) # type: ignore ub = ub + max(item_l, item_u) # type: ignore - if lb != np.NINF or ub != np.PINF: + if lb != -np.inf or ub != np.inf: return lb, ub # Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0 # TODO(necula): add more principled support for floordiv and mod @@ -682,9 +682,9 @@ def bounds(self) -> tuple[float, float]: except InconclusiveDimensionOperation: continue if dec.factor > 0: - return (np.NINF, -1) + return (-np.inf, -1) else: - return (1, np.PINF) + return (1, np.inf) return lb, ub diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index cb0c354384a7..1f305b3e55b2 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -1208,10 +1208,10 @@ def compare_reconstructed_operand(r_jax, r_tf, tol): # should also consider the gap between it and zero. Note that this code # relies on the singular values being in descending order. def compute_absolute_gap(s, m, n): - forward_appendant = np.Inf if m == n else 0 + forward_appendant = np.inf if m == n else 0 forward_diff = jnp.diff(s, axis=-1, append=forward_appendant) backward_diff = jnp.diff( - s[..., ::-1], axis=-1, append=np.Inf)[..., ::-1] + s[..., ::-1], axis=-1, append=np.inf)[..., ::-1] absolute_gap = jnp.minimum(jnp.abs(forward_diff), jnp.abs(backward_diff)) return absolute_gap diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index d5b645551426..1f22a6589349 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -234,52 +234,52 @@ def test_poly_bounds(self): bounded_le4 = 5 - a bounded_ge2 = b + 1 bounded_ge0_le4 = a % 5 - self.assertEqual(a.bounds(), (1, np.PINF)) - self.assertEqual(bounded_le4.bounds(), (np.NINF, 4)) - self.assertEqual(bounded_ge2.bounds(), (2, np.PINF)) + self.assertEqual(a.bounds(), (1, np.inf)) + self.assertEqual(bounded_le4.bounds(), (-np.inf, 4)) + self.assertEqual(bounded_ge2.bounds(), (2, np.inf)) self.assertEqual(bounded_ge0_le4.bounds(), (0, 4)) # Additions - self.assertEqual((bounded_ge0_le4 + bounded_le4).bounds(), (np.NINF, 8)) - self.assertEqual((bounded_ge0_le4 + bounded_ge2).bounds(), (2, np.PINF)) - self.assertEqual((bounded_le4 + bounded_ge2).bounds(), (np.NINF, np.PINF)) + self.assertEqual((bounded_ge0_le4 + bounded_le4).bounds(), (-np.inf, 8)) + self.assertEqual((bounded_ge0_le4 + bounded_ge2).bounds(), (2, np.inf)) + self.assertEqual((bounded_le4 + bounded_ge2).bounds(), (-np.inf, np.inf)) # Subtractions - self.assertEqual((bounded_ge0_le4 - bounded_le4).bounds(), (-4, np.PINF)) - self.assertEqual((- bounded_ge0_le4 + bounded_le4).bounds(), (np.NINF, 4)) - self.assertEqual((bounded_ge0_le4 - bounded_ge2).bounds(), (np.NINF, 2)) - self.assertEqual((- bounded_ge0_le4 + bounded_ge2).bounds(), (-2, np.PINF)) - self.assertEqual((bounded_le4 - bounded_ge2).bounds(), (np.NINF, 2)) - self.assertEqual((- bounded_le4 + bounded_ge2).bounds(), (-2, np.PINF)) + self.assertEqual((bounded_ge0_le4 - bounded_le4).bounds(), (-4, np.inf)) + self.assertEqual((- bounded_ge0_le4 + bounded_le4).bounds(), (-np.inf, 4)) + self.assertEqual((bounded_ge0_le4 - bounded_ge2).bounds(), (-np.inf, 2)) + self.assertEqual((- bounded_ge0_le4 + bounded_ge2).bounds(), (-2, np.inf)) + self.assertEqual((bounded_le4 - bounded_ge2).bounds(), (-np.inf, 2)) + self.assertEqual((- bounded_le4 + bounded_ge2).bounds(), (-2, np.inf)) # Multiplications - self.assertEqual((2 * a - 3).bounds(), (-1, np.PINF)) - self.assertEqual((-2 * a - 3).bounds(), (np.NINF, -5)) - self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, np.PINF)) - self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (np.NINF, np.PINF)) - self.assertEqual((a + b - a * b + a * b * a).bounds(), (np.NINF, np.PINF)) - self.assertEqual((a + 2 * b - a).bounds(), (2, np.PINF)) - self.assertEqual((a + 2 * b - a).bounds(), (2, np.PINF)) + self.assertEqual((2 * a - 3).bounds(), (-1, np.inf)) + self.assertEqual((-2 * a - 3).bounds(), (-np.inf, -5)) + self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, np.inf)) + self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (-np.inf, np.inf)) + self.assertEqual((a + b - a * b + a * b * a).bounds(), (-np.inf, np.inf)) + self.assertEqual((a + 2 * b - a).bounds(), (2, np.inf)) + self.assertEqual((a + 2 * b - a).bounds(), (2, np.inf)) # mod self.assertEqual(((b + 1) % 2).bounds(), (0, 1)) self.assertEqual(((b + 1) % -2).bounds(), (-1, 0)) self.assertEqual(((b - 4) % 2).bounds(), (0, 1)) - self.assertEqual(((b + 1) % a).bounds(), (0, np.PINF)) - self.assertEqual((11 % (a + 1)).bounds(), (0, np.PINF)) - self.assertEqual((-11 % (a + 1)).bounds(), (0, np.PINF)) - self.assertEqual((b % (a - 2)).bounds(), (np.NINF, np.PINF)) + self.assertEqual(((b + 1) % a).bounds(), (0, np.inf)) + self.assertEqual((11 % (a + 1)).bounds(), (0, np.inf)) + self.assertEqual((-11 % (a + 1)).bounds(), (0, np.inf)) + self.assertEqual((b % (a - 2)).bounds(), (-np.inf, np.inf)) # floordiv - self.assertEqual(((a + 4) // 2).bounds(), (2, np.PINF)) - self.assertEqual(((a + 4) // -2).bounds(), (np.NINF, -3)) - self.assertEqual(((a + 5) // 2).bounds(), (3, np.PINF)) - self.assertEqual(((a + 5) // -2).bounds(), (np.NINF, -3)) + self.assertEqual(((a + 4) // 2).bounds(), (2, np.inf)) + self.assertEqual(((a + 4) // -2).bounds(), (-np.inf, -3)) + self.assertEqual(((a + 5) // 2).bounds(), (3, np.inf)) + self.assertEqual(((a + 5) // -2).bounds(), (-np.inf, -3)) self.assertEqual((11 // (a + 1)).bounds(), (0, 5)) self.assertEqual((-11 // (a + 1)).bounds(), (-6, -1)) self.assertEqual((-11 // (- a)).bounds(), (0, 11)) # finite negative dividend, infinite divisor - self.assertEqual(((b + 1) // (a + 1)).bounds(), (0, np.PINF)) - self.assertEqual((-b // (a + 1)).bounds(), (np.NINF, -1)) + self.assertEqual(((b + 1) // (a + 1)).bounds(), (0, np.inf)) + self.assertEqual((-b // (a + 1)).bounds(), (-np.inf, -1)) # Generate test cases for floordiv and mod: (a + N) // +-2, (N - a) // +-2 # and then evaluate them for a = 1, 5, 10000 @@ -301,14 +301,14 @@ def test_poly_bounds(self): # Bounds involving mod and floordiv self.assertEqual((5 - a % 5).bounds(), (1, 5)) self.assertEqual((-5 - a % (-5)).bounds(), (-5, -1)) - self.assertEqual((a - 5 % a).bounds(), (1, np.PINF)) - self.assertEqual((a - 5 % a).bounds(), (1, np.PINF)) - self.assertEqual((3 * (a + b) - 5 % (3 * (a + b))).bounds(), (1, np.PINF)) - self.assertEqual((- a + (b - 5) % a).bounds(), (np.NINF, -1)) + self.assertEqual((a - 5 % a).bounds(), (1, np.inf)) + self.assertEqual((a - 5 % a).bounds(), (1, np.inf)) + self.assertEqual((3 * (a + b) - 5 % (3 * (a + b))).bounds(), (1, np.inf)) + self.assertEqual((- a + (b - 5) % a).bounds(), (-np.inf, -1)) # non_negative - self.assertEqual(core.non_negative_dim(a).bounds(), (1, np.PINF)) - self.assertEqual(core.non_negative_dim(a - 5).bounds(), (0, np.PINF)) + self.assertEqual(core.non_negative_dim(a).bounds(), (1, np.inf)) + self.assertEqual(core.non_negative_dim(a - 5).bounds(), (0, np.inf)) self.assertEqual(core.non_negative_dim(15 - a).bounds(), (0, 14)) self.assertEqual((core.non_negative_dim(15 - a) // 3).bounds(), (0, 4)) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 78681100682f..da9b3a5500b6 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -22,9 +22,9 @@ from jax._src.numpy.lax_numpy import ( ComplexWarning as ComplexWarning, - NINF as NINF, - NZERO as NZERO, - PZERO as PZERO, + NINF as NINF, # TODO: removed in Numpy 1.26 + NZERO as NZERO, # TODO: removed in Numpy 1.26 + PZERO as PZERO, # TODO: removed in Numpy 1.26 allclose as allclose, angle as angle, append as append,