In [2]:
from xformers.ops.fmha import (
    memory_efficient_attention_forward,
    memory_efficient_attention_backward, 
    memory_efficient_attention_partial,
    merge_attentions
)
import torch

In [3]:
# Problem sizes for the attention experiment. The tensors built from these
# use the layout (batch_size, head_num, sequence, dim).
batch_size = 1
head_num = 16   # number of attention heads
seq_len = 100   # key/value sequence length
dim = 128       # per-head feature size
# NOTE: despite the name, this is the NUMBER of chunks, not the length of
# each chunk: it is passed to Tensor.chunk(), whose argument is a chunk count.
chunk_size = 5

In [4]:
# Seed the generators so the random inputs, and therefore every comparison
# made later in the notebook, are reproducible across kernel restarts.
torch.manual_seed(0)

# One query position attending over `seq_len` key/value positions, laid out as
# (batch, heads, sequence, dim). The tensors are created directly on the GPU in
# bfloat16 instead of being built on the CPU in float32, copied and then cast.
q = torch.randn(batch_size, head_num, 1, dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch_size, head_num, seq_len, dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch_size, head_num, seq_len, dim, device="cuda", dtype=torch.bfloat16)

In [5]:
%%time

# Reference result: PyTorch's built-in attention over the full key/value
# sequence in a single call. The keyword arguments are the defaults, spelled
# out to make explicit that no mask, dropout or causal masking is applied.
out_dot = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)
out_dot.shape

CPU times: user 0 ns, sys: 10 ms, total: 10 ms
Wall time: 28.7 ms


torch.Size([1, 16, 1, 128])

In [6]:
# Swap the head and sequence axes:
# (batch, heads, seq, dim) -> (batch, seq, heads, dim).
# The underscore-suffixed tensors are the ones fed to the xformers calls.
q_, k_, v_ = (t.transpose(1, 2) for t in (q, k, v))
v_.shape

torch.Size([1, 100, 16, 128])

In [8]:
# Split K and V along the sequence axis (dim 1) into `chunk_size` partitions
# (5 here) so attention can be computed incrementally, one partition at a time.
k_chunks, v_chunks = (kv.chunk(chunk_size, dim=1) for kv in (k_, v_))

In [9]:
# Shape of the query after the axis swap: (batch, q_len, heads, dim).
q_.shape

torch.Size([1, 1, 16, 128])

In [10]:
%%time

# Chunked attention, "stacked" variant: run attention separately against every
# K/V chunk, keep each partial output together with its log-sum-exp (LSE), and
# merge them all at once. The merge follows
# https://github.com/ScalingIntelligence/hydragen/blob/main/hydragen/attention.py#L21
outs, lses = [], []
for k_chunk, v_chunk in zip(k_chunks, v_chunks):
    out_, lse_ = memory_efficient_attention_partial(q_, k_chunk, v_chunk)
    outs.append(out_)
    lses.append(lse_)

# New leading axis indexes the chunk.
outs = torch.stack(outs)
lses = torch.stack(lses)

# Subtract the per-position maximum LSE before exponentiating so exp() cannot
# overflow; the factors are computed once and reused for both results below.
max_lse = lses.max(0).values
adjust_factors = (lses - max_lse[None]).exp()

# LSE of the attention over the full, un-chunked sequence.
lse_full = max_lse + adjust_factors.sum(dim=0).log()

# Move the weights onto the outputs' axis order and add a trailing axis so they
# broadcast over the feature dimension.
# NOTE(review): this assumes lse is (batch, heads, q_len) and out is
# (batch, q_len, heads, dim), as the xformers docs describe - confirm on upgrade.
adjust_factors = adjust_factors.transpose(2, 3).unsqueeze(-1)
new_denominator = adjust_factors.sum(0)

# Weighted average of the partial outputs, then back to (batch, heads, q_len, dim)
# so it can be compared with the scaled_dot_product_attention result.
out_offload = ((outs * adjust_factors).sum(0) / new_denominator).transpose(1, 2)
out_offload.shape

CPU times: user 16.3 ms, sys: 42.2 ms, total: 58.5 ms
Wall time: 71.1 ms


torch.Size([1, 16, 1, 128])

In [11]:
# Inspect the merged result of the chunked attention computation.
out_offload

tensor([[[[-0.2976, -0.2209, -0.2320,  ...,  0.2711, -0.0087,  0.0849]],

         [[ 0.1269,  0.0716,  0.0772,  ...,  0.0294, -0.2892,  0.0221]],

         [[ 0.1076, -0.1437, -0.0286,  ..., -0.0479, -0.0791, -0.0275]],

         ...,

         [[-0.0288,  0.0269, -0.2649,  ...,  0.0051,  0.1333, -0.0487]],

         [[-0.0036,  0.0536,  0.0127,  ..., -0.0956, -0.2699,  0.1205]],

         [[ 0.0820, -0.2063,  0.0273,  ...,  0.0770,  0.2213, -0.1132]]]],
       device='cuda:0')

