In [1]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

In [2]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math

In [3]:
from torch.nn.utils.rnn import pad_sequence

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"DEVICE: {DEVICE}")

DEVICE: cpu


In [5]:
from transformer import Transformer as TestTransformer

In [6]:
# Override torchtext's Multi30k download locations for the "train" and "valid"
# splits with archives hosted on raw.githubusercontent.com.
# NOTE(review): presumably a workaround for the default URLs being unreachable — confirm.
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

In [7]:
# Translation direction: German source text, English target text.
SRC_LANGUAGE, TGT_LANGUAGE = 'de', 'en'

In [8]:
# Per-language lookup tables, keyed by language code: tokenizers and vocabularies.
token_transform, vocab_transform = {}, {}

In [9]:
# Register one spaCy tokenizer per language.
token_transform.update({
    SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
})

In [10]:
# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    """Lazily tokenize one side of every sample in ``data_iter``.

    Args:
        data_iter: Iterable of samples indexable as ``sample[0]`` (source
            text) and ``sample[1]`` (target text).
        language: ``SRC_LANGUAGE`` or ``TGT_LANGUAGE``; selects both the
            element read from each sample and the tokenizer taken from
            ``token_transform``.

    Yields:
        The result of ``token_transform[language]`` applied to the selected
        element of each sample, one result per sample.

    Raises:
        KeyError: if ``language`` is not one of the two configured languages
            or has no tokenizer registered in ``token_transform``.
    """
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
    # Both lookups are loop-invariant, so resolve them once up front.
    column = language_index[language]
    tokenize = token_transform[language]

    for data_sample in data_iter:
        yield tokenize(data_sample[column])

In [11]:
# Special symbols, listed in the order of their vocabulary indices so that
# they are inserted into the vocab at exactly those positions.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
# Indices follow the list positions: <unk>=0, <pad>=1, <bos>=2, <eos>=3.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = range(len(special_symbols))

In [12]:
# Build one vocabulary per language from the tokens of the training split.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # Training data Iterator; re-created for each language so every
    # vocabulary is built from its own pass over the training pairs.
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Create torchtext's Vocab object. ``min_freq=1`` keeps every token seen
    # at least once; the special symbols are placed first so their positions
    # line up with UNK_IDX / PAD_IDX / BOS_IDX / EOS_IDX.
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

In [13]:
# Map out-of-vocabulary tokens to ``UNK_IDX``. Without a default index the
# vocab lookup raises ``RuntimeError`` for any token it has not seen.
for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    vocab_transform[ln].set_default_index(UNK_IDX)

In [14]:
# Seed PyTorch's default random number generator so that the random draws in
# the following cells are repeatable from run to run.
torch.manual_seed(0)

<torch._C.Generator at 0x7f9b1eb612f0>

In [15]:
# Vocabulary sizes (special symbols included), passed to the model constructor.
SRC_VOCAB_SIZE, TGT_VOCAB_SIZE = (
    len(vocab_transform[SRC_LANGUAGE]),
    len(vocab_transform[TGT_LANGUAGE]),
)
print(f"SRC_VOCAB_SIZE: {SRC_VOCAB_SIZE}, TGT_VOCAB_SIZE: {TGT_VOCAB_SIZE}")

SRC_VOCAB_SIZE: 19214, TGT_VOCAB_SIZE: 10837


In [16]:
# Model hyper-parameters, passed positionally to TestTransformer:
# embedding width, number of attention heads, feed-forward hidden width.
EMB_SIZE, NHEAD, FFN_HID_DIM = 512, 8, 512
# Depth of the encoder and decoder stacks.
NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS = 3, 3
# Batch size; not consumed by the model constructor.
# NOTE(review): presumably used by the data loaders further down — confirm.
BATCH_SIZE = 128

In [17]:
# Instantiate the project's own Transformer (imported as TestTransformer).
# All arguments are positional, so their order must match the constructor
# defined in transformer.py.
transformer = TestTransformer(
    SRC_VOCAB_SIZE,
    TGT_VOCAB_SIZE,
    EMB_SIZE,
    NHEAD,
    FFN_HID_DIM,
    0.1,  # presumably the dropout probability (the printed repr shows Dropout(p=0.1)) — confirm against transformer.py
    NUM_ENCODER_LAYERS,
    NUM_DECODER_LAYERS
)
print(f"transformer: {transformer}")

transformer: Transformer(
  (encoder): Encoder(
    (positional_embedding): PositionalEmbedding(
      (embedding_layer): Embedding(19214, 512)
      (position_embedding_layer): PositionEncoding()
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (encoder_layers): ModuleList(
      (0-2): 3 x EncoderLayer(
        (attention): GlobalSelfAttention(
          (multi_head_attention): MultiHeadAttention(
            (q_linear_projection_func): Linear(in_features=512, out_features=512, bias=False)
            (k_linear_projection_func): Linear(in_features=512, out_features=512, bias=False)
            (v_linear_projection_func): Linear(in_features=512, out_features=512, bias=False)
            (attention_projection_func): Linear(in_features=512, out_features=512, bias=False)
            (attention): ScaledDotProductAttention()
          )
          (dropout): Dropout(p=0.1, inplace=False)
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (f

In [18]:
# Re-initialise every parameter with more than one dimension using Xavier
# uniform initialisation; 1-D parameters are left at their defaults.
for p in transformer.parameters():
    if p.dim() <= 1:
        continue
    nn.init.xavier_uniform_(p)

In [19]:
# Move the model to the device selected earlier (GPU if available, else CPU).
transformer = transformer.to(DEVICE)

In [20]:
# Cross-entropy loss; target positions equal to PAD_IDX are ignored, so
# padding does not contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [21]:
# Adam over all model parameters with a fixed learning rate of 1e-4.
# NOTE(review): betas=(0.9, 0.98) and eps=1e-9 look like the settings from the
# original Transformer paper — confirm that this is the intent.
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

  from .autonotebook import tqdm as notebook_tqdm


In [22]:
# helper function to club together sequential operations
def sequential_transforms(*transforms):
    """Chain ``transforms`` into a single callable.

    The returned function feeds its argument to the first transform, hands
    each result to the next transform in order, and returns the last result.
    When no transforms are given, it returns its input unchanged.
    """
    def func(txt_input):
        value = txt_input
        for apply_step in transforms:
            value = apply_step(value)
        return value
    return func

In [23]:
# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    """Wrap ``token_ids`` with BOS/EOS markers and return them as a 1-D tensor.

    Args:
        token_ids: Vocabulary indices of one tokenized sentence; may be empty.

    Returns:
        A 1-D ``torch.long`` tensor ``[BOS_IDX, *token_ids, EOS_IDX]``.
    """
    # Build a single tensor with an explicit integer dtype. An empty
    # ``token_ids`` would otherwise give ``torch.tensor([])`` the default
    # float dtype, and concatenating it with the integer BOS/EOS tensors
    # would promote the whole sequence to float, which cannot be used as
    # embedding indices.
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

In [24]:
# Per-language pipeline turning a raw string into a tensor of token indices:
# tokenization -> numericalization -> add BOS/EOS and create tensor.
text_transform = {
    ln: sequential_transforms(token_transform[ln],   # Tokenization
                              vocab_transform[ln],   # Numericalization
                              tensor_transform)      # Add BOS/EOS and create tensor
    for ln in (SRC_LANGUAGE, TGT_LANGUAGE)
}

In [25]:
# function to collate data samples into batch tensors
def collate_fn(batch):
    """Turn a batch of raw (source, target) string pairs into padded tensors.

    Each string has trailing newlines stripped and is passed through the
    ``text_transform`` pipeline of its language. The resulting variable-length
    sequences are then padded with ``PAD_IDX`` to a common length.

    Returns:
        ``(src_batch, tgt_batch)`` as produced by ``pad_sequence`` with its
        default layout (``batch_first`` is not set, so sequence length is the
        leading dimension).
    """
    encode_src = text_transform[SRC_LANGUAGE]
    encode_tgt = text_transform[TGT_LANGUAGE]

    src_tensors, tgt_tensors = [], []
    for raw_src, raw_tgt in batch:
        src_tensors.append(encode_src(raw_src.rstrip("\n")))
        tgt_tensors.append(encode_tgt(raw_tgt.rstrip("\n")))

    return (pad_sequence(src_tensors, padding_value=PAD_IDX),
            pad_sequence(tgt_tensors, padding_value=PAD_IDX))