# 多头注意力



In [1]:
import sys
sys.path.append('..')

In [2]:
import math
import mindspore
import mindspore.numpy as mnp
import mindspore.nn as nn
import mindspore.ops as ops
from d2l import mindspore as d2l

在实现中，我们选择缩放点积注意力作为每一个注意力头：查询、键和值先分别经过全连接层 `W_q`、`W_k`、`W_v` 变换并拆分到各个头上，各头分别计算缩放点积注意力，最后将各头的输出拼接，再由 `W_o` 做一次线性变换。

In [3]:
class MultiHeadAttention(nn.Cell):
    """Multi-head attention.

    Queries, keys and values are each projected to ``num_hiddens`` features
    by their own ``nn.Dense`` layer. The result is split into ``num_heads``
    heads, which ``transpose_qkv`` folds into the batch axis so a single
    ``d2l.DotProductAttention`` call handles all heads at once. The per-head
    outputs are then merged back by ``transpose_output`` and projected by
    ``W_o``.

    Args:
        key_size: Feature size of the incoming keys.
        query_size: Feature size of the incoming queries.
        value_size: Feature size of the incoming values.
        num_hiddens: Size of the projected representation. It must be
            divisible by ``num_heads`` so it can be split evenly across heads.
        num_heads: Number of attention heads; must be positive.
        dropout: Passed through unchanged to ``d2l.DotProductAttention``.
        has_bias: Whether the four ``nn.Dense`` projections use a bias.
        **kwargs: Forwarded to ``nn.Cell.__init__``.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``num_hiddens`` is
            not divisible by ``num_heads``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, has_bias=False, **kwargs):
        super().__init__(**kwargs)
        # Fail fast on an invalid head split. Without this check the error
        # only surfaces later, as an opaque reshape failure inside
        # `transpose_qkv` on the first forward pass.
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by a "
                f"positive num_heads ({num_heads})")
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Dense(query_size, num_hiddens, has_bias=has_bias)
        self.W_k = nn.Dense(key_size, num_hiddens, has_bias=has_bias)
        self.W_v = nn.Dense(value_size, num_hiddens, has_bias=has_bias)
        self.W_o = nn.Dense(num_hiddens, num_hiddens, has_bias=has_bias)

    def construct(self, queries, keys, values, valid_lens):
        """Attend over all heads in parallel and project the concatenation.

        Args:
            queries: Rank-3 tensor (batch, num_queries, query_size).
            keys: Rank-3 tensor (batch, num_kvpairs, key_size).
            values: Rank-3 tensor (batch, num_kvpairs, value_size).
            valid_lens: ``None``, or a tensor of valid lengths whose leading
                axis is ``batch``. It is handed to ``DotProductAttention``,
                presumably for masking (TODO confirm).

        Returns:
            Tensor of shape (batch, num_queries, num_hiddens).
        """
        # Project, then fold the heads into the batch axis:
        # (batch, seq, num_hiddens) ->
        # (batch * num_heads, seq, num_hiddens / num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # Repeat each example's entry num_heads times along axis 0. This
            # matches the batch-major (batch, head) ordering that
            # `transpose_qkv` produces.
            valid_lens = mnp.repeat(
                valid_lens, repeats=self.num_heads, axis=0)

        # Assumes DotProductAttention returns a rank-3 tensor
        # (batch * num_heads, num_queries, head_dim), which is what
        # `transpose_output` expects. TODO confirm.
        output = self.attention(queries, keys, values, valid_lens)

        # Merge the heads back: (batch, num_queries, num_hiddens).
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

为了使多个头能够并行计算，下面定义两个形状变换函数：`transpose_qkv` 把头的维度合并进批量维度，`transpose_output` 则执行其逆操作。

In [4]:
def transpose_qkv(X, num_heads):
    """Reshape a projected query/key/value tensor for parallel multi-head attention.

    (batch, seq, num_hiddens) -> (batch * num_heads, seq, num_hiddens / num_heads)

    In the leading axis of the result, the heads of one batch element are
    contiguous (batch-major ordering).
    """
    batch, seq_len = X.shape[0], X.shape[1]
    # Split the feature axis into (num_heads, head_dim).
    split = X.reshape(batch, seq_len, num_heads, -1)
    # Move the head axis next to the batch axis: (batch, num_heads, seq, head_dim).
    per_head = split.transpose(0, 2, 1, 3)
    # Fold batch and heads into a single leading axis.
    return per_head.reshape(-1, per_head.shape[2], per_head.shape[3])


def transpose_output(X, num_heads):
    """Undo `transpose_qkv`: concatenate per-head outputs along the feature axis.

    (batch * num_heads, seq, head_dim) -> (batch, seq, num_heads * head_dim)
    """
    seq_len, head_dim = X.shape[1], X.shape[2]
    # Unfold the leading axis into (batch, num_heads).
    unfolded = X.reshape(-1, num_heads, seq_len, head_dim)
    # Put the sequence axis back in front of the head axis.
    swapped = unfolded.transpose(0, 2, 1, 3)
    # Merge (num_heads, head_dim) into one feature axis.
    return swapped.reshape(swapped.shape[0], swapped.shape[1], -1)

测试：使用键和值相同的小批量输入，验证输出形状为（批量大小，查询数，`num_hiddens`）。

In [5]:
# Same feature size for queries, keys, values and the hidden projection.
num_hiddens = 100
num_heads = 5
attention = MultiHeadAttention(
    num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
# Switch the cell to inference mode before running the shape check.
attention.set_train(False)

MultiHeadAttention<
  (attention): DotProductAttention<
    (dropout): Dropout<keep_prob=0.5>
    >
  (W_q): Dense<input_channels=100, output_channels=100>
  (W_k): Dense<input_channels=100, output_channels=100>
  (W_v): Dense<input_channels=100, output_channels=100>
  (W_o): Dense<input_channels=100, output_channels=100>
  >

In [6]:
batch_size = 2
num_queries = 4
num_kvpairs = 6
# One valid length per batch element.
valid_lens = mindspore.Tensor([3, 2], mindspore.int32)
X = mnp.ones((batch_size, num_queries, num_hiddens))
# Y serves as both keys and values below.
Y = mnp.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape

(2, 4, 100)