In [1]:
import torch
import flashinfer
import gc

In [2]:
def get_mask(q_length, kv_length, rank, batch_size):
    """Build a boolean attention mask made of two stacked row blocks, tiled per batch.

    The first block ("upper", ``q_length`` rows) keeps, in row ``i``, columns
    ``j < rank * q_length + i + 1``.
    The second block ("lower", ``q_length`` rows) keeps, in row ``i``, columns
    ``j < kv_length - q_length * (rank + 1) + i + 1``.

    NOTE(review): this presumably encodes a causal mask for a zig-zag / ring
    layout, where this rank owns query chunk ``rank`` and its mirrored chunk at
    the end of the sequence. TODO confirm.

    Args:
        q_length: Rows in each of the two blocks.
        kv_length: Number of key/value columns.
        rank: Chunk index used to offset both blocks.
        batch_size: How many times the stacked ``(2 * q_length, kv_length)``
            block is repeated along dim 0 (0 yields an empty mask).

    Returns:
        A CPU ``torch.bool`` tensor of shape ``(2 * q_length * batch_size, kv_length)``.
    """
    rows = torch.arange(q_length).unsqueeze(1)   # (q_length, 1)
    cols = torch.arange(kv_length).unsqueeze(0)  # (1, kv_length)

    # Upper block: row i keeps everything before rank * q_length + i + 1.
    upper = cols < rows + rank * q_length + 1

    # Lower block: row i masks every column >= start. Clamp at 0 so a negative
    # start masks the whole row, instead of wrapping like a negative slice index
    # (which masked only the last few columns).
    start = (rows + (kv_length - q_length * (rank + 1) + 1)).clamp(min=0)
    lower = cols < start

    block = torch.cat([upper, lower], dim=0)
    # Tile the block once per batch entry along the row dimension.
    return block.repeat(batch_size, 1)

In [5]:
def run_flash_attention(rank=0, batch_size=1, qo_len=128, kv_len=4096, num_qo_heads=32, num_kv_heads=32, head_dim=128, repeat=7, visualize_mask=False,device="cuda",return_tensors=False,verbose=False, mask=None):
    """Time ``flashinfer.single_prefill_with_kv_cache`` with a custom boolean mask.

    Random fp16 q/k/v tensors are generated on ``device``. The prefill kernel is
    then launched ``repeat`` times, and each launch is timed with a pair of CUDA events.
    NOTE(review): the first launch may include one-time setup cost; consider
    discarding it when reporting numbers.

    Args:
        rank: Rank passed to ``get_mask`` when ``mask`` is None.
        batch_size: q has ``qo_len * batch_size`` rows, all attending to the same
            ``kv_len`` keys/values.
        qo_len: Query rows per batch entry. The default mask comes from
            ``get_mask(qo_len // 2, ...)``, which stacks two halves of
            ``qo_len // 2`` rows, so ``qo_len`` must be even when ``mask`` is None.
        kv_len: Number of key/value rows.
        num_qo_heads, num_kv_heads, head_dim: Attention tensor dimensions.
        repeat: Number of timed kernel launches.
        visualize_mask: Only echoed in the verbose log; not otherwise used here.
        device: Device for q, k, v and the mask.
        return_tensors: If True, also return the output of the last launch.
        verbose: If True, print the configuration and per-launch timings.
        mask: Optional mask of shape ``(qo_len * batch_size, kv_len)``. Pass it
            explicitly instead of relying on a notebook-global ``mask``.

    Returns:
        ``[output, compute_times]``. ``output`` is the last kernel output when
        ``return_tensors`` is True; it is None otherwise, or when ``repeat`` is 0.
        ``compute_times`` lists per-launch times in milliseconds.

    Raises:
        ValueError: If the mask shape is not ``(qo_len * batch_size, kv_len)``.
    """
    def print_if_verbose(s):
        if verbose:
            print(s)

    print_if_verbose(f"Running flash attention with rank {rank}, batch size {batch_size}, qo_len {qo_len}, kv_len {kv_len}, num_qo_heads {num_qo_heads}, num_kv_heads {num_kv_heads}, head_dim {head_dim}, visualize_mask {visualize_mask}, device {device}, verbose {verbose}")

    if mask is None:
        # get_mask returns 2 * q_length rows per batch entry, so q_length = qo_len // 2
        # lines its rows up with q's qo_len * batch_size rows.
        mask = get_mask(qo_len // 2, kv_len, rank, batch_size)
    mask = mask.to(device)
    expected_shape = (qo_len * batch_size, kv_len)
    if tuple(mask.shape) != expected_shape:
        raise ValueError(
            f"mask has shape {tuple(mask.shape)}, expected {expected_shape} "
            f"(qo_len * batch_size, kv_len); qo_len must be even when mask is None"
        )

    q = torch.randn(qo_len * batch_size, num_qo_heads, head_dim).half().to(device)
    k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(device)
    v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(device)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    o_custom = None  # stays None when repeat == 0
    compute_times = []
    for _ in range(repeat):
        start_event.record()
        o_custom = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
        end_event.record()

        # Event timings are only valid once the recorded work has finished.
        torch.cuda.synchronize()

        elapsed_time_ms = start_event.elapsed_time(end_event)
        compute_times.append(elapsed_time_ms)
        print_if_verbose(f"Elapsed time: {elapsed_time_ms:.2f} ms")

    return_values = [None, compute_times]
    if return_tensors:
        return_values[0] = o_custom
    return return_values
