In [45]:
import torch
import torch.nn as nn
import math

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [24]:
def masked_softmax(X, valid_lens): # X.shape=[batch_size, num_queries, seq_len]
    """Softmax over the last axis of ``X`` that ignores padded positions.

    Args:
        X: 3-D score tensor of shape [batch_size, num_queries, seq_len].
        valid_lens: one of
            * ``None`` -- plain softmax, no masking;
            * 1-D tensor of shape [batch_size] -- one valid length per batch
              element, shared by all of that element's queries
              (e.g. ``[2, 3]``: batch 0 keeps 2 positions, batch 1 keeps 3);
            * 2-D tensor of shape [batch_size, num_queries] -- one valid
              length per query.

    Returns:
        Tensor with the same shape as ``X``. Entries whose position along the
        last axis is ``>= valid_len`` are (numerically) zero. ``X`` itself is
        never modified.
    """
    def seq_mask(X, valid_len, value=-1e6):
        # X: [rows, seq_len]; valid_len: [rows]. Position j of row i is kept
        # when j < valid_len[i]; the rest get a large negative score so that
        # softmax drives them to ~0.
        max_len = X.size(1)
        positions = torch.arange(max_len, dtype=torch.float32, device=X.device)
        # Lengths may live on another device (e.g. CPU lengths, CUDA scores).
        mask = positions[None, :] < valid_len.to(X.device)[:, None]
        # Out-of-place on purpose: X is usually a reshape() *view* of the
        # caller's tensor, so an in-place write would clobber the input.
        return X.masked_fill(~mask, value)

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_lens.dim() == 1:
        # Same length for every query of a batch element: expand to one
        # length per (batch element, query) row.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        # Already one length per query: flatten to one length per row.
        valid_lens = valid_lens.reshape(-1)
    X = seq_mask(X.reshape(-1, shape[-1]), valid_lens)
    return nn.functional.softmax(X.reshape(shape), dim=-1)

In [30]:
# 1-D valid_lens: one length per batch element, shared by all of its queries.
print(masked_softmax(X=torch.rand(2, 2, 4), valid_lens=torch.tensor([2, 3])))
# 2-D valid_lens: one length per (batch element, query) pair.
print(masked_softmax(X=torch.rand(2, 2, 4), valid_lens=torch.tensor([[2, 1], [4, 3]])))


tensor([[[0.5011, 0.4989, 0.0000, 0.0000],
         [0.5486, 0.4514, 0.0000, 0.0000]],

        [[0.1855, 0.3537, 0.4609, 0.0000],
         [0.2671, 0.4694, 0.2635, 0.0000]]])
tensor([[[0.3938, 0.6062, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]],

        [[0.2467, 0.2770, 0.2372, 0.2391],
         [0.2109, 0.2995, 0.4896, 0.0000]]])


In [63]:
class DotProductAttention(nn.Module):
    """Scaled dot-product attention with optional length masking.

    Computes ``softmax(Q @ K^T / sqrt(d)) @ V``, where ``d`` is the size of
    the last axis of ``Q``. Inputs must be 3-D batched tensors because
    ``torch.bmm`` is used. The weights from the most recent call are kept in
    ``self.attention_weights``; dropout is applied only to the copy used
    to mix ``V``.
    """

    def __init__(self, dropout=0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, valid_lens=None):
        scale = math.sqrt(Q.shape[-1])
        keys_t = K.transpose(1, 2)
        scores = torch.bmm(Q, keys_t) / scale
        self.attention_weights = masked_softmax(scores, valid_lens)
        dropped_weights = self.dropout(self.attention_weights)
        return torch.bmm(dropped_weights, V)


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``DotProductAttention``.

    Queries, keys and values are projected to ``num_hidden`` features by
    separate lazily-initialised linear layers. Each projection is split into
    ``num_heads`` heads of ``num_hidden // num_heads`` features. All heads
    are attended in parallel by folding them into the batch axis. The heads
    are then concatenated again and passed through the output projection
    ``W_o``.

    Args:
        num_hidden: output size of every projection; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        bias: whether the four projections use a bias term.

    Raises:
        ValueError: if ``num_hidden`` is not divisible by ``num_heads``.
    """
    def __init__(self, num_hidden, num_heads, dropout=0, bias=False):
        super().__init__()
        if num_hidden % num_heads != 0:
            raise ValueError(
                f"num_hidden ({num_hidden}) must be divisible by "
                f"num_heads ({num_heads})")
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hidden, bias=bias)
        self.W_k = nn.LazyLinear(num_hidden, bias=bias)
        self.W_v = nn.LazyLinear(num_hidden, bias=bias)
        self.W_o = nn.LazyLinear(num_hidden, bias=bias)

    # for parallel computation
    def transpose_qkv(self, X):
        """Split the last axis into heads and fold the heads into the batch axis.

        [batch_size, seq, num_hidden] -> [batch_size*num_heads, seq, num_hidden/num_heads].
        Row ``b*num_heads + h`` of the result holds head ``h`` of batch element ``b``.
        """
        # [batch_size, seq, num_heads, num_hidden/num_heads]
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        # [batch_size, num_heads, seq, num_hidden/num_heads]
        X = X.permute(0, 2, 1, 3)
        # [batch_size*num_heads, seq, num_hidden/num_heads]
        return X.reshape(-1, X.shape[2], X.shape[3])

    def transpose_output(self, X):
        """Reverse ``transpose_qkv``: unfold the heads and concatenate them."""
        # [batch_size, num_heads, seq, num_hidden/num_heads]
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        # [batch_size, seq, num_heads, num_hidden/num_heads]
        X = X.permute(0, 2, 1, 3)
        # [batch_size, seq, num_hidden]
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, Q, K, V, valid_lens=None):
        """Attend ``Q`` over ``K``/``V`` with ``num_heads`` parallel heads.

        Q: [batch_size, num_queries, *]; K, V: [batch_size, num_kv, *].
        valid_lens: ``None``, [batch_size], or [batch_size, num_queries]
        (same conventions as ``masked_softmax``).
        Returns: [batch_size, num_queries, num_hidden].
        """
        # After transpose_qkv: [batch_size*num_heads, num_queries/num_kv, num_hidden/num_heads]
        Q = self.transpose_qkv(self.W_q(Q))
        K = self.transpose_qkv(self.W_k(K))
        V = self.transpose_qkv(self.W_v(V))

        if valid_lens is not None:
            # Repeat along the batch axis only, so that all num_heads rows of
            # batch element b get b's lengths, matching transpose_qkv's
            # b-major folding. Without dim=0, repeat_interleave flattens a 2-D
            # [batch_size, num_queries] tensor and corrupts per-query masking.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # [batch_size*num_heads, num_queries, num_hidden/num_heads]
        output = self.attention(Q, K, V, valid_lens)
        # [batch_size, num_queries, num_hidden]
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)


In [None]:
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
o = attention(X, Y, Y, valid_lens)
# The output keeps the query layout: one num_hiddens-vector per query,
# regardless of how many key/value pairs were attended over.
assert o.shape == (batch_size, num_queries, num_hiddens), o.shape
o.shape


torch.Size([2, 4, 100])