In [1]:
%load_ext autoreload
%autoreload 2
import torch
from torch import nn
import torch.nn.functional as F
import fairseq

In [2]:
fairseq.__path__

['/env_nlp/lib/python3.9/site-packages/fairseq']

## Load dataset from fairseq

In [3]:
fairseq.tasks

<module 'fairseq.tasks' from '/env_nlp/lib/python3.9/site-packages/fairseq/tasks/__init__.py'>

In [4]:
fairseq.tasks.TASK_REGISTRY

{'multilingual_masked_lm': fairseq.tasks.multilingual_masked_lm.MultiLingualMaskedLMTask,
 'translation': fairseq.tasks.translation.TranslationTask,
 'translation_lev': fairseq.tasks.translation_lev.TranslationLevenshteinTask,
 'translation_multi_simple_epoch': fairseq.tasks.translation_multi_simple_epoch.TranslationMultiSimpleEpochTask,
 'speech_unit_modeling': fairseq.tasks.speech_ulm_task.SpeechUnitLanguageModelingTask,
 'hubert_pretraining': fairseq.tasks.hubert_pretraining.HubertPretrainingTask,
 'multilingual_translation': fairseq.tasks.multilingual_translation.MultilingualTranslationTask,
 'language_modeling': fairseq.tasks.language_modeling.LanguageModelingTask,
 'masked_lm': fairseq.tasks.masked_lm.MaskedLMTask,
 'audio_pretraining': fairseq.tasks.audio_pretraining.AudioPretrainingTask,
 'audio_finetuning': fairseq.tasks.audio_finetuning.AudioFinetuningTask,
 'multilingual_language_modeling': fairseq.tasks.multilingual_language_modeling.MultilingualLanguageModelingTask,
 'spee

In [5]:
torch.manual_seed(0)
VOCAB_SIZE = 40000
SRC_SIZE = 256
BATCH_SIZE = 4
PAD_IDX = 0
src = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SRC_SIZE))
src

tensor([[16044,  8239, 34933,  ...,  2374, 13973, 36703],
        [23757, 31943, 33848,  ..., 22762, 16629, 23216],
        [ 8606,   763, 25616,  ...,  8786, 32391, 13916],
        [ 7937, 25449, 14018,  ...,  3122, 12978, 14438]])

In [6]:
mask = torch.randint(8, SRC_SIZE, ( BATCH_SIZE, 1))
cond = torch.arange(SRC_SIZE, dtype=mask.dtype)
src[cond > mask] = 0
src

tensor([[16044,  8239, 34933,  ...,     0,     0,     0],
        [23757, 31943, 33848,  ...,     0,     0,     0],
        [ 8606,   763, 25616,  ...,     0,     0,     0],
        [ 7937, 25449, 14018,  ...,     0,     0,     0]])

In [7]:
src[0]

