Skip to content

Commit

Permalink
Refactor segmented lambdaweights unit tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 526111119
  • Loading branch information
rjagerman authored and Rax Developers committed Apr 21, 2023
1 parent 462020d commit 6a7f5b0
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions rax/_src/lambdaweights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_computes_normalized_lambdaweights(self, lambdaweight_fn, normalizer):
],
]
}])
def test_lambdaweights_with_batchdim_or_segments(
def test_lambdaweights_with_batchdim(
self, lambdaweight_fn, expected
):
scores = jnp.array([[0.0, 1.0, 2.0], [0.5, 1.5, 1.0]])
Expand All @@ -172,14 +172,33 @@ def test_lambdaweights_with_batchdim_or_segments(
result = lambdaweight_fn(scores, labels)
np.testing.assert_allclose(result, expected, rtol=1e-5)

segmented_scores = jnp.array([0.0, 1.0, 2.0])
segmented_labels = jnp.array([0.0, 2.0, 1.0])
segments = jnp.array([1, 1, 1])
@parameterized.parameters([
lambdaweights.labeldiff_lambdaweight,
lambdaweights.dcg_lambdaweight,
lambdaweights.dcg2_lambdaweight,
])
def test_lambdaweights_with_segments(self, lambdaweight_fn):
scores = jnp.array([
[0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.5, 1.5, 1.0],
])
labels = jnp.array([
[0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
])
mask = jnp.array(
[[1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 0, 1, 1, 1]], dtype=jnp.bool_
)
nonseg_result = lambdaweight_fn(scores, labels, where=mask)

segmented_result = lambdaweight_fn(
segmented_scores, segmented_labels, segments=segments
seg_scores = jnp.array([0.0, 1.0, 2.0, 3.0, 0.5, 1.5, 1.0])
seg_labels = jnp.array([0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 1.0])
segments = jnp.array([1, 1, 1, 1, 2, 2, 2])
seg_result = lambdaweight_fn(seg_scores, seg_labels, segments=segments)

np.testing.assert_allclose(
jnp.sum(nonseg_result, axis=0), seg_result, rtol=1e-5
)
np.testing.assert_allclose(segmented_result, expected[0], rtol=1e-5)

@parameterized.parameters([
lambdaweights.labeldiff_lambdaweight,
Expand Down

0 comments on commit 6a7f5b0

Please sign in to comment.