In [1]:
import torch
import numpy as np

import torch

from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch import Tensor
import time

import torch.nn.functional as F
import numpy as np

from mask import create_full_arrow_mask, create_half_arrow_mask, create_headwise_window_mask

In [2]:
# Rebind the imported flex_attention to its torch.compile-wrapped version;
# flex_headwin_attn calls this module-level name.
# NOTE(review): re-running this cell passes the already-wrapped callable to
# torch.compile again — restart the kernel before benchmarking if in doubt.
flex_attention = torch.compile(flex_attention)

# Module-level cache of block masks, filled and read by flex_headwin_attn
# (see the cache_key it builds) so each mask is constructed only once.
FLEX_MASK_CACHE = {}
def flex_headwin_attn(q: Tensor, k: Tensor, v: Tensor, n_im_tokens: int, n_text_tokens: int, spatial_mask, block_size=(128, 128)):
    """Run flex_attention with a per-head mask, building the block mask once.

    Args:
        q, k, v: attention inputs forwarded unchanged to ``flex_attention``.
            Dimension 1 of ``q`` is used as the number of heads.
        n_im_tokens: number of image tokens in the sequence.
        n_text_tokens: number of text tokens in the sequence.
        spatial_mask: tensor indexed as ``spatial_mask[head, q_idx, kv_idx]``;
            the indexed value is returned directly as the mask decision.
        block_size: forwarded to ``create_block_mask`` as ``BLOCK_SIZE``.

    Returns:
        Whatever ``flex_attention(q, k, v, block_mask=...)`` returns.

    The block mask is cached in ``FLEX_MASK_CACHE``. Tensors hash by object
    identity, so the cache is only hit when the *same* ``spatial_mask`` object
    is passed again.
    """
    num_heads = q.shape[1]
    # The head count is part of the key because it is baked into the block
    # mask (H=num_heads); without it, reusing one mask tensor with a different
    # number of heads would return a stale block mask.
    cache_key = (f"{n_im_tokens}_{n_text_tokens}_{block_size}", num_heads, spatial_mask)

    block_mask = FLEX_MASK_CACHE.get(cache_key)
    if block_mask is None:
        def mask_function(b, h, q_idx, kv_idx):
            # Batch index is ignored: the mask depends only on head and position.
            return spatial_mask[h, q_idx, kv_idx]

        total_tokens = n_im_tokens + n_text_tokens

        block_mask = create_block_mask(
            mask_function,
            B=None,
            H=num_heads,
            Q_LEN=total_tokens,
            KV_LEN=total_tokens,
            BLOCK_SIZE=block_size,
        )
        FLEX_MASK_CACHE[cache_key] = block_mask

    return flex_attention(q, k, v, block_mask=block_mask)

In [3]:
# ---- Benchmark problem size (edit here to try other shapes) ----
batch = 8
num_head = 24
head_dim = 64
seqlen_vision = 4096
seqlen_text = 333

# Short aliases used by the cells below: batch, total sequence length
# (vision + text tokens), number of heads, per-head dimension.
B = batch
N = seqlen_vision + seqlen_text
H = num_head
D = head_dim

# Block granularity used when padding masks for the FlexAttention baseline.
block_size = (128, 128)
device = torch.device("cuda")

In [4]:
# Per-head window sizes: one int32 pair per head, filled group by group.
# NOTE(review): the meaning of the two numbers in each pair is defined by the
# attention kernels / mask builders that consume window_sizes — confirm there.
window_sizes = torch.zeros((H, 2), device=device, dtype=torch.int32)
for head_index, window_pair in (
    (0, (13, 16)),               # head 0
    (slice(1, 4), (13, 15)),     # heads 1-3
    (slice(4, 9), (14, 16)),     # heads 4-8
    (slice(9, None), (14, 14)),  # heads 9 and up
):
    window_sizes[head_index] = torch.tensor(window_pair, device=device, dtype=torch.int32)

In [5]:
# Fix the RNG seed so every Restart & Run All benchmarks and compares the
# same random inputs.
torch.manual_seed(0)

# Draw one (3*B, N, H, D) half-precision tensor and split it along dim 0 into
# three B-sized chunks, giving q, k and v of shape (B, N, H, D) each.
q, k, v = torch.randn(B * 3, N, H, D, dtype=torch.float16, device=device).split(B, dim=0)

In [6]:
from flash_attn_original import flash_attn_func as flash_attn_func_original

# Baseline: FlashAttention called with window_size=(-1, -1).
ori_full_times = []

