Skip to content

Commit

Permalink
segment_max: fix identity for boolean dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 15, 2022
1 parent 4d14899 commit 4f6f4e5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
9 changes: 6 additions & 3 deletions jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,22 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,




def _get_identity(op, dtype):
"""Get an appropriate identity for a given operation in a given dtype."""
if op is lax.scatter_add:
return 0
elif op is lax.scatter_mul:
return 1
elif op is lax.scatter_min:
if jnp.issubdtype(dtype, jnp.integer):
if dtype == dtypes.bool_:
return True
elif jnp.issubdtype(dtype, jnp.integer):
return jnp.iinfo(dtype).max
return float('inf')
elif op is lax.scatter_max:
if jnp.issubdtype(dtype, jnp.integer):
if dtype == dtypes.bool_:
return False
elif jnp.issubdtype(dtype, jnp.integer):
return jnp.iinfo(dtype).min
return -float('inf')
else:
Expand Down
43 changes: 43 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,49 @@ def fn(data, segment_ids):
self.assertAllClose(grad, np.array([0., 0.], np.float32))


@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list({
"testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(
jtu.format_shape_dtype_string(shape, dtype),
reducer.__name__, num_segments, bucket_size),
"dtype": dtype, "shape": shape,
"reducer": reducer, "op": op, "identity": identity,
"num_segments": num_segments, "bucket_size": bucket_size}
for dtype in [np.bool_]
for shape in [(8,), (7, 4), (6, 4, 2)]
for bucket_size in [None, 2]
for num_segments in [None, 1, 3])
for reducer, op, identity in [
(ops.segment_min, np.minimum, True),
(ops.segment_max, np.maximum, False),
]))
def testSegmentReduceBoolean(self, shape, dtype, reducer, op, identity, num_segments, bucket_size):
rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=-2, high=3)
args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)]

if np.issubdtype(dtype, np.integer):
if np.isposinf(identity):
identity = np.iinfo(dtype).max
elif np.isneginf(identity):
identity = np.iinfo(dtype).min

jnp_fun = lambda data, segment_ids: reducer(
data, segment_ids, num_segments=num_segments, bucket_size=bucket_size)

def np_fun(data, segment_ids):
size = num_segments if num_segments is not None else (segment_ids.max() + 1)
out = np.full((size,) + shape[1:], identity, dtype)
for i, val in zip(segment_ids, data):
if 0 <= i < size:
out[i] = op(out[i], val).astype(dtype)
return out

self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
if num_segments is not None:
self._CompileAndCheck(jnp_fun, args_maker)


@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list({
"testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(
Expand Down

0 comments on commit 4f6f4e5

Please sign in to comment.