In [91]:
import torch
import torchtext
import copy

In [92]:
def assertEqual(t1, t2, message=""):
    """Assert that two tensors have the same shape and exactly equal elements.

    Prints a confirmation line when the check passes.

    Args:
        t1, t2: values to compare (normally tensors; t2 may also be a Python
            scalar, since torch.eq accepts one).
        message: label printed on success and included in the assertion error.

    Raises:
        AssertionError: on a shape mismatch between two tensors, or if any
            element differs.
    """
    # torch.eq broadcasts, so e.g. a (1,) tensor would "equal" a (3,) tensor holding
    # the same value; compare shapes explicitly when both sides are tensors.
    if torch.is_tensor(t1) and torch.is_tensor(t2):
        assert t1.shape == t2.shape, "{0}: shape {1} != {2}".format(message, tuple(t1.shape), tuple(t2.shape))
    assert torch.eq(t1, t2).all(), "{0}: tensor values differ".format(message)
    print("Checking: ", message, "     ...OK!")

In [93]:
def _get_clones(module, N):
    return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)])

In [94]:
# d_head corresponds to d_k in the Transformer paper (Vaswani et al., 2017).
def scaled_dot_product_attention(value, key, query, dropout=0.0, training=True):
    """
    Scaled dot-product attention: softmax(query @ key^T / sqrt(d_head)) @ value.

    Note the argument order is (value, key, query).

    Shape:
    - Inputs:
    - query: `(..., T, N * H, E / H)`
    - key: `(..., S, N * H, E / H)`
    - value: `(..., S, N * H, E / H)`

    where E = d_model, E/H = d_head. Leading `...` dimensions are batched by
    torch.matmul.

    - Outputs:
    - `(..., T, N * H, E / H)`, `(..., N * H, T, S)`

    Args:
        dropout: dropout probability applied to the attention weights.
        training: if False, dropout is skipped. Defaults to True, which keeps the
            original behaviour of always applying dropout.

    Returns:
        (attention_output, attention_weights). The returned weights are the
        post-dropout weights, i.e. the ones actually multiplied with value.
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "The d_head of query, key, value must be equal."
    S, T, N_H, d_head = key.shape[-3], query.shape[-3], query.shape[-2], query.shape[-1]

    # Move the heads axis in front of the sequence axis so matmul batches over
    # (..., N * H): (..., L, N * H, d_head) -> (..., N * H, L, d_head).
    query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)

    # Scale the query before the matmul (mathematically equivalent to scaling the scores).
    query = query * (float(d_head) ** -0.5)
    attention_weights = torch.matmul(query, key.transpose(-2, -1))  # (..., N * H, T, S)
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1)
    attention_weights = torch.nn.functional.dropout(attention_weights, p=dropout, training=training)
    # Only the trailing three dims are fixed; any leading `...` dims pass through.
    assert attention_weights.shape[-3:] == (N_H, T, S), "attention_weights should be shape (N * H, T, S)"

    attention_output = torch.matmul(attention_weights, value)  # (..., N * H, T, d_head)

    return attention_output.transpose(-3, -2), attention_weights

In [95]:
def test_sdp():
    """Check scaled_dot_product_attention against torchtext's ScaledDotProduct.

    Both calls use dropout=0.1, so the global RNG is re-seeded to the same value
    right before each call: the dropout masks can only match if both see the same
    RNG state. (Assumes torchtext's module, like ours, draws no other random
    numbers before its dropout -- TODO confirm if the torchtext version changes.)
    """
    q = torch.randn(25, 256, 3)
    k = v = torch.randn(21, 256, 3)

    # call torchtext's SDP
    SDP = torchtext.nn.ScaledDotProduct(dropout=0.1)
    torch.manual_seed(42)
    expected_attn_output, expected_attn_weights = SDP(q, k, v)

    # call our SDP from the same RNG state
    torch.manual_seed(42)
    attn_output, attn_weights = scaled_dot_product_attention(v, k, q, dropout=0.1)

    assert attn_weights.shape == expected_attn_weights.shape
    assertEqual(attn_weights, expected_attn_weights, message = "attention_weights are expected?")
    assert attn_output.shape == expected_attn_output.shape, "attn_output.shape is {0} whereas expected_output.shape is {1}".format(attn_output.shape, expected_attn_output.shape)
    assertEqual(attn_output, expected_attn_output, message = "attention_output is expected?")

test_sdp()

Checking:  attention_weights are expected?      ...OK!
Checking:  attention_output is expected?      ...OK!


In [96]:
class Projection(torch.nn.Module):
    """Project the attended / attending sequences to value, key and query.

    Holds three Linear(d_model, d_model) layers (W_v, W_k, W_q). They are deep
    copies of a single Linear, so all three start with identical weights. That
    Linear is created right after seeding with ``seed``, so its weights equal
    those of a Linear created right after ``torch.manual_seed(seed)``.

    Args:
        d_model: input and output feature size.
        seed: seed used for initialisation (default 42). Pass None to draw the
            initial weights from the current global RNG state instead.
    """

    def __init__(self, d_model, seed=42):
        super().__init__()
        if seed is None:
            linear = torch.nn.Linear(d_model, d_model)
        else:
            # fork_rng restores the global RNG state on exit, so seeding here does
            # not reset the caller's random stream.
            with torch.random.fork_rng():
                torch.manual_seed(seed)
                linear = torch.nn.Linear(d_model, d_model)
        self.linears = _get_clones(linear, 3)

    def forward(self, attended, attending):
        """Apply W_v and W_k to `attended`, and W_q to `attending`.

        In this notebook the inputs are (sentence_len, batch_size, d_model);
        nn.Linear only requires the last dimension to be d_model.

        Returns:
            (value, key, query), each with the shape of the input it came from.
        """
        W_v, W_k, W_q = self.linears
        return W_v(attended), W_k(attended), W_q(attending)

In [98]:
def test_projection():
    """Check Projection against torchtext's InProjContainer.

    The reference container is built from copies of one Linear created right
    after torch.manual_seed(42); key and value share the same copy.
    """
    from torchtext.nn import InProjContainer

    d_model, batch_size = 10, 64
    attending = torch.rand((5, batch_size, d_model))
    attended = torch.rand((6, batch_size, d_model))

    # reference: torchtext's InProjContainer over seeded Linear layers
    torch.manual_seed(42)
    ref_q_proj = torch.nn.Linear(d_model, d_model)
    ref_kv_proj = copy.deepcopy(ref_q_proj)
    container = InProjContainer(ref_q_proj, ref_kv_proj, ref_kv_proj)
    expected_query, expected_key, expected_value = container(attending, attended, attended)

    # ours: Projection takes (attended, attending) and returns (value, key, query)
    value, key, query = Projection(d_model)(attended, attending)

    checks = (
        (expected_query, query, "query is expected?"),
        (expected_key, key, "key is expected?"),
        (expected_value, value, "value is expected?"),
    )
    for want, got, _ in checks:
        assert want.shape == got.shape
    for want, got, label in checks:
        assertEqual(want, got, message = label)

test_projection()

Checking:  query is expected?      ...OK!
Checking:  key is expected?      ...OK!
Checking:  value is expected?      ...OK!


In [99]:
class MultiheadAttention(torch.nn.Module):
    """Multi-head attention: Projection -> per-head scaled dot-product attention -> W_o.

    Args:
        nheads: number of heads H; d_model must be divisible by it (checked in forward).
        d_model: model / embedding size E.
    """

    def __init__(self, nheads, d_model):
        super().__init__()
        self.nheads = nheads
        self.projection = Projection(d_model)

        # W_o is initialised from seed 42, so it equals a Linear created right after
        # torch.manual_seed(42). fork_rng restores the global RNG state on exit, so
        # this seeding does not leak into the caller's random stream.
        with torch.random.fork_rng():
            torch.manual_seed(42)
            self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(self, attended, attending, mask=None):
        '''
        Shape:
        - Inputs:
        - attending: :math:`(T, N, d_model)` -- projected to queries
        - attended: :math:`(S, N, d_model)` -- projected to keys and values
        - mask: masking is not implemented; must be None

        - Outputs:
        - attn_output: :math:`(T, N, d_model)`
        - attn_output_weights: :math:`(N * H, T, S)`

        Only 3-D inputs are supported: reshape_into_nheads folds every leading
        dimension into the sequence axis.
        '''
        # Fail loudly instead of silently returning unmasked attention.
        if mask is not None:
            raise NotImplementedError("mask is not supported by this MultiheadAttention yet")

        # checks dimensions & assigns variables
        assert attending.shape[-1] == attended.shape[-1], "attending & attended should have the same d_model"
        d_model = attending.shape[-1]
        assert d_model % self.nheads == 0, "d_model should be divisible by number of heads"
        self.d_k = d_model // self.nheads
        assert attending.shape[-2] == attended.shape[-2], "attending & attended should have the same batch size"
        # Stored on self because reshape_into_nheads reads it.
        self.batch_size = attending.shape[-2]

        # projects attended and attending tensors to v, k, q
        value, key, query = self.projection(attended, attending)
        # (L, N, d_model) -> (L, N * H, d_k): split each token's features into H heads
        value, key, query = (self.reshape_into_nheads(t, self.nheads, self.d_k) for t in (value, key, query))

        # forward multi-heads through SDP (its default dropout is 0.0)
        attn_output, attn_weights = scaled_dot_product_attention(value, key, query)
        assert attn_output.shape == (attending.shape[-3], self.batch_size * self.nheads, self.d_k), "attn_output's shape from SDP should be (..., T, N * H, E / H)"

        # (T, N * H, d_k) -> (T, N, d_model): concat heads, then the output projection
        attn_output = self.reshape_into_nheads(attn_output, 1, d_model)
        attn_output = self.W_o(attn_output)

        return attn_output, attn_weights

    def reshape_into_nheads(self, x, nheads, last_dim):
        """Reshape x to (-1, self.batch_size * nheads, last_dim).

        With (nheads=H, last_dim=d_k) this splits d_model into heads; with
        (nheads=1, last_dim=d_model) it merges the heads back. self.batch_size
        must already have been set by forward.
        """
        return x.reshape(-1, self.batch_size * nheads, last_dim)

In [100]:
def test_multihead_attention():
    """Compare MultiheadAttention against torchtext's MultiheadAttentionContainer.

    All four reference projections are copies of one Linear created right after
    torch.manual_seed(42). Our module initialises its Projection linears and W_o
    from seed 42 as well, so outputs and attention weights should match exactly.
    """
    from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

    d_model, num_heads, bsz = 10, 5, 64
    query = torch.rand((21, bsz, d_model))
    key = value = torch.rand((16, bsz, d_model))

    # call torchtext's InProjContainer and then MHA
    torch.manual_seed(42)
    l1 = torch.nn.Linear(d_model, d_model)
    l2 = l3 = l4 = copy.deepcopy(l1)

    in_proj_container = InProjContainer(l1, l2, l3)
    MHA = MultiheadAttentionContainer(num_heads,
                                      in_proj_container,
                                      ScaledDotProduct(),
                                      l4)
    expected_attn_output, expected_attn_weights = MHA(query, key, value)

    # call our MHA (attended = key/value, attending = query)
    MY_MHA = MultiheadAttention(num_heads, d_model)
    attn_output, attn_weights = MY_MHA(key, query)

    assert expected_attn_output.shape == attn_output.shape
    # Compare against OUR weights; this previously compared expected with itself.
    assert expected_attn_weights.shape == attn_weights.shape
    assertEqual(expected_attn_output, attn_output, message = "attn_output expected?")
    assertEqual(expected_attn_weights, attn_weights, message = "attn_weights expected?")

test_multihead_attention()

Checking:  attn_output expected?      ...OK!
Checking:  attn_weights expected?      ...OK!
