In [1]:
import os

# Restrict CUDA to the device with index 1. This runs before torch is imported
# in the next cell, so the setting is in place before any CUDA work starts.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

In [2]:
import torch
import torch.nn.functional as F
from xformers.ops.fmha import (
    memory_efficient_attention_forward,
    memory_efficient_attention_backward, 
    memory_efficient_attention_partial,
    merge_attentions
)
import torch

In [3]:
# Problem size for the attention experiment.
batch_size, head_num = 1, 16
seq_len, dim = 100, 128
# NOTE: this value is passed as the first argument of Tensor.chunk further
# down, i.e. it is the *number* of sequence chunks, not the length of a chunk.
chunk_size = 4

In [4]:
# Seed the global RNG so the random inputs, and therefore every comparison
# below, are reproducible after a kernel restart.
torch.manual_seed(0)

# Random query/key/value tensors of shape (batch_size, head_num, seq_len, dim).
# They are sampled on the CPU, moved to the GPU and cast to bfloat16.
qkv_shape = (batch_size, head_num, seq_len, dim)
q, k, v = (torch.randn(qkv_shape).cuda().to(torch.bfloat16) for _ in range(3))

In [5]:
# Reference result: attention over the full, unchunked sequence.
out_dot = F.scaled_dot_product_attention(q, k, v)

In [6]:
# Shape of the reference output.
out_dot.size()

torch.Size([1, 16, 100, 128])

In [7]:
# Swap axes 1 and 2 of every input: (batch, heads, seq, dim) -> (batch, seq, heads, dim).
# These transposed views are what the xformers calls below receive.
q_, k_, v_ = (t.transpose(1, 2) for t in (q, k, v))
v_.shape

torch.Size([1, 100, 16, 128])

In [8]:
# Split each transposed tensor along axis 1 (the sequence axis) into
# `chunk_size` pieces.
q_chunks, k_chunks, v_chunks = (t.chunk(chunk_size, dim=1) for t in (q_, k_, v_))

In [9]:
# Only the first query chunk is used in the experiments below.
Q_block = next(iter(q_chunks))

In [10]:
# Blockwise attention of Q_block against every K/V chunk. The per-chunk partial
# results are folded together with a running maximum of the log-sum-exp (LSE)
# values so that every exponential stays bounded.
outs = max_lse = new_denominator = attn_output = new_lse_full = None

for i, k_blk in enumerate(k_chunks):
    out_, lse_ = memory_efficient_attention_partial(Q_block, k_blk, v_chunks[i])
    # Swap axes 1 and 2 of the LSE so that, after unsqueeze(-1), it broadcasts
    # against out_ (presumably (batch, heads, seq) -> (batch, seq, heads)).
    lse_ = lse_.transpose(1, 2)

    if i > 0:
        new_max_lse = torch.maximum(max_lse, lse_)

        # Rescale the running totals and the new chunk to the new maximum.
        old_adjust_factors = torch.exp(max_lse - new_max_lse).unsqueeze(-1)
        new_adjust_factors = torch.exp(lse_ - new_max_lse).unsqueeze(-1)

        attn_output = old_adjust_factors * attn_output + new_adjust_factors * out_
        new_denominator = old_adjust_factors * new_denominator + new_adjust_factors
        new_lse_full = new_max_lse + torch.log(
            torch.exp(new_lse_full - new_max_lse) + torch.exp(lse_ - new_max_lse)
        )

        max_lse = new_max_lse
        continue

    # First chunk: nothing to rescale yet, so every weight is one.
    adjust_factors = torch.ones_like(lse_).unsqueeze(-1)
    max_lse = new_lse_full = lse_
    new_denominator = adjust_factors
    attn_output = out_ * adjust_factors

# Normalise by the accumulated softmax mass, then swap axes 1 and 2 back so the
# layout matches out_dot.
attn_output = (attn_output / new_denominator).transpose(1, 2)

In [32]:
# Shapes of the blockwise result and of the full reference, side by side.
tuple(t.shape for t in (attn_output, out_dot))

(torch.Size([1, 16, 25, 128]), torch.Size([1, 16, 100, 128]))

In [12]:
# Fraction of elements whose sign matches the reference. The reference is cut
# to the query positions covered by attn_output; the length is derived from the
# tensor instead of being hard-coded.
# NOTE: sign agreement is invariant to positive rescaling, so this is only a
# coarse sanity check and not a numerical-closeness test.
n_queries = attn_output.shape[2]
(attn_output.sign() == out_dot[:, :, :n_queries].sign()).float().mean()

tensor(0.9992, device='cuda:0')

In [13]:
# Argmax of the reference over the last axis, limited to the query positions
# that attn_output covers (length derived from the tensor, not hard-coded).
out_dot[:, :, :attn_output.shape[2]].argmax(-1)

