# How-to Log Sum Exp

flash-attention 的 softmax：先对每个分块单独计算 softmax，再利用各块的 sum exp 把结果合并为完整的 softmax。

In [1]:

## 简单的softmax
import torch

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    """Compute softmax over every element of ``x`` without any stabilization.

    Each element is exponentiated and divided by the sum of all the
    exponentials. No maximum is subtracted beforehand, so inputs with a
    large magnitude can overflow ``exp`` and produce ``nan`` entries.

    Args:
        x: Input tensor; the normalizing sum runs over all of its elements.

    Returns:
        A tensor shaped like ``x`` holding ``exp(x) / sum(exp(x))``.
    """
    exponentials = torch.exp(x)
    return exponentials / exponentials.sum()

In [2]:
## Verify the naive implementation against torch.softmax

# Seed the RNG so that a fresh Restart & Run All reproduces the same numbers.
torch.manual_seed(0)

x = torch.randn(10)  # generate normally distributed random numbers
a = torch.softmax(x, dim=-1) # reference output
b = naive_softmax(x) # our naive version

print("a", a)
print("b", b)
# For small-magnitude inputs like these, both versions should agree within tolerance.
print("allclose", torch.allclose(a, b, atol=1e-6))

a tensor([0.1631, 0.0121, 0.2070, 0.0681, 0.0422, 0.0418, 0.2359, 0.1124, 0.0797,
        0.0379])
b tensor([0.1631, 0.0121, 0.2070, 0.0681, 0.0422, 0.0418, 0.2359, 0.1124, 0.0797,
        0.0379])
allclose True


In [3]:
# With large-magnitude inputs the naive version is numerically unstable:
# once exp() overflows to inf, the overflowing entries come out as nan
# (inf / inf) and every other entry is divided by inf and becomes 0.
naive_softmax(100 * x)

tensor([nan, 0., nan, 0., 0., 0., nan, 0., 0., 0.])

In [4]:
# Split the vector into two equal-sized chunks and run softmax on each chunk
# independently.

x = torch.randn(10)

x1, x2 = x.chunk(2)
s1, s2 = map(naive_softmax, (x1, x2))

# Softmax over the whole vector: the result the per-chunk outputs should
# eventually be merged into.
target = naive_softmax(x)

print("We have:")
print(f"s1 = {s1}")
print(f"s2 = {s2}")

print("We want:")
print(f"target = {target}")

We have:
s1 = tensor([0.0409, 0.4080, 0.2601, 0.1996, 0.0914])
s2 = tensor([0.6396, 0.1177, 0.0909, 0.0431, 0.1087])
We want:
target = tensor([0.0269, 0.2687, 0.1713, 0.1314, 0.0602, 0.2184, 0.0402, 0.0310, 0.0147,
        0.0371])


In [None]:

## Merge the per-chunk results. Each chunk's sum of exponentials ("sum exp")
## is the statistic needed for the correction; its logarithm is what is
## usually called "log sum exp".

sum_exp_x1 = torch.exp(x1).sum()
sum_exp_x2 = torch.exp(x2).sum()
sum_exp_total = sum_exp_x1 + sum_exp_x2  # combined sum of exponentials of both chunks

# Multiplying by a chunk's own sum undoes its local normalization; dividing by
# the combined sum re-normalizes it against both chunks together.
s1_corrected = (s1 * sum_exp_x1) / sum_exp_total
s2_corrected = (s2 * sum_exp_x2) / sum_exp_total
s_combined = torch.cat((s1_corrected, s2_corrected))

print("After correction with help of sum_exp values:")
print("s_combined", s_combined)

print("allclose(s_combined, target):", torch.allclose(s_combined, target))

After correction with help of sum_exp values:
s_combined tensor([0.0269, 0.2687, 0.1713, 0.1314, 0.0602, 0.2184, 0.0402, 0.0310, 0.0147,
        0.0371])
allclose(s_combined, target): True
