In [1]:
from xformers.ops.fmha import (
    memory_efficient_attention_forward,
    memory_efficient_attention_backward, 
    memory_efficient_attention_partial,
    merge_attentions
)
import torch

In [2]:
# Problem sizes. The tensors are handed to scaled_dot_product_attention, which
# reads them as (batch, heads, seq_len, head_dim).
BATCH, N_HEADS, HEAD_DIM = 1, 21, 128
Q_LEN, KV_LEN = 2, 100

# Seed the RNG so the random inputs (and every number derived from them) are
# the same on each Restart & Run All.
torch.manual_seed(0)

# Allocate directly on the GPU in bfloat16 rather than creating float32 tensors
# on the CPU and then copying and casting them.
tensor_kwargs = dict(device="cuda", dtype=torch.bfloat16)
q = torch.randn(BATCH, N_HEADS, Q_LEN, HEAD_DIM, **tensor_kwargs)
k = torch.randn(BATCH, N_HEADS, KV_LEN, HEAD_DIM, **tensor_kwargs)
v = torch.randn(BATCH, N_HEADS, KV_LEN, HEAD_DIM, **tensor_kwargs)

In [3]:
%%time

out_dot = torch.nn.functional.scaled_dot_product_attention(q, k, v)
out_dot.shape

CPU times: user 1.67 ms, sys: 4.53 ms, total: 6.2 ms
Wall time: 4.55 ms


torch.Size([1, 21, 2, 128])

In [4]:
# Swap dims 1 and 2 of each tensor, e.g. v goes from (1, 21, 100, 128) to
# (1, 100, 21, 128); these transposed views are what the xformers call consumes.
q_, k_, v_ = (t.transpose(1, 2) for t in (q, k, v))
v_.shape

torch.Size([1, 100, 21, 128])

In [5]:
# NOTE: despite its name, this is the *number* of chunks requested from
# Tensor.chunk, not the length of each chunk. Keys and values are split along
# dim 1 into 5 partitions so attention can be computed incrementally.
partition_size = 5
k_chunks, v_chunks = (t.chunk(partition_size, dim=1) for t in (k_, v_))

In [6]:
%%time

# https://github.com/ScalingIntelligence/hydragen/blob/main/hydragen/attention.py#L21
outs, lses = [], []
for i in range(len(k_chunks)):
    out_, lse_ = memory_efficient_attention_partial(q_, k_chunks[i], v_chunks[i])
    outs.append(out_)
    lses.append(lse_)
    
outs = torch.stack(outs)
lses = torch.stack(lses)

max_lse = lses.max(0).values

adjust_factors = (lses - max_lse[None]).exp()
adjust_factors = adjust_factors.transpose(2, 3).unsqueeze(-1)
new_denominator = adjust_factors.sum(0)
out_offload = ((outs * adjust_factors).sum(0) / new_denominator).transpose(1, 2)
out_offload.shape

CPU times: user 24.8 ms, sys: 44.6 ms, total: 69.4 ms
Wall time: 67.7 ms


torch.Size([1, 21, 2, 128])

In [7]:
# Element-wise gap between the chunked result and the reference attention
# output; every entry is expected to be very close to zero.
out_offload - out_dot

tensor([[[[ 1.2345e-03, -2.4971e-04, -6.6012e-04,  ..., -2.1459e-04,
            3.7201e-04, -1.3451e-04],
          [-7.3791e-05, -3.8874e-04, -3.0696e-04,  ..., -5.2631e-05,
            6.5991e-04, -8.1521e-04]],

         [[ 3.0283e-04, -8.0040e-04,  3.8594e-04,  ...,  1.2854e-04,
            3.9496e-04,  7.4293e-05],
          [-1.1277e-04,  5.5200e-04, -6.7055e-04,  ..., -2.2515e-04,
           -6.5799e-04,  3.8579e-05]],

         [[-5.3185e-04,  4.1805e-05, -7.8999e-05,  ..., -1.3923e-04,
            3.5194e-04, -2.9349e-04],
          [-1.7911e-04,  4.5266e-04, -3.9779e-04,  ..., -1.9804e-05,
            2.9813e-04, -7.9170e-05]],

         ...,

         [[-8.3151e-04,  1.2643e-04,  1.1572e-03,  ..., -1.8989e-04,
           -8.5682e-08,  9.7290e-05],
          [ 6.8823e-04,  9.1299e-04,  1.8278e-04,  ...,  2.0218e-04,
            8.0469e-04,  1.2076e-04]],

         [[-2.3107e-04,  7.3041e-04, -3.1593e-04,  ..., -6.9819e-05,
            3.9328e-05, -5.7943e-04],
          [ 2.