Skip to content

Commit

Permalink
[rax] add a light discount option for dcg2_lambdaweight
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604465293
  • Loading branch information
xuanhuiwang authored and Rax Developers committed Feb 7, 2024
1 parent 4ae8a74 commit ef3b2c8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
25 changes: 18 additions & 7 deletions rax/_src/lambdaweights.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def labeldiff_lambdaweight(
*,
where: Optional[Array] = None,
segments: Optional[Array] = None,
weights: Optional[Array] = None
weights: Optional[Array] = None,
) -> Array:
r"""Absolute label difference lambdaweights.
Expand Down Expand Up @@ -96,7 +96,7 @@ def _generic_dcg_lambdaweight(
topn: Optional[int] = None,
normalize: bool = False,
gain_fn: Callable[[Array], Array] = metrics.default_gain_fn,
discount_fn: Callable[[Array], Array] = metrics.default_discount_fn
discount_fn: Callable[[Array], Array] = metrics.default_discount_fn,
) -> Array:
r"""Generic DCG lambdaweights with customized rank pair discount fn."""

Expand Down Expand Up @@ -148,7 +148,7 @@ def dcg_lambdaweight(
topn: Optional[int] = None,
normalize: bool = False,
gain_fn: Callable[[Array], Array] = metrics.default_gain_fn,
discount_fn: Callable[[Array], Array] = metrics.default_discount_fn
discount_fn: Callable[[Array], Array] = metrics.default_discount_fn,
) -> Array:
r"""DCG lambdaweights.
Expand Down Expand Up @@ -220,7 +220,8 @@ def dcg2_lambdaweight(
topn: Optional[int] = None,
normalize: bool = False,
gain_fn: Callable[[Array], Array] = metrics.default_gain_fn,
discount_fn: Callable[[Array], Array] = metrics.default_discount_fn
discount_fn: Callable[[Array], Array] = metrics.default_discount_fn,
light_discount: bool = False,
) -> Array:
r"""DCG v2 ("lambdaloss") lambdaweights.
Expand All @@ -231,6 +232,12 @@ def dcg2_lambdaweight(
|\op{discount}(|\op{rank}(s_i) - \op{rank}(s_j)|) -
\op{discount}(|\op{rank}(s_i) - \op{rank}(s_j)|+1)|
Or the following when ``light_discount`` is ``True``:
.. math::
\lambda_{ij}(s, y) = |\op{gain}(y_i) - \op{gain}(y_j)| \cdot
|\op{discount}(|\op{rank}(s_i) - \op{rank}(s_j)|)|
Args:
scores: A ``[..., list_size]``-:class:`~jax.Array`, indicating the score of
each item.
Expand All @@ -249,6 +256,7 @@ def dcg2_lambdaweight(
normalize: Whether to use the normalized DCG formulation.
gain_fn: A function mapping labels to gain values.
discount_fn: A function mapping ranks to discount values.
light_discount: If ``True``, make the position discount light as above.
Returns:
DCG v2 ("lambdaloss") lambdaweights.
Expand All @@ -258,9 +266,12 @@ def _rank_pair_discount(discount_fn, ranks, valid_pairs, topn):
ranks_abs_diffs = jnp.abs(utils.compute_pairs(ranks, operator.sub))
ranks_max = utils.compute_pairs(ranks, jnp.maximum)

discounts = jnp.abs(
discount_fn(ranks_abs_diffs) - discount_fn(ranks_abs_diffs + 1)
)
if light_discount:
discounts = jnp.abs(discount_fn(ranks_abs_diffs))
else:
discounts = jnp.abs(
discount_fn(ranks_abs_diffs) - discount_fn(ranks_abs_diffs + 1)
)
discounts = jnp.where(ranks_abs_diffs != 0, discounts, 0.0)

if topn is not None:
Expand Down
17 changes: 17 additions & 0 deletions rax/_src/lambdaweights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Tests for rax._src.weights."""

import doctest
import functools
import math

from absl.testing import absltest
Expand Down Expand Up @@ -67,6 +68,22 @@ class LambdaweightsTest(parameterized.TestCase):
0.0,
],
},
{
"lambdaweight_fn": functools.partial(
lambdaweights.dcg2_lambdaweight, light_discount=True
),
"expected": [
0.0,
3.0 * abs(discount(1)) * abs(gain(0.0) - gain(1.0)),
3.0 * abs(discount(2)) * abs(gain(0.0) - gain(0.3)),
3.0 * abs(discount(1)) * abs(gain(1.0) - gain(0.0)),
0.0,
3.0 * abs(discount(1)) * abs(gain(1.0) - gain(0.3)),
3.0 * abs(discount(2)) * abs(gain(0.3) - gain(0.0)),
3.0 * abs(discount(1)) * abs(gain(0.3) - gain(1.0)),
0.0,
],
},
])
def test_computes_lambdaweights(self, lambdaweight_fn, expected):
scores = jnp.array([0.0, 1.0, 2.0])
Expand Down

0 comments on commit ef3b2c8

Please sign in to comment.