In [1]:
import torch
import time
# Fixed seed so the random benchmark inputs are identical on every Restart & Run All.
torch.manual_seed(0)
batch_size = 1440
# Target CUDA device name; the input tensors below are created on the CPU in pinned host memory.
device="cuda:1"
# Query: (batch, 32 heads, q_len=1, head_dim=128), bf16.
query_states = torch.randn(batch_size, 32, 1, 128, dtype=torch.bfloat16, device="cpu").pin_memory()
# Each K/V tensor holds batch_size * 8 * 512 * 128 bf16 values (2 bytes each) = batch_size MiB,
# i.e. 1440 MiB here; repeating the 8 KV heads 4x to match the 32 query heads gives 5760 MiB each.
key_states = torch.randn(batch_size, 8, 512, 128, dtype=torch.bfloat16, device="cpu").pin_memory()
value_states = torch.randn(batch_size, 8, 512, 128, dtype=torch.bfloat16, device="cpu").pin_memory()



In [26]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``: an input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim), with the copies of each head adjacent.
    When ``n_rep`` is 1 the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Add a size-1 axis right after the head axis, broadcast it to n_rep,
        # then fold it into the head axis so each head's copies sit next to each other.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
# Baseline: expand the K/V heads to the query head count with repeat_kv, then run one SDPA call
# over all query heads. time.perf_counter() is used for the timings because it is monotonic and
# high-resolution, unlike the wall clock returned by time.time().
n_rep = query_states.shape[1] // key_states.shape[1]  # query heads per KV head (32 // 8 = 4)

time_start = time.perf_counter()
torch.cuda.nvtx.range_push("repeat kv")
new_key_states = repeat_kv(key_states, n_rep)
new_value_states = repeat_kv(value_states, n_rep)
torch.cuda.nvtx.range_pop()
print(f"time cost repeat kv {time.perf_counter() - time_start} s")

torch.cuda.nvtx.range_push("dot attention")
time_start = time.perf_counter()
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    new_key_states,
    new_value_states,
    attn_mask=None,
    dropout_p=0.0,
    # K/V were already expanded to the query head count above, so GQA broadcasting stays off.
    enable_gqa=False,
    # No causal mask is requested: the query tensor has q_len == 1.
    is_causal=False,
)
print(f"dot attn cost {time.perf_counter()-time_start:.6f} seconds")
torch.cuda.nvtx.range_pop()


time cost repeat kv 1.4284155368804932 s
dot attn cost 0.067475 seconds


-2

In [27]:

# Grouped attention: instead of expanding K/V, split the query heads into groups, run SDPA for
# each group against the un-expanded K/V, and gather the results back into the original head order.
time_start = time.perf_counter()
num_query_heads = query_states.shape[1]     # e.g. 32
num_kv_heads = key_states.shape[1]          # e.g. 8
# The strided pairing below only reproduces repeat_kv when the number of groups equals the
# number of query heads per KV head, so derive it instead of hard-coding 4.
num_groups = num_query_heads // num_kv_heads      # e.g. 4
heads_per_group = num_query_heads // num_groups   # e.g. 8

# Pre-allocate the gathered output once, in pinned host memory, with the same shape/dtype/device
# as the query tensor; each group writes its own head positions into it.
attn_outputs_full = torch.zeros(
    query_states.shape, dtype=query_states.dtype, device=query_states.device, pin_memory=True
)

