In [1]:
import torch
import torch.nn
import torchtext
import copy

In [2]:
def assertEqual(t1, t2, message=""):
    assert torch.eq(t1, t2).all()
    print("Checking: ", message, "     ...OK!")

In [3]:
def _get_clones(module, N):
    return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)])

In [13]:
def get_percentage_equal(t1, t2):
    nonequal_elements = int((t1 != t2).sum())
    equal_elements = int((t1 == t2).sum())
    equal_percentage = equal_elements / (nonequal_elements + equal_elements)
    return equal_percentage

In [4]:
# d_head corresponds to d_k in the paper
def scaled_dot_product_attention(value, key, query, dropout=0.0):
    """
    Shape:
    - Inputs:
    - query: `(..., T, N * H, E / H)`
    - key: `(..., S, N * H, E / H)`
    - value: `(..., S, N * H, E /H)`
    
    where E = d_model, E/H = d_head
    
    - Outputs:
    - `(..., T, N * H, E / H)`, `(N * H, T, S)`
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "The d_head of query, key, value must be equal."
    S, T, N_H, d_head = key.shape[-3], query.shape[-3], query.shape[-2], query.shape[-1]
    
    query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)

    # calculates attention weights
    query = query * (float(d_head) ** -0.5)     
    attention_weights = torch.matmul(query, key.transpose(-2,-1))
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1)
    attention_weights = torch.nn.functional.dropout(attention_weights, p=dropout)
    assert attention_weights.shape == (N_H, T, S), "attention_weights should be shape (N * H, T, S)"

    attention_output = torch.matmul(attention_weights, value)

    return attention_output.transpose(-3, -2), attention_weights

In [5]:
def test_sdp():
    q = torch.randn(25, 256, 3)
    k = v = torch.randn(21, 256, 3)
    
    # call torchtext's SDP
    SDP = torchtext.nn.ScaledDotProduct(dropout=0.1)
    torch.manual_seed(42)
    expected_attn_output, expected_attn_weights = SDP(q, k, v)
    
    # call our SDP
    torch.manual_seed(42)    
    attn_output, attn_weights = scaled_dot_product_attention(v, k, q, dropout=0.1) 
    
    assert attn_weights.shape == expected_attn_weights.shape
    assertEqual(attn_weights, expected_attn_weights, message = "attention_weights are expected?")
    assert attn_output.shape == expected_attn_output.shape, "attn_output.shape is {0} whereas expected_output.shape is {1}".format(attn_output.shape, expected_attn_output.shape) 
    assertEqual(attn_output, expected_attn_output, message = "attention_output is expected?")
    
test_sdp()

Checking:  attention_weights are expected?      ...OK!
Checking:  attention_output is expected?      ...OK!


In [53]:
class Projection(torch.nn.Module):
    def __init__(self, d_model, W_v, W_k, W_q):
        super().__init__()
        self.W_v, self.W_k, self.W_q = W_v, W_k, W_q   
        
    def forward(self, attended, attending):
        # input dimension is (sentence_len, batch_size, d_model)
        value = self.W_v(attended)
        key = self.W_k(attended)
        query = self.W_q(attending)

        return value, key, query

In [54]:
def test_projection():
    from torchtext.nn import InProjContainer

    d_model, batch_size = 10, 64
    q = torch.rand((5, batch_size, d_model))
    k = v = torch.rand((6, batch_size, d_model))

    torch.manual_seed(42)
    l1 = torch.nn.Linear(d_model, d_model)
    l2 = l3 = copy.deepcopy(l1)
    
    # call torchtext's InProjContainer class
    in_proj_container = InProjContainer(l1, l2, l3)
    expected_query, expected_key, expected_value = in_proj_container(q, k, v)

    # call our Projection class
    projection = Projection(d_model, l1, l2, l3)
    value, key, query = projection(k, q)

    assert expected_query.shape == query.shape
    assert expected_key.shape == key.shape
    assert expected_value.shape == value.shape

    assertEqual(expected_query, query, message = "query is expected?")
    assertEqual(expected_key, key, message = "key is expected?")
    assertEqual(expected_value, value, message = "value is expected?")

test_projection()

Checking:  query is expected?      ...OK!
Checking:  key is expected?      ...OK!
Checking:  value is expected?      ...OK!


In [56]:
class MultiheadAttention(torch.nn.Module):
    def __init__(self, nheads, d_model, W_v, W_k, W_q, W_o, dropout = 0):
        super().__init__()
        self.nheads = nheads
        self.projection = Projection(d_model, W_v, W_k, W_q)
        self.W_v, self.W_k, self.W_q, self.W_o = W_v, W_k, W_q, W_o
        
    def forward(self, attended, attending, mask=None):
        '''
        Shape:
        - Inputs:
        - attending: :math:`(..., T, N, d_model)`
        - attended: :math:`(..., S, N, d_model)`

        - Outputs:
        - attn_output: :math:`(..., T, N, d_model)`
        - attn_output_weights: :math:`(N * H, T, S)`
        '''
        # checks dimensions & assigns variables
        assert attending.shape[-1] == attended.shape[-1], "attending & attended should have the same d_model"
        d_model = attending.shape[-1]
        assert d_model % self.nheads == 0, "d_model should be divisible by number of heads"
        self.d_k = d_model // self.nheads  
        assert attending.shape[-2] == attended.shape[-2], "attending & attended should have the same batch size"
        self.batch_size = attending.shape[-2]
        
        # projects attended and attending tensors to v, k, q
        value, key, query = self.projection(attended, attending)
        value, key, query = self.reshape_into_nheads(value, self.nheads, self.d_k), self.reshape_into_nheads(key, self.nheads, self.d_k), self.reshape_into_nheads(query, self.nheads, self.d_k)
                
        # forward multi-heads through SDP
        attn_output, attn_weights = scaled_dot_product_attention(value, key, query)
        assert attn_output.shape == (attending.shape[-3], self.batch_size * self.nheads, self.d_k), "attn_output's shape from SDP should be (..., T, N * H, E / H)"

        # concats multi-heads and forward through final layer
        attn_output = self.reshape_into_nheads(attn_output, 1, d_model)
        attn_output = self.W_o(attn_output)
        
        return attn_output, attn_weights
    
    def reshape_into_nheads(self, x, nheads, last_dim):
        return x.reshape(-1, self.batch_size * nheads, last_dim)

In [59]:
def test_multihead_attention():
    from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
    from torch.nn.functional import multi_head_attention_forward as mha_forward

    d_model, num_heads, bsz = 10, 5, 64
    query = torch.rand((21, bsz, d_model))
    key = value = torch.rand((16, bsz, d_model))

    torch.manual_seed(42)
    l1 = torch.nn.Linear(d_model, d_model, bias=False)
    l2 = l3 = l4 = copy.deepcopy(l1)
    
    # call torchtext's MHA
    in_proj_container = InProjContainer(l1, l2, l3)
    torchtext_mha = MultiheadAttentionContainer(num_heads, 
                                      in_proj_container, 
                                      ScaledDotProduct(),
                                      l4)
    torchtext_attn_output, torchtext_attn_weights = torchtext_mha(query, key, value)
    
    # call torch's MHA
    in_proj_weight = torch.cat([torchtext_mha.in_proj_container.query_proj.weight,
                                    torchtext_mha.in_proj_container.key_proj.weight,
                                    torchtext_mha.in_proj_container.value_proj.weight])
    torch_attn_output, torch_attn_weights = mha_forward(query, key, value,
                            d_model, num_heads,
                            in_proj_weight, None,
                            None, None,
                            False, 0.0,
                            torchtext_mha.out_proj.weight, None)
    
    print(get_percentage_equal(torch_attn_output, torchtext_attn_output))
#     assertEqual(torch_attn_output, torchtext_attn_output)
    
    # call our MHA
    my_mha = MultiheadAttention(num_heads, d_model, l1, l2, l3, l4, dropout = 0.0)
    attn_output, attn_weights = my_mha(key, query)
    
    print(get_percentage_equal(torchtext_attn_output, attn_output))
    
    assert torchtext_attn_output.shape == attn_output.shape
    assert torchtext_attn_weights.shape == attn_weights.shape 
    assertEqual(torchtext_attn_output, attn_output, message = "attn_output expected?")
    assertEqual(torchtext_attn_weights, attn_weights, message = "attn_weights expected?")
    
test_multihead_attention()

0.6359375
1.0
Checking:  attn_output expected?      ...OK!
Checking:  attn_weights expected?      ...OK!


In [20]:
?torch.nn.Linear

[0;31mInit signature:[0m [0mtorch[0m[0;34m.[0m[0mnn[0m[0;34m.[0m[0mLinear[0m[0;34m([0m[0min_features[0m[0;34m:[0m [0mint[0m[0;34m,[0m [0mout_features[0m[0;34m:[0m [0mint[0m[0;34m,[0m [0mbias[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

Args:
    in_features: size of each input sample
    out_features: size of each output sample
    bias: If set to ``False``, the layer will not learn an additive bias.
        Default: ``True``

Shape:
    - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
      additional dimensions and :math:`H_{in} = \text{in\_features}`
    - Output: :math:`(N, *, H_{out})` where all but the last dimension
      are the same shape as the input and :math:`H_{out} = \text{out\_featu

In [10]:
import torch
import torch.nn
from torch.nn import Dropout, LayerNorm, Linear

In [11]:
class FeedForward(torch.nn.Module):
    def __init__(self, d_model, d_ff, dropout = 0.1):
        super().__init__()
        self.linear1 = Linear(d_model, d_ff)
        self.linear2 = Linear(d_ff, d_model)
        self.dropout = Dropout(dropout)
    
    def forward(self, x):
        x = self.dropout(torch.nn.functional.relu(self.linear1(x)))
        x = self.linear2(x)
        return x

In [12]:
class EncoderLayer(torch.nn.Module):
    def __init__(self, nheads, d_model, d_ff = 2048, dropout = 0.1, custom_encoder_layer = None):
        super().__init__()
        # sublayer1
        self.attention = MultiheadAttention(nheads, d_model, dropout = dropout) 
        self.dropout1 = Dropout(dropout)
        self.norm1 = LayerNorm(d_model)
        # sublayer2
        self.feed_forward = FeedForward(d_model, d_ff, dropout = dropout) 
        self.dropout2 = Dropout(dropout)
        self.norm2 = LayerNorm(d_model)
        
        # custom shit for testing
        if custom_encoder_layer != None:
            self.custom_mha_result = custom_encoder_layer.mha_result
            self.custom_feedforward_result = custom_encoder_layer.feedforward_result
            self.custom_sublayer1 = custom_encoder_layer.sublayer1
            self.custom_sublayer1_normalized = custom_encoder_layer.sublayer1_normalized
            self.custom_sublayer2_normalized = custom_encoder_layer.sublayer2_normalized
            
        
    def forward(self, src):
        # sublayer1 
        # res layer
        torch.manual_seed(42)
        if self.custom_mha_result != None:
            src = self.dropout1(self.custom_mha_result) + src 
        else:
            src = self.dropout1(self.attention(src, src)[0]) + src 
        # normalize
        src = self.norm1(src) 
        print("src vs. sublayer1_normalized: ", get_percentage_equal(src, self.custom_sublayer1_normalized))
        assertEqual(src, self.custom_sublayer1_normalized,  "src vs. sublayer1_normalized")

        # sublayer2
        # res layer
        if self.custom_feedforward_result != None:
            torch.manual_seed(42)
            src = self.dropout2(self.custom_feedforward_result) + src
        else:
            src = self.dropout2(self.feed_forward(src)) + src 
        # normalize
        src = self.norm2(src) # normalize
        assertEqual(src, self.custom_sublayer2_normalized,  "src vs. sublayer2_normalized")
        
        
        return src 

In [14]:
torch.manual_seed(42)
src_data = torch.rand(10, 32, 512)

In [15]:
encoder_layer = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8, dropout=0.1)
out = encoder_layer(src_data)

In [16]:
def test_encoder_layer():
    # expected encoder layer
    encoder_layer = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8, dropout=0.1)
    out = encoder_layer(src_data)
    # my encoder layer
    my_encoder_layer = EncoderLayer(d_model=512, nheads=8, dropout=0.1, custom_encoder_layer = encoder_layer)
    my_out = my_encoder_layer(src_data)

    assert my_out.shape == out.shape
    print("percentage my output is similar to the expected", get_percentage_equal(my_out, out))
    assertEqual(my_out, out, "my encoder layer output is same as expected?")
    
test_encoder_layer()

ModuleAttributeError: 'TransformerEncoderLayer' object has no attribute 'mha_result'