In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with a built-in causal mask.

    The inputs are linearly projected, split into ``num_heads`` heads of size
    ``d_model // num_heads``, attended per head, merged back and passed
    through an output projection. A lower-triangular mask is always applied,
    so query position ``i`` only attends to key positions ``j <= i``.

    Args:
        d_model: Size of the last dimension of the inputs and the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads) -> None:
        super().__init__()
        # Fail early with a clear message instead of an opaque .view() shape
        # error on the first forward pass.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads

        # Q, K, V projection matrices (d_model -> d_model each)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection matrix applied after the heads are merged
        self.W_o = nn.Linear(d_model, d_model)

        # Softmax over the last dimension (the key positions)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        """Apply causally masked multi-head attention.

        Args:
            q: Query tensor of shape ``(batch, q_len, d_model)``.
            k: Key tensor of shape ``(batch, k_len, d_model)``.
            v: Value tensor of shape ``(batch, k_len, d_model)``.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch, q_len, _ = q.shape
        # Keys/values carry their own sequence length; it does not have to
        # match the query length.
        k_len = k.shape[1]
        # Per-head feature size: d_model is split evenly across the heads.
        n_d = self.d_model // self.num_heads

        # Project q, k, v into the attention space.
        q, k, v = self.W_q(q), self.W_k(k), self.W_v(v)

        # view:    (batch, len, d_model)        -> (batch, len, num_heads, n_d)
        # permute: (batch, len, num_heads, n_d) -> (batch, num_heads, len, n_d)
        # so that every head is processed in parallel by the batched matmul.
        q = q.view(batch, q_len, self.num_heads, n_d).permute(0, 2, 1, 3)
        k = k.view(batch, k_len, self.num_heads, n_d).permute(0, 2, 1, 3)
        v = v.view(batch, k_len, self.num_heads, n_d).permute(0, 2, 1, 3)

        # Scaled dot product: (batch, num_heads, q_len, k_len).
        # Dividing by sqrt(n_d) keeps the logits in a range where softmax
        # does not saturate.
        score = q @ k.transpose(2, 3) / math.sqrt(n_d)

        # Lower-triangular (causal) mask, created on the same device as the
        # scores so the module also works after .to("cuda") and similar.
        # Column 0 is never masked, so no row becomes all -inf.
        mask = torch.tril(
            torch.ones(q_len, k_len, dtype=torch.bool, device=score.device)
        )

        # Masked-out positions get -inf so softmax assigns them zero weight.
        score = score.masked_fill(~mask, float('-inf'))
        score = self.softmax(score) @ v

        # Merge heads: (batch, num_heads, q_len, n_d) -> (batch, q_len, d_model).
        # reshape() copies if the permuted tensor is not contiguous.
        score = score.permute(0, 2, 1, 3).reshape(batch, q_len, self.d_model)

        return self.W_o(score)

In [None]:
# Quick smoke test: run a single forward pass and report the output shape.
attention = MultiHeadAttention(d_model=512, num_heads=8)
# Three independent random inputs, drawn in q, k, v order.
query, key, value = (torch.randn(2, 5, 512) for _ in range(3))
output = attention(query, key, value)
print(f"快速测试结果: 输入形状 (2, 5, 512) -> 输出形状 {output.shape}")