In [1]:
import torch
import torch.nn.functional as F
from xformers.ops.fmha import (
    memory_efficient_attention_forward,
    memory_efficient_attention_backward, 
    memory_efficient_attention_partial,
    merge_attentions
)
import torch

In [2]:
# Problem size shared by every cell below; tensors are built as
# (batch_size, head_num, seq_len, dim).
batch_size = 1
head_num = 16   # attention heads
seq_len = 100   # positions along the sequence axis
dim = 128       # per-head feature width

# NOTE: despite the name, this is the *number* of pieces handed to
# Tensor.chunk(...), so each piece spans seq_len / chunk_size positions.
chunk_size = 5

In [3]:
# Seed the generator so that re-running this cell (or the whole notebook)
# produces the same q/k/v; without it every run compares different tensors.
torch.manual_seed(0)

# (batch, heads, seq, dim): the layout scaled_dot_product_attention consumes.
qkv_shape = (batch_size, head_num, seq_len, dim)

# Sample in float32 on the CPU, then move and cast in one .to() call instead
# of a .cuda() copy followed by a separate dtype conversion.
q = torch.randn(qkv_shape).to(device="cuda", dtype=torch.bfloat16)
k = torch.randn(qkv_shape).to(device="cuda", dtype=torch.bfloat16)
v = torch.randn(qkv_shape).to(device="cuda", dtype=torch.bfloat16)

In [4]:
# Reference result: PyTorch's attention over the full sequence, called with
# its default arguments (no mask, no dropout, non-causal).
out_dot = F.scaled_dot_product_attention(q, k, v)

In [5]:
# Sanity check: the reference output keeps the (batch, heads, seq, dim) shape of q.
out_dot.size()

torch.Size([1, 16, 100, 128])

In [6]:
# Swap the heads and sequence axes, (batch, heads, seq, dim) ->
# (batch, seq, heads, dim); these are the tensors later chunked and passed to
# the xformers partial-attention call.
q_, k_, v_ = (t.transpose(1, 2) for t in (q, k, v))
v_.size()

torch.Size([1, 100, 16, 128])

In [7]:
# Split every tensor along the sequence axis (dim 1) into `chunk_size` pieces.
# Tensor.chunk returns a tuple of views, one per piece.
q_chunks, k_chunks, v_chunks = (
    t.chunk(chunk_size, dim=1) for t in (q_, k_, v_)
)

In [8]:
# Only the first query chunk is used; the merge loop below attends it against
# every key/value chunk in turn.
Q_block = q_chunks[0]

In [9]:
# Online-softmax merge of per-chunk partial attention results.
#
# For every key/value chunk the partial call returns an output `out_` and a
# log-sum-exp `lse_`.  The full-sequence attention for Q_block is
#     sum_i exp(lse_i) * out_i  /  sum_i exp(lse_i)
# and is accumulated here relative to a running maximum of the LSEs, so every
# exponent is <= 0 and the exponentials cannot overflow.
outs = None             # placeholder; not used anywhere in this cell
max_lse = None          # running element-wise maximum of the chunk LSEs
new_denominator = None  # sum_i exp(lse_i - max_lse)
attn_output = None      # sum_i exp(lse_i - max_lse) * out_i
new_lse_full = None     # log-sum-exp over all chunks merged so far

for i, (k_blk, v_blk) in enumerate(zip(k_chunks, v_chunks)):
    out_, lse_ = memory_efficient_attention_partial(Q_block, k_blk, v_blk)
    # Swap axes 1 and 2 so that, after a trailing unsqueeze, the LSE broadcasts
    # against `out_` (whose last axis is the feature dimension).
    lse_ = lse_.transpose(1, 2)

    if i == 0:
        # First chunk: it defines the reference maximum, so its weight is 1.
        adjust_factors = torch.ones_like(lse_).unsqueeze(-1)
        max_lse = lse_
        new_lse_full = lse_
        new_denominator = adjust_factors
        attn_output = out_ * adjust_factors
        continue

    new_max_lse = torch.maximum(max_lse, lse_)

    # Rescale both the running accumulators and the incoming chunk to the new
    # reference maximum.
    old_adjust_factors = (max_lse - new_max_lse).exp().unsqueeze(-1)
    new_adjust_factors = (lse_ - new_max_lse).exp().unsqueeze(-1)

    new_denominator = new_denominator * old_adjust_factors + new_adjust_factors
    attn_output = attn_output * old_adjust_factors + out_ * new_adjust_factors

    # Stable log(exp(new_lse_full) + exp(lse_)); the second term is the weight
    # already computed for the incoming chunk.
    merged_mass = (new_lse_full - new_max_lse).exp() + new_adjust_factors.squeeze(-1)
    new_lse_full = new_max_lse + merged_mass.log()

    max_lse = new_max_lse

