Skip to content

Commit

Permalink
API: Remove NINF and PINF usages
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Aug 9, 2023
1 parent 1bd5fd2 commit 1fedf04
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 51 deletions.
2 changes: 1 addition & 1 deletion jax/_src/lax/ann.py
Expand Up @@ -255,7 +255,7 @@ def _comparator_builder(op_type, is_max_k):


def _get_init_val_literal(op_type, is_max_k):
return np.array(np.NINF if is_max_k else np.Inf, dtype=op_type)
return np.array(-np.inf if is_max_k else np.inf, dtype=op_type)

def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k,
reduction_dimension, recall_target, is_max_k,
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -100,9 +100,9 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape:
e = np.e
euler_gamma = np.euler_gamma
inf = np.inf
NINF = np.NINF
PZERO = np.PZERO
NZERO = np.NZERO
NINF = -np.inf # TODO: removed in Numpy 1.26
PZERO = 0.0 # TODO: removed in Numpy 1.26
NZERO = -0.0 # TODO: removed in Numpy 1.26
nan = np.nan

# NumPy utility functions
Expand Down
14 changes: 7 additions & 7 deletions jax/experimental/jax2tf/shape_poly.py
Expand Up @@ -190,18 +190,18 @@ def __lt__(self, other: '_DimAtom'):
def bounds(self) -> tuple[float, float]:
"""Returns the lower and upper bounds, or -+ inf."""
if self.var is not None:
return (1, np.PINF) # variables are assumed to be >= 1
return (1, np.inf) # variables are assumed to be >= 1
opnd_bounds = [opnd.bounds() for opnd in self.operands]
if self.operation == _DimAtom.FLOORDIV: # a // b
(a_l, a_u), (b_l, b_u) = opnd_bounds
def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
assert b != 0
if not np.isinf(b): # divisor is finite
return math.floor(a / b) if not np.isinf(a) else np.NINF if (a >= 0) != (b >= 0) else np.PINF
return math.floor(a / b) if not np.isinf(a) else -np.inf if (a >= 0) != (b >= 0) else np.inf
elif not np.isinf(a): # dividend is finite and divisor is infinite
return -1 if (a >= 0) != (b >= 0) else 0
else: # both dividend and divisor are infinite
return np.NINF if (a >= 0) != (b >= 0) else np.PINF
return -np.inf if (a >= 0) != (b >= 0) else np.inf

# Same reasoning as for multiplication: the bounds are among the cross-product
# of the bounds.
Expand All @@ -216,7 +216,7 @@ def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
elif b_u < 0: # negative divisor
return (b_l + 1, 0)
else:
return (np.NINF, np.PINF)
return (-np.inf, np.inf)

elif self.operation == _DimAtom.NON_NEGATIVE:
(b_l, b_h), = opnd_bounds
Expand Down Expand Up @@ -668,7 +668,7 @@ def bounds(self) -> tuple[float, float]:
lb = lb + min(item_l, item_u) # type: ignore
ub = ub + max(item_l, item_u) # type: ignore

if lb != np.NINF or ub != np.PINF:
if lb != -np.inf or ub != np.inf:
return lb, ub
# Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0
# TODO(necula): add more principled support for floordiv and mod
Expand All @@ -682,9 +682,9 @@ def bounds(self) -> tuple[float, float]:
except InconclusiveDimensionOperation:
continue
if dec.factor > 0:
return (np.NINF, -1)
return (-np.inf, -1)
else:
return (1, np.PINF)
return (1, np.inf)

