A couple of useful metrics and losses for multiclass classification. I wrote them for my imbalanced dataset.

In [1]:
import jax
import jax.numpy as jnp

# With imbalanced data, micro-averaged precision is dominated by the majority
# class. Macro averaging avoids this by giving every class equal weight.
def macro_precision(y_true, y_pred, threshold=0.5):
    """Return the macro-averaged precision over all classes.

    Scores in ``y_pred`` are binarized (a score strictly greater than
    ``threshold`` counts as a positive prediction), precision is computed
    independently for every class column, and the per-class values are
    averaged with equal weight.

    Args:
        y_true: 2-D array whose columns are classes; presumably 0/1
            indicators of shape (n_samples, n_classes).
        y_pred: Per-class scores with the same shape as ``y_true``.
        threshold: Cut-off above which a score is treated as a positive
            prediction. Defaults to 0.5.

    Returns:
        Scalar array with the mean of the per-class precisions. A class that
        receives no positive prediction contributes a precision of 0.
    """
    predicted = jnp.where(y_pred > threshold, 1.0, 0.0)

    # Reducing over the sample axis yields the counts of every class at once,
    # so no Python-level loop over the columns is needed.
    true_positives = jnp.sum(y_true * predicted, axis=0)
    predicted_positives = jnp.sum(predicted, axis=0)

    # eps keeps the division finite for classes that were never predicted.
    eps = jnp.finfo(float).eps
    return jnp.mean(true_positives / (predicted_positives + eps))


# Focal loss is most effective when both alpha (class weights) and gamma
# (focusing parameter) are used. I use `compute_class_weight` from sklearn
# to compute alpha.
def categorical_focal_cross_entropy(logits, labels_one_hot, alpha, gamma=2.0):
    """Return the mean class-weighted focal cross-entropy of a batch.

    Per sample the loss is ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``, where
    ``p_t`` is the softmax probability of the labelled class and ``alpha_t``
    is that class's entry in ``alpha``.

    Args:
        logits: Unnormalized class scores; softmax is taken over the last
            axis.
        labels_one_hot: One-hot labels with the same shape as ``logits``.
            Only the argmax position of each row is used, so rows are
            expected to be strictly one-hot.
        alpha: Per-class weights indexed by class id. Any array-like is
            accepted (list, NumPy array or JAX array).
        gamma: Focusing exponent applied to ``1 - p_t``. With ``gamma=0`` the
            loss reduces to alpha-weighted cross-entropy.

    Returns:
        Scalar array with the loss averaged over all samples.
    """
    class_ids = jnp.argmax(labels_one_hot, axis=-1)

    # log_softmax is computed in log space, so the log-probability stays exact
    # (and differentiable) even when the probability itself underflows to 0;
    # log(softmax(x) + tiny) would cap the loss and kill the gradient there.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    log_p_t = jnp.take_along_axis(
        log_probs, class_ids[..., None], axis=-1
    )[..., 0]
    p_t = jnp.exp(log_p_t)

    # Converting first lets lists and NumPy arrays be indexed by a JAX array,
    # including when the labels are traced under jit.
    alpha_t = jnp.asarray(alpha)[class_ids]

    loss = -alpha_t * jnp.power(1 - p_t, gamma) * log_p_t
    return jnp.mean(loss)
