-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move segment-specific utils to a separate file.
PiperOrigin-RevId: 526586948
- Loading branch information
Showing
5 changed files
with
255 additions
and
175 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Copyright 2023 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Utilities for segmented functionality.""" | ||
|
||
from typing import Optional, Union | ||
|
||
import jax.numpy as jnp | ||
|
||
from rax._src.types import Array | ||
|
||
|
||
def same_segment_mask(segments: Array) -> Array: | ||
"""Returns an array indicating whether a pair is in the same segment.""" | ||
return jnp.expand_dims(segments, -1) == jnp.expand_dims(segments, axis=-2) | ||
|
||
|
||
def segment_sum( | ||
a: Array, segments: Array, where: Optional[Array] = None | ||
) -> Array: | ||
"""Returns segment sum.""" | ||
if where is not None: | ||
where = jnp.expand_dims(where, -1) & jnp.expand_dims(where, -2) | ||
return jnp.sum( | ||
jnp.expand_dims(a, -2) * jnp.int32(same_segment_mask(segments)), | ||
axis=-1, | ||
where=where, | ||
) | ||
|
||
|
||
def segment_max( | ||
a: Array, | ||
segments: Array, | ||
where: Optional[Array] = None, | ||
initial: Optional[Union[float, int]] = None, | ||
) -> Array: | ||
"""Returns segment max.""" | ||
mask = same_segment_mask(segments) | ||
if where is not None: | ||
mask &= jnp.expand_dims(where, -1) & jnp.expand_dims(where, -2) | ||
initial = jnp.min(a) if initial is None else initial | ||
return jnp.max( | ||
jnp.broadcast_to(jnp.expand_dims(a, -2), mask.shape), | ||
axis=-1, | ||
where=mask, | ||
initial=initial | ||
) | ||
|
||
|
||
def in_segment_indices(segments: Array) -> Array: | ||
"""Returns 0-based indices per segment. | ||
For example: segments = [0, 0, 0, 1, 2, 2], then the in-segment indices are | ||
[0, 1, 2 | 0 | 0, 1], where we use "|" to mark the boundaries of the segments. | ||
Returns [0, 1, 2, 0, 0, 1] for segments [0, 0, 0, 1, 2, 2]. | ||
Args: | ||
segments: A :class:`jax.numpy.ndarray` to indicate segments of items that | ||
should be grouped together. Like ``[0, 0, 1, 0, 2]``. The segments may or | ||
may not be sorted. | ||
Returns: | ||
An Array with 0-based indices per segment. | ||
""" | ||
same_segments = jnp.int32(same_segment_mask(segments)) | ||
lower_triangle = jnp.tril(jnp.ones_like(same_segments)) | ||
return jnp.sum(same_segments * lower_triangle, axis=-1) - 1 | ||
|
||
|
||
def first_item_segment_mask( | ||
segments: Array, where: Optional[Array] = None | ||
) -> Array: | ||
"""Constructs a mask that selects the first item per segment. | ||
Args: | ||
segments: A :class:`jax.numpy.ndarray` to indicate segments of items that | ||
should be grouped together. Like ``[0, 0, 1, 0, 2]``. The segments may or | ||
may not be sorted. | ||
where: An optional :class:`jax.numpy.ndarray` to indicate invalid items. | ||
Returns: | ||
A :class:`jax.numpy.ndarray` of the same shape as ``segments`` that selects | ||
the first valid item in each segment. | ||
""" | ||
# Construct a same-segment mask. | ||
mask = same_segment_mask(segments) | ||
|
||
# Mask out invalid items. | ||
if where is not None: | ||
mask = mask & (jnp.expand_dims(where, -1) & jnp.expand_dims(where, -2)) | ||
|
||
# Remove duplicated columns in the mask so only the first item for each | ||
# segment appears in the result. | ||
mask = mask & (jnp.cumsum(mask, axis=-1) == 1) | ||
|
||
# Collapse mask to original `segments` shape, so we get a mask that selects | ||
# exactly the first item per segment. | ||
return jnp.any(mask, axis=-2) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright 2023 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# pytype: skip-file | ||
"""Tests for rax._src.segment_utils.""" | ||
|
||
import doctest | ||
import functools | ||
from absl.testing import absltest | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
import rax | ||
from rax._src import segment_utils | ||
|
||
|
||
class SegmentsTest(absltest.TestCase): | ||
|
||
def test_same_segment_mask(self): | ||
segments = jnp.asarray([0, 0, 1]) | ||
expected = jnp.asarray([[1, 1, 0], [1, 1, 0], [0, 0, 1]]) | ||
actual = jnp.int32(segment_utils.same_segment_mask(segments)) | ||
np.testing.assert_array_equal(actual, expected) | ||
|
||
def test_segment_sum(self): | ||
scores = jnp.asarray([1.0, 2.0, 4.0]) | ||
segments = jnp.asarray([0, 0, 1]) | ||
expected = jnp.asarray([3.0, 3.0, 4.0]) | ||
actual = segment_utils.segment_sum(scores, segments) | ||
np.testing.assert_array_equal(actual, expected) | ||
|
||
def test_segment_max(self): | ||
scores = jnp.array([1.0, 2.0, 4.0, -5.0, -5.5, -4.5]) | ||
segments = jnp.array([0, 0, 1, 2, 2, 2]) | ||
expected = jnp.array([2.0, 2.0, 4.0, -4.5, -4.5, -4.5]) | ||
actual = segment_utils.segment_max(scores, segments) | ||
np.testing.assert_array_equal(actual, expected) | ||
|
||
def test_segment_max_with_initial(self): | ||
scores = jnp.array([1.0, 2.0, 4.0, 1.0, 2.0, 3.0]) | ||
segments = jnp.array([0, 0, 1, 2, 2, 2]) | ||
expected = jnp.array([2.5, 2.5, 4.0, 3.0, 3.0, 3.0]) | ||
actual = segment_utils.segment_max(scores, segments, initial=2.5) | ||
np.testing.assert_array_equal(actual, expected) | ||
|
||
def test_segment_max_with_where(self): | ||
scores = jnp.array([1.0, 2.0, 4.0, 1.0, 2.0, 3.0]) | ||
segments = jnp.array([0, 0, 1, 2, 2, 2]) | ||
mask = jnp.array([1, 0, 1, 1, 1, 0]) | ||
actual = segment_utils.segment_max(scores, segments, where=mask) | ||
# Only non-masked entries have well-defined behavior under max, so we only | ||
# check those. | ||
np.testing.assert_equal(actual[0], jnp.array(1.0)) | ||
np.testing.assert_equal(actual[2], jnp.array(4.0)) | ||
np.testing.assert_equal(actual[3], jnp.array(2.0)) | ||
np.testing.assert_equal(actual[4], jnp.array(2.0)) | ||
|
||
def test_in_segment_indices(self): | ||
segments = jnp.asarray([0, 0, 0, 1, 2, 2]) | ||
expected = jnp.asarray([0, 1, 2, 0, 0, 1]) | ||
actual = segment_utils.in_segment_indices(segments) | ||
np.testing.assert_array_equal(actual, expected) | ||
|
||
def test_in_segment_indices_unordered(self): | ||
segments = jnp.asarray([0, 0, 1, 0, 2, 2]) | ||
expected = jnp.asarray([0, 1, 0, 2, 0, 1]) | ||
actual = segment_utils.in_segment_indices(segments) | ||
np.testing.assert_array_equal(actual, expected) | ||
|
||
def test_first_item_segment_mask(self): | ||
segments = jnp.array([0, 0, 1, 1, 1, 2, 2, 1, 1, 3, 3, 3]) | ||
expected = jnp.array([1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0], dtype=jnp.bool_) | ||
actual = segment_utils.first_item_segment_mask(segments) | ||
np.testing.assert_array_equal(actual, expected) | ||
|
||
def test_first_item_segment_mask_with_where(self): | ||
segments = jnp.array([0, 0, 1, 1, 1, 2, 2, 1, 1, 3, 3, 3]) | ||
where = jnp.array([1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0], dtype=jnp.bool_) | ||
expected = jnp.array([1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0], dtype=jnp.bool_) | ||
actual = segment_utils.first_item_segment_mask(segments, where=where) | ||
np.testing.assert_array_equal(actual, expected) | ||
|
||
|
||
def load_tests(loader, tests, ignore): | ||
del loader, ignore # Unused. | ||
tests.addTests( | ||
doctest.DocTestSuite( | ||
segment_utils, | ||
globs={ | ||
"functools": functools, | ||
"jax": jax, | ||
"jnp": jnp, | ||
"rax": rax, | ||
"segment_utils": segment_utils, | ||
}, | ||
) | ||
) | ||
return tests | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |
Oops, something went wrong.