In [1]:
import torch
import flashinfer
import gc

In [2]:
# torch.cuda.set_device('cuda:3')


In [3]:
def get_mask(q_length, kv_length, rank, batch_size):
    a = torch.tril(torch.ones(q_length, kv_length))
    b = torch.cat([a] * batch_size, dim=0)
    return b
    

In [4]:
def run_flash_attention(rank=0, batch_size=1, qo_len=128, kv_len=4096, num_qo_heads=32, num_kv_heads=32, head_dim=128, repeat=7, visualize_mask=False,device="cuda",return_tensors=False,verbose=False):
    def print_if_verbose(s):
        if verbose:
            print(s)
        return
    
    print_if_verbose(f"Running flash attention with rank {rank}, batch size {batch_size}, qo_len {qo_len}, kv_len {kv_len}, num_qo_heads {num_qo_heads}, num_kv_heads {num_kv_heads}, head_dim {head_dim}, visualize_mask {visualize_mask}, device {device}, verbose {verbose}")
    q = torch.randn(qo_len * batch_size, num_qo_heads, head_dim).half().to(device)
    k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(device)
    v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(device)
    mask = get_mask(qo_len, kv_len, rank, batch_size)
    mask = mask.to(device)

    compute_times = []
    for _ in range(repeat):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        o_custom = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
        end_event.record()

        # Waits for everything to finish running
        torch.cuda.synchronize()

        elapsed_time_ms = start_event.elapsed_time(end_event)
        compute_times.append(elapsed_time_ms)
        print_if_verbose(f"Elapsed time: {elapsed_time_ms:.2f} ms")
        torch.cuda.empty_cache()
    
    median_compute_time = torch.tensor(compute_times).median()
    return_values = [None, compute_times, median_compute_time]
    if return_tensors:
        return_values[0] = o_custom
    return return_values


In [5]:
results = {}
tp_size = 4

In [None]:
from collections import namedtuple

configs = dict(
    llama8b=dict(
        num_qo_heads=32,
        num_kv_heads=8,
        head_dim=128,
    ),
    llama70b=dict(
        num_qo_heads=64,
        num_kv_heads=8,
        head_dim=128,
    )
)

from multiprocessing import Process, Queue

if __name__ == "__main__":
    for name, model_config in configs.items():
        for k in range(10, 20+ 1):
            qo_len = kv_len = 2 ** k
            config = dict(
                rank=0,
                batch_size=1,
                qo_len=qo_len,
                kv_len=kv_len,
                num_qo_heads=model_config['num_qo_heads'] // tp_size,
                num_kv_heads=model_config['num_kv_heads'] // tp_size,
                head_dim=model_config['head_dim'],
            )
            try:
                proc = Process(
                    target=run_flash_attention, 
                    kwargs=dict(
                        **config,
                        repeat=7,
                        return_tensors=False,
                    )
                )
                item = run_flash_attention(
                    **config,
                    repeat=7,
                    return_tensors=False,
                )
                config['tp_size'] = tp_size
                computed_time = item[-1]
                print(f"k: {k}, computed_time: {computed_time:.2f}")
                
                config_named_tuple = namedtuple('Config', config.keys())
                config_named_tuple(**config)
                results[config_named_tuple] = computed_time
            except Exception as e:
                print(f"Error: {e}")
                print(f"Config: {config}")
                raise e
            torch.cuda.empty_cache()

k: 10, computed_time: 2.35
k: 11, computed_time: 0.29
k: 12, computed_time: 0.93
k: 13, computed_time: 2.41
k: 14, computed_time: 13.29
k: 15, computed_time: 33.12
k: 16, computed_time: 130.45
Error: CUDA out of memory. Tried to allocate 16.00 GiB. GPU 0 has a total capacity of 79.14 GiB of which 14.05 GiB is free. Including non-PyTorch memory, this process has 65.07 GiB memory in use. Of the allocated memory 64.59 GiB is allocated by PyTorch, and 2.00 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Config: {'rank': 0, 'batch_size': 1, 'qo_len': 131072, 'kv_len': 131072, 'num_qo_heads': 8, 'num_kv_heads': 2, 'head_dim': 128}


: 