In [39]:
import torch
import time
# Benchmark fixtures: random bfloat16 query / key / value tensors created on the
# CPU and then copied into pinned (page-locked) host memory via .pin_memory().
batch_size = 90
# Not referenced by any code visible in this part of the notebook.
device="cuda:0"
# Query tensor: batch_size x 32 x 512 x 128 bfloat16 values (2 bytes each),
# i.e. 90 * 32 * 512 * 128 * 2 bytes = 360 MiB at batch_size = 90.
query_states = torch.randn(batch_size, 32, 512, 128, dtype=torch.bfloat16, device="cpu").pin_memory()
# NOTE(review): the earlier note here read "480MB, *4 = 1920MB". 480 MiB is the
# query size at batch_size = 120, so it looks stale for batch_size = 90 — confirm.
# Key / value tensors: batch_size x 8 x 512 x 128 bfloat16 values,
# i.e. 90 * 8 * 512 * 128 * 2 bytes = 90 MiB each at batch_size = 90.
key_states = torch.randn(batch_size, 8, 512, 128, dtype=torch.bfloat16, device="cpu").pin_memory()
value_states = torch.randn(batch_size, 8, 512, 128, dtype=torch.bfloat16, device="cpu").pin_memory()
# NOTE(review): the earlier note here read "960MB*3 = 2880MB"; that figure does not
# match the key/value shapes above — confirm what it referred to.
# Disabled experiment: timed four rounds of copying key_states and value_states
# to "cuda:1".
# time_start = time.time()
# for i in range(4):
#     key_states.to(device="cuda:1")
#     value_states.to(device="cuda:1")
# print(f"time cost to cuda:1 {time.time() - time_start} s")

In [53]:
import time
# Zero-filled bfloat16 scratch buffer of shape (batch_size, 8, 4, 512, 128),
# copied into pinned host memory. It is not read or written by the active code
# in this cell: repeat_kv builds its result from an expanded view instead of
# copying into this buffer.
hidden_states_cache = torch.zeros(
    (batch_size, 8, 4, 512, 128),
    device="cpu",
    dtype=torch.bfloat16,
).pin_memory()
# Adapted from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis.

    Produces the same values as
    ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``.

    Args:
        hidden_states: 4-D tensor laid out as
            (batch, num_key_value_heads, seqlen, head_dim).
        n_rep: Number of copies to make of each key/value head.

    Returns:
        ``hidden_states`` itself (no copy) when ``n_rep == 1``; otherwise a
        tensor of shape (batch, num_key_value_heads * n_rep, seqlen, head_dim).

    Raises:
        ValueError: If ``hidden_states`` does not have exactly four dimensions
            (the shape unpacking below fails), even when ``n_rep == 1``.
    """
    # Unpack first so that a non-4-D input is rejected regardless of n_rep.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # Insert a size-1 axis right after the head axis and broadcast it to n_rep.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)

    # Fold the (kv_heads, n_rep) axes into one head axis; copies of the same
    # source head end up adjacent to each other.
    total_heads = kv_heads * n_rep
    return tiled.reshape(bsz, total_heads, seq_len, dim)
def repeat_kv_optimized_v2(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times using expand + reshape.

    Args:
        hidden_states: Tensor laid out as
            (batch, num_key_value_heads, seqlen, head_dim).
        n_rep: Number of copies to make of each key/value head.

    Returns:
        ``hidden_states`` itself when ``n_rep == 1`` (returned before the shape
        is inspected); otherwise a tensor of shape
        (batch, num_key_value_heads * n_rep, seqlen, head_dim) in which the
        copies of one source head are adjacent.
    """
    if n_rep == 1:
        return hidden_states

    bsz, kv_heads, seq_len, dim = hidden_states.shape
    # Add a size-1 axis after the head axis, then broadcast it to n_rep copies.
    with_rep_axis = hidden_states[:, :, None]
    broadcast = with_rep_axis.expand(bsz, kv_heads, n_rep, seq_len, dim)
    # Merge the head axis and the repeat axis into a single head axis.
    merged_heads = kv_heads * n_rep
    return broadcast.reshape(bsz, merged_heads, seq_len, dim)
# Benchmark: expand the key/value heads by a factor of 4 with repeat_kv, then
# run scaled-dot-product attention on the resulting tensors. No attention mask
# is passed.
print(f"key_states {key_states.device}")

# Durations are measured with time.perf_counter(), the standard-library clock
# meant for timing short intervals; time.time() is a wall clock and can be
# adjusted while the measurement is running.
time_start = time.perf_counter()
new_key_states = repeat_kv(key_states, 4)
new_value_states = repeat_kv(value_states, 4)
print(f"time cost repeat kv {time.perf_counter() - time_start} s")

time_start = time.perf_counter()
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    new_key_states,
    new_value_states,
    attn_mask=None,
    dropout_p=0.0,
    # Non-causal and unmasked: no causal mask is applied in this benchmark.
    is_causal=False,
)
print(f"dot attn cost {time.perf_counter()-time_start:.6f} seconds")

key_states cpu
time cost repeat kv 0.12833380699157715 s
dot attn cost 0.559304 seconds