tensor([[[ 55, 118, 109,  94,  58,  28,  32,  11, 109, 109,  74, 118,  56,  41,
          118,  53, 109,  55, 101,  22, 114,  44,  44,  98, 114],
         [108,  47,   4,  43,  20, 115,  61,  52,  31, 106, 119, 119,  80, 108,
           20, 103,  47,  39, 108,  73, 115,  47,  47,  32, 114],
         [  3,  11,   3,  45,   3,   8,   3,   3,   3,   3,   3,   3,   3,  47,
            3,  89,   3, 111,   3,   3,   3, 126,  17,  42,   3],
         [110,   7,  32, 121,  11, 110, 112,  21,  45,  45,  21,  59,  81,  32,
           32,   8,  32,  39,  39,   8,  81,  59,   9,  81,  96],
         [ 60,  12, 127,  60, 127,  15,  12,  64,  12,  34,  78,  57, 109,  57,
           12,  57,  57,  57,  72,  15,  57,  12, 127,  72, 116],
         [ 16,  85,  32,  10,  91,  64,  32,  32,  19,  19,  16, 118,  19,  49,
           91,  19,  19,  32, 125,  35,  93,  19,  49,  19,  19],
         [ 76,  83,  76,  76, 104,  76, 104,  76,  26,  87,  14,  40, 117,  73,
           51,  84, 104,  70,  77,  87,  14,

In [14]:
# Argmax of the blockwise result over the last axis, for comparison with the
# reference argmax shown in the previous cell.
attn_output.argmax(dim=-1)

tensor([[[ 55, 118, 109,  94,  58,  28,  32,  11, 109, 109,  74, 118,  56,  41,
          118,  53, 109,  55, 101,  22, 114,  44,  44,  98, 114],
         [108,  47,   4,  43,  20, 115,  61,  52,  31, 106, 119, 119,  80, 108,
           20, 103,  47,  39, 108,  73, 115,  47,  47,  32, 114],
         [  3,  11,  55,  45,   3,   8,   3,   3,   3,   3,   3,   3,   3,  47,
            3,  89,   3, 111,   3,   3,   3, 126,  17,  42,   3],
         [110,   7,  32, 121,  11, 110, 112,  21,  45,  45,  21,  59,  81,  32,
           32,   8,  32,  39,  39,   8,  81,  59,   9,  81,  96],
         [ 60,  12, 127,  60, 127,  15,  12,  64,  12,  34,  78,  57, 109,  57,
           12,  57,  57,  57,  72,  15,  57,  12, 127,  72, 116],
         [ 16,  85,  32,  10,  91,  64,  32,  32,  19,  19,  16, 118,  19,  49,
           91,  19,  19,  32, 125,  35,  93,  19,  49,  19,  19],
         [ 76,  83,  76,  76, 104,  76, 104,  76,  26,  87,  14,  40, 117,  73,
           51,  84, 104,  70,  77,  87,  14,

In [15]:
def local_attention(Q_block, k_chunks, v_chunks):
    """Accumulate blockwise attention of one query block over several K/V chunks.

    Each chunk is processed with ``memory_efficient_attention_partial``. The
    partial outputs are combined using a running element-wise maximum of the
    per-chunk log-sum-exp (LSE) values, so every exponential stays bounded.

    Args:
        Q_block: query tensor, passed unchanged to the partial-attention call.
        k_chunks: indexable sequence of key chunks.
        v_chunks: indexable sequence of value chunks, index-aligned with
            ``k_chunks``.

    Returns:
        A tuple ``(attn_output, max_lse, new_lse_full)``:

        * ``attn_output`` -- the *unnormalised* sum
          ``sum_i exp(lse_i - max_lse) * out_i`` over the chunks. Divide it by
          ``exp(new_lse_full - max_lse).unsqueeze(-1)`` to obtain the
          softmax-normalised attention output.
        * ``max_lse`` -- element-wise maximum of the per-chunk LSEs (after the
          axis swap applied below).
        * ``new_lse_full`` -- log-sum-exp over all per-chunk LSEs.

        All three are ``None`` when ``k_chunks`` is empty.
    """
    max_lse = None
    attn_output = None
    new_lse_full = None

    for i, k_chunk in enumerate(k_chunks):
        out_, lse_ = memory_efficient_attention_partial(Q_block, k_chunk, v_chunks[i])
        # Swap axes 1 and 2 of the LSE so that, after unsqueeze(-1), it
        # broadcasts against out_.
        lse_ = lse_.transpose(1, 2)

        if i == 0:
            # First chunk: weight one everywhere. The multiplication also
            # promotes out_ to the common dtype of out_ and lse_.
            attn_output = out_ * torch.ones_like(lse_).unsqueeze(-1)
            max_lse = lse_
            new_lse_full = lse_
            continue

        new_max_lse = torch.maximum(max_lse, lse_)

        # Rescale the running sum and the new chunk to the new maximum.
        old_adjust_factors = torch.exp(max_lse - new_max_lse).unsqueeze(-1)
        new_adjust_factors = torch.exp(lse_ - new_max_lse).unsqueeze(-1)
        attn_output = old_adjust_factors * attn_output + new_adjust_factors * out_

        # Stable log(exp(new_lse_full) + exp(lse_)).
        new_lse_full = new_max_lse + torch.log(
            torch.exp(new_lse_full - new_max_lse) + torch.exp(lse_ - new_max_lse)
        )

        max_lse = new_max_lse

    return attn_output, max_lse, new_lse_full

In [17]:
# Run local_attention twice: once over the first two K/V chunks and once over
# the remaining ones. Each call returns (attn_output, max_lse, new_lse_full).
(
    (attn_output_0, max_lse_0, new_lse_full_0),
    (attn_output_1, max_lse_1, new_lse_full_1),
) = (
    local_attention(Q_block, k_chunks[part], v_chunks[part])
    for part in (slice(None, 2), slice(2, None))
)

In [20]:
# Merge the two group-level partial results into the attention of Q_block over
# all K/V chunks.
#
# Each local_attention() result is the unnormalised sum
#     sum_i exp(lse_i - max_lse_g) * out_i
# for its group g. After rescaling by exp(max_lse_g - global_max_lse), the
# matching denominator is the group's full softmax mass
#     exp(new_lse_full_g - global_max_lse)
# and not exp(max_lse_g - global_max_lse). The latter only counts the largest
# chunk of each group and leaves the merged output off by a positive factor.
attn_outputs = [attn_output_0, attn_output_1]
max_lses = [max_lse_0, max_lse_1]
full_lses = [new_lse_full_0, new_lse_full_1]

denominators = []
adjusted_outputs = []

all_lses = torch.cat([lse.unsqueeze(0) for lse in max_lses], dim=0)
global_max_lse = torch.max(all_lses, dim=0)[0]

for output, lse, full_lse in zip(attn_outputs, max_lses, full_lses):
    # Numerator: move the group's sum from its own maximum to the global one.
    adjusted_outputs.append(output * torch.exp(lse - global_max_lse).unsqueeze(-1))
    # Denominator: total softmax mass of the group relative to the global maximum.
    denominators.append(torch.exp(full_lse - global_max_lse).unsqueeze(-1))

final_output = torch.zeros_like(adjusted_outputs[0])
final_denominator = torch.zeros_like(denominators[0])

for adj_output, denom in zip(adjusted_outputs, denominators):
    final_output += adj_output
    final_denominator += denom

# Normalise, then swap axes 1 and 2 so the layout matches out_dot.
merged_output = (final_output / final_denominator).transpose(1, 2)

In [33]:
# Re-run of the merge step, written so that it only needs the results of the
# two local_attention() calls and no variables left over from other cells.
attn_outputs = [attn_output_0, attn_output_1]
max_lses = [max_lse_0, max_lse_1]
full_lses = [new_lse_full_0, new_lse_full_1]
global_max_lse = torch.max(
    torch.cat([lse.unsqueeze(0) for lse in max_lses], dim=0), dim=0
)[0]

# Numerators are the group sums moved from their own maximum to the global one.
# Denominators must be each group's full softmax mass,
# exp(new_lse_full_g - global_max_lse); using exp(max_lse_g - global_max_lse)
# instead would leave the result off by a positive scale factor.
adjusted_outputs = [
    output * torch.exp(lse - global_max_lse).unsqueeze(-1)
    for output, lse in zip(attn_outputs, max_lses)
]
denominators = [
    torch.exp(full_lse - global_max_lse).unsqueeze(-1) for full_lse in full_lses
]

final_output = torch.zeros_like(adjusted_outputs[0])
final_denominator = torch.zeros_like(denominators[0])

for adj_output, denom in zip(adjusted_outputs, denominators):
    final_output += adj_output
    final_denominator += denom

# Normalise, then swap axes 1 and 2 so the layout matches out_dot.
merged_output = (final_output / final_denominator).transpose(1, 2)

In [41]:
# Fraction of positions whose argmax over the last axis matches the reference.
# The reference is cut to the query positions covered by merged_output; the
# length is derived from the tensor instead of being hard-coded.
# NOTE: argmax is invariant to positive rescaling, so this cannot detect a
# wrong normalisation.
n_queries = merged_output.shape[2]
(merged_output.argmax(-1) == out_dot[:, :, :n_queries].argmax(-1)).float().mean()

tensor(0.9975, device='cuda:0')

In [36]:
# Fraction of elements whose sign matches the reference, over the query
# positions covered by merged_output (length derived, not hard-coded).
# NOTE: sign agreement is invariant to positive rescaling, so this is only a
# coarse sanity check and not a numerical-closeness test.
n_queries = merged_output.shape[2]
(merged_output.sign() == out_dot[:, :, :n_queries].sign()).float().mean()

tensor(0.9992, device='cuda:0')