Skip to content

Commit

Permalink
Move segment-specific utils to a separate file.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 526586948
  • Loading branch information
rjagerman authored and Rax Developers committed Apr 24, 2023
1 parent 6a7f5b0 commit 6146a4f
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 175 deletions.
39 changes: 24 additions & 15 deletions rax/_src/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@

import jax.numpy as jnp

from rax._src import segment_utils
from rax._src import utils
from rax._src.types import Array
from rax._src.types import CutoffFn
Expand Down Expand Up @@ -212,7 +213,7 @@ def mrr_metric(

# Get the maximum reciprocal rank.
if segments is not None:
values = utils.segment_max(
values = segment_utils.segment_max(
relevant_items * retrieved_items * reciprocal_ranks,
segments,
where=where,
Expand All @@ -229,7 +230,7 @@ def mrr_metric(
# In the segmented case, values retain their list dimension. This constructs
# a mask so that only the first item per segment is used in reduce_fn.
if segments is not None:
where = utils.first_item_segment_mask(segments, where=where)
where = segment_utils.first_item_segment_mask(segments, where=where)

# Setup mask to ignore lists with only invalid items in reduce_fn.
elif where is not None:
Expand Down Expand Up @@ -310,10 +311,12 @@ def recall_metric(

# Compute number of retrieved+relevant items and relevant items.
if segments is not None:
n_retrieved_relevant = utils.segment_sum(
n_retrieved_relevant = segment_utils.segment_sum(
retrieved_items * relevant_items, segments, where=where
)
n_relevant = utils.segment_sum(relevant_items, segments, where=where)
n_relevant = segment_utils.segment_sum(
relevant_items, segments, where=where
)
else:
n_retrieved_relevant = jnp.sum(
retrieved_items * relevant_items, where=where, axis=-1
Expand All @@ -327,7 +330,7 @@ def recall_metric(
# In the segmented case, values retain their list dimension. This constructs
# a mask so that only the first item per segment is used in reduce_fn.
if segments is not None:
where = utils.first_item_segment_mask(segments, where=where)
where = segment_utils.first_item_segment_mask(segments, where=where)

# Setup mask to ignore lists with only invalid items in reduce_fn.
elif where is not None:
Expand Down Expand Up @@ -408,10 +411,12 @@ def precision_metric(

# Compute number of retrieved+relevant items and retrieved items.
if segments is not None:
n_retrieved_relevant = utils.segment_sum(
n_retrieved_relevant = segment_utils.segment_sum(
retrieved_items * relevant_items, segments, where=where
)
n_retrieved = utils.segment_sum(retrieved_items, segments, where=where)
n_retrieved = segment_utils.segment_sum(
retrieved_items, segments, where=where
)
else:
n_retrieved_relevant = jnp.sum(
retrieved_items * relevant_items, where=where, axis=-1
Expand All @@ -425,7 +430,7 @@ def precision_metric(
# In the segmented case, values retain their list dimension. This constructs
# a mask so that only the first item per segment is used in reduce_fn.
if segments is not None:
where = utils.first_item_segment_mask(segments, where=where)
where = segment_utils.first_item_segment_mask(segments, where=where)

# Setup mask to ignore lists with only invalid items in reduce_fn.
elif where is not None:
Expand Down Expand Up @@ -512,15 +517,19 @@ def ap_metric(
prec_at_k = ((ranks_i >= ranks_j) * relevant_i * relevant_j) / ranks_i

# Only include precision@k for retrieved items.
prec_mask = None if segments is None else utils.same_segment_mask(segments)
prec_mask = None
if segments is not None:
prec_mask = segment_utils.same_segment_mask(segments)
prec_at_k = jnp.sum(
prec_at_k * jnp.expand_dims(retrieved_items, -1), axis=-1, where=prec_mask
)

# Compute summed precision@k for each list and the number of relevant items.
if segments is not None:
sum_prec_at_k = utils.segment_sum(prec_at_k, segments, where=where)
n_relevant = utils.segment_sum(relevant_items, segments, where=where)
sum_prec_at_k = segment_utils.segment_sum(prec_at_k, segments, where=where)
n_relevant = segment_utils.segment_sum(
relevant_items, segments, where=where
)
else:
sum_prec_at_k = jnp.sum(prec_at_k, axis=-1)
n_relevant = jnp.sum(relevant_items, where=where, axis=-1)
Expand All @@ -532,7 +541,7 @@ def ap_metric(
# In the segmented case, values retain their list dimension. This constructs
# a mask so that only the first item per segment is used in reduce_fn.
if segments is not None:
where = utils.first_item_segment_mask(segments, where=where)
where = segment_utils.first_item_segment_mask(segments, where=where)

# Setup mask to ignore lists with only invalid items in reduce_fn.
elif where is not None:
Expand Down Expand Up @@ -621,7 +630,7 @@ def dcg_metric(

# Compute DCG.
if segments is not None:
values = utils.segment_sum(
values = segment_utils.segment_sum(
retrieved_items * gains * discounts, segments, where=where
)
else:
Expand All @@ -630,7 +639,7 @@ def dcg_metric(
# In the segmented case, values retain their list dimension. This constructs
# a mask so that only the first item per segment is used in reduce_fn.
if segments is not None:
where = utils.first_item_segment_mask(segments, where=where)
where = segment_utils.first_item_segment_mask(segments, where=where)

# Setup mask to ignore lists with only invalid items in reduce_fn.
elif where is not None:
Expand Down Expand Up @@ -733,7 +742,7 @@ def ndcg_metric(
# In the segmented case, values retain their list dimension. This constructs
# a mask so that only the first item per segment is used in reduce_fn.
if segments is not None:
where = utils.first_item_segment_mask(segments, where=where)
where = segment_utils.first_item_segment_mask(segments, where=where)

# Setup mask to ignore lists with only invalid items in reduce_fn.
elif where is not None:
Expand Down
111 changes: 111 additions & 0 deletions rax/_src/segment_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright 2023 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for segmented functionality."""

from typing import Optional, Union

import jax.numpy as jnp

from rax._src.types import Array


def same_segment_mask(segments: Array) -> Array:
  """Builds a pairwise mask of items that carry the same segment id.

  Args:
    segments: An array of segment ids whose last axis enumerates the items.

  Returns:
    A boolean array with one more trailing axis than ``segments``, where entry
    ``[..., i, j]`` is True exactly when ``segments[..., i]`` equals
    ``segments[..., j]``.
  """
  # Broadcasting a column view against a row view compares every item pair.
  column_view = jnp.expand_dims(segments, axis=-1)
  row_view = jnp.expand_dims(segments, axis=-2)
  return column_view == row_view


def segment_sum(
    a: Array, segments: Array, where: Optional[Array] = None
) -> Array:
  """Sums the values of ``a`` within each segment.

  Args:
    a: The values to sum. The last axis enumerates the items.
    segments: Segment ids used to group the items of ``a``.
    where: Optional mask selecting the items that take part in the sum.
      Non-boolean masks are interpreted by truthiness (nonzero means valid).

  Returns:
    An array that keeps the item axis: each position holds the sum over the
    selected items that share its segment. Positions whose own item is not
    selected by ``where`` evaluate to 0.
  """
  pair_mask = None
  if where is not None:
    # Coerce to bool before combining: bitwise `&` on integer masks is not a
    # logical AND (e.g. `2 & 1 == 0`), and the `where=` argument of jnp
    # reductions expects a boolean array.
    valid = jnp.asarray(where, dtype=jnp.bool_)
    # A pair (i, j) contributes only if both of its items are valid.
    pair_mask = jnp.expand_dims(valid, -1) & jnp.expand_dims(valid, -2)
  return jnp.sum(
      # Zero out the contributions of items that live in other segments.
      jnp.expand_dims(a, -2) * jnp.int32(same_segment_mask(segments)),
      axis=-1,
      where=pair_mask,
  )


def segment_max(
    a: Array,
    segments: Array,
    where: Optional[Array] = None,
    initial: Optional[Union[float, int]] = None,
) -> Array:
  """Computes the maximum of ``a`` within each segment.

  Args:
    a: The values to reduce. The last axis enumerates the items.
    segments: Segment ids used to group the items of ``a``.
    where: Optional mask selecting the items that take part in the maximum.
      Non-boolean masks are interpreted by truthiness (nonzero means valid).
    initial: Optional starting value of the maximum. It acts as a lower bound
      on every result and is the value produced for positions that have no
      selected item. When omitted, the minimum over all of ``a`` is used.

  Returns:
    An array that keeps the item axis: each position holds the maximum over
    the selected items that share its segment.
  """
  mask = same_segment_mask(segments)
  if where is not None:
    # Coerce to bool before combining: bitwise `&` on integer masks is not a
    # logical AND (e.g. `2 & 1 == 0`) and would also turn `mask` into an
    # integer array, while `jnp.max(where=...)` expects a boolean array.
    valid = jnp.asarray(where, dtype=jnp.bool_)
    # A pair (i, j) is considered only if both of its items are valid.
    mask = mask & jnp.expand_dims(valid, -1) & jnp.expand_dims(valid, -2)
  if initial is None:
    # `jnp.max` needs an `initial` whenever `where` is given; the global
    # minimum of `a` never exceeds any value of `a`.
    initial = jnp.min(a)
  return jnp.max(
      jnp.broadcast_to(jnp.expand_dims(a, -2), mask.shape),
      axis=-1,
      where=mask,
      initial=initial,
  )


def in_segment_indices(segments: Array) -> Array:
  """Numbers the items of every segment, starting from zero.

  For ``segments = [0, 0, 0, 1, 2, 2]`` the result is ``[0, 1, 2, 0, 0, 1]``:
  every item receives its position among the items with the same segment id.

  Args:
    segments: A :class:`jax.numpy.ndarray` of segment ids that group items
      together, like ``[0, 0, 1, 0, 2]``. The ids may or may not be sorted.

  Returns:
    An Array holding the 0-based index of each item within its own segment.
  """
  pairwise = jnp.int32(same_segment_mask(segments))
  # The lower triangle (diagonal included) keeps, in row ``i``, only the
  # same-segment items located at positions ``<= i``.
  up_to_self = jnp.tril(pairwise)
  # The row count includes the item itself, hence the ``- 1``.
  return jnp.sum(up_to_self, axis=-1) - 1


def first_item_segment_mask(
    segments: Array, where: Optional[Array] = None
) -> Array:
  """Constructs a mask that selects the first item per segment.

  Args:
    segments: A :class:`jax.numpy.ndarray` to indicate segments of items that
      should be grouped together. Like ``[0, 0, 1, 0, 2]``. The segments may or
      may not be sorted.
    where: An optional :class:`jax.numpy.ndarray` to indicate which items are
      valid. Items for which it is falsy are never selected. Non-boolean masks
      are interpreted by truthiness (nonzero means valid).

  Returns:
    A boolean :class:`jax.numpy.ndarray` of the same shape as ``segments`` that
    selects the first valid item in each segment.
  """
  # Construct a same-segment mask over all item pairs.
  mask = same_segment_mask(segments)

  # Mask out pairs that involve an invalid item.
  if where is not None:
    # Coerce to bool first: bitwise `&` on integer masks is not a logical AND
    # (e.g. `2 & 1 == 0`) and would silently drop valid pairs.
    valid = jnp.asarray(where, dtype=jnp.bool_)
    mask = mask & jnp.expand_dims(valid, -1) & jnp.expand_dims(valid, -2)

  # Within each row, keep only the first True entry, so that every row points
  # at the first valid item of its segment.
  mask = mask & (jnp.cumsum(mask, axis=-1) == 1)

  # Collapse the rows back to the shape of `segments`: an item is selected if
  # any row points at it, i.e. if it is the first valid item of its segment.
  return jnp.any(mask, axis=-2)


114 changes: 114 additions & 0 deletions rax/_src/segment_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2023 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pytype: skip-file
"""Tests for rax._src.segment_utils."""

import doctest
import functools
from absl.testing import absltest
import jax
import jax.numpy as jnp
import numpy as np

import rax
from rax._src import segment_utils


class SegmentsTest(absltest.TestCase):
  """Unit tests for the helpers in ``segment_utils``."""

  def test_same_segment_mask(self):
    """Item pairs are marked exactly when their segment ids match."""
    segment_ids = jnp.asarray([0, 0, 1])
    want = jnp.asarray([[1, 1, 0], [1, 1, 0], [0, 0, 1]])
    got = jnp.int32(segment_utils.same_segment_mask(segment_ids))
    np.testing.assert_array_equal(got, want)

  def test_segment_sum(self):
    """Every item receives the sum of its whole segment."""
    values = jnp.asarray([1.0, 2.0, 4.0])
    segment_ids = jnp.asarray([0, 0, 1])
    want = jnp.asarray([3.0, 3.0, 4.0])
    got = segment_utils.segment_sum(values, segment_ids)
    np.testing.assert_array_equal(got, want)

  def test_segment_max(self):
    """Every item receives the maximum of its segment, negatives included."""
    values = jnp.array([1.0, 2.0, 4.0, -5.0, -5.5, -4.5])
    segment_ids = jnp.array([0, 0, 1, 2, 2, 2])
    want = jnp.array([2.0, 2.0, 4.0, -4.5, -4.5, -4.5])
    got = segment_utils.segment_max(values, segment_ids)
    np.testing.assert_array_equal(got, want)

  def test_segment_max_with_initial(self):
    """``initial`` acts as a lower bound on every segment maximum."""
    values = jnp.array([1.0, 2.0, 4.0, 1.0, 2.0, 3.0])
    segment_ids = jnp.array([0, 0, 1, 2, 2, 2])
    want = jnp.array([2.5, 2.5, 4.0, 3.0, 3.0, 3.0])
    got = segment_utils.segment_max(values, segment_ids, initial=2.5)
    np.testing.assert_array_equal(got, want)

  def test_segment_max_with_where(self):
    """Items masked out by ``where`` do not contribute to the maximum."""
    values = jnp.array([1.0, 2.0, 4.0, 1.0, 2.0, 3.0])
    segment_ids = jnp.array([0, 0, 1, 2, 2, 2])
    valid = jnp.array([1, 0, 1, 1, 1, 0])
    got = segment_utils.segment_max(values, segment_ids, where=valid)
    # Only non-masked entries have well-defined behavior under max, so we only
    # check those.
    for index, want in ((0, 1.0), (2, 4.0), (3, 2.0), (4, 2.0)):
      np.testing.assert_equal(got[index], jnp.array(want))

  def test_in_segment_indices(self):
    """Items are numbered from zero within each (sorted) segment."""
    segment_ids = jnp.asarray([0, 0, 0, 1, 2, 2])
    want = jnp.asarray([0, 1, 2, 0, 0, 1])
    got = segment_utils.in_segment_indices(segment_ids)
    np.testing.assert_array_equal(got, want)

  def test_in_segment_indices_unordered(self):
    """Numbering also works when the items of a segment are not adjacent."""
    segment_ids = jnp.asarray([0, 0, 1, 0, 2, 2])
    want = jnp.asarray([0, 1, 0, 2, 0, 1])
    got = segment_utils.in_segment_indices(segment_ids)
    np.testing.assert_array_equal(got, want)

  def test_first_item_segment_mask(self):
    """Only the first occurrence of every segment id is selected."""
    segment_ids = jnp.array([0, 0, 1, 1, 1, 2, 2, 1, 1, 3, 3, 3])
    want = jnp.array([1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0], dtype=jnp.bool_)
    got = segment_utils.first_item_segment_mask(segment_ids)
    np.testing.assert_array_equal(got, want)

  def test_first_item_segment_mask_with_where(self):
    """With ``where``, the first *valid* item of every segment is selected."""
    segment_ids = jnp.array([0, 0, 1, 1, 1, 2, 2, 1, 1, 3, 3, 3])
    valid = jnp.array([1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0], dtype=jnp.bool_)
    want = jnp.array([1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0], dtype=jnp.bool_)
    got = segment_utils.first_item_segment_mask(segment_ids, where=valid)
    np.testing.assert_array_equal(got, want)


def load_tests(loader, tests, ignore):
  """Adds the doctests of ``segment_utils`` to the collected test suite.

  Args:
    loader: Unused.
    tests: The test suite collected so far; the doctests are added to it.
    ignore: Unused.

  Returns:
    The ``tests`` suite, extended with the ``segment_utils`` doctests.
  """
  del loader, ignore  # Unused.
  # Names that the doctest examples can use without importing them.
  doctest_globals = dict(
      functools=functools,
      jax=jax,
      jnp=jnp,
      rax=rax,
      segment_utils=segment_utils,
  )
  doctest_suite = doctest.DocTestSuite(segment_utils, globs=doctest_globals)
  tests.addTests(doctest_suite)
  return tests


# Run the absl test runner only when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
Loading

0 comments on commit 6146a4f

Please sign in to comment.