for group_idx in range(num_groups):
    # K/V are not expanded: every group uses all key/value heads, and only the query selection
    # changes. Query heads are taken with a stride of num_groups, so the j-th selected query head
    # (index group_idx + j * num_groups) is paired with KV head j.
    # Example: 32 query heads, 8 key/value heads -> group 0 uses query heads 0, 4, 8, ..., 28.
    query_indices = torch.arange(group_idx, num_query_heads, num_groups)
    query_group = query_states[:, query_indices, :, :]  # (batch, heads_per_group, seq_len, head_dim)
    print(query_group.shape)
    key_group = key_states    # (batch, 8, seq_len, head_dim)
    value_group = value_states # (batch, 8, seq_len, head_dim)

    time_start_tmp = time.perf_counter()
    attn_out = torch.nn.functional.scaled_dot_product_attention(
        query_group, key_group, value_group,
        attn_mask=None,
        dropout_p=0.0,
        enable_gqa=False,
        is_causal=False
    )
    print(f"real attn out cost {time.perf_counter() - time_start_tmp} s")
    # Scatter this group's result back to the head positions it was read from, so the gathered
    # tensor has the same (batch, num_query_heads, seq_len, head_dim) layout as the baseline.
    attn_outputs_full[:, query_indices, :, :] = attn_out

print(f"dot attn cost {time.perf_counter()-time_start:.6f} seconds")

# Time the host-to-device copy of the gathered result (uses the `device` constant from the setup cell).
time_start_move = time.perf_counter()
attn_outputs_full_gpu = attn_outputs_full.to(device=device)
print(f"time cost move to cuda:1 {time.perf_counter() - time_start_move} s")

# Check the grouped result against the baseline `attn_output` computed with repeat_kv.
if torch.allclose(attn_output, attn_outputs_full, atol=1e-6):
    print("attn_output == attn_output_group")
else:
    print("attn_output != attn_output_group")

torch.Size([1440, 8, 1, 128])
real attn out cost 0.039719581604003906 s
torch.Size([1440, 8, 1, 128])
real attn out cost 0.03912043571472168 s
torch.Size([1440, 8, 1, 128])
real attn out cost 0.038926124572753906 s
torch.Size([1440, 8, 1, 128])
real attn out cost 0.039426565170288086 s
dot attn cost 0.162191 seconds
time cost move to cuda:1 0.0032503604888916016 s
attn_output == attn_output_group


In [22]:
# ==== Multithreaded version ====
# Same grouped attention as the sequential cell, with one worker thread per group.
# Relies on `num_groups` and `num_query_heads` defined by the sequential grouped-attention cell
# and on the baseline `attn_output` for the final comparison.
import threading  # no longer used in this cell; kept because it is a notebook-global import
from concurrent.futures import ThreadPoolExecutor

print("\n=== Multithreaded group attention ===")

time_start = time.perf_counter()
attn_outputs_full_threaded = torch.zeros(
    query_states.shape, dtype=query_states.dtype, device=query_states.device
)


def compute_group_attn(group_idx):
    """Run SDPA for one query-head group and write it into the shared output tensor.

    The group consists of query heads ``group_idx, group_idx + num_groups, ...``; they attend
    over the un-expanded ``key_states`` / ``value_states``. Each group writes to its own head
    indices of ``attn_outputs_full_threaded``, so the workers never write to the same positions.
    """
    query_indices = torch.arange(group_idx, num_query_heads, num_groups)
    query_group = query_states[:, query_indices, :, :]
    key_group = key_states
    value_group = value_states
    attn_out = torch.nn.functional.scaled_dot_product_attention(
        query_group, key_group, value_group,
        attn_mask=None,
        dropout_p=0.0,
        enable_gqa=False,
        is_causal=False
    )
    attn_outputs_full_threaded[:, query_indices, :, :] = attn_out


# A thread pool replaces the manual Thread start/join bookkeeping. Draining the map() iterator
# re-raises any exception from a worker here, instead of silently leaving that group's slice
# of the output filled with zeros.
with ThreadPoolExecutor(max_workers=num_groups) as pool:
    list(pool.map(compute_group_attn, range(num_groups)))

print(f"dot attn multithreaded cost {time.perf_counter()-time_start:.6f} seconds")

if torch.allclose(attn_output, attn_outputs_full_threaded, atol=1e-6):
    print("attn_output == attn_output_group_threaded")
else:
    print("attn_output != attn_output_group_threaded")


=== Multithreaded group attention ===
dot attn multithreaded cost 0.181015 seconds
attn_output == attn_output_group_threaded
