Skip to content

Commit

Permalink
Clean up docstrings.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 529439175
  • Loading branch information
rjagerman authored and Rax Developers committed May 4, 2023
1 parent 90c72e4 commit d522ef6
Show file tree
Hide file tree
Showing 7 changed files with 407 additions and 411 deletions.
8 changes: 8 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@
bibtex_default_style = 'alpha'
bibtex_reference_style = 'author_year'

# -- Options for katex -------------------------------------------------------

katex_options = r'''{
macros: {
"\\II": "\\mathbb{I}\\left[#1\\right]",
"\\op": "\\operatorname{#1}",
}
}'''

# -- Intersphinx configuration -----------------------------------------------

Expand Down
84 changes: 41 additions & 43 deletions rax/_src/lambdaweights.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,18 @@ def labeldiff_lambdaweight(
\lambda_{ij}(s, y) = |y_i - y_j|
Args:
scores: A ``[..., list_size]``-:class:`~jax.numpy.ndarray`, indicating the
score of each item.
labels: A ``[..., list_size]``-:class:`~jax.numpy.ndarray`, indicating the
relevance label for each item.
where: An optional ``[..., list_size]``-:class:`~jax.numpy.ndarray`,
indicating which items are valid for computing the lambdaweights. Items
for which this is False will be ignored when computing the lambdaweights.
segments: An optional ``[..., list_size]``-:class:`~jax.numpy.ndarray`,
indicating segments within each list. The loss will only be computed on
scores: A ``[..., list_size]``-:class:`~jax.Array`, indicating the score of
each item.
labels: A ``[..., list_size]``-:class:`~jax.Array`, indicating the relevance
label for each item.
where: An optional ``[..., list_size]``-:class:`~jax.Array`, indicating
which items are valid for computing the lambdaweights. Items for which
this is False will be ignored when computing the lambdaweights.
segments: An optional ``[..., list_size]``-:class:`~jax.Array`, indicating
segments within each list. The lambdaweights will only be computed on
items that share the same segment.
weights: An optional ``[..., list_size]``-:class:`~jax.numpy.ndarray`,
indicating the weight for each item.
weights: An optional ``[..., list_size]``-:class:`~jax.Array`, indicating
the weight for each item.
Returns:
Absolute label difference lambdaweights.
Expand Down Expand Up @@ -102,23 +102,23 @@ def dcg_lambdaweight(
Definition :cite:p:`burges2006learning`:
.. math::
\lambda_{ij}(s, y) = |\operatorname{gain}(y_i) - \operatorname{gain}(y_j)|
\cdot |\operatorname{discount}(\operatorname{rank}(s_i)) -
\operatorname{discount}(\operatorname{rank}(s_j))|
\lambda_{ij}(s, y) = |\op{gain}(y_i) - \op{gain}(y_j)|
\cdot |\op{discount}(\op{rank}(s_i)) -
\op{discount}(\op{rank}(s_j))|
Args:
scores: A ``[..., list_size]``-:class:`~jax.numpy.ndarray`, indicating the
score of each item.
labels: A ``[..., list_size]``-:class:`~jax.numpy.ndarray`, indicating the
relevance label for each item.
where: An optional ``[..., list_size]``-:class:`~jax.numpy.ndarray`,
indicating which items are valid for computing the lambdaweights. Items
for which this is False will be ignored when computing the lambdaweights.
segments: An optional ``[..., list_size]``-:class:`~jax.numpy.ndarray`,
indicating segments within each list. The loss will only be computed on
scores: A ``[..., list_size]``-:class:`~jax.Array`, indicating the score of
each item.
labels: A ``[..., list_size]``-:class:`~jax.Array`, indicating the relevance
label for each item.
where: An optional ``[..., list_size]``-:class:`~jax.Array`, indicating
which items are valid for computing the lambdaweights. Items for which
this is False will be ignored when computing the lambdaweights.
segments: An optional ``[..., list_size]``-:class:`~jax.Array`, indicating
segments within each list. The lambdaweights will only be computed on
items that share the same segment.
weights: An optional ``[..., list_size]``-:class:`~jax.numpy.ndarray`,
indicating the weight for each item.
weights: An optional ``[..., list_size]``-:class:`~jax.Array`, indicating
the weight for each item.
topn: The topn cutoff. If ``None``, no cutoff is performed.
normalize: Whether to use the normalized DCG formulation.
gain_fn: A function mapping labels to gain values.
Expand Down Expand Up @@ -192,27 +192,25 @@ def dcg2_lambdaweight(
Definition :cite:p:`wang2018lambdaloss`:
.. math::
\lambda_{ij}(s, y) = |\operatorname{gain}(y_i) - \operatorname{gain}(y_j)|
\cdot |\operatorname{discount}(
|\operatorname{rank}(s_i) - \operatorname{rank}(s_j)|) -
\operatorname{discount}(
|\operatorname{rank}(s_i) - \operatorname{rank}(s_j)| + 1)|
\lambda_{ij}(s, y) = |\op{gain}(y_i) - \op{gain}(y_j)| \cdot
|\op{discount}(|\op{rank}(s_i) - \op{rank}(s_j)|) -
\op{discount}(|\op{rank}(s_i) - \op{rank}(s_j)|+1)|
Args:
scores: A ``[..., list_size]``-:class:`~jax.numpy.ndarray`, indicating the
score of each item.
labels: A ``[..., list_size]``-:class:`~jax.numpy.ndarray`, indicating the
relevance label for each item.
where: An optional ``[..., list_size]``-:class:`~jax.numpy.ndarray`,
indicating which items are valid for computing the lambdaweights. Items
for which this is False will be ignored when computing the lambdaweights.
segments: An optional ``[..., list_size]``-:class:`~jax.numpy.ndarray`,
indicating segments within each list. The loss will only be computed on
scores: A ``[..., list_size]``-:class:`~jax.Array`, indicating the score of
each item.
labels: A ``[..., list_size]``-:class:`~jax.Array`, indicating the relevance
label for each item.
where: An optional ``[..., list_size]``-:class:`~jax.Array`, indicating
which items are valid for computing the lambdaweights. Items for which
this is False will be ignored when computing the lambdaweights.
segments: An optional ``[..., list_size]``-:class:`~jax.Array`, indicating
segments within each list. The lambdaweights will only be computed on
items that share the same segment.
weights: An optional ``[..., list_size]``-:class:`~jax.numpy.ndarray`,
indicating the weight for each item.
topn: The topn cutoff. If ``None``, no cutoff is performed. Topn cutoff is
uses the method described in :cite:p:`jagerman2022optimizing`.
weights: An optional ``[..., list_size]``-:class:`~jax.Array`, indicating
the weight for each item.
topn: The topn cutoff. If ``None``, no cutoff is performed. Topn cutoff uses
the method described in :cite:p:`jagerman2022optimizing`.
normalize: Whether to use the normalized DCG formulation.
gain_fn: A function mapping labels to gain values.
discount_fn: A function mapping ranks to discount values.
Expand Down
Loading

0 comments on commit d522ef6

Please sign in to comment.