# Stable Softmax and Cross Entropy from Scratch

- See [Blog (Jay Mody): Numerically Stable Softmax and Cross Entropy](https://jaykmody.com/blog/stable-softmax/) for a nice write-up.
- `dim=1` is used by default because, for multi-dimensional inputs, PyTorch expects the class (channel) dimension at `dim=1` (e.g. `(N, C)` or `(N, C, d1, ..., dk)`).

### Setup

In [6]:
import torch
import numpy
import torch.nn.functional as F
import numpy as np

# Random logits x of shape (batch_size, num_classes) and integer targets y of
# shape (batch_size,). They are created once in NumPy and mirrored as torch
# tensors, so every implementation below sees exactly the same data.
batch_size = 16
num_classes = 64

# np.random.rand samples U[0, 1); scale and shift to U[-50, 50).
x_np = np.random.rand(batch_size, num_classes).astype(np.float32) * 100 - 50
# np.random.randint's `high` is exclusive, so targets lie in [0, num_classes - 1].
y_np = np.random.randint(low=0, high=num_classes, size=(batch_size,))

# Alternative using torch's RNG. API differences vs. numpy:
#   - np.random.rand takes dimensions as separate arguments, while torch.rand
#     also accepts a shape tuple.
#   - Both torch calls return tensors, so convert with .numpy() to get arrays.
#   x_np = (torch.rand((batch_size, num_classes), dtype=torch.float32) * 100 - 50).numpy()
#   y_np = torch.randint(low=0, high=num_classes, size=(batch_size,)).numpy()

# torch.tensor copies the arrays; dtypes carry over (float32 logits, integer targets).
x_pt = torch.tensor(x_np)
y_pt = torch.tensor(y_np)

print(f"x: {x_pt}")
print(f"y: {y_pt}")

x: tensor([[ 39.1881, -34.0590, -27.0345,  ..., -36.3017,  18.9438, -27.8977],
        [-36.6272,   2.8195,  17.4534,  ...,  41.3051, -39.7945, -44.1213],
        [ 29.3098,  -9.7967,  38.6168,  ...,  31.2592,  31.0836, -22.3497],
        ...,
        [ -9.5909, -36.8363,  24.4595,  ..., -20.8998,   3.5788, -29.3600],
        [ 28.8947, -48.7126, -35.8005,  ...,  27.6717,  15.4610, -30.0755],
        [ 16.3976, -26.9319,  12.1767,  ..., -32.3698,  10.6299,  39.5487]])
y: tensor([23, 21,  4,  5, 30, 32, 62, 38, 51, 36, 53,  5, 32, 31, 35, 29])


## Softmax

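- Definition: $\text{softmax}(x)_k = \frac{e^{x_k}}{\sum_j e^{x_j}}$
- Stable version:
    - Multiply the numerator and denominator by the same constant $C > 0$ (i.e. by $\frac{C}{C} = 1$), which leaves the value unchanged
    - $\text{softmax}(x)_k = \frac{e^{x_k}}{\sum_j e^{x_j}}$
    - $\text{softmax}(x)_k = \frac{C \cdot e^{x_k}}{C \cdot \sum_j e^{x_j}}$
    - $\text{softmax}(x)_k = \frac{e^{x_k + \log C}}{\sum_j e^{x_j + \log C}}$ (using $C = e^{\log C}$)
    - Let $\log C = -\max(x)$
    - $\text{softmax}(x)_k = \frac{e^{x_k - \max(x)}}{\sum_j e^{x_j - \max(x)}}$
    - Every exponent is now $\le 0$, so each term lies in $(0, 1]$ and cannot overflow. The largest term is $e^0 = 1$, so the denominator is $\ge 1$ and never underflows to $0$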

In [8]:
# Reference result: PyTorch's built-in softmax over the class dimension.
smax_pt = x_pt.softmax(dim=1)

In [9]:
def softmax_torch(logits: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Numerically stable softmax along ``dim`` (PyTorch).

    Shifting every logit by the per-slice maximum leaves the result
    mathematically unchanged, but the largest exponent becomes exp(0) == 1,
    so ``exp`` cannot overflow.

    Args:
        logits: Input tensor, e.g. shape (N, K) for (batch_size, num_classes).
        dim: Dimension to normalize over (default 1, the class dimension).

    Returns:
        Tensor with the same shape as ``logits`` whose slices along ``dim``
        sum to 1.
    """
    # amax returns only the values (torch.max with a dim also returns indices).
    # keepdim=True keeps the reduced axis so the subtraction broadcasts.
    peak = logits.amax(dim=dim, keepdim=True)
    shifted_exp = (logits - peak).exp()
    normalizer = shifted_exp.sum(dim=dim, keepdim=True)
    return shifted_exp / normalizer

def softmax_numpy(logits: np.ndarray, dim: int = 1) -> np.ndarray:
    """Numerically stable softmax along axis ``dim`` (NumPy).

    Uses the same algorithm as the PyTorch version. NumPy names the reduction
    arguments ``axis``/``keepdims`` where torch uses ``dim``/``keepdim``.

    Args:
        logits: Input array, e.g. shape (N, K) for (batch_size, num_classes).
        dim: Axis to normalize over (default 1). Named ``dim`` to mirror the
            torch function's signature.

    Returns:
        Array with the same shape as ``logits`` whose slices along ``dim``
        sum to 1.
    """
    # keepdims=True keeps the reduced axis so the subtraction/division broadcast.
    peak = logits.max(axis=dim, keepdims=True)
    shifted_exp = np.exp(logits - peak)
    normalizer = shifted_exp.sum(axis=dim, keepdims=True)
    return shifted_exp / normalizer

# Check both of our implementations against the F.softmax reference.
smax_ours_pt = softmax_torch(x_pt, dim=1)
smax_ours_np = softmax_numpy(x_np, dim=1)

torch_matches = torch.allclose(smax_pt, smax_ours_pt)
numpy_matches = np.allclose(smax_pt.numpy(), smax_ours_np)
print(f"Allclose torch: {torch_matches}")
print(f"Allclose numpy: {numpy_matches}")

Allclose torch: True
Allclose numpy: True


## Cross Entropy

- Definition: $H(p, q) = - \sum_i p_i \cdot \log q_i$ where $p, q$ are probability distributions
    - If $p$ is one-hot at the target class $y$ (i.e. $p_i = 1$ if $i = y$, else $p_i = 0$):
        - $H(p, q) = - p_y \cdot \log q_y$
        - $H(p, q) = - \log q_y$
        - $H(p, q) = - \log \text{softmax}(x)_y$ (taking $q = \text{softmax}(x)$)
- Plug in our definition of stable softmax
    - $H(p, q) = - \log \text{softmax}(x)_y$
    - $H(p, q) = - \log \frac{e^{x_y - \max(x)}}{\sum_j e^{x_j - \max(x)}}$
    - $H(p, q) = - \left(\log e^{x_y - \max(x)} - \log \sum_j e^{x_j - \max(x)}\right)$
    - $H(p, q) = - \left(x_y - \max(x) - \log \sum_j e^{x_j - \max(x)}\right)$
    - $H(p, q) = - x_y + \max(x) + \log \sum_j e^{x_j - \max(x)}$

In [12]:
# Reference loss from PyTorch (default reduction="mean").
ce_pt = F.cross_entropy(input=x_pt, target=y_pt)
ce_pt

tensor(44.4969)

In [13]:
def cross_entropy_torch_v1(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy loss from raw logits, computed stably (PyTorch).

    Uses the log-sum-exp form of ``-log(softmax(x)_y)``::

        CE = -x_y + max(x) + log(sum_j exp(x_j - max(x)))

    No probability is ever materialized. The shifted exponents are <= 0, so
    ``exp`` cannot overflow. The sum contains exp(0) == 1, so ``log`` never
    sees an underflowed zero.

    Args:
        logits: Unnormalized scores with the class dimension at dim 1, i.e.
            shape (N, K), or (N, K, d1, ..., dk) as ``F.cross_entropy`` accepts.
        targets: Integer class indices in [0, K), shaped like ``logits`` with
            the class dimension removed: (N,) or (N, d1, ..., dk).

    Returns:
        0-dim tensor: the loss averaged over all N * d1 * ... * dk positions.
        This matches ``F.cross_entropy`` with its default ``reduction="mean"``.
    """
    max_logits = torch.amax(logits, dim=1, keepdim=True)  # (N,K,...) -> (N,1,...)
    log_sum = torch.log(torch.sum(torch.exp(logits - max_logits), dim=1))  # (N,K,...) -> (N,...)
    # Pick the correct-class logit at every position. gather along the class
    # dim is identical to logits[arange(N), targets] for 2-D input and extends
    # to trailing dims. gather needs int64 indices on the same device as logits.
    index = targets.to(device=logits.device, dtype=torch.long).unsqueeze(1)  # (N,1,...)
    logits_y = torch.gather(logits, 1, index).squeeze(1)  # (N,K,...) -> (N,...)
    ce = torch.mean(-logits_y + max_logits.squeeze(dim=1) + log_sum)
    return ce

def cross_entropy_numpy_v1(logits: np.ndarray, targets: np.ndarray) -> np.floating:
    """Mean cross-entropy loss from raw logits, computed stably (NumPy).

    Same log-sum-exp formulation as the torch version::

        CE = -x_y + max(x) + log(sum_j exp(x_j - max(x)))

    NumPy API differences: ``axis`` instead of ``dim`` and ``keepdims``
    instead of ``keepdim``.

    Args:
        logits: Unnormalized scores with the class axis at axis 1, i.e.
            shape (N, K) or (N, K, d1, ..., dk).
        targets: Integer class indices in [0, K), shaped like ``logits`` with
            the class axis removed: (N,) or (N, d1, ..., dk).

    Returns:
        NumPy scalar: the loss averaged over all N * d1 * ... * dk positions.
    """
    max_logits = np.max(logits, axis=1, keepdims=True)  # (N,K,...) -> (N,1,...)
    log_sum = np.log(np.sum(np.exp(logits - max_logits), axis=1))  # (N,K,...) -> (N,...)
    # take_along_axis along the class axis is identical to
    # logits[arange(N), targets] for 2-D input and extends to trailing axes.
    index = np.expand_dims(targets, axis=1)  # (N,...) -> (N,1,...)
    logits_y = np.take_along_axis(logits, index, axis=1).squeeze(axis=1)  # (N,...)
    ce = np.mean(-logits_y + max_logits.squeeze(axis=1) + log_sum)
    return ce

def cross_entropy_torch_v2(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy loss as -log of a stable softmax (PyTorch).

    Computes ``mean(-log(softmax(x)_y))`` from the probabilities returned by
    ``softmax_torch``. The softmax itself is stable, but the separate ``log``
    is not. If the target's probability underflows to 0 (its logit far below
    the maximum), the loss becomes ``inf``. The log-sum-exp form
    ``-x_y + max(x) + log(sum_j exp(x_j - max(x)))`` stays finite.

    Args:
        logits: Unnormalized scores with the class dimension at dim 1, i.e.
            shape (N, K) or (N, K, d1, ..., dk).
        targets: Integer class indices in [0, K), shaped like ``logits`` with
            the class dimension removed: (N,) or (N, d1, ..., dk).

    Returns:
        0-dim tensor: the loss averaged over all N * d1 * ... * dk positions.
    """
    smax = softmax_torch(logits, dim=1)  # (N,K,...) -> (N,K,...) probabilities
    # Gather each position's target-class probability. This equals
    # smax[arange(N), targets] for 2-D input and extends to trailing dims.
    # gather needs int64 indices on the same device as the gathered tensor.
    index = targets.to(device=logits.device, dtype=torch.long).unsqueeze(1)  # (N,1,...)
    smax_y = torch.gather(smax, 1, index).squeeze(1)  # (N,K,...) -> (N,...)
    ce = torch.mean(-torch.log(smax_y))
    return ce

def cross_entropy_numpy_v2(logits: np.ndarray, targets: np.ndarray) -> np.floating:
    """Mean cross-entropy loss as -log of a stable softmax (NumPy).

    Computes ``mean(-log(softmax(x)_y))`` from ``softmax_numpy``'s output.
    Apart from ``torch`` -> ``np``, no API differences show up here because
    ``softmax_numpy`` already handles the ``axis``/``keepdims`` naming.
    Caveat: if the target's probability underflows to 0, ``log`` yields ``inf``.
    The log-sum-exp form ``-x_y + max(x) + log(sum_j exp(x_j - max(x)))``
    stays finite.

    Args:
        logits: Unnormalized scores with the class axis at axis 1, i.e.
            shape (N, K) or (N, K, d1, ..., dk).
        targets: Integer class indices in [0, K), shaped like ``logits`` with
            the class axis removed: (N,) or (N, d1, ..., dk).

    Returns:
        NumPy scalar: the loss averaged over all N * d1 * ... * dk positions.
    """
    smax = softmax_numpy(logits, dim=1)  # (N,K,...) -> (N,K,...) probabilities
    # take_along_axis along the class axis equals smax[arange(N), targets]
    # for 2-D input and extends to trailing axes.
    index = np.expand_dims(targets, axis=1)  # (N,...) -> (N,1,...)
    smax_y = np.take_along_axis(smax, index, axis=1).squeeze(axis=1)  # (N,...)
    ce = np.mean(-np.log(smax_y))
    return ce

# Run our four implementations on the same data and compare each to F.cross_entropy.
ce_ours_torch_v1 = cross_entropy_torch_v1(x_pt, y_pt)  # torch.Tensor inputs
ce_ours_torch_v2 = cross_entropy_torch_v2(x_pt, y_pt)  # torch.Tensor inputs

ce_ours_numpy_v1 = cross_entropy_numpy_v1(x_np, y_np)  # np.ndarray inputs
ce_ours_numpy_v2 = cross_entropy_numpy_v2(x_np, y_np)  # np.ndarray inputs

print(f'Allclose torch v1: {torch.allclose(ce_pt, ce_ours_torch_v1)}')
print(f'Allclose torch v2: {torch.allclose(ce_pt, ce_ours_torch_v2)}')
# Convert the torch reference explicitly with .numpy() (as in the softmax
# check) rather than handing a Tensor to NumPy and relying on implicit conversion.
print(f'Allclose numpy v1: {np.allclose(ce_pt.numpy(), ce_ours_numpy_v1)}')
print(f'Allclose numpy v2: {np.allclose(ce_pt.numpy(), ce_ours_numpy_v2)}')

Allclose torch v1: True
Allclose torch v2: True
Allclose numpy v1: True
Allclose numpy v2: True
