In [1]:
import math
import torch
from torch import nn

In [37]:
from d2l_common import Module, DotProductAttention

class MultiHeadAttention(Module):
    """Multi-head attention computed with a single batched attention call.

    Queries, keys and values are each projected to ``num_hiddens`` features,
    the feature axis is split into ``num_heads`` equal slices, and the heads
    are folded into the batch axis so that one ``DotProductAttention`` call
    processes every head. The per-head results are then concatenated back
    along the feature axis and passed through the output projection ``W_o``.

    Args:
        num_hiddens: Output size of each of the four linear projections.
            Must be a positive multiple of ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Value forwarded to ``DotProductAttention``.
        bias: Whether the linear projections carry a bias term.
        **kwargs: Accepted for call-site compatibility; not used here.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``num_hiddens`` evenly.
    """

    def __init__(self, num_hiddens, num_heads, dropout=0.1, bias=False, **kwargs):
        # transpose_qkv splits the feature axis evenly across heads, so a
        # non-divisible size would only surface later as an opaque reshape
        # error inside forward(). Reject it up front instead.
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_heads = num_heads
        self.dropout = dropout
        self.attention = DotProductAttention(dropout)
        # LazyLinear infers the input feature size on the first forward pass.
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def transpose_qkv(self, X):
        """Fold the heads into the batch axis.

        Args:
            X: Tensor of shape (batch_size, seq_len, num_hiddens).

        Returns:
            Tensor of shape
            (batch_size * num_heads, seq_len, num_hiddens / num_heads).
        """
        # shape: (batch_size, seq_len, num_heads, num_hiddens / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        # shape: (batch_size, num_heads, seq_len, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        # shape: (batch_size * num_heads, seq_len, num_hiddens / num_heads)
        return X.reshape(-1, X.shape[2], X.shape[3])

    def transpose_output(self, X):
        """Undo ``transpose_qkv``: unfold the heads and concatenate them.

        Args:
            X: Tensor of shape
                (batch_size * num_heads, seq_len, num_hiddens / num_heads).

        Returns:
            Tensor of shape (batch_size, seq_len, num_hiddens).
        """
        # shape: (batch_size, num_heads, seq_len, num_hiddens / num_heads)
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        # shape: (batch_size, seq_len, num_heads, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        # shape: (batch_size, seq_len, num_hiddens)
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, queries, keys, values, valid_lens=None):
        """Apply multi-head attention.

        Args:
            queries: Tensor of shape (batch_size, no. of queries, features).
            keys: Tensor of shape (batch_size, no. of key-value pairs, features).
            values: Tensor of shape (batch_size, no. of key-value pairs, features).
            valid_lens: Optional tensor whose first axis has length
                batch_size, or ``None``. It is repeated once per head and
                handed to ``DotProductAttention``; its exact meaning is
                defined there (presumably the number of unmasked key-value
                pairs -- confirm against that class).

        Returns:
            Result of ``W_o`` applied to the concatenated heads; shape
            (batch_size, no. of queries, num_hiddens).
        """
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # Heads were folded into axis 0 batch-major, so every entry must be
            # repeated num_heads times consecutively to stay aligned.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

        # Heads are still folded into the batch axis here; transpose_output
        # expects a 3-D tensor of shape
        # (batch_size * num_heads, no. of queries, num_hiddens / num_heads).
        output = self.attention(queries, keys, values, valid_lens)
        # shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)
        

In [38]:
# Scratch check of how view / permute / reshape transform a tensor's shape.
t = (
    torch.randn(2, 4, 100)
    .view(2, 4, 5, -1)       # (2, 4, 5, 20): split the last axis into 5 slices
    .permute(0, 2, 1, 3)     # (2, 5, 4, 20): swap axes 1 and 2
    .reshape(2, 1, -1)       # (2, 1, 400): flatten everything after the batch axis
)
t.shape

torch.Size([2, 1, 400])

In [40]:
# Smoke test: run MultiHeadAttention on constant inputs and check the output shape.
num_hiddens = 100
num_heads = 5
attention = MultiHeadAttention(num_hiddens, num_heads, dropout=0.5)

batch_size = 2
num_queries = 4
num_kvpairs = 6
# One entry per batch element.
valid_lens = torch.tensor([3, 2])
X = torch.ones(batch_size, num_queries, num_hiddens)
Y = torch.ones(batch_size, num_kvpairs, num_hiddens)

# Y is passed as both the keys and the values.
output = attention(X, Y, Y, valid_lens)
assert output.shape == torch.Size([batch_size, num_queries, num_hiddens])

tensor([[[0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000, 0.0000

