# Transformer test (translation)

Reference: https://pytorch.org/tutorials/beginner/translation_transformer.html

In [1]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

In [2]:
import numpy as np

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math

In [3]:
from torch.nn.utils.rnn import pad_sequence

In [4]:
from torch.utils.data import DataLoader

In [5]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"DEVICE: {DEVICE}")

DEVICE: cpu


In [6]:
from utils import SpecialSymbols, sequential_transforms, tensor_transform, token_to_sentence

In [7]:
from transformer import Transformer as TestTransformer

In [8]:
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

In [9]:
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

## Token transform

In [10]:
token_transform = {}
vocab_transform = {}

In [11]:
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')

In [12]:
print(f'sample src token transform: {token_transform[SRC_LANGUAGE]("hello world")}')
print(f'sample tgt token transform: {token_transform[TGT_LANGUAGE]("hello world")}')

sample src token transform: ['hello', 'world']
sample tgt token transform: ['hello', 'world']


## Vocab transform

In [13]:
# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    for data_sample in data_iter:
        yield token_transform[language](data_sample[language_index[language]])

In [14]:
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # Training data Iterator
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=SpecialSymbols.special_symbols,
                                                    special_first=True)

In [15]:
# Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.
# If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
  vocab_transform[ln].set_default_index(SpecialSymbols.UNK_IDX)

## Initialize

In [16]:
torch.manual_seed(0)

<torch._C.Generator at 0x7efd9faef330>

In [17]:
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
print(f"SRC_VOCAB_SIZE: {SRC_VOCAB_SIZE}, TGT_VOCAB_SIZE: {TGT_VOCAB_SIZE}")
print(f'sample src vocab transform: {vocab_transform[SRC_LANGUAGE]["hello"]} {vocab_transform[SRC_LANGUAGE]["world"]}')
print(f'sample src vocab transform: {vocab_transform[SRC_LANGUAGE].lookup_tokens((0, 18975))}')
print(f'sample tgt vocab transform: {vocab_transform[TGT_LANGUAGE]["hello"]} {vocab_transform[TGT_LANGUAGE]["world"]}')
print(f'sample tgt vocab transform: {vocab_transform[TGT_LANGUAGE].lookup_tokens((5465, 1870))}')

SRC_VOCAB_SIZE: 19214, TGT_VOCAB_SIZE: 10837
sample src vocab transform: 0 18975
sample src vocab transform: ['<unk>', 'world']
sample tgt vocab transform: 5465 1870
sample tgt vocab transform: ['hello', 'world']


In [18]:
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

In [19]:
transformer = TestTransformer(
    SRC_VOCAB_SIZE,
    TGT_VOCAB_SIZE,
    EMB_SIZE,
    NHEAD,
    FFN_HID_DIM,
    0.1,
    NUM_ENCODER_LAYERS,
    NUM_DECODER_LAYERS
)
print(f"transformer: {transformer}")