return lb, ub

Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/tests/jax2tf_limitations.py
Expand Up @@ -1208,10 +1208,10 @@ def compare_reconstructed_operand(r_jax, r_tf, tol):
# should also consider the gap between it and zero. Note that this code
# relies on the singular values being in descending order.
def compute_absolute_gap(s, m, n):
forward_appendant = np.Inf if m == n else 0
forward_appendant = np.inf if m == n else 0
forward_diff = jnp.diff(s, axis=-1, append=forward_appendant)
backward_diff = jnp.diff(
s[..., ::-1], axis=-1, append=np.Inf)[..., ::-1]
s[..., ::-1], axis=-1, append=np.inf)[..., ::-1]
absolute_gap = jnp.minimum(jnp.abs(forward_diff),
jnp.abs(backward_diff))
return absolute_gap
Expand Down
70 changes: 35 additions & 35 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
Expand Up @@ -234,52 +234,52 @@ def test_poly_bounds(self):
bounded_le4 = 5 - a
bounded_ge2 = b + 1
bounded_ge0_le4 = a % 5
self.assertEqual(a.bounds(), (1, np.PINF))
self.assertEqual(bounded_le4.bounds(), (np.NINF, 4))
self.assertEqual(bounded_ge2.bounds(), (2, np.PINF))
self.assertEqual(a.bounds(), (1, np.inf))
self.assertEqual(bounded_le4.bounds(), (-np.inf, 4))
self.assertEqual(bounded_ge2.bounds(), (2, np.inf))
self.assertEqual(bounded_ge0_le4.bounds(), (0, 4))

# Additions
self.assertEqual((bounded_ge0_le4 + bounded_le4).bounds(), (np.NINF, 8))
self.assertEqual((bounded_ge0_le4 + bounded_ge2).bounds(), (2, np.PINF))
self.assertEqual((bounded_le4 + bounded_ge2).bounds(), (np.NINF, np.PINF))
self.assertEqual((bounded_ge0_le4 + bounded_le4).bounds(), (-np.inf, 8))
self.assertEqual((bounded_ge0_le4 + bounded_ge2).bounds(), (2, np.inf))
self.assertEqual((bounded_le4 + bounded_ge2).bounds(), (-np.inf, np.inf))

# Subtractions
self.assertEqual((bounded_ge0_le4 - bounded_le4).bounds(), (-4, np.PINF))
self.assertEqual((- bounded_ge0_le4 + bounded_le4).bounds(), (np.NINF, 4))
self.assertEqual((bounded_ge0_le4 - bounded_ge2).bounds(), (np.NINF, 2))
self.assertEqual((- bounded_ge0_le4 + bounded_ge2).bounds(), (-2, np.PINF))
self.assertEqual((bounded_le4 - bounded_ge2).bounds(), (np.NINF, 2))
self.assertEqual((- bounded_le4 + bounded_ge2).bounds(), (-2, np.PINF))
self.assertEqual((bounded_ge0_le4 - bounded_le4).bounds(), (-4, np.inf))
self.assertEqual((- bounded_ge0_le4 + bounded_le4).bounds(), (-np.inf, 4))
self.assertEqual((bounded_ge0_le4 - bounded_ge2).bounds(), (-np.inf, 2))
self.assertEqual((- bounded_ge0_le4 + bounded_ge2).bounds(), (-2, np.inf))
self.assertEqual((bounded_le4 - bounded_ge2).bounds(), (-np.inf, 2))
self.assertEqual((- bounded_le4 + bounded_ge2).bounds(), (-2, np.inf))

# Multiplications
self.assertEqual((2 * a - 3).bounds(), (-1, np.PINF))
self.assertEqual((-2 * a - 3).bounds(), (np.NINF, -5))
self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, np.PINF))
self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (np.NINF, np.PINF))
self.assertEqual((a + b - a * b + a * b * a).bounds(), (np.NINF, np.PINF))
self.assertEqual((a + 2 * b - a).bounds(), (2, np.PINF))
self.assertEqual((a + 2 * b - a).bounds(), (2, np.PINF))
self.assertEqual((2 * a - 3).bounds(), (-1, np.inf))
self.assertEqual((-2 * a - 3).bounds(), (-np.inf, -5))
self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, np.inf))
self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (-np.inf, np.inf))
self.assertEqual((a + b - a * b + a * b * a).bounds(), (-np.inf, np.inf))
self.assertEqual((a + 2 * b - a).bounds(), (2, np.inf))
self.assertEqual((a + 2 * b - a).bounds(), (2, np.inf))

