# How-to Log Sum Exp

If you use flash-attention intensively, you will at some point come across `lse` values being returned. `lse` stands for "log-sum-exp" and can be used to compute softmax (and thereby also attention) in a blockwise and numerically stable fashion. This notebook aims to explain how this works.

Let's start by defining a naive softmax function ..

In [1]:
import torch

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    return x.exp() / x.exp().sum()

.. and verify that its output matches the output of the official `torch.softmax()` function.

In [2]:
x = torch.randn(10)  # generate normally distributed random numbers
a = torch.softmax(x, dim=-1) # reference output
b = naive_softmax(x) # our naive version

print("a", a)
print("b", b)
print("allclose", torch.allclose(a, b, atol=1e-6))

a tensor([0.0695, 0.0043, 0.0450, 0.0530, 0.0109, 0.0142, 0.0665, 0.6751, 0.0143,
        0.0472])
b tensor([0.0695, 0.0043, 0.0450, 0.0530, 0.0109, 0.0142, 0.0665, 0.6751, 0.0143,
        0.0472])
allclose True


Our naive softmax function has a problem with numerical stability when it gets input vectors with larger elements:

In [3]:
naive_softmax(x * 100)

tensor([0., 0., 0., 0., 0., 0., 0., nan, 0., 0.])
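
To see where the `nan` comes from, it can help to look at the intermediate values. The following is a minimal sketch with hand-picked example values (they are not taken from the run above): in float32, `exp()` overflows to `inf` for inputs above roughly 88.7, and dividing `inf` by `inf` gives `nan`.

```python
import torch

# hand-picked illustrative values; the last one is large enough to overflow in float32
x = torch.tensor([1.0, 50.0, 100.0])

e = x.exp()
print(e)            # the last entry overflows to inf
print(e / e.sum())  # finite / inf gives 0, inf / inf gives nan

has_inf = torch.isinf(e).any().item()
has_nan = torch.isnan(e / e.sum()).any().item()
```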

Before we discuss how to fix this, let's first try to compute softmax in a blockwise fashion.
We start by generating a random vector, splitting it into two chunks of equal size, and computing softmax on each chunk individually.

In [5]:
x = torch.randn(10)

x1,x2 = torch.chunk(x, 2)
s1 = naive_softmax(x1)
s2 = naive_softmax(x2)

print("We have:")
print(f"s1 = {s1}")
print(f"s2 = {s2}")

target = naive_softmax(x)
print("We want:")
print(f"target = {target}")

We have:
s1 = tensor([0.7998, 0.0296, 0.1147, 0.0265, 0.0294])
s2 = tensor([0.1985, 0.3296, 0.0460, 0.1054, 0.3205])
We want:
target = tensor([0.4914, 0.0182, 0.0705, 0.0163, 0.0181, 0.0766, 0.1271, 0.0177, 0.0406,
        0.1236])


If we look at `naive_softmax()` we note that its output has been divided by `x.exp().sum()`. We can call this the "sum exp" value (note the similarity to "log sum exp"). If we have this value for each chunk, we can use it to "undo" the softmax normalization and thereby combine multiple softmax chunks.

In [6]:
sum_exp_x1 = x1.exp().sum()
sum_exp_x2 = x2.exp().sum()
s1_corrected = s1 * sum_exp_x1 / (sum_exp_x1 + sum_exp_x2)
s2_corrected = s2 * sum_exp_x2 / (sum_exp_x1 + sum_exp_x2)

print("After correction with help of sum_exp values:")
s_combined = torch.cat([s1_corrected, s2_corrected])
print("s_combined", s_combined)

print("allclose(s_combined, target):", torch.allclose(s_combined, target))


After correction with help of sum_exp values:
s_combined tensor([0.4914, 0.0182, 0.0705, 0.0163, 0.0181, 0.0766, 0.1271, 0.0177, 0.0406,
        0.1236])
allclose(s_combined, target): True
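
The reason the correction works can be written down in one line. As a sketch in notation that is introduced here only for illustration: let $x^{(1)}$ and $x^{(2)}$ be the two chunks, let $S_k$ be the "sum exp" value of chunk $k$, and let $x_i$ be an element of the first chunk. The second chunk follows by swapping the roles of $S_1$ and $S_2$.

```latex
\mathrm{softmax}(x)_i
  = \frac{e^{x_i}}{S_1 + S_2}
  = \frac{e^{x_i}}{S_1} \cdot \frac{S_1}{S_1 + S_2}
  = \mathrm{softmax}\bigl(x^{(1)}\bigr)_i \cdot \frac{S_1}{S_1 + S_2},
\qquad
S_k = \sum_j e^{x^{(k)}_j}
```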


... but is this helpful at all? Yes, and it becomes more obvious when we realize that we can return this value from our softmax function and do the correction in a blockwise fashion in a loop by accumulating the `sum_exp` value:

In [10]:
from typing import Tuple, Sequence

def naive_softmax2(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    sum_exp = x.exp().sum()
    return x.exp() / sum_exp, sum_exp


def naive_blockwise_softmax(blocks: Sequence[torch.Tensor]) -> torch.Tensor:
    total_sum_exp = 0
    blocks_out = []
    for block in blocks:
        block_softmax, block_sum_exp = naive_softmax2(block)
        blocks_out.append((block_softmax, block_sum_exp))
        total_sum_exp += block_sum_exp

    out = []
    for block_softmax, block_sum_exp in blocks_out:
        out.append(block_softmax * block_sum_exp / total_sum_exp)

    return torch.cat(out)

x_long = torch.randn(20)
chunks = torch.chunk(x_long, 4)
a = naive_blockwise_softmax(chunks)
b = torch.softmax(x_long, dim=-1)
print("a", a)
print("b", b)
print("allclose:", torch.allclose(a, b))


a tensor([0.0608, 0.0217, 0.0446, 0.0180, 0.0289, 0.2553, 0.0101, 0.0150, 0.0699,
        0.0231, 0.0672, 0.0388, 0.0186, 0.0126, 0.0731, 0.0132, 0.0191, 0.0362,
        0.1688, 0.0050])
b tensor([0.0608, 0.0217, 0.0446, 0.0180, 0.0289, 0.2553, 0.0101, 0.0150, 0.0699,
        0.0231, 0.0672, 0.0388, 0.0186, 0.0126, 0.0731, 0.0132, 0.0191, 0.0362,
        0.1688, 0.0050])
allclose: True


OK, now let's look at the numerical stability of softmax. First we can observe an interesting property of the softmax function: its output is shift/translation invariant (i.e. `f(x+a)=f(x)`):

In [11]:
x = torch.randn(8)
print(naive_softmax(x))
print(naive_softmax(x+5))
print(naive_softmax(x-3))

tensor([0.0107, 0.0226, 0.0729, 0.0937, 0.0722, 0.2008, 0.1214, 0.4057])
tensor([0.0107, 0.0226, 0.0729, 0.0937, 0.0722, 0.2008, 0.1214, 0.4057])
tensor([0.0107, 0.0226, 0.0729, 0.0937, 0.0722, 0.2008, 0.1214, 0.4057])
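
The invariance also follows directly from the definition, because the common factor $e^{a}$ cancels between numerator and denominator. A short derivation, with $a$ being an arbitrary scalar shift:

```latex
\mathrm{softmax}(x + a)_i
  = \frac{e^{x_i + a}}{\sum_j e^{x_j + a}}
  = \frac{e^{a}\, e^{x_i}}{e^{a} \sum_j e^{x_j}}
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \mathrm{softmax}(x)_i
```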


This property allows us to deal with problematically large inputs simply by subtracting their maximum:

In [13]:
def stable_softmax(x):
    m = x.max()
    return (x-m).exp() / (x-m).exp().sum()

This "stable" function can now also deal with the larger values that were problematic for our naive function:

In [14]:
large_input = torch.randn(10) * 100

print("naive: ", naive_softmax(large_input))
print("stable: ", stable_softmax(large_input))
print("torch: ", torch.softmax(large_input, dim=-1))


naive:  tensor([0., nan, 0., 0., 0., 0., 0., 0., 0., 0.])
stable:  tensor([4.6243e-44, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])
torch:  tensor([4.6243e-44, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])


In [15]:
def stable_softmax2(x):
    """returns softmax result and log sum exp"""
    m = x.max()
    a = (x - m).exp()
    b = a.sum()
    lse = m + torch.log(b)
    return a / b, lse
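
As a quick sanity check, the returned `lse` value can be compared against PyTorch's built-in `torch.logsumexp()`. The sketch below repeats the function so that it runs on its own. It also illustrates that the softmax output can be recovered as `exp(x - lse)`. The loose tolerance is an assumption chosen to allow for float32 rounding at these magnitudes.

```python
import torch

def stable_softmax2(x):
    """returns softmax result and log sum exp"""
    m = x.max()
    a = (x - m).exp()
    b = a.sum()
    return a / b, m + torch.log(b)

torch.manual_seed(0)
x = torch.randn(10) * 100  # large values that would break the naive version

probs, lse = stable_softmax2(x)
reference_lse = torch.logsumexp(x, dim=-1)  # built-in reference
recovered = torch.exp(x - lse)              # softmax expressed via the lse value

print("lse matches torch.logsumexp:", torch.allclose(lse, reference_lse))
print("softmax == exp(x - lse):", torch.allclose(recovered, probs, atol=1e-3))
```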

Again we can use the returned `lse` values to combine two softmax block results. But to do it in the same way as before, we would need to calculate the `exp()` of the `lse` values, which as we know is not numerically stable:

In [16]:
x = torch.randn(20)

a = torch.softmax(x, dim=-1)

x1, x2 = x.chunk(2)

b1, lse1 = stable_softmax2(x1)
b2, lse2 = stable_softmax2(x2)

c1 = b1 * torch.exp(lse1) / (torch.exp(lse1) + torch.exp(lse2))
c2 = b2 * torch.exp(lse2) / (torch.exp(lse1) + torch.exp(lse2))

print(a)
print(torch.cat([c1, c2]), torch.allclose(a, torch.cat([c1, c2])))


tensor([0.0856, 0.0951, 0.0390, 0.0288, 0.0738, 0.0110, 0.0129, 0.0293, 0.0190,
        0.1449, 0.0112, 0.0428, 0.0573, 0.0605, 0.0056, 0.0085, 0.1530, 0.0500,
        0.0044, 0.0677])
tensor([0.0856, 0.0951, 0.0390, 0.0288, 0.0738, 0.0110, 0.0129, 0.0293, 0.0190,
        0.1449, 0.0112, 0.0428, 0.0573, 0.0605, 0.0056, 0.0085, 0.1530, 0.0500,
        0.0044, 0.0677]) True


But luckily we can rewrite it (`a/(a+b) = 1/(1 + b/a)`) and replace the division of exp values by a subtraction of log values (`exp(a)/exp(b) = exp(a-b)`).

In [19]:
d1 = b1 / (1 + torch.exp(lse2 - lse1))
d2 = b2 / (1 + torch.exp(lse1 - lse2))
print(a)
print(torch.cat([d1, d2]))
print("allclose: ", torch.allclose(a, torch.cat([d1, d2])))

tensor([0.0856, 0.0951, 0.0390, 0.0288, 0.0738, 0.0110, 0.0129, 0.0293, 0.0190,
        0.1449, 0.0112, 0.0428, 0.0573, 0.0605, 0.0056, 0.0085, 0.1530, 0.0500,
        0.0044, 0.0677])
tensor([0.0856, 0.0951, 0.0390, 0.0288, 0.0738, 0.0110, 0.0129, 0.0293, 0.0190,
        0.1449, 0.0112, 0.0428, 0.0573, 0.0605, 0.0056, 0.0085, 0.1530, 0.0500,
        0.0044, 0.0677])
allclose:  True
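
The same idea extends to more than two blocks. The following is a minimal sketch, not taken from any library: `stable_blockwise_softmax` is a hypothetical helper name, and it uses `torch.logsumexp()` to combine the per-block `lse` values into a total without ever exponentiating a large value.

```python
import torch
from typing import Sequence

def stable_softmax2(x):
    """returns softmax result and log sum exp"""
    m = x.max()
    a = (x - m).exp()
    b = a.sum()
    return a / b, m + torch.log(b)

def stable_blockwise_softmax(blocks: Sequence[torch.Tensor]) -> torch.Tensor:
    results = [stable_softmax2(block) for block in blocks]
    # log of the global normalizer, computed stably from the per-block lse values
    total_lse = torch.logsumexp(torch.stack([lse for _, lse in results]), dim=0)
    # exp(lse_k - total_lse) is the share of block k in the global normalizer
    return torch.cat([p * torch.exp(lse - total_lse) for p, lse in results])

torch.manual_seed(0)
x_long = torch.randn(20) * 100  # large inputs on purpose
out = stable_blockwise_softmax(torch.chunk(x_long, 4))
ref = torch.softmax(x_long, dim=-1)
print("allclose:", torch.allclose(out, ref, atol=1e-3))
```

Because every exponent `lse - total_lse` is less than or equal to zero, the `exp()` calls cannot overflow.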


With this fresh knowledge about softmax, we can now take a look at the update function `_update_out_and_lse()` that is used in the ring-flash-attention implementation:

In [18]:
def _update_out_and_lse(
    out: torch.Tensor,
    lse: torch.Tensor,
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
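    block_out = block_out.to(torch.float32)
    # swap the last two dims of block_lse and add a trailing dim so it broadcasts against out
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    # new_lse = log(exp(lse) + exp(block_lse)), computed without exponentiating lse directly
    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    # rescale both partial outputs from their own normalizers to the new common one
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out

    lse = new_lse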
    return out, lse
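
To check that this update rule reproduces full attention, here is a small self-contained sketch. It makes several simplifying assumptions: plain 2-D tensors without batch or head dimensions, no softmax scaling factor, and two hypothetical helpers (`attention_block` and `merge`) that are not part of ring-flash-attention. `merge` applies the same two formulas as the function above, minus the layout handling.

```python
import torch

def attention_block(q, k, v):
    """Attention of q against one key/value block; returns output and row-wise lse."""
    scores = q @ k.T                                     # (n_q, n_k)
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)  # (n_q, 1)
    out = torch.exp(scores - lse) @ v                    # softmax(scores) @ v
    return out, lse

def merge(out, lse, block_out, block_lse):
    # same update rule as above, without dtype cast and transpose
    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return out, new_lse

torch.manual_seed(0)
q, k, v = torch.randn(4, 8), torch.randn(6, 8), torch.randn(6, 8)

# full attention as reference
ref = torch.softmax(q @ k.T, dim=-1) @ v
ref_lse = torch.logsumexp(q @ k.T, dim=-1, keepdim=True)

# process the keys/values in two blocks and merge the partial results
out, lse = attention_block(q, k[:3], v[:3])
block_out, block_lse = attention_block(q, k[3:], v[3:])
out, lse = merge(out, lse, block_out, block_lse)

print("out allclose:", torch.allclose(out, ref, atol=1e-4))
print("lse allclose:", torch.allclose(lse, ref_lse, atol=1e-4))
```

Note that `lse + torch.log(1 + torch.exp(block_lse - lse))` is mathematically the same quantity that `torch.logaddexp(lse, block_lse)` computes.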