In [51]:
import torch
import time
# Benchmark inputs: bf16 tensors created on the host.
batch_size = 90
device = "cuda:1"
# Single-token query; left in ordinary (pageable) host memory here.
query_states_c = torch.randn((batch_size, 32, 1, 128), device="cpu", dtype=torch.bfloat16)
# 512-token query in page-locked memory: 90 * 32 * 512 * 128 * 2 bytes = 360 MiB.
query_states = torch.randn((batch_size, 32, 512, 128), device="cpu", dtype=torch.bfloat16).pin_memory()
# Key and value use 8 heads instead of 32, so each one is 90 MiB; both are page-locked.
key_states = torch.randn((batch_size, 8, 512, 128), device="cpu", dtype=torch.bfloat16).pin_memory()
value_states = torch.randn((batch_size, 8, 512, 128), device="cpu", dtype=torch.bfloat16).pin_memory()


In [52]:
# Time two host-to-device copies of query tensors.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short intervals; time.time() can jump when the system clock is adjusted.
time_start = time.perf_counter()
# First measurement: page-lock the small pageable query, then copy it to the GPU.
query_states_cgpu = query_states_c.pin_memory()
qgpu_c = query_states_cgpu.to(device=device)
print(f"pin and move q cost {time.perf_counter() - time_start} s")
# Second measurement: copy the large query, which is already page-locked.
time_start = time.perf_counter()
qgpu = query_states.to(device=device)
print(f"move q cost {time.perf_counter() - time_start} s")

pin and move q cost 0.003097057342529297 s
move q cost 0.031256675720214844 s


In [53]:
# Copy the pinned key and value tensors to the GPU and time both copies together.
# With batch_size = 90 each tensor is 90 * 8 * 512 * 128 * 2 bytes = 90 MiB,
# so one pass moves 180 MiB.
time_start = time.perf_counter()
# Raise the range to repeat the transfer and average out noise.
for i in range(1):
    gpu_key_states = key_states.to(device=device)
    gpu_value_states = value_states.to(device=device)
print(f"time cost to cuda:1 {time.perf_counter() - time_start} s")

time cost to cuda:1 0.015724897384643555 s


