## Transformer代码

手撕MHA

公式 

$Attention(Q,K,V)= softmax\left(\frac{QK^T}{\sqrt{d_k}}\right) V$

$Q_i = QW^q_i, K_i = KW^k_i, V_i = VW^v_i$

$head_i = Attention(Q_i, K_i, V_i)$

$MultiHead(Q,K,V) = Concat(head_1, \dots, head_h) W^o$


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention, causally masked by default.

    Computes ``Concat(head_1, ..., head_h) W^o`` with
    ``head_i = softmax(Q_i K_i^T / sqrt(d_k)) V_i``, where Q, K and V are
    linear projections of the same input tensor and ``d_k = dim // num_heads``.

    Args:
        dim: Model width; last dimension of both input and output.
        num_heads: Number of attention heads; must divide ``dim``.
        causal: If True (default, the original behaviour), position ``i`` may
            only attend to positions ``<= i``. Set False for bidirectional
            attention.
    """

    def __init__(self, dim, num_heads, causal=True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.causal = causal

        # Message: "dim must be divisible by num_heads".
        assert dim % num_heads == 0, "dim 必须能被 num_heads 整除"
        self.head_dim = dim // num_heads

        # Creation order (k, q, v, fc) is preserved so that seeded parameter
        # initialisation and state_dict keys are unchanged.
        self.w_k = nn.Linear(dim, dim)
        self.w_q = nn.Linear(dim, dim)
        self.w_v = nn.Linear(dim, dim)
        self.fc = nn.Linear(dim, dim)

    def _build_mask(self, seq_len, device, attention_mask):
        """Return a boolean keep-mask (True = may attend), or None if unmasked."""
        mask = None
        if self.causal:
            # Built directly as bool on the input's device: a float mask cannot
            # be combined with ``&``, and a CPU mask breaks GPU inputs.
            mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
            mask = mask[None, None]  # (1, 1, L, L): broadcasts over batch and heads
        if attention_mask is not None:
            extra = attention_mask.to(device=device, dtype=torch.bool)
            mask = extra if mask is None else mask & extra
        return mask

    def forward(self, input_tensor, attention_mask=None):
        """Apply multi-head self-attention (causal when ``self.causal``).

        Args:
            input_tensor: Tensor of shape (B, L, dim).
            attention_mask: Optional mask broadcastable to (B, num_heads, L, L).
                Nonzero/True entries may be attended to; zero/False entries are
                blocked. It is AND-ed with the causal mask when that is enabled.

        Returns:
            Tensor of shape (B, L, dim). A query row with no attendable key
            gets all-zero attention weights, so its output is ``fc`` applied to
            a zero vector (i.e. ``fc.bias``) instead of NaN.
        """
        B, L, dim = input_tensor.shape

        def split_heads(t):
            # (B, L, dim) -> (B, num_heads, L, head_dim)
            return t.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        key = split_heads(self.w_k(input_tensor))
        query = split_heads(self.w_q(input_tensor))
        value = split_heads(self.w_v(input_tensor))

        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        mask = self._build_mask(L, input_tensor.device, attention_mask)
        if mask is not None:
            scores = scores.masked_fill(~mask, float("-inf"))

        # F.softmax is already numerically stable; no manual max-subtraction needed.
        weights = F.softmax(scores, dim=-1)

        if mask is not None:
            # Softmax over a row that is entirely -inf yields NaN; zero those rows.
            weights = weights.masked_fill(~mask.any(dim=-1, keepdim=True), 0.0)

        # (B, H, L, head_dim) -> (B, L, H, head_dim) -> (B, L, dim); reshape copies
        # when needed, so an explicit .contiguous() is unnecessary.
        outputs = torch.matmul(weights, value).transpose(1, 2).reshape(B, L, dim)
        return self.fc(outputs)
