# How-to Log Sum Exp

flash-attention  的softmax

In [1]:

## 简单的softmax
import torch

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    """Textbook softmax over all elements of ``x``: ``exp(x) / sum(exp(x))``.

    Args:
        x: Input tensor; the sum runs over every element, not a single dim.

    Returns:
        A tensor of the same shape as ``x`` whose entries sum to 1.

    Note:
        The input is exponentiated as-is (no max-subtraction), so this
        version is intentionally *not* safe for large-magnitude inputs.
    """
    exp_x = torch.exp(x)
    return exp_x / exp_x.sum()

In [2]:
## Verify the output

x = torch.randn(10)  # generate normally distributed random numbers
a = torch.softmax(x, dim=-1) # reference output
b = naive_softmax(x) # our naive version

print("a", a)
print("b", b)
# The two results should agree up to float32 rounding error.
print("allclose", torch.allclose(a, b, atol=1e-6))

a tensor([0.3142, 0.0223, 0.0794, 0.0450, 0.0121, 0.0764, 0.0184, 0.1597, 0.0350,
        0.2375])
b tensor([0.3142, 0.0223, 0.0794, 0.0450, 0.0121, 0.0764, 0.0184, 0.1597, 0.0350,
        0.2375])
allclose True


In [3]:
# With large-magnitude inputs the naive version is numerically unstable:
# the result contains nan entries (see the output below).
naive_softmax(x * 100)

tensor([nan, 0., 0., 0., 0., 0., 0., nan, 0., nan])

In [4]:
# Split the vector into two equally sized chunks and compute softmax on each one

x = torch.randn(10)

x1,x2 = torch.chunk(x, 2)
# Each chunk is normalised on its own, so s1 and s2 each sum to 1.
s1 = naive_softmax(x1)
s2 = naive_softmax(x2)

print("We have:")
print(f"s1 = {s1}")
print(f"s2 = {s2}")

# Softmax over the full vector: the result the per-chunk outputs must be merged into.
target = naive_softmax(x)
print("We want:")
print(f"target = {target}")

We have:
s1 = tensor([0.0555, 0.1669, 0.2362, 0.5046, 0.0368])
s2 = tensor([0.0629, 0.4399, 0.0854, 0.1635, 0.2482])
We want:
target = tensor([0.0415, 0.1248, 0.1767, 0.3775, 0.0276, 0.0159, 0.1108, 0.0215, 0.0412,
        0.0625])


In [5]:

## Merge: rescale each chunk's softmax by that chunk's share of the total sum of exponentials.
## (The logarithm of such a sum of exponentials is what is called the "log-sum-exp".)

sum_exp_x1 = x1.exp().sum()
sum_exp_x2 = x2.exp().sum()
# s_i * sum_exp_i recovers exp(x_i); dividing by the combined sum renormalises
# over the whole vector.
s1_corrected = s1 * sum_exp_x1 / (sum_exp_x1 + sum_exp_x2)
s2_corrected = s2 * sum_exp_x2 / (sum_exp_x1 + sum_exp_x2)

print("After correction with help of sum_exp values:")
s_combined = torch.cat([s1_corrected, s2_corrected])
print("s_combined", s_combined)

print("allclose(s_combined, target):", torch.allclose(s_combined, target))

After correction with help of sum_exp values:
s_combined tensor([0.0415, 0.1248, 0.1767, 0.3775, 0.0276, 0.0159, 0.1108, 0.0215, 0.0412,
        0.0625])
allclose(s_combined, target): True


In [7]:
from typing import Tuple, Sequence

def naive_softmax2(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Naive softmax that also hands back its normaliser.

    Args:
        x: Input tensor; the normaliser is summed over all of its elements.

    Returns:
        A pair ``(softmax, sum_exp)`` where ``sum_exp`` is the scalar
        ``sum(exp(x))`` and ``softmax`` is ``exp(x) / sum_exp``.
    """
    exp_x = torch.exp(x)
    normalizer = exp_x.sum()
    return exp_x / normalizer, normalizer


def naive_blockwise_softmax(blocks: Sequence[torch.Tensor]) -> torch.Tensor:
    """Softmax over the concatenation of ``blocks``, computed block by block.

    Every block is first normalised independently with ``naive_softmax2``;
    the per-block results are then rescaled by each block's share of the
    overall sum of exponentials and concatenated.

    Args:
        blocks: The chunks of the input vector, in order.

    Returns:
        The concatenated, globally normalised softmax values.
    """
    # First pass: local softmax and local normaliser for every block.
    per_block = [naive_softmax2(block) for block in blocks]

    # Global normaliser = sum of the local ones.
    total_sum_exp = sum(block_sum_exp for _, block_sum_exp in per_block)

    # Second pass: undo the local normalisation and apply the global one.
    return torch.cat([
        block_softmax * block_sum_exp / total_sum_exp
        for block_softmax, block_sum_exp in per_block
    ])

# Check the block-wise version against torch.softmax on a vector split into 4 chunks.
x_long = torch.randn(20)
chunks = torch.chunk(x_long, 4)
a = naive_blockwise_softmax(chunks)
b = torch.softmax(x_long, dim=-1)
print("a", a)
print("b", b)
print("allclose:", torch.allclose(a, b))


a tensor([0.3531, 0.0043, 0.0408, 0.0075, 0.0528, 0.1158, 0.0112, 0.0089, 0.0775,
        0.0209, 0.0083, 0.0081, 0.0206, 0.0078, 0.0427, 0.0377, 0.0243, 0.0412,
        0.1036, 0.0129])
b tensor([0.3531, 0.0043, 0.0408, 0.0075, 0.0528, 0.1158, 0.0112, 0.0089, 0.0775,
        0.0209, 0.0083, 0.0081, 0.0206, 0.0078, 0.0427, 0.0377, 0.0243, 0.0412,
        0.1036, 0.0129])
allclose: True


In [8]:
# Adding or subtracting the same constant to every input leaves the softmax
# unchanged: compare the three printed results.
x = torch.randn(8)
print(naive_softmax(x))
print(naive_softmax(x+5))
print(naive_softmax(x-3))

tensor([0.0115, 0.2500, 0.0473, 0.0237, 0.2880, 0.0985, 0.1659, 0.1152])
tensor([0.0115, 0.2500, 0.0473, 0.0237, 0.2880, 0.0985, 0.1659, 0.1152])
tensor([0.0115, 0.2500, 0.0473, 0.0237, 0.2880, 0.0985, 0.1659, 0.1152])


This property allows us to deal with problematic large inputs simply by subtracting their maximum:

In [9]:
def stable_softmax(x):
    """Softmax over all elements of ``x`` with the maximum subtracted first.

    Shifting by the maximum makes every exponent <= 0, so ``exp`` cannot
    overflow; the shift itself does not change the softmax value.

    Args:
        x: Input tensor.

    Returns:
        A tensor of the same shape as ``x`` holding the softmax values.
    """
    shifted_exp = torch.exp(x - x.max())
    return shifted_exp / shifted_exp.sum()

This "stable" function can now also deal with the larger values that were problematic for our naive function:

In [10]:
# Inputs scaled by 100: large enough that the naive version produces nan.
large_input = torch.randn(10) * 100

# Compare the naive and the stable implementation against the torch reference.
print("naive: ", naive_softmax(large_input))
print("stable: ", stable_softmax(large_input))
print("torch: ", torch.softmax(large_input, dim=-1))


naive:  tensor([nan, 0., 0., 0., 0., nan, 0., nan, 0., 0.])
stable:  tensor([1.0000e+00, 0.0000e+00, 2.8026e-45, 0.0000e+00, 1.1210e-44, 7.7510e-38,
        0.0000e+00, 2.6989e-22, 0.0000e+00, 0.0000e+00])
torch:  tensor([1.0000e+00, 0.0000e+00, 2.8026e-45, 0.0000e+00, 1.1210e-44, 7.7510e-38,
        0.0000e+00, 2.6989e-22, 0.0000e+00, 0.0000e+00])


In [11]:
def stable_softmax2(x):
    """Return the max-shifted softmax of ``x`` together with its log-sum-exp.

    Args:
        x: Input tensor; max and sum are taken over all of its elements.

    Returns:
        A pair ``(softmax, lse)`` where ``lse`` is the scalar
        ``log(sum(exp(x)))``.
    """
    x_max = x.max()
    shifted_exp = torch.exp(x - x_max)
    normalizer = shifted_exp.sum()
    # log(sum(exp(x))) == max + log(sum(exp(x - max)))
    log_sum_exp = x_max + normalizer.log()
    return shifted_exp / normalizer, log_sum_exp

Again, we can now use this to combine two softmax block results. However, to do it in the same way as before we would need to compute the `exp()` of the log-sum-exp values, which, as we know, is not numerically stable:

In [12]:
x = torch.randn(20)

# Reference: softmax over the full vector.
a = torch.softmax(x, dim=-1)

x1, x2 = x.chunk(2)

# Per-chunk softmax plus each chunk's log-sum-exp.
b1, lse1 = stable_softmax2(x1)
b2, lse2 = stable_softmax2(x2)

# Merge exactly as before by turning the log-sum-exp values back into plain
# sums with exp(). Shown for comparison only: this exp() is the step that is
# not numerically stable for large values.
c1 = b1 * torch.exp(lse1) / (torch.exp(lse1) + torch.exp(lse2))
c2 = b2 * torch.exp(lse2) / (torch.exp(lse1) + torch.exp(lse2))

print(a)
print(torch.cat([c1, c2]), torch.allclose(a, torch.cat([c1, c2])))


tensor([0.0140, 0.0597, 0.0266, 0.0128, 0.0549, 0.0315, 0.0457, 0.0235, 0.0570,
        0.0340, 0.1228, 0.0127, 0.0575, 0.0276, 0.0136, 0.0049, 0.1780, 0.0686,
        0.0300, 0.1246])
tensor([0.0140, 0.0597, 0.0266, 0.0128, 0.0549, 0.0315, 0.0457, 0.0235, 0.0570,
        0.0340, 0.1228, 0.0127, 0.0575, 0.0276, 0.0136, 0.0049, 0.1780, 0.0686,
        0.0300, 0.1246]) True


But luckily we can rewrite it (`a/(a+b) = 1/(1 + b/a)`) and replace the division of exp-values by a subtraction of log-values (`exp(a)/exp(b) = exp(a-b)`).


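这样做的好处是全程在对数尺度上运算，减少数值溢出，提高稳定性。
将合并后的结果 `torch.cat([d1, d2])` 与完整计算得到的结果 `a` 进行比较，使用 `torch.allclose()` 函数验证，结果为 True，表示数值稳定的分块合并策略成功达到了与整体计算一致的结果。
这里用到的数学技巧是：$\frac{a}{a+b} = \frac{1}{1 + b/a}$，以及 $\frac{e^{a}}{e^{b}} = e^{a-b}$。
也就是说，要在对数尺度上进行减法而非直接相除，从而保证数值稳定性。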

In [13]:
# Same merge, rewritten so that only the difference of the two log-sum-exp
# values is exponentiated: exp(lse1) / (exp(lse1) + exp(lse2)) == 1 / (1 + exp(lse2 - lse1)).
d1 = b1 / (1 + torch.exp(lse2 - lse1))
d2 = b2 / (1 + torch.exp(lse1 - lse2))
print(a)
print(torch.cat([d1, d2]))
print("allclose: ", torch.allclose(a, torch.cat([d1, d2])))

tensor([0.0140, 0.0597, 0.0266, 0.0128, 0.0549, 0.0315, 0.0457, 0.0235, 0.0570,
        0.0340, 0.1228, 0.0127, 0.0575, 0.0276, 0.0136, 0.0049, 0.1780, 0.0686,
        0.0300, 0.1246])
tensor([0.0140, 0.0597, 0.0266, 0.0128, 0.0549, 0.0315, 0.0457, 0.0235, 0.0570,
        0.0340, 0.1228, 0.0127, 0.0575, 0.0276, 0.0136, 0.0049, 0.1780, 0.0686,
        0.0300, 0.1246])
allclose:  True


With this fresh knowledge about softmax we can now take a look at the `_update_out_and_lse()` function that is used in the ring-flash-attention implementation:

In [14]:
def _update_out_and_lse(
    out: torch.Tensor,
    lse: torch.Tensor,
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    block_out = block_out.to(torch.float32)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out

    lse = new_lse
    return out, lse