Skip to content

Commit

Permalink
Refactor segmented lambdaweights unit tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 526051493
  • Loading branch information
rjagerman authored and Rax Developers committed Apr 21, 2023
1 parent f3a4236 commit a117acd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
16 changes: 11 additions & 5 deletions rax/_src/lambdaweights.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@ def labeldiff_lambdaweight(
Returns:
Absolute label difference lambdaweights.
"""
del scores, where, weights # Unused.
del scores, weights # Unused.

results = jnp.abs(utils.compute_pairs(labels, operator.sub))
if segments is None:
return results
else:
return jnp.where(utils.compute_pairs(segments, operator.eq), results, 0.0)

if where is not None:
results = jnp.where(utils.compute_pairs(where, operator.eq), results, 0.0)

if segments is not None:
results = jnp.where(
utils.compute_pairs(segments, operator.eq), results, 0.0
)

return results


def dcg_lambdaweight(
Expand Down
35 changes: 27 additions & 8 deletions rax/_src/lambdaweights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_computes_normalized_lambdaweights(self, lambdaweight_fn, normalizer):
],
]
}])
def test_lambdaweights_with_batchdim_or_segments(
def test_lambdaweights_with_batchdim(
self, lambdaweight_fn, expected
):
scores = jnp.array([[0.0, 1.0, 2.0], [0.5, 1.5, 1.0]])
Expand All @@ -172,14 +172,33 @@ def test_lambdaweights_with_batchdim_or_segments(
result = lambdaweight_fn(scores, labels)
np.testing.assert_allclose(result, expected, rtol=1e-5)

segmented_scores = jnp.array([0.0, 1.0, 2.0])
segmented_labels = jnp.array([0.0, 2.0, 1.0])
segments = jnp.array([1, 1, 1])
@parameterized.parameters([
lambdaweights.labeldiff_lambdaweight,
lambdaweights.dcg_lambdaweight,
lambdaweights.dcg2_lambdaweight,
])
def test_lambdaweights_with_segments(self, lambdaweight_fn):
scores = jnp.array([
[0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.5, 1.5, 1.0],
])
labels = jnp.array([
[0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
])
mask = jnp.array(
[[1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 0, 1, 1, 1]], dtype=jnp.bool_
)
nonseg_result = lambdaweight_fn(scores, labels, where=mask)

seg_scores = jnp.array([0.0, 1.0, 2.0, 3.0, 0.5, 1.5, 1.0])
seg_labels = jnp.array([0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 1.0])
segments = jnp.array([1, 1, 1, 1, 2, 2, 2])
seg_result = lambdaweight_fn(seg_scores, seg_labels, segments=segments)

segmented_result = lambdaweight_fn(
segmented_scores, segmented_labels, segments=segments
np.testing.assert_allclose(
jnp.sum(nonseg_result, axis=0), seg_result, rtol=1e-5
)
np.testing.assert_allclose(segmented_result, expected[0], rtol=1e-5)

@parameterized.parameters([
lambdaweights.labeldiff_lambdaweight,
Expand All @@ -197,7 +216,7 @@ def test_lambdaweights_with_empty_list(self, lambdaweight_fn):

@parameterized.parameters([{
"lambdaweight_fn": lambdaweights.labeldiff_lambdaweight,
"expected": [0.0, 1.0, 0.3, 1.0, 0.0, 0.7, 0.3, 0.7, 0.0]
"expected": [0.0, 0.0, 0.3, 0.0, 0.0, 0.0, 0.3, 0.0, 0.0]
}, {
"lambdaweight_fn":
lambdaweights.dcg_lambdaweight,
Expand Down

0 comments on commit a117acd

Please sign in to comment.