# Regression Losses

This tutorial shows how to use core regression losses in `braintools.metric` and when to choose each:

- L1 / MAE: `absolute_error`, `l1_loss`
- L2 / MSE: `squared_error`, `l2_loss`
- Robust: `huber_loss`, `log_cosh`
- Embeddings: `cosine_distance` (and `cosine_similarity`)

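Several functions accept a `reduction` argument (`'none'`, `'mean'`, or `'sum'`) and an optional `axis` for per-sample aggregation. Others return elementwise values by default, as the printed outputs in this tutorial show, so check the shape of each result before using it as a scalar loss.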
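To make the reduction options concrete, here is a minimal pure-Python sketch of what `'none'`, `'mean'`, and `'sum'` conventionally mean when applied along the last axis. The helper `reduce_last_axis` is hypothetical and only illustrates the semantics; it is not the `braintools` implementation.

```python
def reduce_last_axis(errors, reduction="mean"):
    """Reduce each row of a 2-D list along its last axis (hypothetical helper)."""
    if reduction == "none":
        # Keep every elementwise value.
        return [list(row) for row in errors]
    if reduction == "sum":
        return [sum(row) for row in errors]
    if reduction == "mean":
        return [sum(row) / len(row) for row in errors]
    raise ValueError(f"unknown reduction: {reduction!r}")


# Two samples with three elementwise errors each.
errors = [[1.0, 2.0, 3.0],
          [0.0, 4.0, 2.0]]

none_out = reduce_last_axis(errors, "none")  # same shape as the input
sum_out = reduce_last_axis(errors, "sum")    # one value per sample
mean_out = reduce_last_axis(errors, "mean")  # one value per sample
```

With `'none'` the result keeps the input shape; `'sum'` and `'mean'` collapse the last axis to one value per sample.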

In [1]:
import jax.numpy as jnp
import braintools as bt

## Setup: sample predictions and targets

We'll use simple arrays for clarity; in practice these are model outputs and labels.

In [2]:
y_pred = jnp.array([[1.0, 2.0, 3.0],
                     [2.0, 2.5, 2.0]])
y_true = jnp.array([[1.1, 1.9, 3.2],
                     [2.0, 2.0, 2.0]])
y_outlier = jnp.array([[1.0, 2.0, 10.0],
                       [2.0, 2.5, -5.0]])  # to show robustness

## L1 loss (Mean Absolute Error)

Use L1 when robustness to outliers is important.

In [3]:
# Elementwise absolute error, then mean over last axis (per-sample MAE)
mae_per_sample = bt.metric.absolute_error(y_pred, y_true, axis=-1, reduction='mean')
print('MAE per sample:', mae_per_sample)

# Direct L1 loss API. With default arguments the value printed below (0.9) equals the
# summed absolute error over all six elements; the mean would be 0.15.
l1 = bt.metric.l1_loss(y_pred, y_true)
print('l1_loss (default):', l1)

# Outlier comparison
print('MAE with outlier:', bt.metric.absolute_error(y_outlier, y_true, axis=-1, reduction='mean'))

MAE per sample: [0.13333337 0.16666667]
l1_loss (default): 0.9000001
MAE with outlier: [2.3333335 2.5      ]
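As a cross-check on the numbers above, the per-sample MAE can be recomputed by hand. This is a dependency-free sketch using plain Python lists with the same values as the setup cell; `mae_per_row` is a hypothetical helper, not part of `braintools`.

```python
y_pred = [[1.0, 2.0, 3.0], [2.0, 2.5, 2.0]]
y_true = [[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]]
y_outlier = [[1.0, 2.0, 10.0], [2.0, 2.5, -5.0]]


def mae_per_row(pred, true):
    """Mean absolute error of each row (hypothetical helper)."""
    return [
        sum(abs(p - t) for p, t in zip(p_row, t_row)) / len(p_row)
        for p_row, t_row in zip(pred, true)
    ]


mae = mae_per_row(y_pred, y_true)             # roughly [0.1333, 0.1667]
mae_outlier = mae_per_row(y_outlier, y_true)  # roughly [2.3333, 2.5]

# Sum of all six absolute errors, for comparison with a scalar result.
total_abs_error = sum(
    abs(p - t) for p_row, t_row in zip(y_pred, y_true) for p, t in zip(p_row, t_row)
)
```

A single outlier of size about 7 raises each row's MAE by roughly 7/3, which is linear growth in the size of the error.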


## L2 loss (Mean Squared Error)

Use L2 when larger errors should be penalized more heavily.

In [4]:
# Squared error mean over last axis (per-sample MSE)
mse_per_sample = bt.metric.squared_error(y_pred, y_true, axis=-1, reduction='mean')
print('MSE per sample:', mse_per_sample)

# Direct L2 loss API. The result printed below is elementwise and matches
# 0.5 * squared error; reduce it yourself if you need a scalar.
l2 = bt.metric.l2_loss(y_pred, y_true)
print('l2_loss (elementwise):', l2)

# Outlier comparison
print('MSE with outlier:', bt.metric.squared_error(y_outlier, y_true, axis=-1, reduction='mean'))

MSE per sample: [0.02000001 0.08333334]
l2_loss (elementwise): [[0.005      0.005      0.02000001]
 [0.         0.125      0.        ]]
MSE with outlier: [15.420001 16.416668]
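The printed values can be reproduced from first principles. The sketch below uses plain Python lists with the tutorial's data; treating the `l2_loss` output as half the squared error per element is an inference from the numbers above, not a statement about the library's documented behaviour.

```python
y_pred = [[1.0, 2.0, 3.0], [2.0, 2.5, 2.0]]
y_true = [[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]]

# Elementwise squared error.
sq_err = [
    [(p - t) ** 2 for p, t in zip(p_row, t_row)]
    for p_row, t_row in zip(y_pred, y_true)
]

# Per-sample MSE: mean over the last axis.
mse = [sum(row) / len(row) for row in sq_err]  # roughly [0.02, 0.0833]

# Half squared error per element (assumed convention, inferred from the output).
half_sq_err = [[0.5 * e for e in row] for row in sq_err]
```

Because the error is squared, an outlier of size 7 contributes about 49 to the sum, which is why the per-sample MSE jumps from below 0.1 to above 15.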


## Huber loss (robust L2)

Huber behaves like L2 near zero and L1 for large residuals; set `delta` to tune the transition.

In [5]:
huber = bt.metric.huber_loss(y_pred, y_true, delta=1.0)
huber_outlier = bt.metric.huber_loss(y_outlier, y_true, delta=1.0)
print('Huber (elementwise):', huber)
print('Huber with outlier (elementwise):', huber_outlier)

Huber (elementwise): [[0.005      0.005      0.02000001]
 [0.         0.125      0.        ]]
Huber with outlier (elementwise): [[5.000002e-03 5.000002e-03 6.300000e+00]
 [0.000000e+00 1.250000e-01 6.500000e+00]]
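The piecewise behaviour can be written out directly. This is a minimal sketch of the standard Huber definition for a single residual, not the `braintools` source; the function name `huber` is chosen here for illustration.

```python
def huber(residual, delta=1.0):
    """Standard Huber loss for one residual (illustrative sketch)."""
    a = abs(residual)
    if a <= delta:
        # Quadratic region: same as half the squared error.
        return 0.5 * residual * residual
    # Linear region: slope delta, offset so both pieces meet at |r| = delta.
    return delta * (a - 0.5 * delta)


small = huber(0.5)      # 0.125, quadratic branch
large = huber(6.8)      # about 6.3, linear branch
negative = huber(-7.0)  # 6.5, symmetric in the sign of the residual
```

These values agree with the outlier entries printed above (6.3 and 6.5), while the small residuals match half the squared error.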


## log-cosh (smooth robust loss)

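`log_cosh` is a smooth loss that behaves like `x**2 / 2` for small residuals and like `|x| - log(2)` for large ones, so it is less sensitive to outliers than L2. In the cell below it is applied to the residual `y_pred - y_true`.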

In [6]:
lc = bt.metric.log_cosh(y_pred - y_true)
lc_outlier = bt.metric.log_cosh(y_outlier - y_true)
print('log-cosh (elementwise):', lc)
print('log-cosh with outlier (elementwise):', lc_outlier)

log-cosh (elementwise): [[0.00499171 0.00499171 0.01986814]
 [0.         0.12011451 0.        ]]
log-cosh with outlier (elementwise): [[4.99171019e-03 4.99171019e-03 6.10685444e+00]
 [0.00000000e+00 1.20114505e-01 6.30685377e+00]]
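Computing `log(cosh(x))` naively overflows for large `|x|` because `cosh` grows exponentially. A common workaround, sketched here in plain Python, uses the identity `log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2)`. Whether `braintools` uses this exact form is not confirmed by the tutorial; the function below is an illustration only.

```python
import math


def log_cosh(x):
    """Numerically stable log(cosh(x)) for one value (illustrative sketch)."""
    a = abs(x)
    # exp(-2a) is at most 1, so it cannot overflow.
    return a + math.log1p(math.exp(-2.0 * a)) - math.log(2.0)


small = log_cosh(0.5)    # about 0.1201, close to 0.5**2 / 2
large = log_cosh(7.0)    # about 6.3069, close to 7 - log(2)
huge = log_cosh(1000.0)  # finite, whereas math.cosh(1000.0) overflows
```

The small and large values line up with the printed entries 0.12011451 and 6.30685377 up to float32 rounding.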


## Cosine distance (1 - cosine similarity)

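Use it to compare the directions of vectors such as embeddings. Cosine distance ignores vector length and is bounded in [0, 2]: 0 for identical directions, 1 for orthogonal vectors, and 2 for opposite directions.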

In [7]:
# Pairwise aligned vectors [..., D] -> [...]
v1 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
v2 = jnp.array([[0.0, 1.0], [1.0, 0.0], [1.0, -1.0]])
cd = bt.metric.cosine_distance(v1, v2, epsilon=1e-8)
print('Cosine distance:', cd)

# Also available: cosine_similarity (aligned) and pairwise matrix version in bt.metric.cosine_similarity (X,Y)
cs_aligned = bt.metric.cosine_similarity(v1, v2)
print('Cosine similarity:', cs_aligned)

Cosine distance: [1. 1. 1.]
Cosine similarity: [0. 0. 0.]
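For reference, cosine similarity is the dot product divided by the product of the two norms. The sketch below is a plain-Python illustration; clamping the denominator with `epsilon` is an assumption about how such a guard might work, not the confirmed `braintools` behaviour.

```python
import math


def cosine_similarity(u, v, epsilon=1e-8):
    """Cosine similarity of two equal-length vectors (illustrative sketch)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    # Assumed guard: clamp the denominator so zero vectors do not divide by zero.
    return dot / max(norm_u * norm_v, epsilon)


def cosine_distance(u, v, epsilon=1e-8):
    return 1.0 - cosine_similarity(u, v, epsilon)


orthogonal = cosine_distance([1.0, 0.0], [0.0, 1.0])  # 1.0
scaled = cosine_distance([1.0, 1.0], [2.0, 2.0])      # about 0.0
opposite = cosine_distance([1.0, 0.0], [-1.0, 0.0])   # 2.0
zero_vec = cosine_similarity([0.0, 0.0], [1.0, 0.0])  # 0.0, no division error
```

All three pairs in the cell above are orthogonal, which is why every distance is 1 and every similarity is 0.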


## Guidance

- Prefer L1/Huber/log-cosh when outliers are present or robustness is desired.
- Use L2/MSE for well-behaved noise where larger errors should be penalized quadratically.
- For embeddings, cosine distance normalizes each vector's length implicitly, so the overall magnitude of the vectors does not need rescaling.
- Use `axis` to aggregate per-sample (e.g., `axis=-1`) and set `reduction` explicitly when needed.
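One way to see why the robust losses tolerate outliers is to compare the derivative of each loss with respect to the residual. The sketch below uses the standard textbook definitions (with the L2 loss taken as the plain squared error); the function names are chosen here for illustration and are not `braintools` APIs.

```python
import math


def grad_l1(r):
    # d|r|/dr = sign(r), taken as 0 at r == 0.
    return (r > 0) - (r < 0)


def grad_l2(r):
    # d(r**2)/dr grows without bound.
    return 2.0 * r


def grad_huber(r, delta=1.0):
    # Residual clipped to [-delta, delta].
    return max(-delta, min(delta, r))


def grad_log_cosh(r):
    # d log(cosh(r))/dr = tanh(r), bounded by 1 in magnitude.
    return math.tanh(r)


for r in (0.1, 7.0):
    print(r, grad_l1(r), grad_l2(r), grad_huber(r), grad_log_cosh(r))
```

For a residual of 7.0 the squared-error gradient is 14.0, while the L1, Huber (`delta=1.0`), and log-cosh gradients all stay at or below 1 in magnitude, so a single outlier cannot dominate a gradient step.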

## Pitfalls

- Ensure predictions and targets have the same shape for arithmetic losses.
- For cosine metrics, avoid zero vectors or set a small `epsilon`.
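- Be explicit about `reduction` to avoid surprises, because defaults differ among functions: in the outputs above, `l1_loss` returned a scalar while `l2_loss`, `huber_loss`, and `log_cosh` returned elementwise arrays.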