# Normalise by the accumulated softmax mass, then swap axes 1 and 2 back so the
# layout matches `out_dot`.
attn_output = (attn_output / new_denominator).transpose(1, 2)

In [10]:
# Fraction of elements whose sign matches the reference for the same query
# rows.  The slice length is read from attn_output (the first query chunk)
# instead of a hard-coded 20, so it stays correct if seq_len or chunk_size
# change.
# NOTE(review): sign agreement is a coarse check; it does not bound the
# numeric error between the two results.
num_queries = attn_output.shape[2]
(attn_output.sign() == out_dot[:, :, :num_queries].sign()).float().mean()

tensor(0.9993, device='cuda:0')

In [11]:
# Arg-max over the feature axis of the reference output, restricted to the
# query rows covered by Q_block.  The row count comes from Q_block's sequence
# axis rather than a literal 20, so it follows seq_len / chunk_size.
torch.argmax(out_dot[:, :, : Q_block.shape[1]], dim=-1)

tensor([[[102, 126,   8,  62, 126,  91,  46,  87,   9, 126,  77,   9, 113, 126,
           54, 126,  42,  35,   7, 123],
         [  9,  27,  21, 104,  30,  31,  14,   9,  57,  27, 103, 124,  51,  81,
           21,  61,  30,  90,   9,  92],
         [  0,  43,  43,  43,  61,  43,  16,  43, 118,   0,   9,   9,  61, 118,
           91,   9,  43,  61,  43,   0],
         [ 13,  13,  98,  41,   2,  38,  13,  95, 124, 113,  35, 117,  98,  97,
          107,  34,  21,  41,  41,   4],
         [122,  81,  44, 109, 109,  81,  66, 103,  66, 104,  93, 101,   5,  38,
           11, 103,  66,  68,  58,  66],
         [102, 125,  50,  96,   0, 103,  96,  31,  91,  64,   5, 124, 125,   5,
            1, 103,   7, 125,  36,  64],
         [ 68,  83,  56,  99,  99, 120, 112,  45,   8,  58, 120,  99,  56, 120,
          123,  77,  92,  77, 123,  92],
         [ 85,  85,  81,  38,  79,  20,  97,  51,  78,  38,  78,   7,  77,  57,
           20,  23,  97,  51,  34,  97],
         [ 53,  58,   0, 117,  5

In [13]:
# Same arg-max over the feature axis for the merged chunked result, shown for
# side-by-side comparison with the reference output.
torch.argmax(attn_output, dim=-1)

tensor([[[102, 126,   8,  62, 126,  91,  46,  87, 113, 126,  77,   9, 113, 126,
           54, 126,  42,  35,   7, 123],
         [  9,  27,  21, 104,  30,  31,  14,   9,  57,  27, 103, 124,  51,  81,
           21,  61,  30,  90,   9,  92],
         [  0,  43,  43,  43,  61,  43,  16,  43, 118,   0,   9,   9,  61, 118,
           91,   9,  43,  61,  43,   0],
         [ 13,  13,  98,  41,   2,  38,  13,  95, 124, 113,  35, 117,  98,  97,
          107,  34,  21,  41,  41,   4],
         [122,  81,  44, 109, 109,  81,  66, 103,  66, 104,  93, 101,   5,  38,
           11, 103,  66,  68,  58,  66],
         [102, 125,  50,  96,   0, 103,  96,  31,  91,  64,   5, 124, 125,   5,
            1, 103,   7, 125,  36,  64],
         [ 68,  83,  56,  99,  99, 120, 112,  45,   8,  58, 120,  99,  56, 120,
          123,  77,  92,  77, 123,  92],
         [ 85,  85,  81,  38,  79,  20,  97,  51,  78,  81,  78,   7,  77,  57,
           20,  23,  97,  51,  34,  97],
         [ 53,  58,   0, 117,  5