# Warm-up: 100 untimed calls, then wait for the GPU to drain.
for _ in range(100):
    flash_attn_func_original(q, k, v, window_size=(-1, -1))
torch.cuda.synchronize()

# CUDA events created with timing enabled; reused for every timed iteration.
start_ori = torch.cuda.Event(enable_timing=True)
end_ori = torch.cuda.Event(enable_timing=True)

# 100 timed calls; each one is bracketed by the two events and synchronized
# before the elapsed time (milliseconds) is read.
for _ in range(100):
    start_ori.record()
    o_ori = flash_attn_func_original(q, k, v, window_size=(-1, -1))
    end_ori.record()
    torch.cuda.synchronize()
    tot_original = start_ori.elapsed_time(end_ori)
    ori_full_times.append(tot_original)

# Mean latency over the timed iterations.
ori_full_mean = np.mean(ori_full_times)

In [7]:
from flash_attn_ours import  headwise_arrow_attn, headwise_half_arrow_attn
######################################################
#       test headwise full arrow attention time
######################################################
hw_fa_times = []
flex_full_arrow_times = []
# NOTE(review): the next two lists are never filled in this cell; they are
# kept only in case another cell relies on them.
output_hw_full_arrows_rights = []
o_ori_vs_o_fa = []

# ---- Custom head-wise full-arrow kernel ----
# Warm-up: 100 untimed calls, then wait for the GPU to drain.
for _ in range(100):
    headwise_arrow_attn(q, k, v, window_sizes=window_sizes, seqlen_q_vision=seqlen_vision, seqlen_k_vision=seqlen_vision)
torch.cuda.synchronize()

start_hw_full_arrow = torch.cuda.Event(enable_timing=True)
end_hw_full_arrow = torch.cuda.Event(enable_timing=True)

# 100 timed calls; elapsed_time is read (in milliseconds) after synchronizing.
for _ in range(100):
    start_hw_full_arrow.record()
    o_hw_full_arrow = headwise_arrow_attn(q, k, v, window_sizes=window_sizes, seqlen_q_vision=seqlen_vision, seqlen_k_vision=seqlen_vision)
    end_hw_full_arrow.record()
    torch.cuda.synchronize()
    tot_hw_fa = start_hw_full_arrow.elapsed_time(end_hw_full_arrow)
    hw_fa_times.append(tot_hw_fa)
hw_full_arrow_mean = np.mean(np.array(hw_fa_times))

# ---- FlexAttention baseline with the same mask ----
# Use the configured block_size everywhere (mask construction, padding and the
# flex call) so the three can never disagree if block_size is changed.
window_masks, _ = create_full_arrow_mask(H, seqlen_vision, seqlen_text, window_sizes, block_size)

# Zero-pad the mask so both sequence dimensions are rounded up to a multiple of
# the block size; the original mask occupies the top-left N x N corner.
flex_masks = torch.zeros(window_masks.shape[0],
                        (N + block_size[0] - 1) // block_size[0] * block_size[0],
                        (N + block_size[1] - 1) // block_size[1] * block_size[1],
                        dtype=window_masks.dtype,
                        device=window_masks.device)
flex_masks[:, :N, :N] = window_masks

# flex_headwin_attn reads the head count from dimension 1, so swap the
# sequence and head axes: (B, N, H, D) -> (B, H, N, D).
qp, kp, vp = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)

# Warm-up (the first call also builds and caches the block mask).
for _ in range(100):
    flex_headwin_attn(qp, kp, vp, seqlen_vision, seqlen_text, flex_masks, block_size)
torch.cuda.synchronize()

start_flex_fa = torch.cuda.Event(enable_timing=True)
end_flex_fa = torch.cuda.Event(enable_timing=True)
for _ in range(100):
    start_flex_fa.record()
    o_flex_full_arrow = flex_headwin_attn(qp, kp, vp, seqlen_vision, seqlen_text, flex_masks, block_size)
    end_flex_fa.record()
    torch.cuda.synchronize()
    tot_flex_full_arrow = start_flex_fa.elapsed_time(end_flex_fa)
    flex_full_arrow_times.append(tot_flex_full_arrow)
flex_full_arrow_mean = np.mean(np.array(flex_full_arrow_times))

# Sanity check: both implementations must agree (rtol = atol = 1e-3) once the
# flex output is permuted back to (B, N, H, D).
print(torch.allclose(o_flex_full_arrow.permute(0, 2, 1, 3), o_hw_full_arrow, rtol=1e-3, atol=1e-3))

