# How-to Log Sum Exp

If you work with flash-attention intensively, you will at some point come across `lse` values being returned. `lse` stands for "log sum exp" and can be used to compute softmax (and thereby also attention) in a blockwise and numerically stable fashion. This notebook aims to explain how this works.

Let's start by defining a naive softmax function ...

In [5]:
from typing import Tuple, Sequence
import torch
import math

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    return x.exp() / x.exp().sum()

... and verify that its output matches the output of the official `torch.softmax()` function.

In [6]:
x = torch.randn(10)  # generate normally distributed random numbers
a = torch.softmax(x, dim=-1) # reference output
b = naive_softmax(x) # our naive version

print("a", a)
print("b", b)
print("allclose", torch.allclose(a, b, atol=1e-6))

a tensor([0.0323, 0.1455, 0.0659, 0.3275, 0.0416, 0.1432, 0.0871, 0.0258, 0.0234,
        0.1077])
b tensor([0.0323, 0.1455, 0.0659, 0.3275, 0.0416, 0.1432, 0.0871, 0.0258, 0.0234,
        0.1077])
allclose True


Our naive softmax function has a problem with numerical stability when it gets input vectors with larger elements:

In [7]:
naive_softmax(x * 100)

tensor([0., 0., 0., nan, 0., 0., 0., 0., 0., 0.])
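The `nan` comes from an overflow in `exp()`. The following minimal sketch (the two-element tensor is chosen only for illustration) shows that in float32 the exponential of a sufficiently large value becomes `inf`, and `inf / inf` is then `nan`, while all finite entries are divided by `inf` and become `0`:

```python
import math
import torch

# Largest finite float32 value and the largest input exp() can handle
float32_max = torch.finfo(torch.float32).max
overflow_threshold = math.log(float32_max)  # roughly 88.7

big = torch.tensor([10.0, 100.0])
exp_big = big.exp()               # exp(100) is beyond float32 range -> inf
ratio = exp_big / exp_big.sum()   # finite / inf -> 0, inf / inf -> nan

print("overflow threshold:", overflow_threshold)
print("exp:", exp_big)
print("naive softmax:", ratio)
```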

Before we look into how to fix this, let's first see how a blockwise computation of softmax can be implemented with `naive_softmax()`. We generate a vector, split it into two chunks of equal size, and compute the softmax of each chunk individually.

In [8]:
x = torch.randn(10)

x1,x2 = torch.chunk(x, 2)
s1 = naive_softmax(x1)
s2 = naive_softmax(x2)

print("We have:")
print(f"s1 = {s1}")
print(f"s2 = {s2}")

target = naive_softmax(x)
print("We want:")
print(f"target = {target}")

We have:
s1 = tensor([0.1469, 0.2743, 0.1178, 0.3475, 0.1134])
s2 = tensor([0.0403, 0.4899, 0.1561, 0.2785, 0.0353])
We want:
target = tensor([0.0721, 0.1347, 0.0578, 0.1706, 0.0557, 0.0205, 0.2494, 0.0795, 0.1418,
        0.0180])


If we look at `naive_softmax()`, we note that its output has been divided by `x.exp().sum()`. We can call this the "sum exp" value (note the similarity to "log sum exp"). If we have this value for each chunk, we can use it to "undo" the softmax normalization and thereby combine multiple softmax chunks.

In [9]:
se_x1 = x1.exp().sum()
se_x2 = x2.exp().sum()
s1_corrected = s1 * se_x1 / (se_x1 + se_x2)
s2_corrected = s2 * se_x2 / (se_x1 + se_x2)

print("After correction with help of sum exp values:")
s_combined = torch.cat([s1_corrected, s2_corrected])
print("s_combined", s_combined)

print("allclose(s_combined, target):", torch.allclose(s_combined, target))


After correction with help of sum exp values:
s_combined tensor([0.0721, 0.1347, 0.0578, 0.1706, 0.0557, 0.0205, 0.2494, 0.0795, 0.1418,
        0.0180])
allclose(s_combined, target): True
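Why the correction works can be written out in one line. As a sketch, let $S_1$ and $S_2$ denote the "sum exp" values of the two chunks $x^{(1)}$ and $x^{(2)}$ (these symbols are introduced here only for illustration and correspond to `se_x1` and `se_x2`). For an element $x_i$ of the first chunk:

$$
\operatorname{softmax}(x)_i
= \frac{e^{x_i}}{S_1 + S_2}
= \frac{e^{x_i}}{S_1} \cdot \frac{S_1}{S_1 + S_2}
= \operatorname{softmax}\bigl(x^{(1)}\bigr)_i \cdot \frac{S_1}{S_1 + S_2}
$$

The same holds for elements of the second chunk with $S_2$ in the numerator of the correction factor.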


... but is this helpful at all? Yes, and it becomes more obvious when we realize that we can return this value from our softmax function and apply the correction blockwise in a loop by accumulating the `sum_exp` value:

In [10]:
def naive_softmax2(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    sum_exp = x.exp().sum()
    return x.exp() / sum_exp, sum_exp


def naive_blockwise_softmax(blocks: Sequence[torch.Tensor]) -> torch.Tensor:
    out = []
    sum_exp = 0
    for block in blocks:
        block_softmax, block_sum_exp = naive_softmax2(block)
        for o in out:
            o *= sum_exp / (sum_exp + block_sum_exp)
        
        out.append(block_softmax * block_sum_exp / (block_sum_exp + sum_exp))
        sum_exp += block_sum_exp
        
    return torch.cat(out)
    

x_long = torch.randn(20)
chunks = torch.chunk(x_long, 4)
print(naive_blockwise_softmax(chunks))
print(torch.softmax(x_long, dim=-1))


tensor([0.0476, 0.0185, 0.1522, 0.0062, 0.0050, 0.0552, 0.1088, 0.0218, 0.0304,
        0.1719, 0.0211, 0.0510, 0.1319, 0.0152, 0.0238, 0.0267, 0.0116, 0.0490,
        0.0310, 0.0209])
tensor([0.0476, 0.0185, 0.1522, 0.0062, 0.0050, 0.0552, 0.1088, 0.0218, 0.0304,
        0.1719, 0.0211, 0.0510, 0.1319, 0.0152, 0.0238, 0.0267, 0.0116, 0.0490,
        0.0310, 0.0209])
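The blocks do not need to have the same size for this to work. The sketch below restates the blockwise loop in compact form (it is a re-implementation for illustration, not a replacement for the function above) and checks it on deliberately uneven blocks; the block sizes are arbitrary.

```python
import torch

def naive_blockwise_softmax(blocks):
    out = []
    sum_exp = 0
    for block in blocks:
        block_sum_exp = block.exp().sum()
        block_softmax = block.exp() / block_sum_exp
        # rescale the blocks seen so far to the enlarged denominator
        out = [o * sum_exp / (sum_exp + block_sum_exp) for o in out]
        # rescale the new block from its own denominator to the combined one
        out.append(block_softmax * block_sum_exp / (sum_exp + block_sum_exp))
        sum_exp = sum_exp + block_sum_exp
    return torch.cat(out)

torch.manual_seed(0)
x_long = torch.randn(20)
uneven_blocks = torch.split(x_long, [3, 7, 1, 9])  # sizes sum to 20

result = naive_blockwise_softmax(uneven_blocks)
matches = torch.allclose(result, torch.softmax(x_long, dim=-1), atol=1e-6)
print("allclose:", matches)
```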


Now let's look at the numerical stability of softmax. First, we can observe an interesting property of the softmax function: its output is shift/translation invariant (i.e. `f(x+a)=f(x)`):

In [11]:
x = torch.randn(8)
print(naive_softmax(x))
print(naive_softmax(x+5))
print(naive_softmax(x-3))

tensor([0.0107, 0.0226, 0.0729, 0.0937, 0.0722, 0.2008, 0.1214, 0.4057])
tensor([0.0107, 0.0226, 0.0729, 0.0937, 0.0722, 0.2008, 0.1214, 0.4057])
tensor([0.0107, 0.0226, 0.0729, 0.0937, 0.0722, 0.2008, 0.1214, 0.4057])
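The invariance follows directly from the definition: the factor $e^{a}$ appears in both the numerator and the denominator and cancels. In exact arithmetic, for any scalar shift $a$:

$$
\operatorname{softmax}(x + a)_i
= \frac{e^{x_i + a}}{\sum_j e^{x_j + a}}
= \frac{e^{a}\, e^{x_i}}{e^{a} \sum_j e^{x_j}}
= \operatorname{softmax}(x)_i
$$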


This property allows us to deal with problematically large inputs simply by subtracting their maximum. After the shift, the largest exponent is 0, so `exp()` can no longer overflow:

In [12]:
def stable_softmax(x):
    m = x.max()
    return (x-m).exp() / (x-m).exp().sum()

This "stable" function can now also deal with the larger values that were problematic for our naive function:

In [13]:
large_input = torch.randn(10) * 100

print("naive: ", naive_softmax(large_input))
print("stable: ", stable_softmax(large_input))
print("torch: ", torch.softmax(large_input, dim=-1))


naive:  tensor([0., 0., 0., 0., 0., 0., nan, 0., 0., 0.])
stable:  tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 5.2809e-36, 0.0000e+00, 1.7820e-34,
        1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])
torch:  tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 5.2809e-36, 0.0000e+00, 1.7820e-34,
        1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])


In [14]:
def stable_softmax2(x):
    """returns softmax result and log sum exp"""
    m = x.max()
    a = (x - m).exp()
    b = a.sum()
    lse = m + torch.log(b)
    return a / b, lse
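The `lse` value returned here is the logarithm of the "sum exp" value from before: since `b` is the sum of `exp(x - m)`, adding `m` back in log space gives `log(sum(exp(x)))` without ever computing `exp(x)` directly. As a small sanity check (a sketch with arbitrary test inputs), it should agree with `torch.logsumexp()`, and the softmax itself can be recovered as `exp(x - lse)`:

```python
import torch

def stable_softmax2(x):
    """returns softmax result and log sum exp"""
    m = x.max()
    a = (x - m).exp()
    b = a.sum()
    return a / b, m + torch.log(b)

torch.manual_seed(0)
x = torch.randn(10) * 10

probs, lse = stable_softmax2(x)
reference_lse = torch.logsumexp(x, dim=-1)  # PyTorch's built-in log-sum-exp
recovered = torch.exp(x - lse)              # softmax expressed through lse

print("lse matches:", torch.allclose(lse, reference_lse))
print("softmax recovered:", torch.allclose(recovered, probs, atol=1e-5))
```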

We can now use the `lse` value returned by `stable_softmax2()` to combine two softmax block results. Doing it in the same way as before requires computing `exp(lse)`, which, as we know, is not numerically stable for large values:

In [15]:
x = torch.randn(20)

a = torch.softmax(x, dim=-1)

x1, x2 = x.chunk(2)

b1, lse1 = stable_softmax2(x1)
b2, lse2 = stable_softmax2(x2)

c1 = b1 * torch.exp(lse1) / (torch.exp(lse1) + torch.exp(lse2))
c2 = b2 * torch.exp(lse2) / (torch.exp(lse1) + torch.exp(lse2))

print(a)
print(torch.cat([c1, c2]), torch.allclose(a, torch.cat([c1, c2])))


tensor([0.2087, 0.0162, 0.0830, 0.0191, 0.0285, 0.0618, 0.0780, 0.0043, 0.0012,
        0.0224, 0.0149, 0.1135, 0.0741, 0.0192, 0.0053, 0.0331, 0.0273, 0.0171,
        0.0762, 0.0962])
tensor([0.2087, 0.0162, 0.0830, 0.0191, 0.0285, 0.0618, 0.0780, 0.0043, 0.0012,
        0.0224, 0.0149, 0.1135, 0.0741, 0.0192, 0.0053, 0.0331, 0.0273, 0.0171,
        0.0762, 0.0962]) True


Luckily, `log` and `exp` come to the rescue: the same correction factor can be computed from the difference of the two `lse` values alone:

In [16]:
d1 = b1 * torch.exp(-torch.log(1 + torch.exp(lse2 - lse1)))
d2 = b2 * torch.exp(-torch.log(1 + torch.exp(lse1 - lse2)))
print(a)
print(torch.cat([d1, d2]), torch.allclose(a, torch.cat([d1, d2])))

tensor([0.2087, 0.0162, 0.0830, 0.0191, 0.0285, 0.0618, 0.0780, 0.0043, 0.0012,
        0.0224, 0.0149, 0.1135, 0.0741, 0.0192, 0.0053, 0.0331, 0.0273, 0.0171,
        0.0762, 0.0962])
tensor([0.2087, 0.0162, 0.0830, 0.0191, 0.0285, 0.0618, 0.0780, 0.0043, 0.0012,
        0.0224, 0.0149, 0.1135, 0.0741, 0.0192, 0.0053, 0.0331, 0.0273, 0.0171,
        0.0762, 0.0962]) True
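For small random inputs both variants give the same result, so the difference only shows up with larger values. The following sketch uses hand-picked inputs around 100 (chosen only so that `exp(lse)` overflows in float32) and compares the two ways of combining the blocks:

```python
import torch

def stable_softmax2(x):
    m = x.max()
    a = (x - m).exp()
    b = a.sum()
    return a / b, m + torch.log(b)

x = torch.tensor([100.0, 101.0, 102.0, 99.0, 100.5, 98.0])
x1, x2 = x.chunk(2)
b1, lse1 = stable_softmax2(x1)
b2, lse2 = stable_softmax2(x2)

# exp(lse) overflows to inf here, so the correction factor becomes inf / inf = nan
unstable = torch.cat([
    b1 * torch.exp(lse1) / (torch.exp(lse1) + torch.exp(lse2)),
    b2 * torch.exp(lse2) / (torch.exp(lse1) + torch.exp(lse2)),
])

# only the difference of the lse values is exponentiated
stable = torch.cat([
    b1 * torch.exp(-torch.log(1 + torch.exp(lse2 - lse1))),
    b2 * torch.exp(-torch.log(1 + torch.exp(lse1 - lse2))),
])

reference = torch.softmax(x, dim=-1)
print("unstable:", unstable)
print("stable:  ", stable)
print("allclose(stable, reference):", torch.allclose(stable, reference))
```

Mathematically, the correction factor `exp(-log(1 + exp(lse2 - lse1)))` is the same as `torch.sigmoid(lse1 - lse2)`, which may be a convenient alternative way to write it.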


To understand why `b1 * torch.exp(lse1) / (torch.exp(lse1) + torch.exp(lse2))` is equal to `b1 * torch.exp(-torch.log(1 + torch.exp(lse2-lse1)))`, we recall some school math basics:

In [17]:
a = 5
b = 3

print("math.exp(5)/math.exp(3) =", math.exp(5) / math.exp(3))
print("math.exp(5 - 3) =", math.exp(5 - 3))

print("a/(a+b) =", a / (a+b))
print("1/(1+b/a) =", 1 / (1+b/a))

math.exp(5)/math.exp(3) = 7.38905609893065
math.exp(5 - 3) = 7.38905609893065
a/(a+b) = 0.625
1/(1+b/a) = 0.625
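Applied to the correction factor, these two rules give the identity in two steps. Writing $\mathrm{lse}_1$ and $\mathrm{lse}_2$ for `lse1` and `lse2`, first divide numerator and denominator by $e^{\mathrm{lse}_1}$, then use $1/y = \exp(-\log y)$:

$$
\frac{e^{\mathrm{lse}_1}}{e^{\mathrm{lse}_1} + e^{\mathrm{lse}_2}}
= \frac{1}{1 + e^{\mathrm{lse}_2 - \mathrm{lse}_1}}
= \exp\Bigl(-\log\bigl(1 + e^{\mathrm{lse}_2 - \mathrm{lse}_1}\bigr)\Bigr)
$$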


With this fresh knowledge about softmax, we can now take a look at the `_update_out_and_lse()` function that is used in the ring-flash-attention implementation:

In [18]:
def _update_out_and_lse(
    out: torch.Tensor,
    lse: torch.Tensor,
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
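    # accumulate in float32 regardless of the dtype of the incoming block
    block_out = block_out.to(torch.float32)
    # swap the last two axes and append a trailing axis
    # so that block_lse broadcasts against out
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    # log(exp(lse) + exp(block_lse)), computed from the difference of the two
    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    # rescale both partial outputs to the new, combined normalization
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out

    lse = new_lse
    return out, lse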
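To see the function in action, here is a minimal sketch that splits the keys and values of a toy attention problem into two blocks and merges the partial results. Several things in it are assumptions made for illustration and are not taken from ring-flash-attention: `attention_block()` is a hypothetical plain-PyTorch stand-in for a flash-attention call, the tensor layouts (`out` as batch, seqlen, heads, head_dim and `lse` as batch, heads, seqlen) are assumed, and the first block is used directly to initialize `out` and `lse`.

```python
from typing import Tuple
import torch

def _update_out_and_lse(out, lse, block_out, block_lse) -> Tuple[torch.Tensor, torch.Tensor]:
    block_out = block_out.to(torch.float32)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return out, new_lse

def attention_block(q, k, v):
    """Hypothetical reference attention returning (out, lse) in the assumed layouts."""
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
    lse = torch.logsumexp(scores, dim=-1)           # (batch, heads, seqlen_q)
    probs = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd", probs, v)  # (batch, seqlen_q, heads, head_dim)
    return out, lse

torch.manual_seed(0)
q = torch.randn(2, 4, 3, 8)
k = torch.randn(2, 6, 3, 8)
v = torch.randn(2, 6, 3, 8)

# split keys and values along the sequence axis into two blocks
k1, k2 = k.chunk(2, dim=1)
v1, v2 = v.chunk(2, dim=1)

# first block initializes the accumulators (assumed initialization)
out1, lse1 = attention_block(q, k1, v1)
out = out1.to(torch.float32)
lse = lse1.transpose(-2, -1).unsqueeze(dim=-1)

# second block is merged with the update function
out2, lse2 = attention_block(q, k2, v2)
out, lse = _update_out_and_lse(out, lse, out2, lse2)

# compare against attention over the full, unsplit keys and values
ref_out, ref_lse = attention_block(q, k, v)
out_matches = torch.allclose(out, ref_out, atol=1e-5)
lse_matches = torch.allclose(lse.squeeze(-1).transpose(-2, -1), ref_lse, atol=1e-5)
print("out matches:", out_matches)
print("lse matches:", lse_matches)
```

The weights `exp(lse - new_lse)` and `exp(block_lse - new_lse)` play the same role as the correction factors `sum_exp / (sum_exp + block_sum_exp)` and `block_sum_exp / (sum_exp + block_sum_exp)` in the naive blockwise softmax, only expressed in log space.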