## Annotated Multi-Head Attention

This is adapted from:
http://nlp.seas.harvard.edu/2018/04/03/attention.html#embeddings-and-softmax

That work was released under the MIT License, Copyright (c) 2018 Alexander Rush.

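Here, some features such as attention masks were removed. A `linears=True` flag was also added (passing `linears=False` replaces the four linear projections with identity functions), and debugging prints were added to check the tensor dimensions at each step.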

In [9]:
import copy
import math

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


def clones(module, N):
    # This function is adapted from:
    #     https://github.com/harvardnlp/annotated-transformer
    #     MIT License, Copyright (c) 2018 Alexander Rush
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    Each copy starts with the same parameter values as ``module`` but owns its
    own tensors, so the copies are trained independently afterwards.
    """
    layers = nn.ModuleList()
    for _ in range(N):
        # deepcopy so that no parameters are shared between the copies
        layers.append(copy.deepcopy(module))
    return layers


def attention(query, key, value, dropout=None):
    # This function is adapted from:
    #     https://github.com/harvardnlp/annotated-transformer
    #     MIT License, Copyright (c) 2018 Alexander Rush
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    The last dimension of ``query`` is d_k. ``key`` is transposed on its last
    two dimensions before the batched matrix product, so leading (batch/head)
    dimensions are carried through. Tensor sizes are printed at each step for
    debugging.

    Args:
        query, key, value: tensors combined via ``torch.matmul``.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the
            attention weights before they are multiplied with ``value``.

    Returns:
        ``(attention_result, p_attn)``. ``p_attn`` is the softmax weight
        tensor, after dropout if ``dropout`` was given.
    """
    scale = math.sqrt(query.size(-1))

    print("    key 1:", key.size())
    key_t = key.transpose(-2, -1)
    print("    key 2:", key_t.size())
    print("    query:", query.size())

    scores = query @ key_t / scale
    print("    scores:", scores.size())

    # Normalise over the key axis (last dimension of the scores).
    p_attn = scores.softmax(dim=-1)
    print("    p_attn:", p_attn.size())
    p_attn = p_attn if dropout is None else dropout(p_attn)

    attention_result = p_attn @ value
    print("    attention_result:", attention_result.size())
    return attention_result, p_attn


class MultiHeadedAttention(nn.Module):
    # This class is adapted from:
    #     https://github.com/harvardnlp/annotated-transformer
    #     MIT License, Copyright (c) 2018 Alexander Rush
    """Multi-head attention with optional learned input/output projections.

    Args:
        h: number of attention heads.
        hidden_size: model width; must be divisible by ``h``. Each head works
            on ``hidden_size // h`` features.
        linears: if True, use four ``nn.Linear(hidden_size, hidden_size)``
            layers (query, key and value projections, then the output
            projection). If False, all four are identity functions.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, h, hidden_size, linears=True, dropout=0.1):
        """Take in model size and number of heads."""
        super().__init__()
        assert hidden_size % h == 0
        self.h = h
        # d_v is assumed to equal d_k.
        self.d_k = hidden_size // h
        if not linears:
            # Identity stand-ins: a plain list, since lambdas are not modules.
            self.linears = [lambda arg: arg] * 4
        else:
            self.linears = clones(nn.Linear(hidden_size, hidden_size), 4)
        # Holds the attention weights from the most recent forward() call.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value):
        """Implements Figure 2: project, attend per head, merge, re-project.

        Prints tensor sizes after each stage for debugging.
        """
        n_batch = query.size(0)

        # 1) Project query/key/value with the first three linears, then split
        #    the feature axis into heads:
        #    (n_batch, seq, h * d_k) -> (n_batch, h, seq, d_k).
        print("query, key, value 1:", query.size(), key.size(), value.size())
        split_heads = []
        for proj, tensor in zip(self.linears, (query, key, value)):
            reshaped = proj(tensor).view(n_batch, -1, self.h, self.d_k)
            split_heads.append(reshaped.transpose(1, 2))
        query, key, value = split_heads
        print("query, key, value 2:", query.size(), key.size(), value.size())

        # 2) Attention for every head in one batched call. The returned
        #    weights (post-dropout) are stored on self.attn.
        x, self.attn = attention(query, key, value, self.dropout)
        print("x 1:", x.size())

        # 3) Merge heads: (n_batch, h, seq, d_k) -> (n_batch, seq, h * d_k).
        #    contiguous() is needed because view() cannot operate on the
        #    non-contiguous result of transpose().
        merged = x.transpose(1, 2).contiguous()
        x = merged.view(n_batch, -1, self.h * self.d_k)
        print("x 2:", x.size())

        # Final output projection (the fourth linear).
        out_proj = self.linears[-1]
        x = out_proj(x)
        print("x 3:", x.size())
        return x

    
batch_size = 64
sequence_length = 10
hidden_size = 32
attention_heads = 8


mha = MultiHeadedAttention(h=attention_heads, hidden_size=hidden_size)


# Case 1: a single query position attends over `sequence_length` key/value
# positions, so the output has sequence length 1.
print("With a single attention query:\n")
query = torch.ones(batch_size, 1, hidden_size)
value = torch.ones(batch_size, sequence_length, hidden_size)
# Call the module itself (not .forward) so that any registered hooks run.
result = mha(query, value, value)
print("result:", result.size())
print("\n")

# Case 2: one query per key/value position, so the output keeps the full
# sequence length.
print("With as many attention queries as there are values:\n")
query = torch.ones(batch_size, sequence_length, hidden_size)
value = torch.ones(batch_size, sequence_length, hidden_size)
result = mha(query, value, value)
print("result:", result.size())

With as many attention queries as there are values:

query, key, value 1: torch.Size([64, 1, 32]) torch.Size([64, 10, 32]) torch.Size([64, 10, 32])
query, key, value 2: torch.Size([64, 8, 1, 4]) torch.Size([64, 8, 10, 4]) torch.Size([64, 8, 10, 4])
    key 1: torch.Size([64, 8, 10, 4])
    key 2: torch.Size([64, 8, 4, 10])
    query: torch.Size([64, 8, 1, 4])
    scores: torch.Size([64, 8, 1, 10])
    p_attn: torch.Size([64, 8, 1, 10])
    attention_result: torch.Size([64, 8, 1, 4])
x 1: torch.Size([64, 8, 1, 4])
x 2: torch.Size([64, 1, 32])
x 3: torch.Size([64, 1, 32])
result: torch.Size([64, 1, 32])


With a single attention query:

query, key, value 1: torch.Size([64, 10, 32]) torch.Size([64, 10, 32]) torch.Size([64, 10, 32])
query, key, value 2: torch.Size([64, 8, 10, 4]) torch.Size([64, 8, 10, 4]) torch.Size([64, 8, 10, 4])
    key 1: torch.Size([64, 8, 10, 4])
    key 2: torch.Size([64, 8, 4, 10])
    query: torch.Size([64, 8, 10, 4])
    scores: torch.Size([64, 8, 10, 10])
    p