In [1]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn

from torch.autograd import Variable
import torch.nn.functional as F

# Seed torch's global RNG so that tensors drawn with torch.randn and the
# default nn.Linear initialisation are reproducible across runs.
torch.manual_seed(9)

%matplotlib inline

In [2]:
class SelfAttention(nn.Module):
    """Single scaled dot-product self-attention head with batched weights.

    The three projection tensors ``Wq``, ``Wk`` and ``Wv`` each have shape
    ``(batch_size, nfeats, nfeats)``. They are stored as ``nn.Parameter`` so
    the module registers them: they show up in ``parameters()`` and
    ``state_dict()``, follow ``.to(device)``, and can be handed to an
    optimizer.

    Args:
        batch_size: Leading dimension of every weight tensor. Because the
            projections are applied with ``torch.bmm``, inputs to ``forward``
            must have the same leading dimension.
        nfeats: Size of both trailing dimensions of every weight tensor; also
            determines the ``1 / sqrt(nfeats)`` scaling of the scores.
    """

    def __init__(self, batch_size, nfeats):
        super(SelfAttention, self).__init__()

        self.nfeats = nfeats
        # Wq, Wk, Wv are learnable. nn.Parameter sets requires_grad=True by
        # default and registers each tensor with the module. The draw order
        # (Wq, Wk, Wv) is kept so seeded initialisation is unchanged.
        self.Wq = nn.Parameter(torch.randn(batch_size, nfeats, nfeats))
        self.Wk = nn.Parameter(torch.randn(batch_size, nfeats, nfeats))
        self.Wv = nn.Parameter(torch.randn(batch_size, nfeats, nfeats))

    def forward(self, X):
        """Apply the attention head to ``X``.

        Args:
            X: Tensor of shape ``(batch_size, nfeats, nfeats)``. The first
                two sizes are forced by ``bmm(W*, X)``; the last one is forced
                by the final ``bmm(softmax(W), V)``, whose inner dimensions
                only agree when ``X.shape[2] == nfeats``.

        Returns:
            Tensor of shape ``(batch_size, nfeats, nfeats)``.
        """
        Q = torch.bmm(self.Wq, X)
        K = torch.bmm(self.Wk, X)
        V = torch.bmm(self.Wv, X)
        # Scale the raw scores by 1/sqrt(nfeats) before the softmax.
        scale = 1.0 / np.sqrt(self.nfeats)
        W = torch.bmm(Q.transpose(2, 1), K) * scale
        # Normalise over the last axis, then weight V with the result.
        return torch.bmm(F.softmax(W, dim=2), V)
    
class MultiHeadAttention(nn.Module):
    """Run ``r`` independent ``SelfAttention`` heads and mix them linearly.

    Every head receives the same input. Their outputs are concatenated along
    dimension 2 and passed through ``nn.Linear(r * nfeats, nfeats)``, so each
    head's output must have ``nfeats`` as its last dimension.

    Args:
        r: Number of attention heads to create.
        batch_size: Forwarded unchanged to every ``SelfAttention`` head.
        nfeats: Forwarded to every head; also the output width of the final
            linear layer.
    """

    def __init__(self, r, batch_size, nfeats):
        super(MultiHeadAttention, self).__init__()

        # nn.ModuleList, not a plain Python list: only then are the heads
        # registered as sub-modules, so parameters(), state_dict(), .to() and
        # train()/eval() reach them. It still supports len(), indexing and
        # iteration like the list it replaces.
        self.multihead = nn.ModuleList(
            [SelfAttention(batch_size, nfeats) for _ in range(r)]
        )
        self.transform = nn.Linear(r * nfeats, nfeats)

    def forward(self, X):
        """Apply every head to ``X``, concatenate on dim 2 and project.

        Args:
            X: Input tensor, passed unchanged to each head.

        Returns:
            The output of ``self.transform`` on the concatenated head outputs;
            its last dimension has size ``nfeats``.
        """
        Y = torch.cat([head(X) for head in self.multihead], dim=2)
        return self.transform(Y)

In [3]:
# Build a tiny input: two 2-element float32 vectors stacked row-wise, with a
# leading batch axis of size 1 -> shape (1, 2, 2).
x1 = torch.tensor([1.0, 2.0], dtype=torch.float32)
x2 = torch.tensor([3.0, 2.0], dtype=torch.float32)
X = torch.stack((x1, x2), dim=0).unsqueeze(dim=0)
print(X.shape)

# Two heads, batch size 1, feature size 2; run one forward pass on X.
multihead_attention = MultiHeadAttention(r=2, batch_size=1, nfeats=2)
print(multihead_attention(X))

torch.Size([1, 2, 2])
tensor([[[ 0.4898, -0.9724],
         [ 0.4850, -0.8700]]], grad_fn=<AddBackward0>)