tensor([16044,  8239, 34933, 13760,  8963,  8379, 15427, 38503, 23497, 25683,
        14101, 26866, 22756, 21399, 15878, 20376, 20056,  9868, 28794, 36033,
        33126, 38119, 26391,  6254, 12824, 37841,  7269, 10969, 17549, 25480,
        17481, 37012, 37363, 25360, 34995, 30085, 37822,  1999, 33711, 36288,
        13778, 11043, 39896,  5489, 38771, 33757, 27283,  5395, 24782, 13671,
        39840, 35969, 17673,   441,  2111, 29280,  1503, 29706, 39476, 18327,
        17699, 11326, 28363, 39474, 30775, 36700,  5218, 37632, 21568,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 

In [8]:
mask[0]

tensor([68])

In [9]:
src[0][mask[0]:]

tensor([21568,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 

In [10]:
PAD_IDX, UNK_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 

In [11]:
mask

tensor([[ 68],
        [187],
        [ 70],
        [232]])

In [12]:
src[ ..., mask.flatten()]

tensor([[21568,     0,     0,     0],
        [ 5210, 29448,  3756,     0],
        [18816,     0, 22200,     0],
        [ 1261, 20180, 32506, 35495]])

In [13]:
src[ -1, mask.flatten()]

tensor([ 1261, 20180, 32506, 35495])

In [14]:
src[..., mask]

tensor([[[21568],
         [    0],
         [    0],
         [    0]],

        [[ 5210],
         [29448],
         [ 3756],
         [    0]],

        [[18816],
         [    0],
         [22200],
         [    0]],

        [[ 1261],
         [20180],
         [32506],
         [35495]]])

In [15]:
src[:, mask]

tensor([[[21568],
         [    0],
         [    0],
         [    0]],

        [[ 5210],
         [29448],
         [ 3756],
         [    0]],

        [[18816],
         [    0],
         [22200],
         [    0]],

        [[ 1261],
         [20180],
         [32506],
         [35495]]])

In [16]:
src[ torch.arange(BATCH_SIZE), mask.flatten()]

tensor([21568, 29448, 22200, 35495])

In [17]:
def get_dummy_input(batch_size=BATCH_SIZE):
    src = torch.randint(0, VOCAB_SIZE, (batch_size, SRC_SIZE))
    mask_idx = torch.randint(5, SRC_SIZE, (batch_size, 1))
    mask = torch.arange(SRC_SIZE, dtype=mask_idx.dtype) > mask_idx
    src[mask] = PAD_IDX
    src[torch.arange(batch_size), mask_idx.flatten()] = EOS_IDX 
    src[torch.arange(batch_size), [0] * batch_size] = BOS_IDX 
    return src
    

In [18]:
inp_src = get_dummy_input()
inp_src.shape

torch.Size([4, 256])

In [19]:
inp_src[1]

tensor([    2, 15972, 15186, 33687,  7189, 11198, 33410, 26120, 11018, 17925,
        11110, 23014, 10660, 39489, 39925, 27707, 15524,  9861,  7877,  7718,
        39587, 14424, 38577, 16335, 31572, 27186,  5577, 23159, 21380, 15793,
        23256,  7198, 12418,    83, 26115, 27085,  4939, 38541,  3492, 17425,
        26522, 23430, 18799,  7458, 36629,  6228, 20074,  5946, 10105, 27729,
        33061, 15769,  5364, 38293,  6839,  8292, 24068, 16047, 35618, 30603,
        16354, 21665, 28910, 13011, 10032,  2379, 22497, 14553, 17121,  4575,
        26733, 34502, 15577, 24152, 14701,  9446, 26291, 22800, 29143, 13978,
        35651, 36058,  2862,  6714,   669, 26882, 33605, 33209, 16135, 24079,
        26510,   547,  4596, 19943, 22198, 38599, 14579,  5990, 11129, 26483,
        10397,  4215, 13105, 18763, 32515,  5837, 12026,  4727, 38680,   635,
        12220, 33208, 36484, 21749, 19059,  2134, 30329, 15965, 32806, 10379,
         2318, 35858,  7641, 10860, 10499,  2917, 37127, 16797, 

In [20]:
inp_src[0]

tensor([    2,  7725, 16233, 14146, 20513, 24204, 22223,  1718, 39329, 34007,
         4556,  5061, 39462,  8961, 25738, 33286, 19453, 28166, 33446, 39014,
        21963, 28594,  4437, 24382, 39986, 31346, 38579, 13954, 16140,  1053,
        29001, 25267, 24392, 21381, 22964, 18421, 22059, 36993, 27691, 39631,
        29182,  5964, 21615, 30156, 38485, 17912,  1938, 23257, 33933, 31326,
         7068, 12908, 10217, 38229,  8454,  8292, 25008, 22568, 17370, 22046,
        27013, 35983, 23295,  6565,  9253,  6459,  3185,  7711,  5216, 28528,
         8208, 23584, 32638, 37450, 15180, 35101,  8576,   989, 18458, 14375,
           86, 18387, 31237,  9252, 19603, 22975, 31780, 15173, 30236, 13452,
        27546, 30193, 38163, 11980, 12646, 31879, 39604,  2576, 11229, 30656,
         9239,  1935, 14410, 28284, 12803, 24701,  9347, 12368, 22064,  7817,
        38484, 22260, 37093, 38591,  4088, 30846,    56, 25264,   721, 18869,
        21753, 15113, 14380,  1790,  3279, 20198,  9550, 38689, 

## Embedding

In [21]:
HIDDEN_SIZE = 2048
EMBED_SIZE = 1024

In [22]:
def Embedding(voc_size, embed_dim):
    embedding = nn.Embedding(num_embeddings=voc_size, embedding_dim=embed_dim,
                             padding_idx=PAD_IDX, )
    nn.init.uniform_(embedding.weight, -0.1, 0.1)
    nn.init.constant_(embedding.weight[PAD_IDX], 0.)
    return embedding


In [23]:
embedding_net = Embedding(VOCAB_SIZE, embed_dim=EMBED_SIZE)

In [24]:
embedding_net.weight

Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0256,  0.0326,  0.0972,  ..., -0.0654, -0.0206, -0.0127],
        [-0.0246, -0.0414,  0.0920,  ...,  0.0560,  0.0799, -0.0465],
        ...,
        [ 0.0165,  0.0210,  0.0956,  ..., -0.0821,  0.0432, -0.0001],
        [-0.0118, -0.0252, -0.0988,  ...,  0.0664, -0.0533,  0.0143],
        [ 0.0847,  0.0929,  0.0016,  ..., -0.0736, -0.0177, -0.0841]],
       requires_grad=True)

In [25]:
embedding_net.weight.shape

torch.Size([40000, 1024])

In [26]:
em = embedding_net(src)

In [27]:
# batch_size, seq_size, hidden_size
em.shape

torch.Size([4, 256, 1024])

In [28]:
embedding_net.weight[1].data.copy_(1.0)

tensor([1., 1., 1.,  ..., 1., 1., 1.])

In [29]:
embedding_net.weight

Parameter containing:
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.0000e+00,  1.0000e+00,  1.0000e+00,  ...,  1.0000e+00,
          1.0000e+00,  1.0000e+00],
        [-2.4572e-02, -4.1433e-02,  9.2018e-02,  ...,  5.6014e-02,
          7.9891e-02, -4.6524e-02],
        ...,
        [ 1.6473e-02,  2.0970e-02,  9.5622e-02,  ..., -8.2123e-02,
          4.3167e-02, -1.1914e-04],
        [-1.1751e-02, -2.5235e-02, -9.8790e-02,  ...,  6.6410e-02,
         -5.3273e-02,  1.4340e-02],
        [ 8.4664e-02,  9.2933e-02,  1.5503e-03,  ..., -7.3641e-02,
         -1.7708e-02, -8.4070e-02]], requires_grad=True)

In [30]:
embedding_net.weight.shape, src.shape

(torch.Size([40000, 1024]), torch.Size([4, 256]))

In [31]:
embedding_net.weight.index_select(0, src.flatten()).view(BATCH_SIZE, SRC_SIZE, EMBED_SIZE).shape

torch.Size([4, 256, 1024])

In [32]:
src

tensor([[16044,  8239, 34933,  ...,     0,     0,     0],
        [23757, 31943, 33848,  ...,     0,     0,     0],
        [ 8606,   763, 25616,  ...,     0,     0,     0],
        [ 7937, 25449, 14018,  ...,     0,     0,     0]])

In [33]:
encoder_embed = Embedding(VOCAB_SIZE, EMBED_SIZE)
decoder_embed = Embedding(VOCAB_SIZE, EMBED_SIZE)

In [34]:
encoder_embed.weight.requires_grad

True

In [35]:
FREEZE = False

In [36]:
if FREEZE:
    encoder_embed.weight.requires_grad = False
    decoder_embed.weight.requires_grad = False

## Encoder

In [37]:
def LSTM(embed_size, hidden_size, num_layers, bidirectional=False, dropout=0., batch_first=True, bias=True):
    lstm = nn.LSTM(input_size=embed_size,
                   hidden_size=hidden_size,
                   num_layers=num_layers,
                   bidirectional=bidirectional,
                   dropout=dropout,
                   batch_first=batch_first,
                   bias=bias
                   )
    for name, param in lstm.named_parameters():
        if "weight" in name or "bias" in name:
            nn.init.uniform_(param.data, -0.1, 0.1)
            
    return lstm

In [38]:
lstm = LSTM(1024, 2048, 2)

In [39]:
lstm.weight_hh_l0

Parameter containing:
tensor([[-0.0839, -0.0312,  0.0732,  ..., -0.0152, -0.0167,  0.0414],
        [-0.0888, -0.0119,  0.0562,  ...,  0.0124,  0.0246, -0.0816],
        [-0.0145, -0.0381, -0.0945,  ...,  0.0341,  0.0032,  0.0858],
        ...,
        [-0.0844,  0.0543, -0.0329,  ..., -0.0752,  0.0890,  0.0523],
        [-0.0571,  0.0174, -0.0177,  ..., -0.0368,  0.0885,  0.0769],
        [-0.0156,  0.0214,  0.0489,  ..., -0.0976,  0.0896, -0.0419]],
       requires_grad=True)

In [40]:
lstm.weight_hh_l0.shape, lstm.weight_hh_l1.shape

(torch.Size([8192, 2048]), torch.Size([8192, 2048]))

In [41]:
lstm.weight_ih_l0.shape, lstm.weight_ih_l1.shape

(torch.Size([8192, 1024]), torch.Size([8192, 2048]))

In [42]:
lstm.bias_ih_l0.shape, lstm.bias_ih_l1.shape

(torch.Size([8192]), torch.Size([8192]))

In [43]:
lstm.weight_ih_l0.shape, lstm.weight_ih_l1.shape, 

(torch.Size([8192, 1024]), torch.Size([8192, 2048]))

In [44]:
lstm(embedding_net(src))[0].shape

torch.Size([4, 256, 2048])

In [45]:
out, (hn, cn) = lstm(embedding_net(src))

In [46]:
out.shape

torch.Size([4, 256, 2048])

In [47]:
hn.shape

torch.Size([2, 4, 2048])

In [48]:
cn.shape

torch.Size([2, 4, 2048])

In [49]:
lstm = LSTM(1024, 2048, 4)

In [50]:
out, (hn, cn) = lstm(embedding_net(src))
out.shape, hn.shape, cn.shape

(torch.Size([4, 256, 2048]),
 torch.Size([4, 4, 2048]),
 torch.Size([4, 4, 2048]))

In [51]:
lstm.weight_ih_l3.shape, lstm.weight_ih_l3.shape

(torch.Size([8192, 2048]), torch.Size([8192, 2048]))

In [52]:
lstm = LSTM(1024, 2048, 3)

In [53]:
h0 = torch.randn(2048)
h0 = h0.expand(3, 4, -1) #   dim 0 -> num_layers * bidirectional, dim 1 -> batch_size,
h0.shape

torch.Size([3, 4, 2048])

In [54]:
c0 = torch.randn(2048)
c0 = c0.expand( 3, 4, -1) 
c0.shape

torch.Size([3, 4, 2048])

In [55]:
out, (hn, cn) = lstm(embedding_net(src), (h0, c0))

In [56]:
out.shape

torch.Size([4, 256, 2048])

In [57]:
hn.shape, hn.shape, cn.shape

(torch.Size([3, 4, 2048]), torch.Size([3, 4, 2048]), torch.Size([3, 4, 2048]))

In [58]:
# Bidirectional
lstm = LSTM(1024, 2048, 3, bidirectional=True)

In [59]:
lstm.weight_hh_l0_reverse

Parameter containing:
tensor([[ 0.0164,  0.0693,  0.0390,  ..., -0.0925,  0.0564,  0.0884],
        [-0.0491,  0.0902,  0.0821,  ..., -0.0888,  0.0981, -0.0451],
        [-0.0084, -0.0034, -0.0749,  ...,  0.0713,  0.0714,  0.0286],
        ...,
        [-0.0234,  0.1000, -0.0939,  ..., -0.0311,  0.0420, -0.0901],
        [ 0.0302, -0.0847,  0.0112,  ..., -0.0763,  0.0141,  0.0681],
        [-0.0454,  0.0429, -0.0436,  ..., -0.0012, -0.0645, -0.0658]],
       requires_grad=True)

In [60]:
lstm.weight_hh_l0_reverse.shape

torch.Size([8192, 2048])

In [61]:
lstm.weight_ih_l0.shape, lstm.weight_ih_l1.shape, lstm.weight_ih_l2.shape

(torch.Size([8192, 1024]), torch.Size([8192, 4096]), torch.Size([8192, 4096]))

In [62]:
lstm.weight_ih_l0_reverse.shape, lstm.weight_ih_l1_reverse.shape, lstm.weight_ih_l2_reverse.shape

(torch.Size([8192, 1024]), torch.Size([8192, 4096]), torch.Size([8192, 4096]))

In [63]:
len(lstm.all_weights)

6

In [64]:
len(lstm.all_weights[0]), type(lstm.all_weights[0]) # Wih, Wif, Wig, Wio

(4, list)

In [65]:
lstm.all_weights[0][0], lstm.weight_ih_l0[:, :1024]

(Parameter containing:
 tensor([[ 0.0270, -0.0652, -0.0140,  ..., -0.0091,  0.0693,  0.0468],
         [ 0.0362,  0.0413, -0.0744,  ...,  0.0002,  0.0516, -0.0539],
         [ 0.0288, -0.0080, -0.0532,  ..., -0.0147,  0.0828,  0.0289],
         ...,
         [-0.0231,  0.0375, -0.0368,  ..., -0.0428, -0.0488, -0.0843],
         [ 0.0663,  0.0398, -0.0229,  ...,  0.0997, -0.0129, -0.0330],
         [ 0.0962, -0.0518, -0.0279,  ..., -0.0304,  0.0987,  0.0815]],
        requires_grad=True),
 tensor([[ 0.0270, -0.0652, -0.0140,  ..., -0.0091,  0.0693,  0.0468],
         [ 0.0362,  0.0413, -0.0744,  ...,  0.0002,  0.0516, -0.0539],
         [ 0.0288, -0.0080, -0.0532,  ..., -0.0147,  0.0828,  0.0289],
         ...,
         [-0.0231,  0.0375, -0.0368,  ..., -0.0428, -0.0488, -0.0843],
         [ 0.0663,  0.0398, -0.0229,  ...,  0.0997, -0.0129, -0.0330],
         [ 0.0962, -0.0518, -0.0279,  ..., -0.0304,  0.0987,  0.0815]],
        grad_fn=<SliceBackward0>))

In [66]:
(lstm.all_weights[0][0]  - lstm.weight_ih_l0[:, :1024]).sum()

tensor(0., grad_fn=<SumBackward0>)

In [67]:
len(lstm.all_weights[2]), len(lstm.all_weights[-1])

(4, 4)

In [68]:
lstm.all_weights[1][0].shape, lstm.weight_ih_l0_reverse.shape

(torch.Size([8192, 1024]), torch.Size([8192, 1024]))

In [69]:
lstm.all_weights[1][0], lstm.weight_ih_l0_reverse[:, :1024]

(Parameter containing:
 tensor([[ 0.0386, -0.0226,  0.0679,  ..., -0.0732,  0.0101, -0.0297],
         [ 0.0937,  0.0623,  0.0302,  ...,  0.0047, -0.0335,  0.0184],
         [ 0.0924, -0.0607, -0.0399,  ...,  0.0347, -0.0368,  0.0720],
         ...,
         [-0.0131,  0.0256,  0.0554,  ..., -0.0180, -0.0110,  0.0932],
         [ 0.0628, -0.0708,  0.0190,  ..., -0.0954,  0.0142,  0.0919],
         [ 0.0825,  0.0975, -0.0044,  ..., -0.0327,  0.0937,  0.0865]],
        requires_grad=True),
 tensor([[ 0.0386, -0.0226,  0.0679,  ..., -0.0732,  0.0101, -0.0297],
         [ 0.0937,  0.0623,  0.0302,  ...,  0.0047, -0.0335,  0.0184],
         [ 0.0924, -0.0607, -0.0399,  ...,  0.0347, -0.0368,  0.0720],
         ...,
         [-0.0131,  0.0256,  0.0554,  ..., -0.0180, -0.0110,  0.0932],
         [ 0.0628, -0.0708,  0.0190,  ..., -0.0954,  0.0142,  0.0919],
         [ 0.0825,  0.0975, -0.0044,  ..., -0.0327,  0.0937,  0.0865]],
        grad_fn=<SliceBackward0>))

In [70]:
lstm.all_weights[2][0].shape, lstm.weight_ih_l1.shape

(torch.Size([8192, 4096]), torch.Size([8192, 4096]))

In [71]:
lstm.all_weights[2][0], lstm.weight_ih_l1[:, :1024]

(Parameter containing:
 tensor([[ 0.0657, -0.0548,  0.0271,  ...,  0.0324, -0.0260, -0.0590],
         [ 0.0863,  0.0830,  0.0494,  ...,  0.0477, -0.0521, -0.0444],
         [ 0.0781,  0.0236, -0.0523,  ...,  0.0738,  0.0493, -0.0803],
         ...,
         [-0.0906, -0.0942,  0.0503,  ..., -0.0435,  0.0104,  0.0342],
         [-0.0487, -0.0549, -0.0973,  ..., -0.0448, -0.0126, -0.0669],
         [-0.0916, -0.0943,  0.0814,  ...,  0.0107, -0.0140, -0.0919]],
        requires_grad=True),
 tensor([[ 0.0657, -0.0548,  0.0271,  ..., -0.0245,  0.0995,  0.0231],
         [ 0.0863,  0.0830,  0.0494,  ...,  0.0550, -0.0011,  0.0104],
         [ 0.0781,  0.0236, -0.0523,  ..., -0.0888, -0.0014,  0.0096],
         ...,
         [-0.0906, -0.0942,  0.0503,  ...,  0.0483, -0.0793, -0.0889],
         [-0.0487, -0.0549, -0.0973,  ..., -0.0006,  0.0926, -0.0764],
         [-0.0916, -0.0943,  0.0814,  ..., -0.0793, -0.0456, -0.0537]],
        grad_fn=<SliceBackward0>))

In [72]:
lstm.all_weights[0][0].shape,lstm.all_weights[0][1].shape,  lstm.all_weights[0][2].shape, lstm.all_weights[0][3].shape, 

(torch.Size([8192, 1024]),
 torch.Size([8192, 2048]),
 torch.Size([8192]),
 torch.Size([8192]))

In [73]:
lstm.all_weights[2][0].shape,lstm.all_weights[2][1].shape,  lstm.all_weights[2][2].shape, lstm.all_weights[2][3].shape, 

(torch.Size([8192, 4096]),
 torch.Size([8192, 2048]),
 torch.Size([8192]),
 torch.Size([8192]))

In [74]:
lstm.all_weights[4][3], lstm.bias_hh_l2

(Parameter containing:
 tensor([-0.0824,  0.0584, -0.0415,  ..., -0.0184,  0.0727, -0.0223],
        requires_grad=True),
 Parameter containing:
 tensor([-0.0824,  0.0584, -0.0415,  ..., -0.0184,  0.0727, -0.0223],
        requires_grad=True))

In [75]:
lstm.all_weights[5][3], lstm.bias_hh_l2_reverse

(Parameter containing:
 tensor([ 0.0708,  0.0131, -0.0339,  ..., -0.0379,  0.0161,  0.0285],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.0708,  0.0131, -0.0339,  ..., -0.0379,  0.0161,  0.0285],
        requires_grad=True))

In [76]:
out, (hn, cn) = lstm(embedding_net(src))

In [77]:
out.shape

torch.Size([4, 256, 4096])

In [78]:
hn.shape

torch.Size([6, 4, 2048])

In [79]:
cn.shape

torch.Size([6, 4, 2048])

In [80]:
hn[0]

tensor([[ 0.0887, -0.2716, -0.0035,  ..., -0.1872, -0.2095,  0.2743],
        [ 0.2145,  0.4700, -0.0473,  ...,  0.0824, -0.2775,  0.1891],
        [-0.1518,  0.3974, -0.0902,  ...,  0.0360, -0.4124,  0.1423],
        [-0.0665, -0.2272, -0.2140,  ..., -0.1708, -0.3700,  0.1284]],
       grad_fn=<SelectBackward0>)

In [81]:
hn[1]

tensor([[ 0.2423,  0.4435,  0.3883,  ...,  0.0613, -0.2588,  0.1261],
        [-0.1673, -0.0230,  0.2761,  ..., -0.1330, -0.0902, -0.2122],
        [ 0.3518,  0.5100,  0.3094,  ...,  0.0697, -0.0110, -0.0222],
        [ 0.3315, -0.0788,  0.2140,  ...,  0.0008, -0.0609,  0.1283]],
       grad_fn=<SelectBackward0>)

In [82]:

encoder_config = {
    'vocab_size': SRC_SIZE,
    'embed_dim': EMBED_SIZE, 
    'embed_net': encoder_embed,
    'hidden_size': HIDDEN_SIZE,
    'bidirectional': True, 
    'model_type': 'lstm',
    'batch_first': True,
    'num_layers': 2,
    'dropout_in': 0.1, 
    'dropout_out': 0.1,

    }

In [83]:
encoder_embed.weight.shape

torch.Size([40000, 1024])

In [84]:
def get_dummy_input(batch_size=BATCH_SIZE):
    src = torch.randint(0, VOCAB_SIZE, (batch_size, SRC_SIZE))
    mask_idx = torch.randint(5, SRC_SIZE, (batch_size, 1))
    mask = torch.arange(SRC_SIZE, dtype=mask_idx.dtype) > mask_idx
    src[mask] = PAD_IDX
    src[torch.arange(batch_size), mask_idx.flatten()] = EOS_IDX 
    src[torch.arange(batch_size), [0] * batch_size] = BOS_IDX 
    src_lengths = (src !=0).sum(axis=1)
    return src, src_lengths
    

In [85]:
class Dropout(nn.Module):
    
    def __init__(self, p=0.0):
        super().__init__()
        self.p = p
    
    def forward(self, x, inplace=False):
        if self.p > 0 and self.training:
            return F.dropout(x, self.p, training=True, inplace=inplace)
        else:
            return x

In [86]:
torch.manual_seed(0)
x = torch.randn(5, 4, 8)
drop_net = Dropout(0.1)

In [87]:
drop_net.training

True

In [88]:
drop_net.train()

Dropout()

In [89]:
y = drop_net(x)
y[2]

tensor([[-0.0000e+00,  1.0222e+00,  1.2342e+00,  1.4332e+00, -1.6424e+00,
          2.8525e+00, -5.2569e-01,  3.7283e-01],
        [-1.8104e+00, -6.1083e-01, -5.3315e-01, -0.0000e+00, -1.1855e+00,
          1.2388e+00, -1.5630e-01,  8.9528e-01],
        [-1.0372e-01,  7.6339e-01, -9.3146e-01,  9.9091e-04,  9.3544e-01,
         -4.4448e-01,  1.1550e+00,  3.9795e-01],
        [-2.7333e-01,  2.5584e+00, -2.0908e+00, -5.5252e-02, -1.1611e+00,
         -1.0628e+00,  0.0000e+00,  7.8899e-01]])

In [90]:
drop_net.training

True

In [91]:
drop_net.eval()

Dropout()

In [92]:
drop_net.training

False

In [93]:
y = drop_net(x)
y[2]

tensor([[-5.6925e-01,  9.1997e-01,  1.1108e+00,  1.2899e+00, -1.4782e+00,
          2.5672e+00, -4.7312e-01,  3.3555e-01],
        [-1.6293e+00, -5.4974e-01, -4.7983e-01, -4.9968e-01, -1.0670e+00,
          1.1149e+00, -1.4067e-01,  8.0575e-01],
        [-9.3348e-02,  6.8705e-01, -8.3832e-01,  8.9182e-04,  8.4189e-01,
         -4.0003e-01,  1.0395e+00,  3.5815e-01],
        [-2.4600e-01,  2.3025e+00, -1.8817e+00, -4.9727e-02, -1.0450e+00,
         -9.5650e-01,  3.3532e-02,  7.1009e-01]])

In [94]:
class SeqEncoderModel(nn.Module):
    
    def __init__(self, encoder_config):
        super(SeqEncoderModel, self).__init__()
        self.vocab_size = encoder_config['vocab_size']
        self.embed_net = encoder_config['embed_net']
        self.hidden_size = encoder_config['hidden_size']
        self.num_layers = encoder_config['num_layers']
        self.batch_first = encoder_config['batch_first']
        self.bidirectional = encoder_config['bidirectional']
        self.dropout_in = encoder_config['dropout_in']
        self.dropout_out = encoder_config['dropout_out']
        
        self.model = self.get_model(encoder_config['model_type'])
        self.dropout_in_model = Dropout(self.dropout_in)
        self.dropout_out_model = Dropout(self.dropout_out)
        
        self.num_direction = 2 if self.bidirectional else 1
        self.units = self.hidden_size * 2 if self.bidirectional else 1
    
    def get_model(self, model_type, **kwargs):
        if model_type == 'lstm':
            return LSTM(embed_size=self.embed_net.weight.shape[1],
                        hidden_size=self.hidden_size,
                        num_layers=self.num_layers,
                        bidirectional=self.bidirectional,
                        dropout=self.dropout_out if self.num_layers > 1 else 0.0,
                        batch_first=self.batch_first,
                        )
        else:
            raise NotImplementedError
    def forward(self, src_tokens, src_lengths, ):
        bs, ts = src_tokens.size()
        x = self.embed_net(src_tokens)
        x = self.dropout_in_model(x)
        h0, c0 = torch.zeros(self.num_layers * self.num_direction, bs, self.hidden_size).to(device=src_tokens.device), \
            torch.zeros(self.num_layers * self.num_direction, bs, self.hidden_size).to(device=src_tokens.device)
        out, (hn, cn) = self.model(x, (h0, c0))
        x = self.dropout_out_model(out)
        
        hn = self.reshape_hidden(hn, bs)
        cn = self.reshape_hidden(cn, bs)
        
        return x, (hn, cn)
    
    def reshape_hidden(self, h, batch_size):
        if not self.bidirectional:
            return h
        h = h.view(self.num_layers, 2, batch_size, -1).transpose(1, 2).contiguous()
        h = h.view(self.num_layers, batch_size, -1)
        return h
        
        

In [95]:
encoder = SeqEncoderModel(encoder_config)

In [96]:
encoder.model

LSTM(1024, 2048, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)

In [97]:
src, src_lengths = get_dummy_input()

In [98]:
src.shape

torch.Size([4, 256])

In [99]:
src_lengths.shape

torch.Size([4])

In [100]:
src_lengths[-3]

tensor(59)

In [101]:
src[-3]

tensor([    2,  7468, 19430, 18143,  5529,  2244,  8060, 15349, 26755, 15458,
        27830,  2026, 26136,  8813, 25890,  5216, 16913,  6806,  3128, 37882,
        15613,  1872,  3989, 37173, 31764, 30318, 12432, 38457,  5387, 27609,
         6411, 17284, 24931, 23437, 24264,  2216, 35462, 13931, 32271, 19423,
        27902, 15291,  2932, 27558,  5229, 31215, 38406, 12103, 30495,  7789,
         2491, 16002, 31367, 25722, 39821, 17545, 33764, 36206,     3,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 

In [102]:
x, (hn, cn) = encoder(src, src_lengths)

In [103]:
x.shape

torch.Size([4, 256, 4096])

In [104]:
hn.shape

torch.Size([2, 4, 4096])

In [105]:
cn.shape

torch.Size([2, 4, 4096])

## Decoder

In [106]:
def Linear(in_units, out_units, bias=True):
    lin = nn.Linear(in_features=in_units, out_features=out_units, bias=bias)
    lin.weight.data.uniform_(-0.1, 0.1)
    if bias:
        lin.bias.data.uniform_(-0.1, 0.1)
    
    return lin

In [107]:
def Identity():
    return lambda x: x

In [108]:
def LSTMCell(input_size, hidden_size, bias=True):
    cell = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size, bias=bias)
    for name, param in cell.named_parameters():
        if "weight" in name or "bias" in name:
            param.data.uniform_(-0.1, 0.1)
    return cell

In [109]:
class AttentionLayer(nn.Module):
    
    def __init__(self, input_size, src_hidden_size, out_size, bias=False):
        super(AttentionLayer, self).__init__()
        self.inp_proj = Linear(input_size + src_hidden_size, src_hidden_size, bias=bias)
        self.out_proj = Linear(input_size + src_hidden_size, out_size, bias=bias)
        self.scale = nn.Parameter(torch.empty(src_hidden_size, dtype=torch.float32).uniform_(-0.1, 0.1), requires_grad=True)
        # self.score_net = Linear(src_hidden_size + src_hidden_size)
    def forward(self, input_hidden, src_hiddens, src_lengths=None):
        bs, ts, hs = src_hiddens.size()
        # bs ts hs
        x = input_hidden.expand(ts, -1, -1).transpose(0, 1)
        x = torch.cat([src_hiddens, x], dim=2)
        # bs, ts, hs
        score_proj = self.inp_proj(x)
        score_proj = F.tanh(score_proj)
        # bs, ts
        score_proj = score_proj @ self.scale
        assert tuple(score_proj.size()) == tuple((bs, ts))
        
        if src_lengths is not None:
            mask_idx = torch.arange(1, ts + 1).to(device=src_lengths.device)
            mask_idx = mask_idx.tile((bs, 1))
            mask = mask_idx < src_lengths.unsqueeze(dim=1)
            score_proj.float().masked_fill_(mask, -torch.inf)

        # bs, ts
        attn_scores = F.softmax(score_proj, dim=1)
        context_vector = src_hiddens * attn_scores.unsqueeze(dim=2)
        # bs, hs
        context_vector = context_vector.sum(axis=1)
        
        out = F.tanh(self.out_proj(
            torch.cat([input_hidden, context_vector], dim=1)
        ))
        
        return out, attn_scores

In [110]:
decoder_config = {
    'vocab_size': VOCAB_SIZE,
    'embed_dim': decoder_embed.weight.size(1), 
    'embed_net': decoder_embed,
    'hidden_size': HIDDEN_SIZE,
    'decoder_only': True,
    'encoder_units': encoder.units, 
    'batch_first': True,
    'num_layers': 2,
    'dropout_in': 0.1, 
    'dropout_out': 0.1,
    'attention': True,
    'residual': True

    }

In [111]:
class SeqDecoderModel(nn.Module):
    
    def __init__(self, decoder_config):
        super(SeqDecoderModel, self).__init__()
        self.vocab_size = decoder_config['vocab_size']
        self.embed_net = decoder_config['embed_net']
        self.embed_dim = decoder_config['embed_dim']
        self.hidden_size = decoder_config['hidden_size']
        self.num_layers = decoder_config['num_layers']
        self.batch_first = decoder_config['batch_first']
        self.encoder_units = decoder_config['encoder_units']
        self.dropout_in = decoder_config['dropout_in']
        self.dropout_out = decoder_config['dropout_out']
        self.decoder_only = decoder_config['decoder_only']
        self.residual = decoder_config['residual']
        
        self.dropout_in_model = Dropout(self.dropout_in)
        self.dropout_out_model = Dropout(self.dropout_out)
        
        # input_feed_size = 0 if self.encoder_units > 0 else self.hidden_size
        input_feed_size = 0 if decoder_config['decoder_only'] else self.hidden_size
        
        if self.encoder_units != self.hidden_size:
            self.hn_proj = Linear(self.encoder_units, self.hidden_size)
            self.cn_proj = Linear(self.encoder_units, self.hidden_size)
        else:
            self.hn_proj = Identity()
            self.cn_proj = Identity()
        
        self.layers = nn.ModuleList([
            LSTMCell(input_size=input_feed_size + self.embed_dim if layer == 0 else self.hidden_size,
                     hidden_size=self.hidden_size)
            for layer in range(self.num_layers)
            
        ])
        
        if decoder_config['attention']:
            self.attention = AttentionLayer(self.hidden_size, self.encoder_units, self.hidden_size, bias=False)
        else:
            self.attention = None
        
        self.out_embed_net = Linear(self.hidden_size, self.embed_dim)
        
        self.classifier = Linear(self.embed_dim, self.vocab_size)
        
    def forward(self, tgt_tokens, context, ):
        bs, ts = tgt_tokens.size()
        encoder_out, (encoder_hn, encoder_cn), src_lengths = context
        
        # bs, ts, emb_size
        x = self.embed_net(tgt_tokens)
        x = self.dropout_in_model(x)
        # ts, bs, emb_size
        x = x.transpose(0, 1)
        
        if self.decoder_only:
            previous_hn = x.new_zeros(self.num_layers, bs, self.hidden_size)
            previous_cn = x.new_zeros(self.num_layers, bs, self.hidden_size)
            input_feed = None
        else:
            previous_hn = [self.hn_proj(encoder_hn[i]) for i in range(self.num_layers)]
            previous_cn = [self.cn_proj(encoder_cn[i]) for i in range(self.num_layers)]
            input_feed = x.new_zeros(bs, self.hidden_size)
        
        outs = []
        attn_scores = x.new_zeros(bs, ts, encoder_out.size(1)) if self.attention else None
        for seq in range(ts):
            if input_feed is None:
                input = x[seq, ...]
            else:
                input = torch.cat((x[seq, ...], input_feed), dim=1)
            
            for i in range(self.num_layers):
                hn, cn = self.layers[i](input, (previous_hn[i], previous_cn[i]))
                input = self.dropout_out_model(hn)
                if self.residual:
                    input = input + previous_hn[i]
                previous_hn[i] = hn
                previous_cn[i] = cn
            
            if self.attention:
                out, scores = self.attention(hn, encoder_out, src_lengths)
                attn_scores[:, seq, :] = scores
            else:
                out = hn
            
            out = self.dropout_out_model(out)
            outs.append(out)
            if input_feed is not None:
                input_feed = out
        

        x = torch.stack(outs, dim=0).view(ts, bs, -1)
        x = x.transpose(0, 1)
        
        x = self.out_embed_net(x)
        x = self.dropout_out_model(x)
        
        x = self.classifier(x)
        
        return x, attn_scores
              

In [112]:
encoder = SeqEncoderModel(encoder_config)

In [113]:
decoder = SeqDecoderModel(decoder_config)

In [114]:
src, src_lengths = get_dummy_input()

In [115]:
tgt, tgt_lengths = get_dummy_input()

In [116]:
enc_out, (enc_hn, enc_cn) = encoder(src, src_lengths)

In [117]:
enc_out.shape

torch.Size([4, 256, 4096])

In [118]:
src_lengths.shape

torch.Size([4])

In [119]:
context = (enc_out, (enc_hn, enc_cn), src_lengths)
out, attn = decoder(tgt, context)

torch.Size([256, 4, 2048])


In [120]:
out.shape

torch.Size([4, 256, 40000])

In [121]:
attn.shape

torch.Size([4, 256, 256])

In [122]:
attn[0][0].topk(2)

torch.return_types.topk(
values=tensor([0.1071, 0.0727], grad_fn=<TopkBackward0>),
indices=tensor([190, 191]))