# Ranking Loss for Learning-to-Rank

This tutorial shows how to use the listwise ranking loss in `braintools.metric`
with masking, per-item weights, and reduction options. The loss is suited to
information retrieval, recommendation, and other Learning-to-Rank tasks.

Covered API:
- `bt.metric.ranking_softmax_loss(logits, labels, *, where=None, weights=None, reduce_fn=jnp.mean)`
  - `where`: boolean mask for valid items (padding handling)
  - `weights`: per-item weights
  - `reduce_fn`: `jnp.mean`, `jnp.sum`, or `None` (unreduced)


In [1]:
import jax.numpy as jnp
import braintools as bt

## 1) Basic usage (single list)

`logits` are the scores to be ranked; `labels` are non-negative relevance
values (binary or graded). The loss is computed over the last dimension.

In [2]:
# One list of 4 items
logits = jnp.array([2.0, 1.0, 0.5, 0.2])
labels = jnp.array([1.0, 0.0, 0.0, 0.0])  # item 0 is most relevant
loss = bt.metric.ranking_softmax_loss(logits, labels)
print(loss)  # scalar (default reduce_fn=jnp.mean)

0.5632142
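
The printed value can be checked by hand. The sketch below assumes the loss is the softmax cross-entropy `-sum(labels * log_softmax(logits))` over the last dimension, which is consistent with the output above. It uses plain NumPy so the arithmetic is easy to inspect; it is not the library's implementation.

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.5, 0.2])
labels = np.array([1.0, 0.0, 0.0, 0.0])

# Numerically stable log-softmax: subtract the max before exponentiating
shifted = logits - logits.max()
log_softmax = shifted - np.log(np.exp(shifted).sum())

# Assumed definition: cross-entropy between the labels and softmax(logits)
loss = -(labels * log_softmax).sum()
print(round(float(loss), 4))  # 0.5632
```

With a single relevant item, this reduces to the negative log-probability that the softmax assigns to that item.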


## 2) Batched lists with masks (variable lengths)

Use `where` to ignore padded items. Pass a boolean array with the same shape
as `logits` and `labels`: `True` marks a valid item, `False` marks padding.

In [3]:
# Two lists, padded to length 5
logits = jnp.array([[2.0, 1.0, 0.5, -1.0,  0.0],
                     [0.8, 0.3, 1.2, -2.0, -1.0]])
labels = jnp.array([[1.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 1.0, 0.0, 0.0]])
# First list has 4 valid items; second has 3 valid items
where  = jnp.array([[ True,  True,  True,  True, False],
                    [ True,  True,  True, False, False]])

# Default reduce_fn=jnp.mean -> scalar over batch
loss_mean = bt.metric.ranking_softmax_loss(logits, labels, where=where)
print('Mean loss (scalar):', loss_mean)

# Unreduced per-list losses
loss_per_list = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=None)
print('Per-list loss:', loss_per_list)  # shape (batch,)

Mean loss (scalar): 0.6130267
Per-list loss: [0.49518192 0.73087144]
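
To see what the mask does, here is a rough NumPy sketch that reproduces the per-list values above. It assumes masking is equivalent to sending padded logits to `-inf` and dropping padded label terms. The helper name `masked_softmax_loss` is hypothetical, and the sketch assumes every list has at least one valid item.

```python
import numpy as np

def masked_softmax_loss(logits, labels, where):
    """Per-list softmax cross-entropy that ignores positions where `where` is False."""
    # Padded logits go to -inf so they contribute nothing to the softmax
    masked = np.where(where, logits, -np.inf)
    shifted = masked - masked.max(axis=-1, keepdims=True)
    log_z = np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Replace -inf at padded slots with 0 so 0 * -inf never produces NaN
    log_softmax = np.where(where, shifted - log_z, 0.0)
    return -(np.where(where, labels, 0.0) * log_softmax).sum(axis=-1)

logits = np.array([[2.0, 1.0, 0.5, -1.0,  0.0],
                   [0.8, 0.3, 1.2, -2.0, -1.0]])
labels = np.array([[1.0, 0.0, 0.0, 0.0, 0.0],
                   [0.0, 0.0, 1.0, 0.0, 0.0]])
where = np.array([[True, True, True, True, False],
                  [True, True, True, False, False]])

per_list = masked_softmax_loss(logits, labels, where)
print(per_list)  # approximately [0.4952 0.7309]
```

The padded logits (`0.0` in the first list, `-2.0` and `-1.0` in the second) have no influence on the result, which is the point of the mask.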


## 3) Reductions: sum vs mean vs none

- `reduce_fn=jnp.sum`: sum across the batch
- `reduce_fn=jnp.mean`: average across the batch (default)
- `reduce_fn=None`: return unreduced per-list values, shape `(batch,)`

When the mask is all-False (no valid items) and the inputs contain no NaN,
the mean reduction returns 0.0 instead of NaN.

In [4]:
sum_loss  = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=jnp.sum)
mean_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=jnp.mean)
none_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=None)
print('sum:',  sum_loss)
print('mean:', mean_loss)
print('none:', none_loss, ' sum(none)=', jnp.sum(none_loss))

sum: 1.2260534
mean: 0.6130267
none: [0.49518192 0.73087144]  sum(none)= 1.2260534
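
The unreduced output makes custom aggregation possible. The sketch below starts from the per-list values printed above and recomputes the sum and mean with NumPy. The last step, averaging only over lists that contain at least one valid item, is a hypothetical variant for illustration and not something the library is stated to do.

```python
import numpy as np

# Per-list values as printed by reduce_fn=None above
per_list = np.array([0.49518192, 0.73087144])
where = np.array([[True, True, True, True, False],
                  [True, True, True, False, False]])

total = per_list.sum()   # matches reduce_fn=jnp.sum
mean = per_list.mean()   # matches reduce_fn=jnp.mean

# Hypothetical variant: average only over lists with at least one valid item
has_items = where.any(axis=-1)
n_nonempty = max(int(has_items.sum()), 1)  # guard against division by zero
mean_nonempty = (per_list * has_items).sum() / n_nonempty

print(round(float(total), 4), round(float(mean), 4))  # 1.2261 0.613
```

Here both lists are non-empty, so `mean_nonempty` equals `mean`; the two differ only when a batch contains all-padding rows.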


## 4) Per-item weights

Pass `weights` to emphasize specific items within a list. `weights` must have
the same shape as `labels` and `logits`. It multiplies the labels before the
softmax cross-entropy is computed, so a weight only has an effect on items
with a non-zero label.

In [5]:
weights = jnp.array([[1.0, 0.5, 0.5, 1.0, 0.0],
                     [1.0, 1.0, 2.0, 0.0, 0.0]])
weighted_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, weights=weights, reduce_fn=None)
print('Weighted per-list loss:', weighted_loss)

Weighted per-list loss: [0.49518192 1.4617429 ]
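
The weighted output can be explained with a little arithmetic. Assuming the weights multiply the labels, as described above, and given that each list here has a single non-zero label, each per-list loss is scaled by the weight on its relevant item. The NumPy sketch below checks this against the printed values.

```python
import numpy as np

labels = np.array([[1.0, 0.0, 0.0, 0.0, 0.0],
                   [0.0, 0.0, 1.0, 0.0, 0.0]])
weights = np.array([[1.0, 0.5, 0.5, 1.0, 0.0],
                    [1.0, 1.0, 2.0, 0.0, 0.0]])

# Effective labels after weighting; weights on zero-label items vanish
effective = labels * weights

# Unweighted per-list losses printed earlier in this tutorial
unweighted = np.array([0.49518192, 0.73087144])

# With one relevant item per list, the loss scales by that item's weight
scale = effective.sum(axis=-1)   # [1.0, 2.0]
weighted = unweighted * scale
print(weighted)  # approximately [0.4952 1.4617]
```

The first list is unchanged because its relevant item has weight 1.0; the 0.5 weights sit on items whose label is zero. The second list doubles because its relevant item has weight 2.0.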


## 5) Tips & Pitfalls

- Shapes: the loss operates on the last dimension `(…, list_size)`; batch dimensions are leading.
- `where` must be boolean and broadcastable to `(…, list_size)`; the simplest choice is the same shape as `logits` and `labels`.
- `weights` must match the `labels`/`logits` shape.
- `reduce_fn=None` returns per-list values, which you can aggregate manually.
- If a list has no valid items (`where` all-False), mean reduction returns 0.0 (when inputs have no NaN).
- For large batches and lists, prefer JIT-compiling code paths that call this loss with static shapes.
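
One way to keep shapes static is to pad every list to a fixed length and pass the resulting mask as `where`. The helper below is a minimal sketch of that idea; `pad_lists` is a hypothetical name and not part of `braintools`.

```python
import numpy as np

def pad_lists(score_lists, list_size, pad_value=0.0):
    """Pad variable-length lists to (batch, list_size) and return a validity mask."""
    batch = len(score_lists)
    padded = np.full((batch, list_size), pad_value, dtype=np.float32)
    where = np.zeros((batch, list_size), dtype=bool)
    for i, scores in enumerate(score_lists):
        n = min(len(scores), list_size)  # lists longer than list_size are truncated
        padded[i, :n] = scores[:n]
        where[i, :n] = True
    return padded, where

logits, where = pad_lists([[2.0, 1.0, 0.5, -1.0], [0.8, 0.3, 1.2]], list_size=5)
print(logits.shape)        # (2, 5)
print(where.sum(axis=-1))  # [4 3]
```

The same helper can pad `labels` and `weights`, so that all four arrays share one fixed shape from batch to batch.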