transformer: Transformer(
  (encoder): Encoder(
    (positional_embedding): PositionalEmbedding(
      (embedding_layer): Embedding(19214, 512)
      (position_embedding_layer): PositionEncoding()
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (encoder_layers): ModuleList(
      (0-2): 3 x EncoderLayer(
        (attention): GlobalSelfAttention(
          (multi_head_attention): MultiHeadAttention(
            (q_linear_projection_func): Linear(in_features=512, out_features=512, bias=False)
            (k_linear_projection_func): Linear(in_features=512, out_features=512, bias=False)
            (v_linear_projection_func): Linear(in_features=512, out_features=512, bias=False)
            (attention_projection_func): Linear(in_features=512, out_features=512, bias=False)
            (attention): ScaledDotProductAttention()
          )
          (dropout): Dropout(p=0.1, inplace=False)
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (f

In [20]:
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [21]:
transformer = transformer.to(DEVICE)

In [22]:
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=SpecialSymbols.PAD_IDX)

In [23]:
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

  from .autonotebook import tqdm as notebook_tqdm


In [24]:
# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor

In [25]:
# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=SpecialSymbols.PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=SpecialSymbols.PAD_IDX)
    return src_batch, tgt_batch

## Sample data

In [26]:
sample_val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
sample_val_dataloader = DataLoader(sample_val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
sample_data_batch = next(iter(sample_val_dataloader))
# data: (src, tgt)
# src/tgt: (tokens, batch_size)
print(f"src.size(): {sample_data_batch[0].size()}, tgt.size(): {sample_data_batch[1].size()}")
sample_print_size = 5
for i in range(sample_print_size):
    if i >= BATCH_SIZE:
        break
    sample_src_sentence = token_to_sentence(sample_data_batch[0][:, i].numpy(), vocab_transform[SRC_LANGUAGE], True)
    sample_tgt_sentence = token_to_sentence(sample_data_batch[1][:, i].numpy(), vocab_transform[TGT_LANGUAGE], True)
    print(f"""
sample sentence {i}:
    src ({SRC_LANGUAGE}): {sample_src_sentence}
    tgt ({TGT_LANGUAGE}): {sample_tgt_sentence}
""")

src.size(): torch.Size([35, 128]), tgt.size(): torch.Size([30, 128])

sample sentence 0:
    src (de):  Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen                         
    tgt (en):  A group of men are loading cotton onto a truck                   


sample sentence 1:
    src (de):  Ein Mann schläft in einem grünen Raum auf einem Sofa .                       
    tgt (en):  A man sleeping in a green room on a couch .                  


sample sentence 2:
    src (de):  Ein Junge mit Kopfhörern sitzt auf den Schultern einer Frau .                       
    tgt (en):  A boy wearing headphones sits on a woman 's shoulders .                  


sample sentence 3:
    src (de):  Zwei Männer bauen eine blaue <unk> auf einem <unk> See auf                       
    tgt (en):  Two men setting up a blue ice fishing hut on an iced over lake               


sample sentence 4:
    src (de):  Ein Mann mit beginnender Glatze , der eine rote Rettungsweste trägt , sitzt in eine

## Evaluation

In [27]:
def evaluate(model):
    model.eval()
    losses = 0

    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in val_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(list(val_dataloader))

In [28]:
def eval(model, src):
    model.eval()
    src = src.to(DEVICE)
    tgt = model(src, torch.tensor([[SpecialSymbols.BOS_IDX]]))
    return tgt

In [29]:
sample_input = sample_data_batch[0][:, 0]
sample_input_as_batch = sample_data_batch[0][:, 0].unsqueeze(0)
sample_next_token_prob = eval(transformer, sample_input_as_batch)
print(f"sample_input_as_batch: {sample_input_as_batch.size()}, sample_next_token_prob: {sample_next_token_prob.size()}")
print(f"SRC_VOCAB_SIZE: {SRC_VOCAB_SIZE}, TGT_VOCAB_SIZE: {TGT_VOCAB_SIZE}")

sample_next_token_max_prob, sample_next_token_idx = torch.max(sample_next_token_prob, dim=2)
print(f"sample_next_token_idx: {sample_next_token_idx} ({sample_next_token_idx.size()}), sample_next_token_max_prob: {sample_next_token_max_prob}")
sample_next_token = token_to_sentence(np.array([sample_next_token_idx]), vocab_transform[TGT_LANGUAGE], True)
print(f"sample_next_token: {sample_next_token}")

sample_input_as_batch: torch.Size([1, 35]), sample_next_token_prob: torch.Size([1, 1, 10837])
SRC_VOCAB_SIZE: 19214, TGT_VOCAB_SIZE: 10837
sample_next_token_idx: tensor([[5778]]) (torch.Size([1, 1])), sample_next_token_max_prob: tensor([[1.2038]], grad_fn=<MaxBackward0>)
sample_next_token: previously


  output = F.softmax(output)
  sample_next_token = token_to_sentence(np.array([sample_next_token_idx]), vocab_transform[TGT_LANGUAGE], True)
