In [1]:
import sys
import math
from tqdm import tqdm
sys.path.insert(0, '../')

import torch
from torch import nn
from torch.nn import functional as F

from attention import MultiHeadAttention
from encoder import Encoder
from decoder import Decoder
from positional_encoding import PositionalEncoder

In [11]:
def padding(data: list, pad_value: int = 0, show_progress: bool = True) -> tuple:
    """Right-pad every sequence in ``data`` to the length of the longest one.

    Args:
        data: List of lists (e.g. token ids) of varying length. The input
            lists are not modified; new lists are built for the result.
        pad_value: Filler appended to sequences shorter than the longest
            one. Defaults to 0.
        show_progress: When True (the default) the loop is wrapped in
            ``tqdm`` so a progress bar is displayed.

    Returns:
        A ``(padded, max_len)`` tuple: the list of padded sequences and the
        length they were all padded to. An empty ``data`` yields ``([], 0)``.
    """
    # default=0 keeps an empty batch from raising ValueError inside max().
    max_len = max(map(len, data), default=0)
    samples = tqdm(data) if show_progress else data
    padded = [
        sample + [pad_value] * (max_len - len(sample))
        for sample in samples
    ]
    return padded, max_len

# Upper bound on the ids used below (every id in `data` is < VOCAB_SIZE).
# Presumably the vocabulary size fed to the model - confirm where it is used.
VOCAB_SIZE = 100

# Toy corpus: ten integer sequences of differing lengths (1 to 20 items).
# None of them contains 0, the value padding() appends as filler.
data = [
    [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
    [60, 96, 51, 32, 90],
    [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
    [75, 51],
    [66, 88, 98, 47],
    [21, 39, 10, 64, 21],
    [98],
    [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
    [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
    [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43],
]

In [20]:
# Scratch check: a 10x10 tensor of random integers drawn from [0, 10).
# No seed is set, so the displayed values differ on every run.
torch.randint(low=0, high=10, size=(10, 10))

tensor([[1, 9, 8, 4, 5, 0, 3, 0, 1, 4],
        [8, 5, 4, 1, 7, 8, 2, 6, 1, 0],
        [4, 8, 9, 2, 4, 6, 0, 4, 2, 1],
        [8, 7, 3, 2, 3, 3, 4, 0, 9, 7],
        [8, 4, 2, 5, 4, 6, 1, 8, 4, 1],
        [1, 0, 0, 8, 5, 5, 8, 9, 2, 2],
        [8, 3, 6, 0, 0, 4, 8, 7, 0, 7],
        [2, 6, 0, 8, 4, 0, 2, 3, 9, 4],
        [6, 7, 9, 7, 4, 7, 9, 4, 5, 1],
        [6, 6, 0, 7, 3, 9, 1, 1, 6, 6]])

In [17]:
# Pad the corpus to a common length. `data` is rebound to the padded copy
# and MAX_LEN records the length of the longest sequence.
(data, MAX_LEN) = padding(data=data)

100%|██████████| 10/10 [00:00<?, ?it/s]


In [18]:
# Turn the padded corpus into a float32 tensor, one row per sequence.
# NOTE(review): the ids become floats here; if they are later used as
# embedding indices an integer dtype would be needed - confirm downstream use.
data_tensor = torch.tensor(data, dtype=torch.float32)

In [19]:
# Display the padded tensor to eyeball the trailing zero padding.
data_tensor

tensor([[62., 13., 47., 39., 78., 33., 56., 13., 39., 29., 44., 86., 71., 36.,
         18., 75.,  0.,  0.,  0.,  0.],
        [60., 96., 51., 32., 90.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.],
        [35., 45., 48., 65., 91., 99., 92., 10.,  3., 21., 54.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.],
        [75., 51.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.],
        [66., 88., 98., 47.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.],
        [21., 39., 10., 64., 21.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.],
        [98.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.],
        [77., 65., 51., 77., 19., 15., 35., 19., 23., 97., 50., 46., 53., 42.,
         45., 91., 66.,  3., 43., 10.],
        [70., 64., 98., 25., 99., 53.,  4., 13.,