# Tutorial 1: Classification Losses

This tutorial introduces the core classification losses in `braintools.metric` and how to pick and use them effectively:

- Binary/multilabel: `braintools.metric.sigmoid_binary_cross_entropy`
- Multiclass: `braintools.metric.softmax_cross_entropy`, `braintools.metric.softmax_cross_entropy_with_integer_labels`
- Imbalanced data: `braintools.metric.sigmoid_focal_loss`
- Regularization: `braintools.metric.smooth_labels`

All examples use JAX arrays and are shape- and type-checked.

In [1]:
import jax.numpy as jnp
import braintools

## Binary / Multilabel (Sigmoid Cross-Entropy)

Use when each class is an independent yes/no decision (not mutually exclusive).

In [2]:
# Binary/multilabel setup (elementwise BCE)
logits = jnp.array([1.0, -1.0, 0.0])        # unnormalized logits
labels = jnp.array([1.0,  0.0, 1.0])        # binary targets in {0,1}

loss = braintools.metric.sigmoid_binary_cross_entropy(logits, labels)
print(loss)          # per-element BCE, same shape as logits
print(loss.mean())   # common reduction

[0.3132617 0.3132617 0.6931472]
0.4398902


Tips
- Labels must be float (0/1) or probabilities, shape-broadcastable with `logits`.
- Also use this for multilabel problems, where one example can belong to several classes at once.
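
To make the numbers above concrete, here is a minimal plain-Python sketch of the elementwise formula, written without `braintools` so the arithmetic is visible. The helper name `sigmoid_bce` is illustrative and not part of the library; the library's internals may differ, though the results should agree to float32 precision.

```python
import math

def sigmoid_bce(logit, label):
    """Illustrative helper (not a braintools function): BCE from a raw logit."""
    # Numerically stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

logits = [1.0, -1.0, 0.0]
labels = [1.0, 0.0, 1.0]

losses = [sigmoid_bce(x, y) for x, y in zip(logits, labels)]
mean_loss = sum(losses) / len(losses)

print([round(v, 4) for v in losses])  # [0.3133, 0.3133, 0.6931]
print(round(mean_loss, 4))            # 0.4399
```

The first two entries are equal because a logit of 1.0 with label 1 and a logit of -1.0 with label 0 are equally confident, equally correct predictions.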

## Multiclass (Softmax Cross-Entropy)

Use when classes are mutually exclusive.

In [3]:
# One-hot or probability targets
# logits shape [..., num_classes]
logits = jnp.array([[2.0, 1.0, 0.1]])
targets = jnp.array([[1.0, 0.0, 0.0]])  # one-hot

loss = braintools.metric.softmax_cross_entropy(logits, targets)
print(loss)  # shape [...]

[0.41702995]


In [4]:
# Integer labels (preferred for single-label classification)
logits = jnp.array([[2.0, 1.0, 0.1]])
labels = jnp.array([0])  # class index

loss = braintools.metric.softmax_cross_entropy_with_integer_labels(logits, labels)
print(loss)

[0.41702995]


Notes
- `softmax_cross_entropy` expects float targets; `*_with_integer_labels` expects integer labels.
- Shapes: `logits[..., C]`, `labels[...]` or `targets[..., C]`.
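
The two variants compute the same quantity for a one-hot target, which is why both cells print the same value. A minimal plain-Python sketch of that equivalence, under the standard definition of softmax cross-entropy (the helper names are illustrative, not `braintools` functions):

```python
import math

def log_softmax(logits):
    """Log-probabilities via the log-sum-exp trick (subtract the max first)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def ce_with_targets(logits, targets):
    """Cross-entropy against a probability vector (one-hot or soft)."""
    return -sum(t * lp for t, lp in zip(targets, log_softmax(logits)))

def ce_with_index(logits, label):
    """Cross-entropy against a single integer class index."""
    return -log_softmax(logits)[label]

logits = [2.0, 1.0, 0.1]
loss_one_hot = ce_with_targets(logits, [1.0, 0.0, 0.0])
loss_index = ce_with_index(logits, 0)

print(round(loss_one_hot, 4), round(loss_index, 4))  # 0.417 0.417
```

The integer-label form only reads one log-probability, so it avoids building a one-hot array when each example has exactly one class.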

## Focal Loss (Imbalanced Data)

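Use focal loss to focus learning on hard, misclassified examples in imbalanced settings. It scales each example's cross-entropy by `(1 - p_t) ** gamma`, where `p_t` is the predicted probability of the true class, so confident correct predictions contribute little. For multilabel/binary, use `sigmoid_focal_loss`.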

In [5]:
# Imbalanced binary classification example
logits = jnp.array([2.0, -1.0, 0.5, -2.0])
labels = jnp.array([1.0, 0.0, 1.0, 0.0])

# Alpha balances positive vs negative; gamma focuses on hard examples
loss = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)
print(loss)
print(loss.mean())

# Unweighted focal (no class weighting)
loss_unweighted = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0)
print(loss_unweighted)

[0.00045089 0.01699354 0.01689337 0.00135267]
0.008922619
[0.00180356 0.02265805 0.06757348 0.00180356]


Guidance
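- `alpha` weights positive examples and `1 - alpha` weights negatives (in the example above, positives are scaled by 0.25 and negatives by 0.75). Treat it as a hyperparameter to tune together with `gamma`; 0.25 is a common starting point.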
- Increase `gamma` (1–5) to emphasize harder examples more.
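
The printed values can be reproduced by hand. The following is a plain-Python sketch of the computation, assuming the usual focal-loss definition; `sigmoid_focal` is an illustrative helper, not the library function, and it is not hardened against extreme logits.

```python
import math

def sigmoid_focal(logit, label, alpha=0.25, gamma=2.0):
    """Illustrative focal loss for one binary decision; alpha=None skips weighting."""
    p = 1.0 / (1.0 + math.exp(-logit))
    p_t = p * label + (1.0 - p) * (1.0 - label)   # probability of the true class
    ce = -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))
    loss = (1.0 - p_t) ** gamma * ce               # down-weight easy examples
    if alpha is not None:
        # alpha for positives, (1 - alpha) for negatives
        loss *= alpha * label + (1.0 - alpha) * (1.0 - label)
    return loss

logits = [2.0, -1.0, 0.5, -2.0]
labels = [1.0, 0.0, 1.0, 0.0]

weighted = [sigmoid_focal(x, y) for x, y in zip(logits, labels)]
unweighted = [sigmoid_focal(x, y, alpha=None) for x, y in zip(logits, labels)]

print([round(v, 5) for v in weighted])
print([round(v, 5) for v in unweighted])
```

The hardest example here is the third one (logit 0.5 for a positive label): it keeps the largest unweighted loss, while the two confident predictions are suppressed by roughly two orders of magnitude relative to their plain cross-entropy.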

## Label Smoothing (Regularization)

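Label smoothing discourages overconfidence by blending one-hot labels with a uniform distribution: with `C` classes, each target becomes `(1 - alpha) * one_hot + alpha / C`. Pass the smoothed targets to `softmax_cross_entropy`.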

In [6]:
# One-hot labels [..., C]
labels_one_hot = jnp.array([[1.0, 0.0, 0.0]])
smoothed = braintools.metric.smooth_labels(labels_one_hot, alpha=0.1)

logits = jnp.array([[2.0, 1.0, 0.5]])
loss = braintools.metric.softmax_cross_entropy(logits, smoothed)
print(loss)
print(smoothed)

[0.54770213]
[[0.93333334 0.03333334 0.03333334]]


Rules of thumb
- `alpha` in [0.05, 0.2] is common; larger values smooth more.
- Often improves calibration and robustness; too much smoothing can slightly lower peak accuracy.
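
The smoothed targets and the loss printed above can be checked by hand. This plain-Python sketch assumes the blend is `(1 - alpha) * one_hot + alpha / C`, which matches the printed `smoothed` values; the helper names are illustrative and not part of `braintools`.

```python
import math

def smooth(one_hot, alpha):
    """Blend a one-hot vector with the uniform distribution over its classes."""
    num_classes = len(one_hot)
    return [(1.0 - alpha) * t + alpha / num_classes for t in one_hot]

def softmax_ce(logits, targets):
    """Cross-entropy between softmax(logits) and a target distribution."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(t * (z - lse) for t, z in zip(targets, logits))

smoothed = smooth([1.0, 0.0, 0.0], alpha=0.1)
logits = [2.0, 1.0, 0.5]

loss_smoothed = softmax_ce(logits, smoothed)
loss_hard = softmax_ce(logits, [1.0, 0.0, 0.0])

print([round(t, 4) for t in smoothed])  # [0.9333, 0.0333, 0.0333]
print(round(loss_smoothed, 4), round(loss_hard, 4))
```

For these logits the smoothed loss is larger than the hard-label loss, because part of the target mass now sits on the two classes the model scores lower.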

## Which Loss Should I Use?

- Binary/multilabel tasks → `sigmoid_binary_cross_entropy`
- Multiclass single-label → `softmax_cross_entropy_with_integer_labels`
- Heavily imbalanced (binary/multilabel) → `sigmoid_focal_loss`
- Overconfident models → `smooth_labels` + softmax CE

## Common Pitfalls

- Do not feed integer labels to `softmax_cross_entropy` (use the integer-label variant).
- For multilabel problems, use sigmoid-based losses (not softmax).
- Always match shapes: `logits[..., C]` pairs with one-hot targets `[..., C]` or integer labels `[...]`.
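
The multilabel pitfall is easiest to see numerically. In this plain-Python sketch the logits are hypothetical, chosen for an example where classes 0 and 1 are both present. Sigmoid scores each class independently, while softmax forces the classes to share one unit of probability.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical multilabel case: the model is confident about classes 0 and 1.
logits = [3.0, 3.0, -3.0]

sigmoid_probs = [sigmoid(z) for z in logits]  # each class scored on its own
softmax_probs = softmax(logits)               # classes compete with each other

print([round(p, 3) for p in sigmoid_probs])   # [0.953, 0.953, 0.047]
print([round(p, 3) for p in softmax_probs])   # [0.499, 0.499, 0.001]
```

Under softmax, neither true class can exceed a probability of 0.5 here, so a softmax-based loss would keep penalizing a prediction that is already correct for both labels.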

## Extras: NLL with Log-Probabilities

If you already have log-probabilities, use `nll_loss` directly.

In [7]:
log_probs = jnp.log(jnp.array([0.1, 0.7, 0.2]))
target = 1
print(braintools.metric.nll_loss(log_probs, target))

-0.35667497
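
The printed value, -0.35667497, equals log(0.7), the log-probability of the target class. Under the usual convention, a negative log-likelihood is the negation of that quantity, so it is worth checking the sign before minimizing it. A plain-Python sketch of the arithmetic, independent of `braintools`:

```python
import math

probs = [0.1, 0.7, 0.2]
log_probs = [math.log(p) for p in probs]
target = 1

picked = log_probs[target]  # log-probability of the target class
nll = -picked               # conventional negative log-likelihood

print(round(picked, 4))     # -0.3567
print(round(nll, 4))        # 0.3567
```

With the conventional sign, the result is non-negative and matches what softmax cross-entropy would return for logits whose softmax is `probs`.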


---

See also
- API reference: `braintools.metric` → classification functions
- Ranking and regression losses are covered in separate tutorials.