# mod
self.assertEqual(((b + 1) % 2).bounds(), (0, 1))
self.assertEqual(((b + 1) % -2).bounds(), (-1, 0))
self.assertEqual(((b - 4) % 2).bounds(), (0, 1))
self.assertEqual(((b + 1) % a).bounds(), (0, np.PINF))
self.assertEqual((11 % (a + 1)).bounds(), (0, np.PINF))
self.assertEqual((-11 % (a + 1)).bounds(), (0, np.PINF))
self.assertEqual((b % (a - 2)).bounds(), (np.NINF, np.PINF))
self.assertEqual(((b + 1) % a).bounds(), (0, np.inf))
self.assertEqual((11 % (a + 1)).bounds(), (0, np.inf))
self.assertEqual((-11 % (a + 1)).bounds(), (0, np.inf))
self.assertEqual((b % (a - 2)).bounds(), (-np.inf, np.inf))

# floordiv
self.assertEqual(((a + 4) // 2).bounds(), (2, np.PINF))
self.assertEqual(((a + 4) // -2).bounds(), (np.NINF, -3))
self.assertEqual(((a + 5) // 2).bounds(), (3, np.PINF))
self.assertEqual(((a + 5) // -2).bounds(), (np.NINF, -3))
self.assertEqual(((a + 4) // 2).bounds(), (2, np.inf))
self.assertEqual(((a + 4) // -2).bounds(), (-np.inf, -3))
self.assertEqual(((a + 5) // 2).bounds(), (3, np.inf))
self.assertEqual(((a + 5) // -2).bounds(), (-np.inf, -3))
self.assertEqual((11 // (a + 1)).bounds(), (0, 5))
self.assertEqual((-11 // (a + 1)).bounds(), (-6, -1))
self.assertEqual((-11 // (- a)).bounds(), (0, 11)) # finite negative dividend, infinite divisor
self.assertEqual(((b + 1) // (a + 1)).bounds(), (0, np.PINF))
self.assertEqual((-b // (a + 1)).bounds(), (np.NINF, -1))
self.assertEqual(((b + 1) // (a + 1)).bounds(), (0, np.inf))
self.assertEqual((-b // (a + 1)).bounds(), (-np.inf, -1))

# Generate test cases for floordiv and mod: (a + N) // +-2, (N - a) // +-2
# and then evaluate them for a = 1, 5, 10000
Expand All @@ -301,14 +301,14 @@ def test_poly_bounds(self):
# Bounds involving mod and floordiv
self.assertEqual((5 - a % 5).bounds(), (1, 5))
self.assertEqual((-5 - a % (-5)).bounds(), (-5, -1))
self.assertEqual((a - 5 % a).bounds(), (1, np.PINF))
self.assertEqual((a - 5 % a).bounds(), (1, np.PINF))
self.assertEqual((3 * (a + b) - 5 % (3 * (a + b))).bounds(), (1, np.PINF))
self.assertEqual((- a + (b - 5) % a).bounds(), (np.NINF, -1))
self.assertEqual((a - 5 % a).bounds(), (1, np.inf))
self.assertEqual((a - 5 % a).bounds(), (1, np.inf))
self.assertEqual((3 * (a + b) - 5 % (3 * (a + b))).bounds(), (1, np.inf))
self.assertEqual((- a + (b - 5) % a).bounds(), (-np.inf, -1))

# non_negative
self.assertEqual(core.non_negative_dim(a).bounds(), (1, np.PINF))
self.assertEqual(core.non_negative_dim(a - 5).bounds(), (0, np.PINF))
self.assertEqual(core.non_negative_dim(a).bounds(), (1, np.inf))
self.assertEqual(core.non_negative_dim(a - 5).bounds(), (0, np.inf))
self.assertEqual(core.non_negative_dim(15 - a).bounds(), (0, 14))
self.assertEqual((core.non_negative_dim(15 - a) // 3).bounds(), (0, 4))

Expand Down
6 changes: 3 additions & 3 deletions jax/numpy/__init__.py
Expand Up @@ -22,9 +22,9 @@

from jax._src.numpy.lax_numpy import (
ComplexWarning as ComplexWarning,
NINF as NINF,
NZERO as NZERO,
PZERO as PZERO,
NINF as NINF, # TODO: removed in Numpy 1.26
NZERO as NZERO, # TODO: removed in Numpy 1.26
PZERO as PZERO, # TODO: removed in Numpy 1.26
allclose as allclose,
angle as angle,
append as append,
Expand Down

0 comments on commit 1fedf04

Please sign in to comment.