In [12]:
# Element-wise gap between the chunked result and the single-call reference.
# It should be tiny; it is not expected to be exactly zero because the inputs
# and the reference output are bfloat16.
(out_offload - out_dot)

tensor([[[[ 1.1930e-03, -2.1507e-04,  4.6434e-04,  ..., -3.5664e-04,
            2.1090e-04, -7.1861e-05]],

         [[-6.3568e-05, -2.0878e-04,  7.6659e-05,  ..., -4.1716e-04,
           -1.4636e-04,  8.3708e-04]],

         [[ 1.8580e-04, -1.9218e-04,  3.6756e-04,  ..., -2.7718e-04,
           -4.9496e-04, -1.1262e-05]],

         ...,

         [[-2.2268e-04, -2.8858e-04,  6.8349e-04,  ..., -1.6078e-05,
           -5.1998e-04, -1.0316e-04]],

         [[-2.3383e-04, -1.4468e-04, -3.8647e-04,  ...,  1.1349e-04,
           -3.7390e-04, -1.0186e-04]],

         [[-6.3628e-05, -2.6186e-04, -4.2481e-05,  ..., -1.7399e-04,
           -4.1607e-04,  6.0721e-04]]]], device='cuda:0')

In [13]:
# Chunked attention, "streaming" variant: instead of stacking every partial
# result and merging at the end, fold each chunk into running accumulators so
# only one partial output has to be alive at a time.
#
# Running state (all rescaled relative to the running maximum LSE):
#   max_lse          largest LSE seen so far
#   new_denominator  sum of exp(lse_i - max_lse) over the chunks seen so far
#   attn_output      sum of exp(lse_i - max_lse) * out_i over those chunks
#   new_lse_full     LSE of the attention over all chunks seen so far
max_lse = None
new_denominator = None
attn_output = None
new_lse_full = None

for k_chunk, v_chunk in zip(k_chunks, v_chunks):
    out_, lse_ = memory_efficient_attention_partial(q_, k_chunk, v_chunk)
    # Put the LSE axes in the same order as the output's leading axes.
    # NOTE(review): assumes lse is (batch, heads, q_len) and out is
    # (batch, q_len, heads, dim), as the xformers docs describe - confirm on upgrade.
    lse_ = lse_.transpose(1, 2)

    if max_lse is None:
        # First chunk: it is its own maximum, so its weight is exactly 1.
        max_lse = lse_
        new_lse_full = lse_
        new_denominator = torch.ones_like(lse_).unsqueeze(-1)
        # Multiplying by the ones also promotes the output to the LSE dtype.
        attn_output = out_ * new_denominator
    else:
        new_max_lse = torch.maximum(max_lse, lse_)

        # Rescale what was accumulated under the old maximum, and weight the
        # incoming chunk, both relative to the new maximum.
        old_adjust_factors = torch.exp(max_lse - new_max_lse).unsqueeze(-1)
        new_adjust_factors = torch.exp(lse_ - new_max_lse).unsqueeze(-1)

        new_denominator = old_adjust_factors * new_denominator + new_adjust_factors
        attn_output = old_adjust_factors * attn_output + new_adjust_factors * out_
        # Library log-sum-exp of two tensors: same quantity as the manual
        # max-shifted formula, and it does not produce NaN for -inf inputs.
        new_lse_full = torch.logaddexp(new_lse_full, lse_)

        max_lse = new_max_lse

# Normalise, then return to (batch, heads, q_len, dim) like the other results.
attn_output = attn_output / new_denominator
attn_output = attn_output.transpose(1, 2)

In [14]:
# Element-wise gap between the streaming merge (attn_output) and the stacked
# merge (out_offload); the two should agree up to rounding noise.
attn_output - out_offload

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            3.7253e-09,  0.0000e+00]],

         [[ 1.4901e-08,  0.0000e+00,  0.0000e+00,  ...,  1.8626e-09,
            2.9802e-08,  1.8626e-09]],

         [[ 0.0000e+00, -1.4901e-08, -1.1176e-08,  ..., -1.1176e-08,
           -7.4506e-09, -1.4901e-08]],

         ...,

         [[-3.7253e-09, -1.1176e-08,  0.0000e+00,  ..., -8.3819e-09,
            0.0000e+00, -3.7253e-09]],

         [[ 7.6834e-09,  7.4506e-09, -1.9558e-08,  ...,  1.4901e-08,
            0.0000e+00,  0.0000e+00]],

         [[ 1.4901e-08, -1.4901e-08, -2.4214e-08,  ...,  1.4901e-08,
            0.0000e+00, -2.9802e-08]]]], device='cuda:0')

In [16]:
# Shape of the streaming-merge result: (batch, heads, q_len, dim).
attn_output.shape

torch.Size([1, 16, 1, 128])

In [17]:
# Shape of the stacked-merge result: (batch, heads, q_len, dim).
out_offload.shape

torch.Size([1, 16, 1, 128])

In [21]:
# Fraction of elements whose sign agrees between the two merge variants;
# 1.0 means every element has the same sign in both results.
(attn_output.sign() == out_offload.sign()).float().mean()

tensor(1., device='cuda:0')