In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [2]:
x = torch.randn(1, 3, 4)

In [3]:
num_head = 2
head_dim = 2

In [4]:
linear_q = nn.Linear(4, 4)
linear_k = nn.Linear(4, 4)
linear_v = nn.Linear(4, 4)

linear_out = nn.Linear(4, 4)

Q = linear_q(x)
K = linear_k(x)
V = linear_v(x)

In [5]:
# 分割
def split_head(x, num_head):
    batch_size, seq_len, feature_dim = x.size()
    head_dim = feature_dim // num_head
    output = x.view(batch_size, seq_len, num_head, head_dim).transpose(1, 2)
    return output

In [6]:
Q = split_head(Q, num_head)
K = split_head(K, num_head)
V = split_head(V, num_head)

In [7]:
similarity_score = torch.matmul(Q, K.transpose(-2, -1))
print('similarity_score:\n', similarity_score)
print('size:\n', similarity_score.size())
scale_factor = similarity_score.size(-1) ** 0.5
print('scale_factor:\n', scale_factor)
scale_weight = similarity_score / scale_factor
print('scale_weight:\n', scale_weight)
print('size:\n', scale_weight.size())
attn_weight = F.softmax(scale_weight, dim=-1)
print('attn_weight:\n', attn_weight)
print('size:\n', attn_weight.size())
attn_output = torch.matmul(similarity_score, V)
print('attn_output:\n', attn_output)
print('size:\n', attn_output.size())

similarity_score:
 tensor([[[[-1.3846, -0.3331,  0.3355],
          [-0.5084, -0.0064,  0.2133],
          [ 0.4380,  0.2477,  0.0046]],

         [[ 0.7266,  0.2095, -0.7258],
          [ 0.1571,  0.0338, -0.2635],
          [-0.7530, -0.2057,  0.8573]]]], grad_fn=<UnsafeViewBackward0>)
size:
 torch.Size([1, 2, 3, 3])
scale_factor:
 1.7320508075688772
scale_weight:
 tensor([[[[-0.7994, -0.1923,  0.1937],
          [-0.2935, -0.0037,  0.1232],
          [ 0.2529,  0.1430,  0.0026]],

         [[ 0.4195,  0.1209, -0.4190],
          [ 0.0907,  0.0195, -0.1522],
          [-0.4347, -0.1188,  0.4949]]]], grad_fn=<DivBackward0>)
size:
 torch.Size([1, 2, 3, 3])
attn_weight:
 tensor([[[[0.1807, 0.3316, 0.4878],
          [0.2595, 0.3468, 0.3937],
          [0.3739, 0.3350, 0.2911]],

         [[0.4599, 0.3412, 0.1988],
          [0.3682, 0.3429, 0.2888],
          [0.2039, 0.2796, 0.5165]]]], grad_fn=<SoftmaxBackward0>)
size:
 torch.Size([1, 2, 3, 3])
attn_output:
 tensor([[[[-2.0152,  0.610

In [None]:
# 合并
def combine_head(x):
    batch_size, num_head, seq_len, head_dim = x.size()
    feature_dim = num_head * head_dim
    output = x.transpose(1, 2).contiguous().view(batch_size, seq_len, feature_dim)
    # tensor.contiguous()确保张量内存布局是连续的，将非连续张量转为连续张量，如果张量已是连续，那么直接返回原张量
    # 配合使用
    # view()改变张量形状，要求张量内存布局必须连续
    # 如在调用前对张量进行transpose()，permute()或切片等操作，可能会导致张量内存布局非连续，此时需先调用contiguous()
    # 在GPU上，非连续张量可能会导致某些操作失败或效率低下，因此在将张量传递到GPU前，需确保其连续性
    # 调用contiguous()会创建新张量
    # 在某些情况下，contiguous()会引入额外的内存复制开销，建议只在必要时使用
    return output

In [9]:
attn_output = linear_out(combine_head(attn_output))
print('attn_output:\n', attn_output)

attn_output:
 tensor([[[-0.4509,  0.9675, -1.0888,  0.9564],
         [-0.2544,  0.5803, -0.5386,  0.6358],
         [-0.0670,  0.1700,  0.4976,  0.5800]]], grad_fn=<ViewBackward0>)
