In [15]:
import torch
import torch.nn.functional as F
import time

def attention(Q, K, V):
    """Compute scaled dot-product attention.

    Any number of leading batch dimensions is accepted (none, a single
    batch axis, or batch plus heads), as long as they are broadcastable
    across ``Q``, ``K`` and ``V`` under ``torch.matmul`` rules.

    Args:
        Q: Query tensor of shape (..., seq_len_q, d_k).
        K: Key tensor of shape (..., seq_len_k, d_k).
        V: Value tensor of shape (..., seq_len_k, d_v). Its sequence
            length must match that of ``K`` for the weighted sum to be
            defined.

    Returns:
        A tuple ``(output, weights)``:
        - output: attention-weighted values, shape (..., seq_len_q, d_v).
        - weights: softmax-normalised attention weights, shape
          (..., seq_len_q, seq_len_k); each row sums to 1 over the keys.
    """
    # Scale by sqrt(d_k) using a plain Python scalar so that no extra
    # tensor is allocated on every call.
    scale = Q.size(-1) ** 0.5

    # Swap only the last two axes of K so the function is not tied to
    # 3-D inputs.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

    # Normalise over the key axis.
    weights = F.softmax(scores, dim=-1)

    # Weighted sum of the values.
    output = torch.matmul(weights, V)

    return output, weights

# Generate reproducible random demo inputs.
torch.manual_seed(42)
Q = torch.rand(3, 4, 5)  # queries: batch of 3, 4 positions, feature size 5
K = torch.rand(3, 4, 5)  # keys:    batch of 3, 4 positions, feature size 5
V = torch.rand(3, 4, 8)  # values:  batch of 3, 4 positions, feature size 8


# Start the timer. perf_counter() is a monotonic, high-resolution clock
# intended for measuring short intervals; time.time() is wall-clock time
# and can jump if the system clock is adjusted.
start_time = time.perf_counter()

# Run the attention function on the demo inputs.
output, weights = attention(Q, K, V)

# Stop the timer.
end_time = time.perf_counter()

# Elapsed time, converted from seconds to milliseconds.
elapsed_time = (end_time - start_time) * 1000
print(f"执行时间: {elapsed_time:.2f} 毫秒")


# Print the results.
print("Attention输出：", output)
print("Attention权重：", weights)


执行时间: 2.03 毫秒
Attention输出： tensor([[[0.4623, 0.3159, 0.6992, 0.5253, 0.4834, 0.3035, 0.5636, 0.2056],
         [0.4446, 0.3195, 0.6882, 0.5235, 0.5048, 0.2769, 0.5689, 0.2020],
         [0.4413, 0.3142, 0.6860, 0.5176, 0.5019, 0.2756, 0.5704, 0.2002],
         [0.4651, 0.3211, 0.6985, 0.5274, 0.4839, 0.3027, 0.5601, 0.2062]],

        [[0.5355, 0.4025, 0.6384, 0.1917, 0.7205, 0.5478, 0.4214, 0.4450],
         [0.5341, 0.4074, 0.6449, 0.1862, 0.7191, 0.5508, 0.4245, 0.4481],
         [0.5426, 0.3713, 0.6287, 0.2080, 0.7285, 0.5407, 0.4197, 0.4284],
         [0.5390, 0.3879, 0.6343, 0.1990, 0.7242, 0.5446, 0.4208, 0.4375]],

        [[0.4563, 0.6375, 0.7427, 0.5676, 0.3553, 0.3468, 0.3884, 0.4641],
         [0.4604, 0.6426, 0.7411, 0.5681, 0.3496, 0.3468, 0.3918, 0.4774],
         [0.4909, 0.6675, 0.7252, 0.5581, 0.3440, 0.3570, 0.4080, 0.5080],
         [0.4436, 0.6311, 0.7478, 0.5745, 0.3518, 0.3393, 0.3854, 0.4677]]])
Attention权重： tensor([[[0.2313, 0.2421, 0.2860, 0.2407],
         [0