True


In [8]:
######################################################
#       test headwise half arrow attention time
######################################################
hw_half_arrow_times = []
flex_half_arrow_times = []
# NOTE(review): the next two lists are never filled in this cell; they are
# kept only in case another cell relies on them.
output_hw_half_arrows_rights = []
o_ori_vs_o_ha = []

# ---- Custom head-wise half-arrow kernel ----
# Warm-up: 100 untimed calls, then wait for the GPU to drain.
for _ in range(100):
    headwise_half_arrow_attn(q, k, v, window_sizes=window_sizes, seqlen_q_vision=seqlen_vision, seqlen_k_vision=seqlen_vision)
torch.cuda.synchronize()

start_hw_ha = torch.cuda.Event(enable_timing=True)
end_hw_ha = torch.cuda.Event(enable_timing=True)

# 100 timed calls; elapsed_time is read (in milliseconds) after synchronizing.
for _ in range(100):
    start_hw_ha.record()
    o_hw_half_arrow = headwise_half_arrow_attn(q, k, v, window_sizes=window_sizes, seqlen_q_vision=seqlen_vision, seqlen_k_vision=seqlen_vision)
    end_hw_ha.record()
    torch.cuda.synchronize()
    tot_time_ha = start_hw_ha.elapsed_time(end_hw_ha)
    hw_half_arrow_times.append(tot_time_ha)
hw_half_arrow_mean = np.mean(np.array(hw_half_arrow_times))

# ---- FlexAttention baseline with the same mask ----
# Use the configured block_size everywhere (mask construction, padding and the
# flex call) so the three can never disagree if block_size is changed.
window_masks_ha, _ = create_half_arrow_mask(H, seqlen_vision, seqlen_text, window_sizes, block_size)

# Zero-pad the mask so both sequence dimensions are rounded up to a multiple of
# the block size. The leading dimension comes from the half-arrow mask itself
# (not from the full-arrow mask built in the previous cell).
flex_masks_ha = torch.zeros(window_masks_ha.shape[0],
                        (N + block_size[0] - 1) // block_size[0] * block_size[0],
                        (N + block_size[1] - 1) // block_size[1] * block_size[1],
                        dtype=window_masks_ha.dtype,
                        device=window_masks_ha.device)
flex_masks_ha[:, :N, :N] = window_masks_ha

# flex_headwin_attn reads the head count from dimension 1, so swap the
# sequence and head axes: (B, N, H, D) -> (B, H, N, D).
qp, kp, vp = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)

# Warm-up (the first call also builds and caches the block mask).
for _ in range(100):
    flex_headwin_attn(qp, kp, vp, seqlen_vision, seqlen_text, flex_masks_ha, block_size)
torch.cuda.synchronize()

start_flex_ha = torch.cuda.Event(enable_timing=True)
end_flex_ha = torch.cuda.Event(enable_timing=True)
for _ in range(100):
    start_flex_ha.record()
    o_flex_half_arrow = flex_headwin_attn(qp, kp, vp, seqlen_vision, seqlen_text, flex_masks_ha, block_size)
    end_flex_ha.record()
    torch.cuda.synchronize()
    tot_flex_half_arrow = start_flex_ha.elapsed_time(end_flex_ha)
    flex_half_arrow_times.append(tot_flex_half_arrow)
flex_half_arrow_mean = np.mean(np.array(flex_half_arrow_times))

# Sanity check: both implementations must agree (rtol = atol = 1e-3) once the
# flex output is permuted back to (B, N, H, D).
print(torch.allclose(o_flex_half_arrow.permute(0, 2, 1, 3), o_hw_half_arrow, rtol=1e-3, atol=1e-3))

True


In [9]:
# Summary of mean latencies collected above. The values come from
# torch.cuda.Event.elapsed_time, i.e. they are in milliseconds.
timing_report = [
    f"original time: {ori_full_mean}",
    f"flex full arrow time: {flex_full_arrow_mean}",
    f"ours full arrow time: {hw_full_arrow_mean}",
    f"flex half arrow time: {flex_half_arrow_mean}",
    f"ours half arrow time: {hw_half_arrow_mean}",
]
print("\n".join(timing_report))

original time: 5.096377291679382
flex full arrow time: 5.218748784065246
ours full arrow time: 1.6892707216739655
flex half arrow time: 2.55883807182312
ours half arrow time: 1.2082035231590271
