In [23]:
import torch
import torchtext
import copy

In [24]:
def assertEqual(t1, t2, message=""):
    assert torch.eq(t1, t2).all()
    print("Checking: ", message, "     ...OK!")

In [25]:
def _get_clones(module, N):
    return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)])

In [81]:
# d_head corresponds to d_k in the paper
def scaled_dot_product_attention(value, key, query, dropout=0.0):
    """
    Shape:
    - Inputs:
    - query: `(..., T, N * H, E / H)`
    - key: `(..., S, N * H, E / H)`
    - value: `(..., S, N * H, E /H)`
    
    where E = d_model, E/H = d_head
    
    - Outputs:
    - `(..., T, N * H, E / H)`, `(N * H, T, S)`
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "The d_head of query, key, value must be equal."
    S, T, N_H, d_head = key.shape[-3], query.shape[-3], query.shape[-2], query.shape[-1]
    
    query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)

    # calculates attention weights
    query = query * (float(d_head) ** -0.5)     
    attention_weights = torch.matmul(query, key.transpose(-2,-1))
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1)
    attention_weights = torch.nn.functional.dropout(attention_weights, p=dropout)
    assert attention_weights.shape == (N_H, T, S), "attention_weights should be shape (N * H, T, S)"

    attention_output = torch.matmul(attention_weights, value)

    return attention_output.transpose(-3, -2), attention_weights

In [82]:
def test_sdp():
    q = torch.randn(25, 256, 3)
    k = v = torch.randn(21, 256, 3)
    
    # call torchtext's SDP
    SDP = torchtext.nn.ScaledDotProduct(dropout=0.1)
    expected_attn_output, expected_attn_weights = SDP(q, k, v)
    
    # call our SDP
    torch.manual_seed(42)    
    attn_output, attn_weights = scaled_dot_product_attention(v, k, q, dropout=0.1) 
    
    assert attn_weights.shape == expected_attn_weights.shape
    assertEqual(attn_weights, expected_attn_weights, message = "attention_weights are expected?")
    assert attn_output.shape == expected_attn_output.shape, "attn_output.shape is {0} whereas expected_output.shape is {1}".format(attn_output.shape, expected_attn_output.shape) 
    assertEqual(attn_output, expected_attn_output, message = "attention_output is expected?")
    
test_sdp()

Checking:  attention_weights are expected?      ...OK!
Checking:  attention_output is expected?      ...OK!


In [28]:
class Projection(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        torch.manual_seed(42)
        linear = torch.nn.Linear(d_model, d_model)
        self.linears = _get_clones(linear, 3)
        
        
    def forward(self, attended, attending):
        # input dimension is (sentence_len, batch_size, d_model)
        W_v, W_k, W_q = self.linears

        # forward
        value = W_v(attended)
        key = W_k(attended)
        query = W_q(attending)

        return value, key, query

In [33]:
def test_projection():
    from torchtext.nn import InProjContainer

    d_model, batch_size = 10, 64
    q = torch.rand((5, batch_size, d_model))
    k = v = torch.rand((6, batch_size, d_model))

    # call torchtext's InProjContainer class
    torch.manual_seed(42)
    l1 = torch.nn.Linear(d_model, d_model)
    torch.manual_seed(42)
    l2 = torch.nn.Linear(d_model, d_model)
    torch.manual_seed(42)
    l3 = torch.nn.Linear(d_model, d_model)
    in_proj_container = InProjContainer(l1, l2, l3)
    expected_query, expected_key, expected_value = in_proj_container(q, k, v)

    # call our Projection class
    projection = Projection(d_model)
    value, key, query = projection(k, q)

    assert expected_query.shape == query.shape
    assert expected_key.shape == key.shape
    assert expected_value.shape == value.shape

    assertEqual(expected_query, query, message = "query is expected?")
    assertEqual(expected_key, key, message = "key is expected?")
    assertEqual(expected_value, value, message = "key is expected?")

test_projection()

Checking:  query is expected?      ...OK!
Checking:  key is expected?      ...OK!
Checking:  key is expected?      ...OK!


In [58]:
class MultiheadAttention(torch.nn.Module):
    def __init__(self, nheads, d_model):
        super().__init__()
        self.nheads = nheads
        self.projection = Projection(d_model)
        
        torch.manual_seed(42)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(self, attended, attending, mask=None):
        '''
        Shape:
        - Inputs:
        - attending: :math:`(..., T, N, d_model)`
        - attended: :math:`(..., S, N, d_model)`

        - Outputs:
        - attn_output: :math:`(..., T, N, d_model)`
        - attn_output_weights: :math:`(N * H, T, S)`
        '''
        # checks dimensions & assigns variables
        assert attending.shape[-1] == attended.shape[-1], "attending & attended should have the same d_model"
        d_model = attending.shape[-1]
        assert d_model % self.nheads == 0, "d_model should be divisible by number of heads"
        self.d_k = d_model // self.nheads  
        assert attending.shape[-2] == attended.shape[-2], "attending & attended should have the same batch size"
        self.batch_size = attending.shape[-2]
        
        # projects attended and attending tensors to v, k, q
        value, key, query = self.projection(attended, attending)
        value, key, query = self.reshape_into_nheads(value, self.nheads, self.d_k), self.reshape_into_nheads(key, self.nheads, self.d_k), self.reshape_into_nheads(query, self.nheads, self.d_k)
                
        attn_output, attn_weights = scaled_dot_product_attention(value, key, query)
        assert attn_output.shape == (attending.shape[-3], self.batch_size * self.nheads, self.d_k), "attn_output's shape from SDP should be (..., T, N * H, E / H)"

        attn_output = self.reshape_into_nheads(attn_output, 1, d_model)
        attn_output = self.W_o(attn_output)
        
        return attn_output, attn_weights
    
    def reshape_into_nheads(self, x, nheads, last_dim):
        return x.reshape(-1, self.batch_size * nheads, last_dim)

In [83]:
from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

d_model, num_heads, bsz = 10, 5, 64
query = torch.rand((21, bsz, d_model))
key = value = torch.rand((16, bsz, d_model))

# call torchtext's InProjContainer and then MHA
torch.manual_seed(42)
l1 = torch.nn.Linear(d_model, d_model)
torch.manual_seed(42)
l2 = torch.nn.Linear(d_model, d_model)
torch.manual_seed(42)
l3 = torch.nn.Linear(d_model, d_model)
in_proj_container = InProjContainer(l1, l2, l3)

torch.manual_seed(42)
l4 = torch.nn.Linear(d_model, d_model)
MHA = MultiheadAttentionContainer(num_heads, 
                                  in_proj_container, 
                                  ScaledDotProduct(),
                                  l4)

expected_attn_output, expected_attn_weights = MHA(query, key, value)

In [84]:
MY_MHA = MultiheadAttention(num_heads, d_model)
attn_output, attn_weights = MY_MHA(key, query)

In [85]:
assert expected_attn_output.shape == attn_output.shape
assert expected_attn_weights.shape == expected_attn_weights.shape

In [86]:
assertEqual(expected_attn_output, attn_output)
assertEqual(expected_attn_outpu, attn_output)

Checking:        ...OK!


In [77]:
expected_attn_output

tensor([[[ 0.0533,  0.0026,  0.1077,  ..., -0.2882, -0.4972, -0.3430],
         [ 0.0465,  0.0241,  0.0815,  ..., -0.3206, -0.4755, -0.3576],
         [ 0.1169,  0.0276,  0.1214,  ..., -0.3538, -0.5311, -0.3243],
         ...,
         [ 0.1029,  0.0102,  0.1307,  ..., -0.2922, -0.5349, -0.3298],
         [ 0.0462,  0.0226,  0.1029,  ..., -0.3630, -0.4679, -0.3299],
         [ 0.0536,  0.0316,  0.0972,  ..., -0.3049, -0.4821, -0.3343]],

        [[ 0.0548,  0.0070,  0.0987,  ..., -0.2948, -0.4959, -0.3506],
         [ 0.0459,  0.0252,  0.0811,  ..., -0.3238, -0.4733, -0.3561],
         [ 0.1115,  0.0290,  0.1223,  ..., -0.3584, -0.5244, -0.3222],
         ...,
         [ 0.1040,  0.0139,  0.1253,  ..., -0.2947, -0.5349, -0.3327],
         [ 0.0484,  0.0221,  0.1069,  ..., -0.3661, -0.4695, -0.3273],
         [ 0.0438,  0.0345,  0.0907,  ..., -0.3074, -0.4727, -0.3376]],

        [[ 0.0519,  0.0042,  0.1042,  ..., -0.2912, -0.4941, -0.3470],
         [ 0.0420,  0.0256,  0.0769,  ..., -0

In [78]:
attn_output

tensor([[[ 4.3185e-02,  2.9902e-02,  1.0467e-01,  ..., -2.8879e-01,
          -4.7713e-01, -3.2269e-01],
         [ 5.2500e-02,  6.7092e-02,  3.3881e-02,  ..., -2.6797e-01,
          -4.6646e-01, -3.6747e-01],
         [ 1.4379e-01,  2.8407e-02,  1.1836e-01,  ..., -3.2853e-01,
          -5.4493e-01, -3.3934e-01],
         ...,
         [ 8.1907e-02, -9.6690e-03,  1.2771e-01,  ..., -2.6678e-01,
          -5.2813e-01, -3.4869e-01],
         [ 3.1156e-02,  1.8001e-02,  1.1451e-01,  ..., -4.2149e-01,
          -4.5648e-01, -3.1561e-01],
         [ 6.9101e-02,  1.1207e-02,  1.0459e-01,  ..., -2.4799e-01,
          -5.1506e-01, -3.4459e-01]],

        [[ 4.8109e-02,  1.4327e-02,  8.6957e-02,  ..., -3.0937e-01,
          -4.9390e-01, -3.4983e-01],
         [ 4.5355e-02,  2.4181e-02,  6.5460e-02,  ..., -3.0616e-01,
          -4.7185e-01, -3.8532e-01],
         [ 1.3692e-01,  3.3717e-02,  1.3769e-01,  ..., -3.5911e-01,
          -5.4853e-01, -3.0835e-01],
         ...,
         [ 8.1717e-02,  7

In [79]:
print(torch.norm(expected_attn_output))
print(torch.norm(attn_output))

tensor(37.4668, grad_fn=<CopyBackwards>)
tensor(37.5427, grad_fn=<CopyBackwards>)
