In [20]:
import torch
import time
import threading
import queue
from concurrent.futures import ThreadPoolExecutor
from math import *
# Benchmark configuration: every SDPA input below is created on the CPU in bfloat16.
dev = "cpu"
dtype = torch.bfloat16

# Set the number of intra-op CPU threads PyTorch may use.
torch.set_num_threads(192)

# Prepare the SDPA input data.
batch_size = 512
num_heads = 32   # number of query heads
seq_len = 512    # key/value sequence length
head_dim = 128
q_seq_len = 1  # decode phase: a single query position per sequence

# Query has `num_heads` (32) heads, while key/value are built with a hard-coded 8 heads,
# i.e. 4 query heads per key/value head.
# Each of key/value holds 512 * 8 * 512 * 128 bf16 elements (512 MiB).
query = torch.randn(batch_size, num_heads, q_seq_len, head_dim, device=dev, dtype=dtype)
key = torch.randn(batch_size, 8, seq_len, head_dim, device=dev, dtype=dtype)
value = torch.randn(batch_size, 8, seq_len, head_dim, device=dev, dtype=dtype)

In [21]:


# Time how long it takes to stage `key` in pinned host memory and copy it to cuda:1.
# The timed window covers BOTH the pin_memory() call and the .to("cuda:1") copy.
time_start_move = time.time()
key_pin = key.pin_memory()
key_gpu = key_pin.to("cuda:1")
time_end_move = time.time()
# Drop the device copy and release cached CUDA memory; `key_pin` itself stays alive.
del key_gpu
torch.cuda.empty_cache()
print(f"move time {time_end_move - time_start_move}")


move time 0.10378456115722656


In [22]:
# Override the intra-op thread count (the first cell set it to 192) before the SDPA runs below.
torch.set_num_threads(160)
from torch.nn.attention import  sdpa_kernel
from torch.nn.attention import  SDPBackend
import time
def compute_sdpa(query, key, value, num_runs=1):
    """Run causal scaled-dot-product attention repeatedly and report timings.

    Each run calls ``torch.nn.functional.scaled_dot_product_attention`` on the
    same inputs with ``is_causal=True`` and ``dropout_p=0.0`` under
    ``torch.no_grad()``.

    Args:
        query, key, value: Tensors forwarded unchanged to the SDPA call.
        num_runs: How many times to repeat the computation.

    Returns:
        A list with one output tensor per run.

    Side effects:
        Prints ``sdpa time [...]`` with the per-run wall-clock durations in
        seconds, each rounded to 6 decimal places.
    """
    attention = torch.nn.functional.scaled_dot_product_attention
    outputs = []
    time_list = []
    with torch.no_grad():
        for _ in range(num_runs):
            began = time.time()
            attn = attention(query, key, value, dropout_p=0.0, is_causal=True)
            finished = time.time()
            time_list.append(round(finished - began, 6))
            outputs.append(attn)
    print(f"sdpa time {time_list}")
    return outputs
