In [1]:
import torch
import torch.nn as nn
import numpy as np
import math

from model import *
from util import *

In [6]:
def transpose_qkv(X, num_heads):
    """Split the last axis into heads and fold the heads into the first axis.

    Args:
        X: tensor indexed as (axis0, axis1, features); ``features`` must be
            divisible by ``num_heads`` or the first ``view`` raises.
        num_heads: number of heads to split ``features`` into.

    Returns:
        Tensor of shape (axis0 * num_heads, axis1, features // num_heads).
        Along the first axis the head index varies fastest, i.e. row
        ``b * num_heads + h`` holds head ``h`` of input row ``b``.
    """
    rows, steps = X.shape[0], X.shape[1]
    # (rows, steps, features) -> (rows, steps, heads, features_per_head)
    per_head = X.view(rows, steps, num_heads, -1)
    # Bring the head axis next to the row axis; contiguous() is required
    # before the final view because transpose only changes strides.
    heads_first = per_head.transpose(1, 2).contiguous()
    return heads_first.view(-1, steps, heads_first.shape[3])

def transpose_output(X, num_heads):
    """Undo the head folding: merge per-head slices back into the last axis.

    Args:
        X: tensor indexed as (rows * num_heads, axis1, features_per_head),
            with the head index varying fastest along the first axis.
        num_heads: number of heads folded into the first axis.

    Returns:
        Tensor of shape (rows, axis1, num_heads * features_per_head) in which
        the heads of each row are concatenated along the last axis.
    """
    steps, width = X.shape[1], X.shape[2]
    # (rows * heads, steps, width) -> (rows, heads, steps, width)
    unfolded = X.view(-1, num_heads, steps, width)
    # Move heads next to the feature axis so the final view concatenates them.
    merged = unfolded.transpose(1, 2).contiguous()
    return merged.view(merged.shape[0], steps, -1)

def handle_valid_length(valid_length, num_heads):
    """Expand per-row valid lengths so there is one entry per (row, head) pair.

    ``transpose_qkv`` folds the heads into the first axis with the head index
    varying fastest (row ``b * num_heads + h`` is head ``h`` of input row
    ``b``). The lengths must therefore be repeated element-wise
    (``[v0, v0, ..., v1, v1, ...]``); tiling the whole vector
    (``[v0, v1, ..., v0, v1, ...]``) would pair most rows with the length of
    a different input row whenever the lengths differ.

    Args:
        valid_length: ``None``, or a tensor whose first axis has one entry
            per input row (1-D, or 2-D with per-step lengths).
        num_heads: number of attention heads.

    Returns:
        ``None`` if ``valid_length`` is ``None``; otherwise a float32 tensor
        on the same device as the input whose first axis is ``num_heads``
        times longer, with each entry repeated ``num_heads`` times in a row.
    """
    if valid_length is None:
        return None
    if valid_length.dim() == 0:
        # A 0-dim input has no axis to repeat along; treat it as a single
        # (1, 1) row so the result has shape (num_heads, 1).
        valid_length = valid_length.reshape(1, 1)
    # Stays on the input's device; no numpy / CPU round-trip needed.
    expanded = torch.repeat_interleave(valid_length, num_heads, dim=0)
    # Callers have always received float32 from this helper; keep that.
    return expanded.to(torch.float32)

class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Query, key and value are each passed through a bias-free linear layer of
    width ``hidden_size``, split into ``num_heads`` slices by
    ``transpose_qkv``, attended over by ``DotProductAttention`` (not defined
    in this section; presumably provided by the ``model`` wildcard import),
    merged again by ``transpose_output`` and sent through a final bias-free
    linear layer.

    Args:
        input_size: feature size of the query/key/value inputs.
        hidden_size: width of the projections and of the output; must be a
            multiple of ``num_heads``.
        num_heads: number of attention heads (positive).
        dropout: forwarded to ``DotProductAttention``.
        **kwargs: forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, input_size, hidden_size, num_heads, dropout, **kwargs):
        super().__init__(**kwargs)
        # Fail at construction time instead of with an opaque ``view`` shape
        # error inside forward(): the projected features are split evenly
        # across the heads, so the width must be divisible by the head count.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                "hidden_size (%s) must be divisible by a positive num_heads (%s)"
                % (hidden_size, num_heads)
            )
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.wq = nn.Linear(input_size, hidden_size, bias=False)
        self.wk = nn.Linear(input_size, hidden_size, bias=False)
        self.wv = nn.Linear(input_size, hidden_size, bias=False)
        self.wo = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, query, key, value, valid_length):
        """Apply multi-head attention.

        Args:
            query, key, value: tensors whose last axis has size ``input_size``.
            valid_length: ``None`` or a tensor of valid lengths; it is expanded
                per head by ``handle_valid_length`` before being handed to the
                attention module.

        Returns:
            Tensor with the query's leading axes and ``hidden_size`` features.
        """
        # Project, then fold the heads into the first axis.
        query = transpose_qkv(self.wq(query), self.num_heads)
        key = transpose_qkv(self.wk(key), self.num_heads)
        value = transpose_qkv(self.wv(value), self.num_heads)
        valid_length = handle_valid_length(valid_length, self.num_heads)
        output = self.attention(query, key, value, valid_length)
        # Concatenate the heads back along the feature axis.
        output_concat = transpose_output(output, self.num_heads)
        return self.wo(output_concat)

In [7]:
# Smoke test: input_size=5, hidden_size=9, num_heads=3, dropout=0.5.
cell = MultiHeadAttention(5, 9, 3, 0.5)
# All-ones input, used below as query, key and value alike (self-attention).
X = torch.ones((2, 4, 5))
# Two entries, matching X's first axis (presumably one valid length per batch element).
valid_length = torch.FloatTensor([2, 3])
# Only the output shape is displayed, so the check does not depend on the
# random weights or on dropout.
cell(X, X, X, valid_length).shape

torch.Size([2, 4, 9])


torch.Size([2, 4, 9])