# Pairwise & Embedding Similarity

This tutorial covers two cosine-similarity APIs in `braintools.metric`:

- Aligned embeddings:
  - `bt.metric.cosine_similarity(predictions, targets)` → per-pair similarity, shape `(…,)`
  - `bt.metric.cosine_distance(predictions, targets)` → `1 − similarity`, shape `(…,)`
- Pairwise similarity matrix:
  - `pairwise_cosine_similarity(X, Y=None)` → `(n, m)` matrix (X vs Y), or `(n, n)` if `Y` is None

Note: both functions are named `cosine_similarity` internally; in the public namespace, the
aligned version (`predictions, targets`) is bound to `bt.metric.cosine_similarity`. To access
the pairwise (matrix) variant, import it explicitly as shown below.

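```python
import jax.numpy as jnp
import braintools as bt
# Import the pairwise (matrix) version explicitly and alias it.
# `_pariwise` is a private module (leading underscore), so this path may change between versions.
from braintools.metric._pariwise import cosine_similarity as pairwise_cosine_similarity
```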

## 1) Aligned embeddings: similarity and distance

Use these when you have matched pairs `(prediction_i, target_i)` and want per-pair scores.
Similarity is scale-invariant (rescaling either vector leaves it unchanged) and lies in [−1, 1]. Distance is `1 - similarity`, so it lies in [0, 2].
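For intuition, the per-pair score is the dot product of each row pair divided by the product of their L2 norms. The sketch below recomputes it with plain `jax.numpy` on the same inputs used in the next cell. `manual_cosine_similarity` is a hypothetical helper for illustration, not part of `braintools`, and the library may treat edge cases such as zero vectors differently.

```python
import jax.numpy as jnp

def manual_cosine_similarity(a, b):
    """Hypothetical helper: row-wise cosine similarity over the last axis."""
    dot = jnp.sum(a * b, axis=-1)  # per-pair dot product
    norms = jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)
    return dot / norms  # nan (0 / 0) if either row is a zero vector

pred = jnp.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [1.0, 1.0, 0.0]])
targ = jnp.array([[1.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])

sim = manual_cosine_similarity(pred, targ)
dist = 1.0 - sim
# Scale invariance: rescaling either input leaves the similarity unchanged
sim_scaled = manual_cosine_similarity(10.0 * pred, 0.5 * targ)
print(sim, dist)
print(jnp.allclose(sim, sim_scaled))
```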

```python
pred = jnp.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [1.0, 1.0, 0.0]])
targ = jnp.array([[1.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])
sim = bt.metric.cosine_similarity(pred, targ)
dist = bt.metric.cosine_distance(pred, targ)
print('similarity:', sim)
print('distance  :', dist)
```

```text
similarity: [1.         0.         0.70710677]
distance  : [0.         1.         0.29289323]
```


Tips
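- Cosine similarity is undefined for a zero vector (its norm is 0). Avoid zero vectors; if necessary, pass a small `epsilon` to `cosine_distance` to guard the denominator.
- Both functions reduce over the last (feature) axis and return one score per pair, shape `(…,)`. For a scalar loss, aggregate over the remaining batch axes, e.g. `jnp.mean(dist)`.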
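A minimal sketch of both tips, written with a hypothetical `safe_cosine_distance` helper in plain `jax.numpy`. It is not the `braintools` implementation, and the exact semantics of the library's `epsilon` argument may differ. Clamping each norm to at least `epsilon` keeps the forward value finite for a zero vector, and `jnp.mean` collapses the per-pair distances into a scalar loss.

```python
import jax.numpy as jnp

def safe_cosine_distance(a, b, epsilon=1e-8):
    """Hypothetical helper: 1 - cosine similarity with norms clamped to >= epsilon."""
    a_norm = jnp.maximum(jnp.linalg.norm(a, axis=-1), epsilon)
    b_norm = jnp.maximum(jnp.linalg.norm(b, axis=-1), epsilon)
    sim = jnp.sum(a * b, axis=-1) / (a_norm * b_norm)
    return 1.0 - sim

pred = jnp.array([[1.0, 0.0], [0.0, 0.0]])  # second row is a zero vector
targ = jnp.array([[1.0, 0.0], [1.0, 0.0]])

per_pair = safe_cosine_distance(pred, targ)  # shape (2,): one distance per pair
loss = jnp.mean(per_pair)                    # scalar loss over the batch axis
print(per_pair, loss)
```

With the clamp, the zero-vector pair gets similarity 0 (distance 1) instead of `nan`.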

## 2) Pairwise similarity matrix (X vs Y)

Use this to compute all-pairs similarities between two sets of embeddings.
For `X: (n, d)`, `Y: (m, d)`, the result is `(n, m)`. With `Y=None`, returns `(n, n)` similarities within `X`.
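One common way to build such a matrix is to L2-normalize every row and take a single matrix product, since the dot product of two unit vectors equals their cosine. The sketch below assumes that approach. `manual_pairwise_cosine` is a hypothetical helper, and `braintools` may compute the matrix differently.

```python
import jax.numpy as jnp

def manual_pairwise_cosine(X, Y=None):
    """Hypothetical helper: (n, d) and (m, d) -> (n, m) cosine-similarity matrix."""
    if Y is None:
        Y = X  # within-set similarities
    Xn = X / jnp.linalg.norm(X, axis=1, keepdims=True)  # unit-length rows
    Yn = Y / jnp.linalg.norm(Y, axis=1, keepdims=True)
    return Xn @ Yn.T  # entry [i, j] = cos(X[i], Y[j])

X = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 0.0],
               [1.0, 1.0, 0.0]])
Y = jnp.array([[1.0, 1.0, 1.0],
               [0.0, 0.0, 1.0]])
S = manual_pairwise_cosine(X, Y)
S_xx = manual_pairwise_cosine(X)
print(S.shape, S_xx.shape)
```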

```python
X = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 0.0],
               [1.0, 1.0, 0.0]])
Y = jnp.array([[1.0, 1.0, 1.0],
               [0.0, 0.0, 1.0]])
S = pairwise_cosine_similarity(X, Y)
print('pairwise shape:', S.shape)
print(S)
# Within-set similarities (X vs X)
S_xx = pairwise_cosine_similarity(X)
print('within shape:', S_xx.shape)
```

```text
pairwise shape: (3, 2)
[[0.57735026 0.        ]
 [0.57735026 0.        ]
 [0.8164966  0.        ]]
within shape: (3, 3)
```


Performance notes
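- Pairwise matrices can be large: an `(n, m)` result holds `n × m` entries, so memory grows with the product of the set sizes (100,000 × 100,000 float32 values take about 40 GB).
- For very large sets, process queries or candidates in batches so that only one block of the matrix is in memory at a time.
- JIT-compile hot paths with `jax.jit` and keep input shapes fixed (e.g., pad the last batch), since each new input shape triggers a recompilation.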
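The notes above can be combined into a batched nearest-neighbour search. This is a sketch under stated assumptions, not a `braintools` feature: `chunked_best_match` and `_unit_rows` are hypothetical helpers, and cosine similarity is computed by normalize-and-multiply. Queries are processed in fixed-size chunks (the last one zero-padded) so the jitted function sees a single static shape and only a `(chunk_size, n_item)` block exists at a time.

```python
import jax
import jax.numpy as jnp

def _unit_rows(x, eps=1e-12):
    # L2-normalize rows; padded all-zero rows stay zero instead of becoming nan
    return x / jnp.maximum(jnp.linalg.norm(x, axis=1, keepdims=True), eps)

@jax.jit
def _best_match_chunk(q_chunk, items_unit):
    sims = _unit_rows(q_chunk) @ items_unit.T  # only a (chunk_size, n_item) block
    return jnp.max(sims, axis=1), jnp.argmax(sims, axis=1)

def chunked_best_match(queries, items, chunk_size=1024):
    """Hypothetical helper: best item per query without a full (n, m) matrix."""
    n = queries.shape[0]
    pad = (-n) % chunk_size
    # Pad so every chunk has the same static shape, hence a single compilation
    q = jnp.pad(queries, ((0, pad), (0, 0)))
    items_unit = _unit_rows(items)
    best_sim, best_idx = [], []
    for start in range(0, q.shape[0], chunk_size):
        s, i = _best_match_chunk(q[start:start + chunk_size], items_unit)
        best_sim.append(s)
        best_idx.append(i)
    # Drop the results that belong to padded rows
    return jnp.concatenate(best_sim)[:n], jnp.concatenate(best_idx)[:n]

queries = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 2.0, 0.0]])
items = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]])
sims, idx = chunked_best_match(queries, items, chunk_size=2)
print(idx, sims)
```

The chunk size trades memory for the number of kernel launches; in practice it would be much larger than 2.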

## 3) Simple retrieval example (top‑k)

Compute similarities between `queries` and `items`, then take top‑k indices.

```python
queries = jnp.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 1.0]])
items   = jnp.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0],
                     [0.0, 1.0, 1.0]])
S_qi = pairwise_cosine_similarity(queries, items)  # (n_query, n_item)
# Top‑k via argsort (descending)
topk = 2
topk_idx = jnp.argsort(-S_qi, axis=1)[:, :topk]
print('top‑k indices per query:', topk_idx)
print('top‑k sims per query  :', jnp.take_along_axis(S_qi, topk_idx, axis=1))
```

```text
top‑k indices per query: [[0 1]
 [2 1]]
top‑k sims per query  : [[1.         0.        ]
 [1.0000001  0.70710677]]
```

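For large item sets, `argsort` fully sorts every row even though only `k` entries are kept. `jax.lax.top_k` returns the `k` largest values and their indices along the last axis directly. The sketch below rebuilds `S_qi` with a hypothetical normalize-and-multiply helper (`unit_rows`) so it runs without `braintools`; when scores tie (as for the first query here), the order of the tied indices may differ from the `argsort` version.

```python
import jax
import jax.numpy as jnp

def unit_rows(x):
    # Hypothetical helper: L2-normalize each row
    return x / jnp.linalg.norm(x, axis=1, keepdims=True)

queries = jnp.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 1.0]])
items = jnp.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 1.0, 1.0]])
S_qi = unit_rows(queries) @ unit_rows(items).T  # (n_query, n_item)

# top_k operates on the last axis and returns values in descending order
topk_vals, topk_idx = jax.lax.top_k(S_qi, 2)
print(topk_idx)
print(topk_vals)
```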

## 4) Choosing the right API

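- Use `bt.metric.cosine_similarity` / `bt.metric.cosine_distance` for aligned pairs (same shape).
- Use `pairwise_cosine_similarity` to build `(n, m)` similarity matrices for retrieval/matching.
- Pre-normalizing inputs is not required, because cosine metrics compare directions, not magnitudes. Normalizing to unit length once can still help when embeddings are reused, since the cosine similarity of unit vectors is a plain dot product.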
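The two APIs are related: for two equal-sized sets, the diagonal of the pairwise matrix holds exactly the aligned per-pair scores, but the pairwise version builds `n × n` entries to use only `n` of them. The sketch below checks this relationship using hypothetical plain-`jax.numpy` stand-ins (`aligned_cos`, `pairwise_cos`) rather than the `braintools` calls.

```python
import jax.numpy as jnp

def aligned_cos(a, b):
    # Hypothetical stand-in for the aligned API: one score per row pair
    return jnp.sum(a * b, axis=-1) / (
        jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1))

def pairwise_cos(X, Y):
    # Hypothetical stand-in for the pairwise API: all (i, j) combinations
    Xn = X / jnp.linalg.norm(X, axis=1, keepdims=True)
    Yn = Y / jnp.linalg.norm(Y, axis=1, keepdims=True)
    return Xn @ Yn.T

pred = jnp.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [1.0, 1.0, 0.0]])
targ = jnp.array([[1.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])

per_pair = aligned_cos(pred, targ)  # shape (3,)
matrix = pairwise_cos(pred, targ)   # shape (3, 3)
print(jnp.allclose(jnp.diag(matrix), per_pair, atol=1e-6))
```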