In [1]:
from xformers.ops.fmha import (
    memory_efficient_attention_forward,
    memory_efficient_attention_backward, 
    memory_efficient_attention_partial,
    merge_attentions
)
import torch

In [2]:
# Seed the RNG so the random inputs -- and therefore every numerical comparison
# made further down -- are identical across "Restart Kernel -> Run All" runs.
torch.manual_seed(0)

# Inputs use the (batch, heads, seq_len, head_dim) layout consumed by
# torch.nn.functional.scaled_dot_product_attention: one query position
# attending over 100 key/value positions, with 21 heads of width 128.
# The tensors are created directly on the GPU in bfloat16, avoiding the
# float32 host allocation, the host-to-device copy and the separate cast.
q = torch.randn(1, 21, 1, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 21, 100, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 21, 100, 128, device="cuda", dtype=torch.bfloat16)

In [3]:
%%time

# Reference result: PyTorch's built-in attention over the whole key/value
# sequence at once. The chunked variants below are compared against it.
out_dot = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# CUDA kernels are launched asynchronously. Wait for the GPU to finish so the
# surrounding %%time figure covers the computation itself, not just the launch.
torch.cuda.synchronize()

out_dot.shape

CPU times: user 2.03 ms, sys: 945 µs, total: 2.98 ms
Wall time: 2.35 ms


torch.Size([1, 21, 1, 128])

In [4]:
# Swap the heads and sequence axes of the SDPA-layout tensors, giving
# (batch, seq_len, heads, head_dim) -- the layout handed to the xformers
# attention calls further down.
q_, k_, v_ = (tensor.transpose(1, 2) for tensor in (q, k, v))
v_.shape

torch.Size([1, 100, 21, 128])

In [5]:
# Number of pieces the key/value sequence is cut into (despite the name this is
# a count of partitions, not the length of each one). Attention is then
# evaluated per piece and merged incrementally.
partition_size = 5

# dim=1 is the sequence axis of the transposed key/value tensors.
k_chunks, v_chunks = (tensor.chunk(partition_size, dim=1) for tensor in (k_, v_))

In [6]:
# Sanity check: the transposed query is not chunked, so it keeps its full shape.
q_.size()

torch.Size([1, 1, 21, 128])

In [7]:
%%time

# https://github.com/ScalingIntelligence/hydragen/blob/main/hydragen/attention.py#L21
#
# Chunked attention, merged in one shot:
#   1. attend to every key/value chunk independently, keeping each chunk's
#      output together with the log-sum-exp (LSE) of its attention scores;
#   2. re-weight the per-chunk outputs by exp(lse_chunk - max_lse) and
#      normalise, which reproduces attention over the full sequence.
partial_outs, partial_lses = [], []
for k_chunk, v_chunk in zip(k_chunks, v_chunks):
    out_, lse_ = memory_efficient_attention_partial(q_, k_chunk, v_chunk)
    partial_outs.append(out_)
    partial_lses.append(lse_)

# Stack along a new leading "chunk" axis.
outs = torch.stack(partial_outs)
lses = torch.stack(partial_lses)

# Subtract the running maximum before exponentiating so exp() cannot overflow.
max_lse = lses.max(0).values
# Per-chunk weights; computed once and reused for both the merged LSE and the
# merged output (the original evaluated this expression twice).
adjust_factors = (lses - max_lse[None]).exp()

# LSE over the full sequence: log(sum_i exp(lse_i)), in max-shifted form.
lse_full = max_lse + adjust_factors.sum(dim=0).log()

# The LSE tensors carry their last two axes in the opposite order to the
# matching axes of the outputs; swap them and append a trailing axis so the
# weights broadcast over head_dim.
adjust_factors = adjust_factors.transpose(2, 3).unsqueeze(-1)
new_denominator = adjust_factors.sum(0)
# Weighted average of the chunk outputs, moved back to the SDPA layout.
out_offload = ((outs * adjust_factors).sum(0) / new_denominator).transpose(1, 2)

# CUDA work is asynchronous; wait for it so %%time measures the real cost.
torch.cuda.synchronize()

out_offload.shape

CPU times: user 2.4 ms, sys: 35.2 ms, total: 37.6 ms
Wall time: 37.3 ms


torch.Size([1, 21, 1, 128])

In [8]:
# Display the merged attention output produced by the chunked computation.
out_offload

tensor([[[[ 0.3315,  0.1697, -0.0938,  ...,  0.1379,  0.1604, -0.1074]],

         [[-0.0157,  0.0269, -0.1003,  ...,  0.1017, -0.1896,  0.1197]],

         [[-0.1364,  0.3072,  0.1040,  ..., -0.0357,  0.0404, -0.0206]],

         ...,

         [[-0.1383, -0.0318, -0.0719,  ...,  0.0949,  0.0897,  0.0715]],

         [[ 0.0010,  0.0593, -0.1481,  ..., -0.0633,  0.1171,  0.0734]],

         [[-0.1108,  0.1146, -0.0208,  ..., -0.2461,  0.0471,  0.0960]]]],
       device='cuda:0')

In [9]:
# Turn the "should be super small" expectation into a hard check instead of
# relying on eyeballing the printed tensor. The inputs are bfloat16, so
# agreement is only expected up to bfloat16 rounding; 1e-2 leaves headroom
# above the ~1e-3 differences seen in practice while still catching a wrong merge.
sdpa_max_abs_diff = (out_offload - out_dot).abs().max().item()
assert sdpa_max_abs_diff < 1e-2, (
    f"chunked attention deviates from SDPA by {sdpa_max_abs_diff:.3e}"
)