In [54]:
from typing import Optional
import math
def scaled_dot_product_attention_with_pinned_memory(
    query: torch.Tensor,
    key: torch.Tensor, 
    value: torch.Tensor,
    output_tensor: Optional[torch.Tensor] = None,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False
) -> torch.Tensor:
    """
    Scaled dot-product attention that can write its result into a preallocated tensor.

    The computation is softmax(query @ key^T * scale + bias) @ value, where the
    bias comes from the causal mask or from ``attn_mask``.

    Args:
        query: Query tensor of shape (..., L, E)
        key: Key tensor of shape (..., S, E)
        value: Value tensor of shape (..., S, Ev)
        output_tensor: Optional preallocated tensor that receives the result. It
            must have shape (..., L, Ev) and the same dtype and device as
            ``query``. This function does not check whether it is pinned. If
            None, a new tensor is allocated.
        attn_mask: Optional mask broadcastable to the attention weights
            (..., L, S). A boolean mask keeps positions that are True; any other
            dtype is added to the attention scores.
        dropout_p: Dropout probability applied to the attention weights
        is_causal: If True, apply a lower-triangular causal mask; cannot be
            combined with ``attn_mask``
        scale: Scaling factor; 1/sqrt(E) is used when None
        enable_gqa: If True, repeat the key/value heads (dim -3) so that they
            match the number of query heads

    Returns:
        The attention output. When ``output_tensor`` is given, that same tensor
        is returned.

    Raises:
        ValueError: If ``output_tensor`` has the wrong shape, dtype or device.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    
    # Shape the result will have: query's shape with the last dim replaced by Ev.
    output_shape = query.shape[:-1] + (value.shape[-1],)
    
    # Validate the caller-supplied output buffer before doing any work.
    if output_tensor is not None:
        if output_tensor.shape != output_shape:
            raise ValueError(f"output_tensor 形状 {output_tensor.shape} 与期望形状 {output_shape} 不匹配")
        if output_tensor.dtype != query.dtype:
            raise ValueError(f"output_tensor 数据类型 {output_tensor.dtype} 与 query 数据类型 {query.dtype} 不匹配")
        if output_tensor.device != query.device:
            raise ValueError(f"output_tensor 设备 {output_tensor.device} 与 query 设备 {query.device} 不匹配")
    
    # Additive bias, starting as an (L, S) matrix of zeros.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    
    if is_causal:
        assert attn_mask is None, "is_causal 和 attn_mask 不能同时使用"
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Build the -inf bias at the mask's own shape so masks with leading
            # batch/head dims are supported.
            mask_bias = torch.zeros_like(attn_mask, dtype=query.dtype)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            mask_bias = attn_mask.to(query.dtype)
        # Combine out of place: an in-place update of the (L, S) bias would fail
        # whenever the mask broadcasts to more dimensions than (L, S).
        attn_bias = attn_bias + mask_bias

    # Grouped-query attention: expand key/value heads to the query head count.
    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    # Attention weights; the bias is added in place to avoid a second
    # scores-sized allocation.
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    
    # Dropout is always applied in training mode when dropout_p > 0.
    if dropout_p > 0.0:
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    
    # Final product: written straight into the caller's buffer when one is given.
    if output_tensor is not None:
        torch.matmul(attn_weight, value, out=output_tensor)
        return output_tensor
    else:
        return attn_weight @ value

In [55]:

# Pinned bf16 zero buffer of shape (batch_size, 8, 4, 512, 128); 360 MiB when batch_size = 90.
# NOTE(review): nothing visible in this notebook reads or writes this module-level
# buffer; repeat_kv builds its result from its argument instead. Confirm whether
# the allocation is still needed.
hidden_states_cache = torch.zeros(batch_size, 8, 4, 512, 128, dtype=torch.bfloat16, device="cpu").pin_memory()
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head dimension.

    Produces the same result as torch.repeat_interleave(x, dim=1, repeats=n_rep):
    the input goes from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When ``n_rep`` is 1
    the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis after the head axis, broadcast it to n_rep, then
        # fold it into the head axis.
        tiled = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
print(f"key_states {key_states.device}")
# Time the expansion of the 8 key/value heads to 32 heads (n_rep = 4).
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short intervals; time.time() can jump when the system clock is adjusted.
time_start = time.perf_counter()
new_key_states = repeat_kv(key_states, 4)
new_value_states = repeat_kv(value_states, 4)

# Called without an argument, so this waits on the current CUDA device.
# key_states and value_states are created on the CPU, so repeat_kv itself
# queues no GPU work here.
torch.cuda.synchronize()
print(f"time cost repeat kv {time.perf_counter() - time_start} s")

# Time a full NaN scan over the original (unexpanded) key and value tensors.
time_start = time.perf_counter()
if torch.isnan(key_states).any():
    print("NaN detected in key_states")
if torch.isnan(value_states).any():
    print("NaN detected in value_states")
print(f"time cost check nan {time.perf_counter() - time_start} s")
# Preallocated pinned CPU buffer with query_states' shape; func_pin passes it
# as output_tensor.
attn_output_c = torch.empty_like(query_states, device="cpu", dtype=torch.bfloat16, pin_memory=True)
def func_here():
    """Time PyTorch's built-in causal scaled_dot_product_attention and print the result.

    Reads the module-level query_states, new_key_states and new_value_states.
    Returns None; the attention output is discarded.
    """
    # perf_counter() is monotonic and high-resolution, unlike the wall clock.
    time_start = time.perf_counter()
    # The result stays bound until the function returns, so releasing it is not
    # part of the measured interval.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        new_key_states,
        new_value_states,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,
    )
    print(f"dot attn cost {time.perf_counter()-time_start:.6f} seconds")
def func_pin():
    """Time the hand-written causal attention that writes into attn_output_c, and print the result.

    Reads the module-level query_states, new_key_states, new_value_states and
    attn_output_c. Returns None; the result is left in attn_output_c.
    """
    # perf_counter() is monotonic and high-resolution, unlike the wall clock.
    time_start = time.perf_counter()
    # The returned tensor is attn_output_c itself, so binding it here does not
    # allocate anything extra.
    attn_output = scaled_dot_product_attention_with_pinned_memory(
        query_states,
        new_key_states,
        new_value_states,
        output_tensor=attn_output_c,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,
    )
    print(f"dot attn pin cost {time.perf_counter()-time_start:.6f} seconds")
# Run the built-in attention first, then the hand-written variant, on the same inputs.
func_here()
func_pin()

key_states cpu
time cost repeat kv 0.11120891571044922 s
time cost check nan 0.01766490936279297 s
dot attn cost 0.182348 seconds
dot attn pin cost 0.702972 seconds


In [3]:
# 测量移动整层layer需要的耗时
import torch
import time
batch_size = 360
device = "cuda:1"
# Alternative weight dimensions, kept for quick switching:
# hd = 16384
# hdd= 6144
hd = 14336
hdd = 4096
# Three pinned bf16 matrices of hd x hdd; at 14336 x 4096 each one is 112 MiB.
w0 = torch.randn(hd, hdd, dtype=torch.bfloat16, device="cpu").pin_memory()
w1 = torch.randn(hd, hdd, dtype=torch.bfloat16, device="cpu").pin_memory()
w2 = torch.randn(hd, hdd, dtype=torch.bfloat16, device="cpu").pin_memory()

# Time 8 rounds of copying all three matrices to the GPU (24 copies in total).
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short intervals; time.time() can jump when the system clock is adjusted.
time_start = time.perf_counter()
for i in range(8):
    w0_gpu = w0.to(device=device)
    w1_gpu = w1.to(device=device)
    w2_gpu = w2.to(device=device)
print(f"time cost to cuda:1 {time.perf_counter() - time_start} s")

time cost to cuda:1 0.2299492359161377 s
