In [13]:
import torch
import math

In [10]:
def softmax_naive(x: torch.Tensor):
    m = x.max()
    f = (x - m).exp()
    l = f.sum()
    softmax = f / l
    return softmax

a = torch.rand((15,))
torch.softmax(a, dim=0), softmax_naive(a)

(tensor([0.0710, 0.0595, 0.0612, 0.0955, 0.0736, 0.0475, 0.0742, 0.0763, 0.0770,
         0.0369, 0.0896, 0.0751, 0.0520, 0.0682, 0.0424]),
 tensor([0.0710, 0.0595, 0.0612, 0.0955, 0.0736, 0.0475, 0.0742, 0.0763, 0.0770,
         0.0369, 0.0896, 0.0751, 0.0520, 0.0682, 0.0424]))

In [42]:

def softmax_tiled(x: torch.Tensor, k=3):
    n_blocks = math.ceil(x.size()[0] / k)

    max_acc = float("-inf")
    f_values = torch.zeros((x.size()[0],))
    l_values = torch.zeros((n_blocks,))
    m_values = torch.zeros((n_blocks,))

    for b in range(0, n_blocks):
        x_b = x[b * k:b * k + k]

        m_b = x_b.max()

        f_b = (x_b - m_b).exp()
        l_b = f_b.sum()

        f_values[b * k:b * k + k] = f_b
        l_values[b] = l_b
        m_values[b] = m_b

        max_acc = max(max_acc, m_b)

    m = max_acc
    scaling_factors = (m_values - m).exp()
    l = (scaling_factors * l_values).sum()

    results = torch.zeros((x.size()[0],))

    for b in range(0, n_blocks):
        scaling_factor = scaling_factors[b]

        f_b = f_values[b * k:b * k + k]
        res_b = (f_b * scaling_factor) / l

        results[b * k:b * k + k] = res_b

    return results

a = torch.rand((15,))
torch.softmax(a, dim=0), softmax_tiled(a, k=3)
# softmax_tiled(a).sum()

(tensor([0.0440, 0.0609, 0.0417, 0.0654, 0.0873, 0.0994, 0.0434, 0.0647, 0.0428,
         0.0982, 0.0640, 0.0617, 0.0536, 0.0838, 0.0891]),
 tensor([0.0440, 0.0609, 0.0417, 0.0654, 0.0873, 0.0994, 0.0434, 0.0647, 0.0428,
         0.0982, 0.0640, 0.0617, 0.0536, 0.0838, 0.0891]))

In [None]:

def softmax_tiled_oneloop(x: torch.Tensor, k=3):
    n_blocks = math.ceil(x.size()[0] / k)

    max_acc = float("-inf")
    f_values = torch.zeros((x.size()[0],))
    l_values = torch.zeros((n_blocks,))
    m_values = torch.zeros((n_blocks,))

    for b in range(0, n_blocks):
        x_b = x[b * k:b * k + k]

        m_b = x_b.max()

        f_b = (x_b - m_b).exp()
        l_b = f_b.sum()

        f_values[b * k:b * k + k] = f_b
        l_values[b] = l_b
        m_values[b] = m_b

        max_acc = max(max_acc, m_b)

    m = max_acc
    scaling_factors = (m_values - m).exp()
    l = (scaling_factors * l_values).sum()

    results = torch.zeros((x.size()[0],))

    for b in range(0, n_blocks):
        scaling_factor = scaling_factors[b]

        f_b = f_values[b * k:b * k + k]
        res_b = (f_b * scaling_factor) / l

        results[b * k:b * k + k] = res_b

    return results

a = torch.rand((15,))
torch.softmax(a, dim=0), softmax_tiled(a, k=3)
# softmax_tiled(a).sum()

In [54]:
def FlashAttention(Q, K, V, B=10):
    N, d = Q.shape

    # Initialize O, l, m
    O = torch.zeros_like(Q)
    l = torch.zeros(N, dtype=Q.dtype)
    m = torch.full((N,), float("-inf"), dtype=Q.dtype)

    # Determine the number of blocks
    Tr = math.ceil(N / B)
    Tc = Tr

    # For each column block
    for j in range(Tc):
        # Load Kj, Vj
        Kj = K[j*B:(j+1)*B, :]
        Vj = V[j*B:(j+1)*B, :]

        # For each row block
        for i in range(Tr):
            # Load Qi, Oi, li, mi
            Qi = Q[i*B:(i+1)*B, :]
            Oi = O[i*B:(i+1)*B, :]
            li = l[i*B:(i+1)*B]
            mi = m[i*B:(i+1)*B]

            # On chip, compute Sij, ~mij, ~Pij, ~lij
            Sij = Qi @ Kj.T
            mij = Sij.max(dim=1)[0]
            Pij = torch.exp(Sij - mij.unsqueeze(1))
            lij = Pij.sum(dim=1)

            # On chip, compute mnewi, lnewi
            mnewi = torch.max(mi, mij)
            lnewi = torch.exp(mi - mnewi) * li + torch.exp(mij - mnewi) * lij

            # Write Oi, li, mi back to HBM
            Oi = torch.diag(1. / lnewi) @ (torch.diag(torch.exp(mi - mnewi) * li) @ Oi + torch.exp(mij - mnewi) * Pij @ Vj)
            O[i*B:(i+1)*B, :] = Oi
            l[i*B:(i+1)*B] = lnewi
            m[i*B:(i+1)*B] = mnewi

    # Return O
    return O

N, d = 20, 10
Q = torch.randn(N, d, dtype=torch.float64)
K = torch.randn(N, d, dtype=torch.float64)
V = torch.randn(N, d, dtype=torch.float64)

torch.abs(torch.softmax(Q @ K.T, dim=-1) @ V - FlashAttention(Q, K, V))

tensor([[5.2988e-04, 3.4815e-03, 8.0749e-04, 3.2067e-03, 2.8264e-03, 2.9001e-03,
         3.1164e-03, 6.2981e-03, 4.6885e-03, 1.5142e-03],
        [3.4915e-03, 1.6627e-01, 1.0068e-02, 8.9618e-02, 3.1834e-02, 5.2418e-02,
         5.2547e-02, 2.5292e-01, 7.5699e-02, 1.7230e-02],
        [2.9257e-02, 1.0018e-01, 5.5842e-02, 1.1918e-01, 1.8103e-02, 1.7371e-02,
         3.8449e-02, 1.8871e-01, 7.7059e-02, 6.8165e-02],
        [2.7490e-03, 9.6256e-02, 2.4736e-01, 4.3178e-02, 1.4328e-01, 1.1132e-03,
         8.9672e-03, 3.4262e-01, 4.8426e-01, 4.4858e-01],
        [1.0496e-01, 4.5636e-01, 6.7951e-01, 9.2824e-01, 5.7093e-01, 2.7699e-01,
         3.5616e-01, 8.3383e-02, 7.9569e-01, 1.0648e+00],
        [1.0492e-01, 1.4487e-01, 1.9831e-01, 2.5742e-03, 1.7217e-01, 2.8173e-01,
         1.5352e-01, 1.9702e-01, 1.7888e-01, 5.0681e-02],
        [2.6330e-01, 1.2877e+00, 2.8098e-01, 7.2526e-01, 9.0763e-01, 3.5823e-02,
         1.2726e+00, 2.0113e+00, 1.0532e+00, 1.0247e+00],
        [2.7644e-04, 1.0646