In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
import flax
from typing import Any
import numpy as np
import functools
from einops import rearrange
import tqdm
import time

# Classifiers

> A **classifier** is a model that is trained to classify data into one of N distinct categories.

One of the simplest possible models to train is the **categorical classifier**. In this setting, we have a labelled dataset consisting of `(data, label) / (x,y)` pairs, where each **label** (y) is an integer representing one of N distinct classes. We would then like to train our classifier model such that it correctly predicts which class a given datapoint (x) falls into.

## Probability Maximization

To derive the classification objective, we'll first note that **a classifier is a probability distribution over classes**. We want to fit the parameters of our classifier model to maximize the probability it assigns to the labels in our dataset. To do this, we minimize the negative log-likelihood:

$$
L(\theta) = - \sum_{x, y \sim D} \log p_\theta(y|x).
$$

This loss function is also referred to as the **cross-entropy loss** in machine learning literature. The name refers to the fact that this loss is in fact equivalent to information-theoretic cross-entropy between the data distribution and the model distribution.


## Model Structure

The final output of a neural network classifier must represent a probability distribution over classes. In the discrete categorical setting, this is simple -- the model can output a vector of probabilities, one for each possible class.

However, probabilities must be non-negative and sum to one, and it is tricky to enforce these constraints directly on the outputs of a neural network. Therefore, models are instead trained to output **logits**, which are unconstrained and real-valued. To get the probabilities from the logits, we exponentiate the values and then explicitly normalize them by dividing by their total sum. This operation is so common that it gets a special name, the **softmax** operation.

In [None]:
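def softmax(logits): # [..., num_classes]
    # Normalize over the last (class) axis so batched logits also work.
    exp_logits = jnp.exp(logits)
    return exp_logits / jnp.sum(exp_logits, axis=-1, keepdims=True)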
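
As a quick sanity check, the sketch below applies a softmax that normalizes over the last axis to a small logit vector and confirms that the result is a valid probability distribution. The logit values are arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp

def softmax(logits):  # [..., num_classes]
    # Exponentiate, then normalize over the class (last) axis.
    exp_logits = jnp.exp(logits)
    return exp_logits / jnp.sum(exp_logits, axis=-1, keepdims=True)

# Arbitrary example logits: unconstrained, and may be negative.
logits = jnp.array([2.0, -1.0, 0.5])
probs = softmax(logits)

print(probs)           # every entry is positive
print(jnp.sum(probs))  # sums to 1, up to floating-point error
```

The largest logit always receives the largest probability, since `exp` is monotonically increasing.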

Putting these together, we can now implement a basic classification loss function:

In [None]:
class Classifier(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(10)(x)

def loss(params, inputs, labels):
    logits = Classifier().apply({'params': params}, inputs)
    one_hot_labels = jax.nn.one_hot(labels, 10)
    probs = softmax(logits)
    log_probs = jnp.log(probs)
    return -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))
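
To see this loss in action without initializing a Flax module, here is a minimal pure-JAX sketch. The linear model `linear_logits`, the parameter dictionary, and the random data are all stand-ins invented for this example; they are not the notebook's `Classifier`. The sketch checks that the one-hot formulation equals directly picking out the log-probability of the correct class, that a model with all-zero parameters has a loss of `log(10)`, and that one small gradient step lowers the loss on the batch.

```python
import jax
import jax.numpy as jnp

NUM_CLASSES = 10

def linear_logits(params, inputs):
    # Hypothetical stand-in for Classifier().apply: a single dense layer.
    return inputs @ params['kernel'] + params['bias']

def one_hot_loss(params, inputs, labels):
    log_probs = jax.nn.log_softmax(linear_logits(params, inputs), axis=-1)
    one_hot_labels = jax.nn.one_hot(labels, NUM_CLASSES)
    return -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))

def gather_loss(params, inputs, labels):
    # Same quantity: index the log-prob of the true class directly.
    log_probs = jax.nn.log_softmax(linear_logits(params, inputs), axis=-1)
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)
    return -jnp.mean(picked)

key_x, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
inputs = jax.random.normal(key_x, (32, 8))  # fake batch of 32 datapoints
labels = jax.random.randint(key_y, (32,), 0, NUM_CLASSES)
params = {
    'kernel': 0.1 * jax.random.normal(key_w, (8, NUM_CLASSES)),
    'bias': jnp.zeros((NUM_CLASSES,)),
}

loss_a = one_hot_loss(params, inputs, labels)
loss_b = gather_loss(params, inputs, labels)

# With all-zero parameters the prediction is uniform, so the loss is log(10).
zero_params = jax.tree_util.tree_map(jnp.zeros_like, params)
uniform_loss = one_hot_loss(zero_params, inputs, labels)

# One small step of gradient descent on this batch.
grads = jax.grad(one_hot_loss)(params, inputs, labels)
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
loss_after = one_hot_loss(new_params, inputs, labels)

print(loss_a, loss_b, uniform_loss, loss_after)
```

The `log(10)` value is a useful reference point: a 10-class model whose loss sits near it is doing no better than guessing uniformly.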

## Example: ImageNet Classification

To get a feel for a practical-size classifier, let's train a model that learns to classify the ImageNet dataset, a set of 1.2 million images with 1000 classes. We will use our homebrew `jaxtransformer` library to provide the backbone model.

In [None]:
from jaxtransformer import TransformerBackbone
from jaxtransformer.modalities import Patch

## Overfitting

## Why do we maximize log-probabilities instead of probabilities?

Because numerically, log-probabilities are easier to work with. First, note that the `log` function is monotonically increasing, so maximizing log-probabilities is equivalent to maximizing probabilities. Working in log space brings a number of numerical benefits. The log-probability of a set of independent samples is the *sum* of the individual log-probs, whereas raw probabilities would need to be multiplied. A product of many small values quickly underflows to zero in floating point. We would much rather work with

$$
\sum_{y,x} \; \log p_\theta(y|x) \qquad \text{instead of} \qquad \prod_{y,x} \; p_\theta(y|x),
$$

and the optimal parameters are the same in either case.
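
The underflow problem is easy to reproduce. The sketch below assumes a made-up dataset of 1000 independent samples to which the model assigns probability 0.01 each; the numbers are illustrative only.

```python
import jax.numpy as jnp

# 1000 independent samples, each assigned probability 0.01 by the model.
probs = jnp.full((1000,), 0.01, dtype=jnp.float32)

# The true product is 1e-2000, far below the smallest positive float32.
likelihood = jnp.prod(probs)

# The sum of logs is 1000 * log(0.01), which is easily representable.
log_likelihood = jnp.sum(jnp.log(probs))

print(likelihood)      # 0.0 in float32
print(log_likelihood)  # roughly -4605.17
```

Once the product has collapsed to zero, its gradient carries no information either, whereas the log-space sum stays well-behaved.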

## Why do some libraries subtract a constant in their softmax implementations?

The raw softmax function can be numerically unstable, as it involves an `exp` operator, which overflows for large inputs. To alleviate this problem, practical machine learning libraries tend to make use of a trick in the softmax equation: since the softmax is a ratio, we can multiply the numerator and denominator by the same constant and get the same result. In logit-space, this means we can subtract a constant **before the exp** and get the same result. Often, the maximum logit is subtracted from the logit vector, which guarantees that `exp` is never called on a positive number.
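
A minimal sketch of the max-subtraction trick follows. The function names `naive_softmax` and `stable_softmax` and the logit values are made up for this example; the built-in `jax.nn.softmax` is included only as a reference to compare against.

```python
import jax
import jax.numpy as jnp

def naive_softmax(logits):
    exp_logits = jnp.exp(logits)
    return exp_logits / jnp.sum(exp_logits, axis=-1, keepdims=True)

def stable_softmax(logits):
    # Subtracting the max logit leaves the ratio unchanged but keeps
    # every argument to exp at or below zero.
    shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
    exp_shifted = jnp.exp(shifted)
    return exp_shifted / jnp.sum(exp_shifted, axis=-1, keepdims=True)

large_logits = jnp.array([1000.0, 1001.0, 1002.0])
small_logits = jnp.array([0.0, 1.0, 2.0])  # the same logits minus 1000

naive = naive_softmax(large_logits)    # exp overflows to inf, inf / inf = nan
stable = stable_softmax(large_logits)  # finite probabilities
reference = naive_softmax(small_logits)
builtin = jax.nn.softmax(large_logits)

print(naive)
print(stable)
print(reference)
```

Both logit vectors describe the same distribution, so `stable` and `reference` agree even though the naive computation fails on the large values.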

## What is the difference between binary cross-entropy and cross-entropy?

The term **binary cross-entropy** often arises when classifying data into two possible categories. The equation for binary cross-entropy is in fact the same as cross-entropy, just written in a slightly different way.

$$
\underbrace{\log p_\theta(y|x)}_{\text{cross-entropy}} = \sum_{y'} \log p_\theta(y'|x) \cdot \mathbb{1}(y' = y) = \underbrace{\log p_\theta(y_1|x) \cdot \mathbb{1}(y_1 = y) + \log p_\theta(y_0|x) \cdot \mathbb{1}(y_0 = y)}_{\text{binary cross-entropy}}
$$

Additionally, the binary equivalent of the softmax activation is the **sigmoid** function, which gives the probability $p(y_1)$ from a real-valued logit, using the fact that for binary classification, $p(y_0) = 1 - p(y_1)$. Using $h_n$ to refer to logits for class labels $y_n$, and $h = h_1 - h_0$ for their difference:

$$
\begin{align}
p(y) = \frac{\exp(h)}{\sum_{h'} \exp(h')} \qquad \rightarrow \qquad p(y_1) & = \frac{\exp(h_1)}{\exp(h_1) + \exp(h_0)} \quad \text{(binary case)} \\
& = \frac{1}{1 + \exp(-(h_1 - h_0))} \\
& = \frac{1}{1 + e^{-h}} \quad \text{(sigmoid function)}.
\end{align}
$$

Note that in the binary case, the probability is purely a function of the difference between the logits $h_1$ and $h_0$, so we only need to learn a single logit.
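
The equivalence can be checked numerically. In the sketch below the logit values are arbitrary, and `binary_cross_entropy` is a helper written for this example rather than a library function.

```python
import jax
import jax.numpy as jnp

h0, h1 = 0.3, 1.7  # arbitrary logits for classes y_0 and y_1

# Two-class softmax: probability assigned to y_1.
p1_softmax = jax.nn.softmax(jnp.array([h0, h1]))[1]

# Sigmoid of the single logit difference h_1 - h_0.
p1_sigmoid = jax.nn.sigmoid(h1 - h0)

def binary_cross_entropy(p1, y):
    # y is 1 if the true class is y_1 and 0 if it is y_0.
    return -(y * jnp.log(p1) + (1 - y) * jnp.log(1 - p1))

# Categorical cross-entropy for the same example, with true class y_1.
ce = -jax.nn.log_softmax(jnp.array([h0, h1]))[1]
bce = binary_cross_entropy(p1_sigmoid, 1)

print(p1_softmax, p1_sigmoid)
print(ce, bce)
```

Both pairs of values agree up to floating-point error, which is why a binary classifier can output one logit instead of two.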

## Can classification labels be something other than one-hot?

Yes, it's perfectly valid to have a dataset where labels are not one-hot, but rather full probability vectors. This is where the interpretation of minimizing cross-entropy between two distributions becomes useful:

$$
L(\theta) = - \sum_{x \sim D} \sum_{y'} p^*(y'|x) \cdot \log p_\theta(y'|x)
$$
where $p^*(y'|x)$ represents the probability vector given by the dataset, or a teacher model, etc. Often, the inner sum over possible labels is computed in vectorized form as a dot product. Full probability targets are most often used in model distillation, to copy the behavior of a teacher model into a student model.
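
A minimal sketch of cross-entropy against soft targets, with the inner sum written as a dot product over the class axis and the result averaged over the batch. The function name `soft_cross_entropy` and the random "student" and "teacher" logits are assumptions made for illustration; no real models are involved.

```python
import jax
import jax.numpy as jnp

def soft_cross_entropy(student_logits, target_probs):
    # Inner sum over classes, computed as a dot product along the last axis.
    log_probs = jax.nn.log_softmax(student_logits, axis=-1)
    return -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1))

key_s, key_t = jax.random.split(jax.random.PRNGKey(0))
student_logits = jax.random.normal(key_s, (4, 10))  # fake student outputs
teacher_logits = jax.random.normal(key_t, (4, 10))  # fake teacher outputs

# Distillation: the teacher's predicted distribution plays the role of p*.
teacher_probs = jax.nn.softmax(teacher_logits, axis=-1)
distill_loss = soft_cross_entropy(student_logits, teacher_probs)

# The loss is smallest when the student matches the teacher exactly.
matched_loss = soft_cross_entropy(teacher_logits, teacher_probs)

# One-hot targets are a special case and recover the usual loss.
labels = jnp.array([3, 1, 4, 1])
hard_loss = soft_cross_entropy(student_logits, jax.nn.one_hot(labels, 10))
log_probs = jax.nn.log_softmax(student_logits, axis=-1)
expected = -jnp.mean(log_probs[jnp.arange(4), labels])

print(distill_loss, matched_loss)
print(hard_loss, expected)
```

Note that `matched_loss` is not zero: with soft targets the minimum of the cross-entropy is the entropy of the target distribution.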

