This notebook contains my implementation of the **softmax** function, which maps a real-valued vector to a vector of probabilities that sum to 1. I also use the **log-sum-exp trick** to avoid overflow and keep the computation numerically stable.

**Softmax**

The softmax function with temperature $T > 0$ can be defined as:

$$
S(x_i, T) = p_i = \frac{e^{x_i / T}}{\sum_{j=1}^{N} e^{x_j / T}}
$$
$$
x, p \in \mathbb{R}^N
$$

where the vector of probabilities $p$ has the property

$$
\sum_{i=1}^{N} p_i = 1
$$

This is useful, for example, in multiclass logistic regression where we want our model to output a probability distribution over $N$ classes.
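
To make the role of $T$ concrete, here is a minimal sketch in plain Python (using `math` rather than `torch`, so it runs without extra dependencies). The helper name `softmax_with_temperature` and the chosen temperatures are illustrative assumptions, not part of the notebook's implementation, and the function is only meant for small, well-scaled inputs.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Hypothetical helper: direct softmax for small, well-scaled inputs (illustration only)."""
    scaled = [value / temperature for value in logits]  # divide every logit by T
    exps = [math.exp(value) for value in scaled]        # exponentiate each scaled logit
    total = sum(exps)                                   # normalizing constant
    return [value / total for value in exps]

logits = [1.0, 2.0, 3.0]

sharp = softmax_with_temperature(logits, 0.5)   # low temperature
flat = softmax_with_temperature(logits, 10.0)   # high temperature

for temperature in (0.5, 1.0, 10.0):
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
```

With these logits, the lowest temperature puts most of the probability mass on the largest logit, while $T = 10$ gives a nearly uniform vector.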

**Log-Sum-Exp Trick**

One problem with the above formula for softmax is the possibility of overflow and underflow in finite-precision floating-point arithmetic. For example, if $x = [1000, -3000, -5000]$ and $T = 1$, computing $e^{x_i}$ overflows for the first entry *and* underflows for the other two: $e^{1000}$ evaluates to infinity, while $e^{-3000}$ and $e^{-5000}$ round to 0.
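
The failure can be reproduced with a short plain-Python sketch. It uses 64-bit floats from the standard library instead of `torch` tensors, and the helper `exp_or_inf` is a hypothetical name introduced here: Python's `math.exp` raises `OverflowError` on overflow, so the helper returns infinity in that case to mimic what `torch.exp` returns.

```python
import math

def exp_or_inf(value):
    """Hypothetical helper: return inf on overflow instead of raising OverflowError."""
    try:
        return math.exp(value)
    except OverflowError:
        return math.inf

x = [1000.0, -3000.0, -5000.0]

exps = [exp_or_inf(value) for value in x]        # first entry overflows, the others underflow
total = sum(exps)                                # inf + 0 + 0 is inf
naive_probs = [value / total for value in exps]  # inf / inf is nan, 0 / inf is 0

print(exps)         # [inf, 0.0, 0.0]
print(naive_probs)  # [nan, 0.0, 0.0]
```

The resulting vector contains a `nan`, so it is not a valid probability distribution and its sum is not 1.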

We can correct for this by factoring a constant $e^{m}$ out of the sum, which is valid for any real $m$:

$$
\log{\sum_i{e^{x_i}}} = \log{\sum_i{e^{x_i} \cdot e^{m} \cdot e^{-m}}} = \log{\left(e^{m} \sum_i{e^{x_i - m}}\right)} = m + \log{\sum_i{e^{x_i - m}}}
$$

Because $m$ is arbitrary, we can prevent overflow by setting $m = \max_i x_i$. Every exponent $x_i - m$ is then at most 0, so no term can overflow. Terms far below the maximum may still underflow to 0, but the largest term is exactly $e^0 = 1$, so the sum is at least 1 and its logarithm is well defined.
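
The link back to softmax is that $p_i = e^{x_i - \mathrm{LSE}(x)}$, where $\mathrm{LSE}(x) = \log{\sum_j{e^{x_j}}}$, since dividing by the sum is the same as subtracting its logarithm inside the exponent. A minimal plain-Python sketch of this, under the assumption of a non-empty list of finite floats and $T = 1$, is shown here. The names `logsumexp` and `softmax_via_lse` are illustrative and are not the functions defined later in the notebook.

```python
import math

def logsumexp(values):
    """Hypothetical helper: log(sum(exp(v))) using the max-shift described above."""
    m = max(values)  # shift so the largest exponent becomes 0
    return m + math.log(sum(math.exp(v - m) for v in values))

def softmax_via_lse(values):
    """Hypothetical helper: softmax written as exp(x_i - LSE(x))."""
    lse_value = logsumexp(values)
    return [math.exp(v - lse_value) for v in values]

x = [1000.0, -3000.0, -5000.0]

print(logsumexp(x))        # 1000.0
print(softmax_via_lse(x))  # [1.0, 0.0, 0.0]
```

The two small entries still underflow to 0, but the result is a valid distribution that sums to 1.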


In [60]:
import torch

In [77]:
# regular implementation - not numerically stable
def softmax(x: torch.Tensor, temperature: float = 1) -> torch.Tensor:
    """Naive, numerically unstable implementation of the softmax function

    Args:
        x (torch.Tensor): input vector. Expected to be 1-D
        temperature (float): scales the logits before exponentiation. Higher temperatures spread
            probability more evenly across entries, while lower temperatures concentrate it on the
            largest logits

    Raises:
        ValueError: if x has greater than 1 dimension
        ValueError: if the temperature is less than or equal to 0

    Returns:
        torch.Tensor: result of softmax operation. Sum of returned vector should be 1
    """
    if x.ndim > 1:
        raise ValueError(f'x is expected to be a 1D vector, instead got shape {x.shape}')

    if temperature <= 0:
        raise ValueError(f'temperature must be in the range (0, inf), instead got {temperature}')

    numerator = torch.exp(x / temperature)

    denominator = numerator.sum()

    probabilities = numerator / denominator

    return probabilities

In [106]:
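def lse(x: torch.Tensor) -> torch.Tensor:
    """Log-sum-exp of x, computed directly. Callers should first shift x so that its max is 0,
    otherwise exp can overflow or the sum can underflow to 0"""
    return torch.log(torch.sum(torch.exp(x)))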

def softmax_stable(x: torch.Tensor, temperature: float = 1) -> torch.Tensor:
    """Numerically stable implementation of the softmax function using the log-sum-exp trick

    Args:
        x (torch.Tensor): input vector. Expected to be 1-D
        temperature (float): scales the logits before exponentiation. Higher temperatures spread
            probability more evenly across entries, while lower temperatures concentrate it on the
            largest logits

    Raises:
        ValueError: if x has greater than 1 dimension
        ValueError: if the temperature is less than or equal to 0

    Returns:
        torch.Tensor: result of softmax operation. Sum of returned vector should be 1
    """
    if x.ndim > 1:
        raise ValueError(f'x is expected to be a 1D vector, instead got shape {x.shape}')

    if temperature <= 0:
        raise ValueError(f'temperature must be in the range (0, inf), instead got {temperature}')

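    # apply the temperature first so the max is taken over the scaled logits
    x_temp_adjusted = x / temperature

    # shift so the largest entry is 0; every exponent is then <= 0 and cannot overflow
    x_adjusted = x_temp_adjusted - x_temp_adjusted.max()

    # the direct lse is safe here because its input has already been shifted
    lse_result = lse(x_adjusted)

    # softmax(x)_i = exp(x_i - lse(x)); the shift cancels between the two terms
    probabilities = torch.exp(x_adjusted - lse_result)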

    return probabilities
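
As a cross-check, the same idea can be sketched with PyTorch's built-in `torch.logsumexp` and `torch.softmax`. This is a standalone comparison under assumed inputs, not a replacement for `softmax_stable`: the logits are one of the harder rows from the test cases in this notebook, and the temperature of 0.5 is an arbitrary choice.

```python
import torch

# Assumed inputs: one row of logits and an arbitrary temperature.
x = torch.tensor([1000000.0, 1.0, 2.0])
temperature = 0.5

scaled = x / temperature

# torch.logsumexp is documented as numerically stabilized, so no manual shift is done here.
via_logsumexp = torch.exp(scaled - torch.logsumexp(scaled, dim=0))

# torch.softmax is the library's built-in softmax along the given dimension.
builtin = torch.softmax(scaled, dim=0)

print(via_logsumexp)
print(builtin)
print(torch.allclose(via_logsumexp, builtin))
```

If the two results agree, that is some evidence the exp-of-difference formulation matches the library's behaviour for this input.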

In [114]:
# define some inputs to test for both implementations
x_test_cases = torch.tensor([[0, 0, 0],
                             [1, 2, 3],
                             [-1000000, 1, 2],
                             [1000000, 1, 2],
                             [50000, 50000, 50000],
                             [-50000, -50000, -50000]], dtype=torch.float32)

temperature_test_cases = [1, .5, .1, .000001, 10, 10_000]

In [124]:
# iterate over all combinations of test cases for both implementations and save the results
softmax_results_regular = []

softmax_results_stable = []

for x in x_test_cases:
    for temperature in temperature_test_cases:
        softmax_results_regular.append(softmax(x, temperature))
        softmax_results_stable.append(softmax_stable(x, temperature))

In [145]:
# test the quality of outputs - do they sum to 1?
softmax_regular_sum_to_one_results = []
softmax_stable_sum_to_one_results = []

for result_regular, result_stable in zip(softmax_results_regular, softmax_results_stable):
    # absolute tolerance for checking whether the sum of the probs is 1
    atol = .00001
    softmax_regular_sum_to_one_results.append(torch.isclose(result_regular.sum(), torch.tensor(1.0), atol=atol))
    softmax_stable_sum_to_one_results.append(torch.isclose(result_stable.sum(), torch.tensor(1.0), atol=atol))

softmax_regular_sum_to_one_results = torch.tensor(softmax_regular_sum_to_one_results)
softmax_stable_sum_to_one_results = torch.tensor(softmax_stable_sum_to_one_results)

p_correct_regular = softmax_regular_sum_to_one_results.float().mean().item() * 100
p_correct_stable = softmax_stable_sum_to_one_results.float().mean().item() * 100

print(f'percent of softmax results summing to 1 for different implementations: \
      \n\t-regular: {p_correct_regular} \n\t-stable (log-sum-exp): {p_correct_stable}')

percent of softmax results summing to 1 for different implementations:       
	-regular: 50.0 
	-stable (log-sum-exp): 100.0
