In [1]:
from xformers.ops.fmha import (
    memory_efficient_attention_forward,
    memory_efficient_attention_backward, 
    memory_efficient_attention_partial,
    merge_attentions
)
import torch

In [2]:
q = torch.randn(1, 21, 2, 128).cuda().to(torch.bfloat16)
k = torch.randn(1, 21, 100, 128).cuda().to(torch.bfloat16)
v = torch.randn(1, 21, 100, 128).cuda().to(torch.bfloat16)

In [3]:
%%time

out_dot = torch.nn.functional.scaled_dot_product_attention(q, k, v)
out_dot.shape

CPU times: user 2.01 ms, sys: 953 µs, total: 2.97 ms
Wall time: 2.41 ms


torch.Size([1, 21, 2, 128])

In [4]:
q_ = q.transpose(1, 2)
k_ = k.transpose(1, 2)
v_ = v.transpose(1, 2)
v_.shape

torch.Size([1, 100, 21, 128])

In [5]:
partition_size = 5 # means we split attention into 5 partitions and incrementally calculate it
k_chunks = k_.chunk(partition_size, dim = 1)
v_chunks = v_.chunk(partition_size, dim = 1)

In [6]:
%%time

# https://github.com/ScalingIntelligence/hydragen/blob/main/hydragen/attention.py#L21
outs, lses = [], []
for i in range(len(k_chunks)):
    out_, lse_ = memory_efficient_attention_partial(q_, k_chunks[i], v_chunks[i])
    outs.append(out_)
    lses.append(lse_)
    
outs = torch.stack(outs)
lses = torch.stack(lses)

max_lse = lses.max(0).values

adjust_factors = (lses - max_lse[None]).exp()
adjust_factors = adjust_factors.transpose(2, 3).unsqueeze(-1)
new_denominator = adjust_factors.sum(0)
out_offload = ((outs * adjust_factors).sum(0) / new_denominator).transpose(1, 2)
out_offload.shape

CPU times: user 18.6 ms, sys: 23.1 ms, total: 41.6 ms
Wall time: 40.9 ms


torch.Size([1, 21, 2, 128])

In [7]:
(out_offload - out_dot) # the difference should be super small

tensor([[[[-2.1083e-04,  4.6998e-04, -2.2201e-04,  ...,  1.3250e-04,
           -9.2015e-05, -2.4086e-04],
          [ 4.1917e-05,  7.1779e-04, -4.3038e-04,  ..., -3.1373e-04,
            5.4985e-06,  4.5154e-04]],

         [[ 3.6941e-04, -1.3298e-03,  8.3983e-05,  ...,  6.5812e-04,
           -3.7771e-04,  1.7650e-05],
          [ 7.5035e-04,  7.2541e-05,  7.7300e-05,  ..., -8.0094e-04,
            3.4310e-04, -3.7353e-04]],

         [[ 1.7822e-05,  6.1104e-04, -1.4340e-04,  ...,  8.2256e-04,
           -5.3018e-04,  3.2209e-04],
          [ 3.9487e-04, -1.3969e-04,  3.4460e-04,  ...,  1.7525e-04,
            3.3581e-04,  6.3409e-04]],

         ...,

         [[-5.8189e-04, -4.5300e-05,  3.0844e-04,  ..., -2.3700e-04,
            2.8881e-04, -3.2416e-04],
          [ 4.6875e-05,  5.4127e-04,  1.7814e-04,  ...,  4.4993e-04,
            4.5493e-05, -7.7967e-04]],

         [[ 3.9813e-04, -1.9488e-04,  4.2811e-04,  ..., -3.5535e-04,
            1.0266e-03, -1.2019e-04],
          [-2.

In [8]:
outs, max_lse = None, None
new_denominator = None
attn_output = None

for i in range(len(k_chunks)):
    out_, lse_ = memory_efficient_attention_partial(q_, k_chunks[i], v_chunks[i])
    lse_ = lse_.transpose(1, 2)

    if i == 0:
        max_lse = lse_
        adjust_factors = torch.ones_like(lse_).unsqueeze(-1)
        new_denominator = adjust_factors
        attn_output = out_ * adjust_factors
    else:
        new_max_lse = torch.maximum(max_lse, lse_)
        
        old_adjust_factors = torch.exp(max_lse - new_max_lse).unsqueeze(-1)
        new_adjust_factors = torch.exp(lse_ - new_max_lse).unsqueeze(-1)
        
        new_denominator = old_adjust_factors * new_denominator + new_adjust_factors
        attn_output = old_adjust_factors * attn_output + new_adjust_factors * out_
        
        max_lse = new_max_lse

attn_output = attn_output / new_denominator
attn_output = attn_output.transpose(1, 2)

In [9]:
attn_output - out_offload

tensor([[[[-1.3039e-08,  0.0000e+00, -1.4901e-08,  ...,  1.4901e-08,
           -7.4506e-09,  0.0000e+00],
          [ 0.0000e+00,  1.4901e-08,  0.0000e+00,  ...,  2.2352e-08,
            3.7253e-09,  0.0000e+00]],

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            3.7253e-09,  0.0000e+00],
          [ 0.0000e+00,  3.7253e-09,  7.4506e-09,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]],

         [[ 0.0000e+00,  0.0000e+00, -1.8626e-09,  ..., -1.4901e-08,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  9.7789e-09,  0.0000e+00,  ...,  7.4506e-09,
           -1.1176e-08, -1.4901e-08]],

         ...,

         [[ 2.9802e-08,  0.0000e+00,  0.0000e+00,  ...,  7.4506e-09,
            0.0000e+00,  0.0000e+00],
          [ 1.8626e-08, -7.4506e-09, -1.4901e-08,  ...,  0.0000e+00,
           -1.4901e-08,  0.0000e+00]],

         [[ 1.4901e-08, -1.1176e-08, -2.9802e-08,  ...,  1.4901e-08,
           -9.3132e-09,  1.4901e-08],
          [-1.