Skip to content

Commit

Permalink
[rax] refactor the lambda weight to reduce the boilerplate.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605008082
  • Loading branch information
xuanhuiwang authored and Rax Developers committed Feb 7, 2024
1 parent 00922cb commit 4ae8a74
Showing 1 changed file with 106 additions and 88 deletions.
194 changes: 106 additions & 88 deletions rax/_src/lambdaweights.py
Expand Up @@ -85,6 +85,59 @@ def labeldiff_lambdaweight(
return results


def _generic_dcg_lambdaweight(
scores: Array,
labels: Array,
rank_pair_discount_fn: Callable[..., Array],
*,
where: Optional[Array] = None,
segments: Optional[Array] = None,
weights: Optional[Array] = None,
topn: Optional[int] = None,
normalize: bool = False,
gain_fn: Callable[[Array], Array] = metrics.default_gain_fn,
discount_fn: Callable[[Array], Array] = metrics.default_discount_fn
) -> Array:
r"""Generic DCG lambdaweights with customized rank pair discount fn."""

ranks = utils.ranks(scores, where=where, segments=segments)
gains = gain_fn(labels)
if weights is not None:
gains *= weights

if normalize:
ideal_dcg = metrics.dcg_metric(
gains,
labels,
where=where,
segments=segments,
topn=topn,
weights=weights,
gain_fn=gain_fn,
discount_fn=discount_fn,
reduce_fn=None,
)
ideal_dcg = jnp.where(ideal_dcg == 0.0, 1.0, ideal_dcg)
if segments is None:
ideal_dcg = jnp.expand_dims(ideal_dcg, -1)
gains /= ideal_dcg

gains_abs_diffs = jnp.abs(utils.compute_pairs(gains, operator.sub))

if where is not None:
valid_pairs = utils.compute_pairs(where, operator.and_)
else:
valid_pairs = jnp.ones_like(gains_abs_diffs, dtype=jnp.bool_)

pair_discount = rank_pair_discount_fn(discount_fn, ranks, valid_pairs, topn)

results = pair_discount * gains_abs_diffs
if segments is None:
return results
else:
return jnp.where(utils.compute_pairs(segments, operator.eq), results, 0.0)


def dcg_lambdaweight(
scores: Array,
labels: Array,
Expand Down Expand Up @@ -127,52 +180,34 @@ def dcg_lambdaweight(
Returns:
DCG lambdaweights.
"""
ranks = utils.ranks(scores, where=where, segments=segments)
gains = gain_fn(labels)
if weights is not None:
gains *= weights

if normalize:
ideal_dcg = metrics.dcg_metric(
gains,
labels,
where=where,
segments=segments,
topn=topn,
weights=weights,
gain_fn=gain_fn,
discount_fn=discount_fn,
reduce_fn=None,
)
ideal_dcg = jnp.where(ideal_dcg == 0.0, 1.0, ideal_dcg)
if segments is None:
ideal_dcg = jnp.expand_dims(ideal_dcg, -1)
gains /= ideal_dcg

gains_abs_diffs = jnp.abs(utils.compute_pairs(gains, operator.sub))
def _rank_pair_discount(discount_fn, ranks, valid_pairs, topn):
discounts = discount_fn(ranks)

if where is not None:
valid_pairs = utils.compute_pairs(where, operator.and_)
else:
valid_pairs = jnp.ones_like(gains_abs_diffs, dtype=jnp.bool_)

discounts = discount_fn(ranks)
if topn is not None:
discounts = jnp.where(ranks <= topn, discounts, 0.0)

if topn is not None:
discounts = jnp.where(ranks <= topn, discounts, 0.0)
discounts_abs_diffs = jnp.abs(utils.compute_pairs(discounts, operator.sub))
discounts_abs_diffs = jnp.where(valid_pairs, discounts_abs_diffs, 0.0)

discounts_abs_diffs = jnp.abs(utils.compute_pairs(discounts, operator.sub))
discounts_abs_diffs = jnp.where(valid_pairs, discounts_abs_diffs, 0.0)
# Scale up the lambdaweights by the constant list size to avoid too small
# values.
weight_scalar = ranks.shape[-1]

# Scale up the lambdaweights by the constant list size to avoid too small
# values.
weight_scalar = labels.shape[-1]
return discounts_abs_diffs * weight_scalar

results = discounts_abs_diffs * gains_abs_diffs * weight_scalar
if segments is None:
return results
else:
return jnp.where(utils.compute_pairs(segments, operator.eq), results, 0.0)
return _generic_dcg_lambdaweight(
scores,
labels,
_rank_pair_discount,
where=where,
segments=segments,
weights=weights,
topn=topn,
normalize=normalize,
gain_fn=gain_fn,
discount_fn=discount_fn,
)


def dcg2_lambdaweight(
Expand Down Expand Up @@ -218,54 +253,37 @@ def dcg2_lambdaweight(
Returns:
DCG v2 ("lambdaloss") lambdaweights.
"""
ranks = utils.ranks(scores, where=where, segments=segments)
gains = gain_fn(labels)
if weights is not None:
gains *= weights

if normalize:
ideal_dcg = metrics.dcg_metric(
gains,
labels,
where=where,
topn=topn,
weights=weights,
gain_fn=gain_fn,
discount_fn=discount_fn,
reduce_fn=None,
)
ideal_dcg = jnp.where(ideal_dcg == 0.0, 1.0, ideal_dcg)
if segments is None:
ideal_dcg = jnp.expand_dims(ideal_dcg, -1)
gains /= ideal_dcg

gains_abs_diffs = jnp.abs(utils.compute_pairs(gains, operator.sub))
def _rank_pair_discount(discount_fn, ranks, valid_pairs, topn):
ranks_abs_diffs = jnp.abs(utils.compute_pairs(ranks, operator.sub))
ranks_max = utils.compute_pairs(ranks, jnp.maximum)

if where is not None:
valid_pairs = utils.compute_pairs(where, operator.and_)
else:
valid_pairs = jnp.ones_like(gains_abs_diffs, dtype=jnp.bool_)

ranks_abs_diffs = jnp.abs(utils.compute_pairs(ranks, operator.sub))
ranks_max = utils.compute_pairs(ranks, jnp.maximum)

discounts = jnp.abs(
discount_fn(ranks_abs_diffs) - discount_fn(ranks_abs_diffs + 1)
discounts = jnp.abs(
discount_fn(ranks_abs_diffs) - discount_fn(ranks_abs_diffs + 1)
)
discounts = jnp.where(ranks_abs_diffs != 0, discounts, 0.0)

if topn is not None:
topn_weights = 1.0 / (1.0 - discount_fn(ranks_max))
discounts *= jnp.where(ranks_max > topn, topn_weights, 1.0)

discounts = jnp.where(valid_pairs, discounts, 0.0)

# Scale up the lambdaweights by the constant list size to avoid too small
# values.
weight_scalar = ranks.shape[-1]

return discounts * weight_scalar

return _generic_dcg_lambdaweight(
scores,
labels,
_rank_pair_discount,
where=where,
segments=segments,
weights=weights,
topn=topn,
normalize=normalize,
gain_fn=gain_fn,
discount_fn=discount_fn,
)
discounts = jnp.where(ranks_abs_diffs != 0, discounts, 0.0)

if topn is not None:
topn_weights = 1.0 / (1.0 - discount_fn(ranks_max))
discounts *= jnp.where(ranks_max > topn, topn_weights, 1.0)

discounts = jnp.where(valid_pairs, discounts, 0.0)

# Scale up the lambdaweights by the constant list size to avoid too small
# values.
weight_scalar = labels.shape[-1]

results = discounts * gains_abs_diffs * weight_scalar
if segments is None:
return results
else:
return jnp.where(utils.compute_pairs(segments, operator.eq), results, 0.0)

0 comments on commit 4ae8a74

Please sign in to comment.