Skip to content

Commit

Permalink
Fix pytype and clean up types across codebase.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 607680244
  • Loading branch information
rjagerman authored and Rax Developers committed Feb 16, 2024
1 parent 99ce1c4 commit 13208fe
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 46 deletions.
4 changes: 3 additions & 1 deletion rax/_src/lambdaweights.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@

import jax.numpy as jnp
from rax._src import metrics
from rax._src import types
from rax._src import utils
from rax._src.types import Array

Array = types.Array


def labeldiff_lambdaweight(
Expand Down
47 changes: 24 additions & 23 deletions rax/_src/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,21 @@
"""

import operator
from typing import Callable, Optional, Tuple
from typing import Callable, Optional

import jax
import jax.numpy as jnp

from rax._src import metrics
from rax._src import segment_utils
from rax._src import types
from rax._src import utils
from rax._src.types import Array
from rax._src.types import LambdaweightFn
from rax._src.types import ReduceFn

Array = types.Array
LambdaweightFn = types.LambdaweightFn
ReduceFn = types.ReduceFn


def softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def softmax_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -143,7 +144,7 @@ def softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def poly1_softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def poly1_softmax_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -241,7 +242,7 @@ def poly1_softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def unique_softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def unique_softmax_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -340,7 +341,7 @@ def unique_softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def listmle_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def listmle_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -422,11 +423,11 @@ def listmle_loss( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def pairwise_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def pairwise_loss(
scores: Array,
labels: Array,
*,
pair_loss_fn: Callable[[Array, Array], Tuple[Array, Array]],
pair_loss_fn: Callable[[Array, Array], tuple[Array, Array]],
lambdaweight_fn: Optional[LambdaweightFn] = None,
where: Optional[Array] = None,
segments: Optional[Array] = None,
Expand Down Expand Up @@ -489,7 +490,7 @@ def pairwise_loss( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(pair_losses, where=valid_pairs, reduce_fn=reduce_fn)


def pairwise_hinge_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def pairwise_hinge_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -530,7 +531,7 @@ def pairwise_hinge_loss( # pytype: disable=annotation-type-mismatch # jnp-type

def _hinge_loss(
scores_diff: Array, labels_diff: Array
) -> Tuple[Array, Array]:
) -> tuple[Array, Array]:
return jax.nn.relu(1.0 - scores_diff), labels_diff > 0

return pairwise_loss(
Expand All @@ -545,7 +546,7 @@ def _hinge_loss(
)


def pairwise_logistic_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def pairwise_logistic_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -586,7 +587,7 @@ def pairwise_logistic_loss( # pytype: disable=annotation-type-mismatch # jnp-t

def _logistic_loss(
scores_diff: Array, labels_diff: Array
) -> Tuple[Array, Array]:
) -> tuple[Array, Array]:
return (
jax.nn.relu(-scores_diff) + jnp.log1p(jnp.exp(-jnp.abs(scores_diff))),
labels_diff > 0,
Expand All @@ -604,7 +605,7 @@ def _logistic_loss(
)


def pairwise_soft_zero_one_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def pairwise_soft_zero_one_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -645,7 +646,7 @@ def pairwise_soft_zero_one_loss( # pytype: disable=annotation-type-mismatch #

def _soft_zero_one_loss(
scores_diff: Array, labels_diff: Array
) -> Tuple[Array, Array]:
) -> tuple[Array, Array]:
return (
jnp.where(
scores_diff > 0,
Expand All @@ -667,7 +668,7 @@ def _soft_zero_one_loss(
)


def pointwise_sigmoid_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def pointwise_sigmoid_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -729,7 +730,7 @@ def pointwise_sigmoid_loss( # pytype: disable=annotation-type-mismatch # jnp-t
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def pointwise_mse_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def pointwise_mse_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -778,7 +779,7 @@ def pointwise_mse_loss( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def pairwise_mse_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def pairwise_mse_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -817,7 +818,7 @@ def pairwise_mse_loss( # pytype: disable=annotation-type-mismatch # jnp-type
The pairwise mean squared error loss.
"""

def _mse_loss(scores_diff: Array, labels_diff: Array) -> Tuple[Array, Array]:
def _mse_loss(scores_diff: Array, labels_diff: Array) -> tuple[Array, Array]:
return (
jnp.square(scores_diff - labels_diff),
jnp.ones_like(labels_diff > 0),
Expand All @@ -835,7 +836,7 @@ def _mse_loss(scores_diff: Array, labels_diff: Array) -> Tuple[Array, Array]:
)


def pairwise_qr_loss( # pytype: disable=annotation-type-mismatch # jnp-type
def pairwise_qr_loss(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -886,7 +887,7 @@ def pairwise_qr_loss( # pytype: disable=annotation-type-mismatch # jnp-type
The pairwise quantile regression loss.
"""

def _qr_loss(scores_diff: Array, labels_diff: Array) -> Tuple[Array, Array]:
def _qr_loss(scores_diff: Array, labels_diff: Array) -> tuple[Array, Array]:
loss_1 = jax.nn.relu(labels_diff - scores_diff)
loss_2 = jax.nn.relu(scores_diff - labels_diff)
if squared:
Expand Down
25 changes: 13 additions & 12 deletions rax/_src/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,14 @@
from typing import Callable, Optional

import jax.numpy as jnp

from rax._src import segment_utils
from rax._src import types
from rax._src import utils
from rax._src.types import Array
from rax._src.types import CutoffFn
from rax._src.types import RankFn
from rax._src.types import ReduceFn

Array = types.Array
CutoffFn = types.CutoffFn
RankFn = types.RankFn
ReduceFn = types.ReduceFn


def _retrieved_items(
Expand Down Expand Up @@ -142,7 +143,7 @@ def default_discount_fn(rank: Array) -> Array:
return 1.0 / jnp.log2(rank + 1)


def mrr_metric( # pytype: disable=annotation-type-mismatch # jnp-type
def mrr_metric(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -242,7 +243,7 @@ def mrr_metric( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn)


def recall_metric( # pytype: disable=annotation-type-mismatch # jnp-type
def recall_metric(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -342,7 +343,7 @@ def recall_metric( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn)


def precision_metric( # pytype: disable=annotation-type-mismatch # jnp-type
def precision_metric(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -442,7 +443,7 @@ def precision_metric( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn)


def ap_metric( # pytype: disable=annotation-type-mismatch # jnp-type
def ap_metric(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -553,7 +554,7 @@ def ap_metric( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn)


def opa_metric( # pytype: disable=annotation-type-mismatch # jnp-type
def opa_metric(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -619,7 +620,7 @@ def opa_metric( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(per_list_opa, where=where, reduce_fn=reduce_fn)


def dcg_metric( # pytype: disable=annotation-type-mismatch # jnp-type
def dcg_metric(
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -715,7 +716,7 @@ def dcg_metric( # pytype: disable=annotation-type-mismatch # jnp-type
return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn)


def ndcg_metric( # pytype: disable=annotation-type-mismatch # jnp-type
def ndcg_metric(
scores: Array,
labels: Array,
*,
Expand Down
8 changes: 5 additions & 3 deletions rax/_src/t12n.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
import jax
import jax.numpy as jnp

from rax._src import types
from rax._src import utils
from rax._src.types import Array
from rax._src.types import LossFn
from rax._src.types import MetricFn

Array = types.Array
LossFn = types.LossFn
MetricFn = types.MetricFn

# Type aliases for ranking loss and metric functions.
LossOrMetricFn = TypeVar("LossOrMetricFn", LossFn, MetricFn)
Expand Down
9 changes: 5 additions & 4 deletions rax/_src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
do **not** need to instantiate, subclass or extend them.
"""

from typing import Optional, Tuple, Union
from typing import Optional, Sequence, Union
import jax

# Protocol is a python 3.8+ feature. For older versions, we can use
Expand Down Expand Up @@ -88,17 +88,18 @@ class ReduceFn(Protocol):
def __call__(
self,
a: Array,
axis: Optional[Union[int, Sequence[int]]],
*,
where: Optional[Array],
axis: Optional[Union[int, Tuple[int, ...]]],
) -> Array:
"""Reduces an array across one or more dimensions.
Args:
a: The array to reduce.
where: An optional :class:`~jax.Array` of the same shape as ``a`` that
indicates which elements to include in the reduction.
axis: One or more axes to use for the reduction. If ``None`` this reduces
across all available axes.
where: An optional :class:`~jax.Array` of the same shape as ``a`` that
indicates which elements to include in the reduction.
Returns:
A :class:`~jax.Array` that represents the reduced result of ``a``
Expand Down
5 changes: 2 additions & 3 deletions rax/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@

import functools
import inspect

from typing import Any, Callable, Optional, Sequence, TypeVar

import jax
import jax.numpy as jnp

from rax._src import segment_utils
from rax._src.types import Array
from rax._src import types

Array = types.Array
T = TypeVar("T")


Expand Down

0 comments on commit 13208fe

Please sign in to comment.