(out_offload - out_dot) # the difference should be super small

tensor([[[[-4.9028e-04,  7.9504e-04, -5.3100e-05,  ...,  2.3939e-04,
            2.0193e-04,  4.1224e-05]],

         [[ 6.1432e-05,  8.3534e-05,  2.5561e-04,  ..., -3.0986e-04,
            7.8350e-04,  6.0281e-04]],

         [[ 3.0224e-04,  5.7122e-04, -5.0990e-04,  ...,  1.9699e-04,
           -1.5710e-04, -4.4663e-04]],

         ...,

         [[-6.0607e-04, -4.2256e-05, -8.3089e-05,  ..., -3.2692e-04,
           -1.2110e-04,  7.2782e-04]],

         [[ 1.8294e-04,  1.7340e-04,  3.7739e-04,  ...,  2.1625e-04,
           -9.7543e-05, -3.3285e-04]],

         [[ 7.0259e-05, -1.2022e-04,  1.6701e-04,  ..., -9.4342e-04,
            4.4819e-04, -1.8079e-04]]]], device='cuda:0')

In [10]:
# Streaming variant of the chunk merge: instead of stacking every partial
# result and merging at the end, keep running accumulators and fold in one
# chunk at a time, rescaling whenever the running maximum LSE changes.
outs, max_lse = None, None  # `outs` is reset as in the original; it is not used below
new_denominator = None      # running sum of the rescaled chunk weights
attn_output = None          # running weighted sum of chunk outputs (un-normalised)
new_lse_full = None         # running log-sum-exp over the chunks seen so far

for i, (k_chunk, v_chunk) in enumerate(zip(k_chunks, v_chunks)):
    out_, lse_ = memory_efficient_attention_partial(q_, k_chunk, v_chunk)
    # Swap the last two axes so the LSE lines up with the output's axes
    # once a trailing axis is appended.
    lse_ = lse_.transpose(1, 2)

    if i == 0:
        # First chunk: it is its own maximum, so its weight is exactly 1.
        max_lse = lse_
        adjust_factors = torch.ones_like(lse_).unsqueeze(-1)
        new_denominator = adjust_factors
        attn_output = out_ * adjust_factors
        new_lse_full = lse_
    else:
        new_max_lse = torch.maximum(max_lse, lse_)

        # Rescale factors relative to the updated maximum. Each exponential is
        # evaluated once and reused for the output, the denominator and the LSE.
        old_weight = torch.exp(max_lse - new_max_lse)
        new_weight = torch.exp(lse_ - new_max_lse)
        old_adjust_factors = old_weight.unsqueeze(-1)
        new_adjust_factors = new_weight.unsqueeze(-1)

        new_denominator = old_adjust_factors * new_denominator + new_adjust_factors
        attn_output = old_adjust_factors * attn_output + new_adjust_factors * out_
        # log(exp(running_lse) + exp(lse_)), evaluated in max-shifted form.
        new_lse_full = new_max_lse + torch.log(torch.exp(new_lse_full - new_max_lse) + new_weight)

        max_lse = new_max_lse

# Normalise the accumulated weighted sum and return to the SDPA layout.
attn_output = attn_output / new_denominator
attn_output = attn_output.transpose(1, 2)

In [11]:
# The streaming merge and the stacked merge combine the same per-chunk
# results, so they should agree to floating-point round-off. Enforce that
# explicitly rather than reading it off the printed tensor.
merge_max_abs_diff = (attn_output - out_offload).abs().max().item()
assert merge_max_abs_diff < 1e-5, (
    f"streaming merge deviates from stacked merge by {merge_max_abs_diff:.3e}"
)

attn_output - out_offload

tensor([[[[-2.9802e-08,  2.9802e-08,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  7.4506e-09]],

         [[ 1.3039e-08, -1.8626e-09,  0.0000e+00,  ...,  7.4506e-09,
            1.4901e-08,  0.0000e+00]],

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            3.7253e-09,  0.0000e+00]],

         ...,

         [[ 0.0000e+00,  3.7253e-09,  0.0000e+00,  ..., -7.4506e-09,
           -7.4506e-09,  7.4506e-09]],

         [[-1.8626e-09,  3.7253e-09,  0.0000e+00,  ..., -7.4506e-09,
           -7.4506e-09,  1.4901e-08]],

         [[ 1.4901e-08,  7.4506e-09, -7.4506e-09,  ...,  0.0000e+00,
            7.4506e-09,  0.0000e+00]]]], device='cuda:0')

In [12]:
# Display the combined log-sum-exp assembled from the per-chunk LSE values.
lse_full

tensor([[[5.0803],
         [5.1817],
         [5.2088],
         [4.8658],
         [4.9444],
         [4.9638],
         [5.4463],
         [5.1694],
         [5.2293],
         [5.3548],
         [5.1799],
         [4.9810],
         [5.1273],
         [5.2139],
         [5.1583],
         [5.1400],
         [4.9378],
         [5.0265],
         [5.1528],
         [5.2309],
         [5.0253]]], device='cuda:0')

In [13]:
# Inspect the .grad attribute of the combined log-sum-exp tensor.
# NOTE(review): nothing visible in this notebook enables requires_grad or calls
# backward(), so this presumably evaluates to None -- confirm whether this
# exploratory cell is still needed.
lse_full.grad