In [1]:
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MutiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        assert (hidden_size % num_heads) == 0, "hidden_size must be divisible by num_heads"

        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)

        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        #hidden_state就是输入的 doc的句子，shape为 [batch_size, seq_len, hidden_size]
        batch_size = hidden_state.size()[0]

        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)

        query = self.split_head(query)
        key = self.split_head(key)
        # value shape为 [batch_size, num_heads, seq_len, hidden_dim]
        value = self.split_head(value)

        ## 计算注意力分数，-1 为 hidden_size,-2为seqlen
        # 输出为 [batch_size, num_heads, seq_len, seq_len]
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))

        if attention_mask != None:
            attention_scores += attention_mask * -1e-9

        ## 对注意力分数进行归一化，按列 进行softmax， 每一列数值范围0-1，和为1
        attention_probs = torch.softmax(attention_scores, dim=-1) #解决ICS问题

        # 输出output的形状为 [batch_size, num_heads, seq_len, hidden_dim]
        output = torch.matmul(attention_probs, value)

        ## 对注意力输出进行拼接，contiguous()确保内存连续性
        # 输出形状为 [batch_size, seq_len, hidden_size]
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)

        output = self.o_linear(output)

        return output


    def split_head(self, x):
        batch_size = x.size()[0]
        # -1位置的dim，自动计算，代表seqlen
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)

In [2]:
import torch.nn as nn

class MHA(nn.modules):
    def __init__(self, num_head, hidden_size):
        super(MHA,self).__init__()
        self.num_head = num_head
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_head
        self.qlinear = nn.Linear(hidden_size, hidden_size)
        self.klinear = nn.Linear(hidden_size, hidden_size)
        self.vlinear = nn.Linear(hidden_size, hidden_size)
        self.olinear = nn.Linear(hidden_size, hidden_size)

    def forward(self,x, mask=None):

        q = self.qlinear(x)
        k = self.klinear(x)
        v = self.vlinear(x)

        q = self.split_head(q)
        k = self.split_head(k)
        v = self.split_head(v)

        att_score = torch.matul(q, k.transpose(-1,-2)) // torch.sqrt(torch.tensor(self.head_dim))

        if mask:
            att_score += mask * -1e-9

        att_prob = torch.softmax(att_score, dim = -1)
        batch_size = x.size()[0]
        output = torch.matul(att_prob, v).transpose(1,2).contiguous().view(batch_size, -1, self.num_head * self.hidde_dim)
        output = self.olinear(output)
        return output

    def split_head(self, x):
        batch_size = x.size()[0]
        return x.view(batch_size, -1, self.num_head, self.hidden_size).transpose(1,2)

TypeError: module() takes at most 2 arguments (3 given)

In [None]:
定义三个函数 forward、init、splithead

forward函数，将输入x变化为融合 注意力的x, 也就是输入的x的形状和输出的x的形状要是一样，这样可以在多层的注意力模块传递
    需要构造qkv，首先使用linear函数将x变成qkv的映射，对qkv的映射进行分头，对分头的qk进行注意力分数计算，需要将k转置需要d正则化，需要mask 乘以-1e-9然后softmax -1dim，然后matul乘以v的映射然后转置到与x形状相同，最后加线性层映射x为x形状
分头函数，需要将 num——head 和 head-dim分头分出来，-1自动计算seqlen