### Bahdanau Attention

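Bahdanau attention applies the attention mechanism to a seq2seq model (here, machine translation).

The main idea is that, when translating, each target word usually corresponds to only a small part of the source sentence.

Compared with the original machine translation model, the decoder's initial hidden state is still set from the encoder's final hidden state (the state at the last time step, for every layer).

The context, however, is no longer just the last layer's hidden state at the last time step. It is a weighted sum of the encoder's last-layer hidden states over all time steps, with the weights determined by the attention mechanism.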

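- query --> the decoder's last-layer hidden state from the previous time step (for the first step, the encoder's final last-layer hidden state)

- key --> the encoder's last-layer hidden state at every time step

- value --> the encoder's last-layer hidden state at every time step

Here the keys and the values are the same tensor, and there are as many key-value pairs as there are tokens in the source sequence.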
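
The weighting described above can be written out as follows. This is a sketch in notation that is assumed here rather than taken from the notes: $h_t$ is the encoder's last-layer hidden state at source step $t$ (key and value), $s_{t'-1}$ is the decoder's last-layer hidden state from the previous target step (query), $T$ is the source length, and $W_q$, $W_k$, $w_v$ are learned parameters of the additive scoring function.

$$
c_{t'} = \sum_{t=1}^{T} \alpha(s_{t'-1}, h_t)\, h_t,
\qquad
\alpha(s_{t'-1}, h_t) = \frac{\exp\big(a(s_{t'-1}, h_t)\big)}{\sum_{\tau=1}^{T} \exp\big(a(s_{t'-1}, h_\tau)\big)},
\qquad
a(q, k) = w_v^\top \tanh(W_q q + W_k k)
$$

The context $c_{t'}$ is recomputed at every decoding step, which is why the decoder has to process the target sequence one step at a time.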

<br>

### Implementation

In [4]:
from torch import nn
import torch

In [5]:
class AdditiveAttention(nn.Module):
    def __init__(self, query_size, key_size, hiden_size):
        super().__init__()
        self.linear_q = nn.Linear(query_size, hiden_size, bias=False)
        self.linear_k = nn.Linear(key_size, hiden_size, bias=False)
        self.dense = nn.Linear(hiden_size, 1, bias=False)
        self.softmax = nn.Softmax(dim=2)

    # query shape: (batch_size, num_query, query_size)
    # key shape: (batch_size, num_pair, key_size)
    # value shape: (batch_size, num_pair, value_size)
    def forward(self, query, key, value):
        query_ = self.linear_q(query)   # (batch_size, num_query, hiden_size)
        key_ = self.linear_k(key)       # (batch_size, num_pair, hiden_size)
        H = torch.tanh(query_[:, :, None, :] + key_[:, None, :, :])      # (batch_size, num_query, num_pair, hiden_size)
        score = self.dense(H)                     # (batch_size, num_query, num_pair, 1)
        score = score.reshape(score.shape[0:3])   # (batch_size, num_query, num_pair)
        weight = self.softmax(score)              # (batch_size, num_query, num_pair)
        output = torch.bmm(weight, value)         # (batch_size, num_query, value_size)
        return output
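
The broadcasting step in `forward` can be hard to read at first. The sketch below uses hypothetical small sizes and random tensors that stand in for the already-projected queries and keys (the names `q`, `k`, `w_v` are illustrative). It checks that the broadcast sum yields the same scores as an explicit loop over every (query, key) pair, and that the softmax weights sum to 1 over the key-value pairs.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, num_query, num_pair, hidden = 2, 3, 4, 5   # illustrative sizes

q = torch.randn(batch_size, num_query, hidden)   # stand-in for projected queries
k = torch.randn(batch_size, num_pair, hidden)    # stand-in for projected keys
w_v = nn.Linear(hidden, 1, bias=False)           # plays the role of self.dense

with torch.no_grad():
    # Broadcast version: every query is added to every key
    feat = torch.tanh(q[:, :, None, :] + k[:, None, :, :])   # (2, 3, 4, 5)
    score = w_v(feat).squeeze(-1)                            # (2, 3, 4)

    # Explicit version: one score per (sample, query, key) triple
    score_loop = torch.empty(batch_size, num_query, num_pair)
    for b in range(batch_size):
        for i in range(num_query):
            for j in range(num_pair):
                score_loop[b, i, j] = w_v(torch.tanh(q[b, i] + k[b, j])).item()

    weight = torch.softmax(score, dim=-1)   # normalize over the key-value pairs

print(torch.allclose(score, score_loop, atol=1e-6))
print(weight.sum(dim=-1))
```

The broadcast form avoids the Python loop at the cost of materializing a `(batch_size, num_query, num_pair, hidden)` tensor.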

The encoder is unchanged.

In [6]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hiden_size, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, hiden_size, num_layers)

    # input x shape: (batch_size, seq_len)
    def forward(self, x):
        x = self.embedding(x).permute(1, 0, 2)
        output, state = self.rnn(x)
        # output shape: (seq_len, batch_size, hiden_size)
        # state shape: (num_layers, batch_size, hiden_size)
        return output, state
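
As a quick check of what the encoder returns, the following sketch runs `nn.Embedding` and `nn.GRU` directly with hypothetical small sizes. It shows that `output` holds the last layer's hidden state at every time step (the attention keys and values), while `state` holds every layer's hidden state at the final time step, so the two agree at `output[-1]` and `state[-1]`.

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, embed_size, hidden_size, num_layers = 10, 4, 6, 2   # illustrative sizes
batch_size, seq_len = 3, 5

embedding = nn.Embedding(vocab_size, embed_size)
rnn = nn.GRU(embed_size, hidden_size, num_layers)

x = torch.randint(0, vocab_size, (batch_size, seq_len))   # random token ids
output, state = rnn(embedding(x).permute(1, 0, 2))

print(output.shape)   # torch.Size([5, 3, 6]): (seq_len, batch_size, hidden_size)
print(state.shape)    # torch.Size([2, 3, 6]): (num_layers, batch_size, hidden_size)

# Last layer at the last time step appears in both tensors
print(torch.allclose(output[-1], state[-1]))   # True
```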

The decoder adds the attention mechanism, so the context concatenated with the input is recomputed at every step.

In [7]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hiden_size, num_layers):
        super().__init__()
        self.attention = AdditiveAttention(hiden_size, hiden_size, hiden_size)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + hiden_size, hiden_size, num_layers)
        self.dense = nn.Linear(hiden_size, vocab_size)

    # y shape: (batch_size, seq_len)
    # enc_outputs is the encoder's return value, i.e. the tuple (output, state)
    def forward(self, y, enc_outputs):
        output = enc_outputs[0]      # (seq_len, batch_size, hiden_size), used as the key-value pairs for each sample
        state = enc_outputs[1]       # (num_layers, batch_size, hiden_size), initial hidden state of the decoder RNN

        key_value = output.permute(1, 0, 2)   # (batch_size, num_pair, key_value_size)
        query = state[-1]     # (batch_size, query_size); one query per sample per step, since the next query depends on the previous step
        query = query.reshape(query.shape[0], 1, query.shape[1])    # (batch_size, num_query, query_size), with num_query = 1

        y = self.embedding(y).permute(1, 0, 2)    # (seq_len, batch_size, embed_size)

        outputs = []

        # iterate over the time steps of the target sequence
        for i in y:   # i shape: (batch_size, embed_size)
            context = self.attention(query, key_value, key_value)  # (batch_size, num_query, key_value_size), with num_query = 1
            context = context.reshape(context.shape[0], context.shape[2])    # (batch_size, key_value_size)
            dec_input = torch.cat((i, context), dim=1)     # (batch_size, embed_size + hiden_size)
            dec_input = dec_input.reshape(1, dec_input.shape[0], dec_input.shape[1])   # (1, batch_size, embed_size + hiden_size)
            dec_outputs = self.rnn(dec_input, state)
            output = dec_outputs[0]    # (1, batch_size, hiden_size), last-layer hidden state used to predict one token
            state = dec_outputs[1]     # (num_layers, batch_size, hiden_size), updated hidden state of the decoder RNN
            query = output.permute(1, 0, 2)    # (batch_size, num_query, query_size), with num_query = 1; query for the next step
            outputs.append(output)

        outputs = torch.cat(outputs, dim=0)   # (seq_len, batch_size, hiden_size)
        return self.dense(outputs).permute(1, 0, 2)      # (batch_size, seq_len, vocab_size)
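
To make the shape flow in `Decoder.forward` easier to follow, here is a minimal sketch of a single decoding step. It uses random stand-in tensors instead of a real encoder and embedding. The sizes and the names `W_q`, `W_k`, `w_v`, and `y_t` are illustrative assumptions, not part of the classes above.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, src_len, embed_size, hidden_size, num_layers = 2, 5, 8, 16, 2

# Random stand-ins for the encoder's return values (shape checking only)
enc_output = torch.randn(src_len, batch_size, hidden_size)
state = torch.randn(num_layers, batch_size, hidden_size)

# Additive scoring parameters, mirroring linear_q, linear_k and dense
W_q = nn.Linear(hidden_size, hidden_size, bias=False)
W_k = nn.Linear(hidden_size, hidden_size, bias=False)
w_v = nn.Linear(hidden_size, 1, bias=False)
rnn = nn.GRU(embed_size + hidden_size, hidden_size, num_layers)

key_value = enc_output.permute(1, 0, 2)      # (batch, src_len, hidden)
query = state[-1].unsqueeze(1)               # (batch, 1, hidden)
y_t = torch.randn(batch_size, embed_size)    # stand-in for one embedded target token

# Attention: score every source position against the current query
feat = torch.tanh(W_q(query).unsqueeze(2) + W_k(key_value).unsqueeze(1))
weights = torch.softmax(w_v(feat).squeeze(-1), dim=-1)    # (batch, 1, src_len)
context = torch.bmm(weights, key_value).squeeze(1)        # (batch, hidden)

# One GRU step on the concatenated [token embedding, context]
dec_input = torch.cat((y_t, context), dim=1).unsqueeze(0)  # (1, batch, embed + hidden)
output, state = rnn(dec_input, state)

# The last-layer hidden state becomes the query for the next step
next_query = output.permute(1, 0, 2)                       # (batch, 1, hidden)
print(weights.shape, context.shape, output.shape, next_query.shape)
```

Because each step's query comes from the previous step's output, the loop over target positions cannot be replaced by a single batched attention call in this design.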