@torch.no_grad()
def scaled_dot_product_attention_help(
    query_states, 
    key_states, 
    value_states, 
    attn_mask=None, dropout_p=0.0, enable_gqa=False, is_causal=False, output_tensor=None):

    time_start = time.time()
   
    num_query_heads = query_states.shape[1]     # e.g. 32
    num_key_heads = key_states.shape[1]
    num_groups = int(num_query_heads//num_key_heads)   # 4 组
    
    query_states = query_states
   
    if output_tensor is None:
        output_tensor = torch.zeros(
            query_states.shape, dtype=query_states.dtype, device=query_states.device, pin_memory=False
        )
    else:
        output_tensor = output_tensor.contiguous()
    
    query_groups = []
    query_indices_list = []
    
    for group_idx in range(num_groups):
        query_indices = torch.arange(group_idx, num_query_heads, num_groups)
        query_group = query_states[:, query_indices, :, :]  # 确保连续内存
        query_groups.append(query_group)
        query_indices_list.append(query_indices)
    
    for group_idx in range(num_groups):
        query_group = query_groups[group_idx]
        query_indices = query_indices_list[group_idx]
        
        
        # 优化5: 使用预取的数据，避免重复内存访问
        key_group = key_states    # (batch, 8, seq_len, head_dim)
        value_group = value_states # (batch, 8, seq_len, head_dim)

        time_start_tmp = time.time()
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
            attn_out = torch.nn.functional.scaled_dot_product_attention(
                query_group, key_group, value_group,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                enable_gqa=enable_gqa,
                is_causal=is_causal
            )
        print(f"single group {group_idx} real attn out cost {time.time() - time_start_tmp} s")
        
        time_start_cpy = time.time()
        output_tensor[:, query_indices, :, :] = attn_out
        print(f"write to output tensor cost {time.time() - time_start_cpy} s")
    print(f"dot attn help cost {time.time()-time_start:.6f} seconds")
    return output_tensor

@torch.no_grad()
def scaled_dot_product_attention_help_split_kv(
    query_states, 
    key_states, 
    value_states, 
    attn_mask=None, dropout_p=0.0, enable_gqa=False, is_causal=False, output_tensor=None):
    """
    将 key_group 和 value_group 拆分成两份，query_group 也相应拆分成两份，分别计算
    确保计算结果准确，符合原函数的根本计算逻辑
    """
    time_start = time.time()
   
    num_query_heads = query_states.shape[1]     # e.g. 32
    num_key_heads = key_states.shape[1]         # e.g. 8
    num_groups = int(num_query_heads // num_key_heads)   # 4 组
    
    query_states = query_states.contiguous()
    key_states = key_states.contiguous()
    value_states = value_states.contiguous()
   
    if output_tensor is None:
        output_tensor = torch.zeros(
            query_states.shape, dtype=query_states.dtype, device=query_states.device, pin_memory=False
        )
    else:
        output_tensor = output_tensor.contiguous()
    
    query_groups = []
    query_indices_list = []
    
    for group_idx in range(num_groups):
        query_indices = torch.arange(group_idx, num_query_heads, num_groups)
        query_group = query_states[:, query_indices, :, :] # 确保连续内存
        query_groups.append(query_group)
        query_indices_list.append(query_indices)
    
    for group_idx in range(num_groups):
        query_group = query_groups[group_idx]
        query_indices = query_indices_list[group_idx]
        
        # 将 key_group 和 value_group 拆分成两份
        # key_group: (batch, 8, seq_len, head_dim) -> 拆成两份，每份 (batch, 4, seq_len, head_dim)
        num_kv_heads = key_states.shape[1]
        num_kv_splits = 4  # 拆分成两份
        kv_heads_per_split = num_kv_heads // num_kv_splits  # 每份的 head 数
        
        # 将 query_group 也拆分成两份，对应 KV 的拆分
        num_query_heads_in_group = query_group.shape[1]
        query_heads_per_split = num_query_heads_in_group // num_kv_splits  # 每份的 query head 数
        
        # 存储每份的计算结果
        attn_out_parts = []
        
        for split_idx in range(num_kv_splits):
            # 拆分 KV: 第一份 [0:4], 第二份 [4:8]
            kv_start = split_idx * kv_heads_per_split
            kv_end = kv_start + kv_heads_per_split
            key_split = key_states[:, kv_start:kv_end, :, :]  # (batch, 4, seq_len, head_dim)
            value_split = value_states[:, kv_start:kv_end, :, :]  # (batch, 4, seq_len, head_dim)
            
            # 拆分 query: 需要按照原始 stride 方式分组
            # query_group 的 heads 是通过 stride=num_groups 方式选择的
            # 例如 group 0: [0, 4, 8, 12, 16, 20, 24, 28]
            # 拆分成两份：第一份 [0, 4, 8, 12]，第二份 [16, 20, 24, 28]
            # 在 query_group 中，这些 heads 的索引是 [0, 1, 2, 3] 和 [4, 5, 6, 7]
            query_start = split_idx * query_heads_per_split
            query_end = query_start + query_heads_per_split
            query_split = query_group[:, query_start:query_end, :, :]  # (batch, 4, q_seq_len, head_dim)
            
            # 计算这一份的 attention
            time_start_tmp = time.time()
            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                attn_out_split = torch.nn.functional.scaled_dot_product_attention(
                    query_split, key_split, value_split,
                    attn_mask=attn_mask,
                    dropout_p=dropout_p,
                    enable_gqa=enable_gqa,
                    is_causal=is_causal
                )
            print(f"single group {group_idx} split {split_idx} real attn out cost {time.time() - time_start_tmp} s")
            attn_out_parts.append(attn_out_split)
        
        # 合并两份的计算结果
        time_start_concat = time.time()
        attn_out = torch.cat(attn_out_parts, dim=1)  # 在 head 维度上拼接
        print(f"single group {group_idx} concat cost {time.time() - time_start_concat} s")
        
        # 将结果写入 output_tensor
        time_start_cpy = time.time()
        output_tensor[:, query_indices, :, :] = attn_out
        print(f"write to output tensor cost {time.time() - time_start_cpy} s")
    
    print(f"dot attn help split kv cost {time.time()-time_start:.6f} seconds")
    return output_tensor



# Benchmark driver: call the grouped SDPA helper `num_runs` times on the tensors
# prepared in the first cell and print the total wall-clock time.
# In the recorded output the first call is much slower than the second
# (0.396 s vs 0.037 s), so the total is dominated by the first iteration.
num_runs = 3
time_start = time.time()
for i in range(num_runs):
    # sdpa_results1 = scaled_dot_product_attention_help_split_kv(query, key, value)
    sdpa_results2 = scaled_dot_product_attention_help(query, key, value)
    
    # Verify that the two implementations produce the same result (currently disabled).
    # NOTE(review): rtol=1e-5 / atol=1e-8 look very tight for bfloat16 inputs —
    # confirm the tolerance before re-enabling this check.
    # if torch.equal(sdpa_results1, sdpa_results2):
    #     print(f"Run {i+1}: Results are exactly equal ✓")
    # elif torch.allclose(sdpa_results1, sdpa_results2, rtol=1e-5, atol=1e-8):
    #     max_diff = (sdpa_results1 - sdpa_results2).abs().max().item()
    #     print(f"Run {i+1}: Results are close (max diff: {max_diff:.2e}) ✓")
    # else:
    #     max_diff = (sdpa_results1 - sdpa_results2).abs().max().item()
    #     mean_diff = (sdpa_results1 - sdpa_results2).abs().mean().item()
    #     print(f"Run {i+1}: Results differ! Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e} ✗")
    #     assert False, f"Results don't match: max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}"
sdpa_time = time.time() - time_start
print(f"sdpa time {sdpa_time}")

single group 0 real attn out cost 0.14505553245544434 s
write to output tensor cost 0.028368234634399414 s
single group 1 real attn out cost 0.058889150619506836 s
write to output tensor cost 0.0022513866424560547 s
single group 2 real attn out cost 0.009257316589355469 s
write to output tensor cost 0.00022125244140625 s
single group 3 real attn out cost 0.009157419204711914 s
write to output tensor cost 0.00022602081298828125 s
dot attn help cost 0.396262 seconds
single group 0 real attn out cost 0.009003877639770508 s
write to output tensor cost 0.00022411346435546875 s
single group 1 real attn out cost 0.00862884521484375 s
write to output tensor cost 0.00022935867309570312 s
single group 2 real attn out cost 0.008588790893554688 s
write to output tensor cost 0.00022411346435546875 s
single group 3 real attn out cost 0.00838160514831543 s
write to output tensor cost 0.00020956993103027344 s
dot attn help cost 0.037035 seconds
single group 0 real attn out cost 0.008901834487915039 s


In [None]:
# Host-to-device transfer benchmark for one expert-sized weight matrix.
# Each expert is a pinned 14336 x 4096 bf16 CPU tensor (112 MiB). The tensors returned by
# .to() are discarded; only the time until torch.cuda.synchronize() returns is measured.
device1 = "cuda:1"
device2 = "cuda:2"

expert1 = torch.randn(14336, 4096, dtype=torch.bfloat16, device="cpu", pin_memory=True)
expert2 = torch.randn(14336, 4096, dtype=torch.bfloat16, device="cpu", pin_memory=True)

# Measurement 1: expert1 -> cuda:1 only.
time_start = time.time()
expert1.to(device1, non_blocking=True)
# expert1.to(device2, non_blocking=True)
torch.cuda.synchronize(device=device1)
# torch.cuda.synchronize(device=device2)
time_end = time.time()
print(f"move time {time_end - time_start}")

# Measurement 2: identical repeat of measurement 1.
time_start = time.time()
expert1.to(device1, non_blocking=True)
# expert1.to(device2, non_blocking=True)
torch.cuda.synchronize(device=device1)
# torch.cuda.synchronize(device=device2)
time_end = time.time()
print(f"move time {time_end - time_start}")

# Measurement 3: the SAME source tensor to both devices; both copies are issued
# before either device is synchronized.
time_start = time.time()
expert1.to(device1, non_blocking=True)
expert1.to(device2, non_blocking=True)
torch.cuda.synchronize(device=device1)
torch.cuda.synchronize(device=device2)
time_end = time.time()
print(f"move time {time_end - time_start}")


# Measurement 4: DIFFERENT source tensors (expert1, expert2) to the two devices,
# again issuing both copies before synchronizing.
time_start = time.time()
expert1.to(device1, non_blocking=True)
expert2.to(device2, non_blocking=True)
torch.cuda.synchronize(device=device1)
torch.cuda.synchronize(device=device2)
time_end = time.time()
print(f"move time {time_end - time_start}")



move time 0.009891986846923828
move time 0.00965428352355957
move time 0.017072439193725586
move time 0.009665727615356445


In [None]:
# 14ms 10 experts
# all   93 ms
import time, torch
# Time repeated host-to-device copies of one pinned 1408 x 2048 bf16 tensor (5.5 MiB).
k = torch.randn(1408, 2048, dtype=torch.bfloat16, device="cpu", pin_memory=True)
num = 64
time_start_torch = time.time()
# num * 3 = 192 copies to cuda:1; `k_g` is rebound on every iteration, so only
# the last device copy stays referenced.
for i in range(num*3):
    k_g = k.to("cuda:1")
# NOTE(review): there is no explicit torch.cuda.synchronize() before the end
# timestamp — confirm the default (blocking) .to() is enough for this measurement.
time_end_torch = time.time()

# The label says "allocate", but the timed loop above performs .to("cuda:1") copies.
print(f"allocate torch tensor time: {time_end_torch - time_start_torch} seconds")

allocate torch tensor time: 0.0935208797454834 seconds


: 