# Install and Import packages

In [1]:

! pip install wandb

Collecting wandb
  Downloading wandb-0.16.6-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.45.0-py2.py3-none-any.whl (267 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m267.1/267.1 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->w

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data as data
from torch.utils.data import DataLoader

import numpy as np
import wandb
from tqdm import tqdm
from transformers import AutoTokenizer

import random
import collections
from time import time
import copy
import math

# Check environment

In [3]:
# check if the system is running on colab or macbook m1 pro
import platform
import os
op_system = platform.system()


# Import Data

In [4]:
if op_system == 'Darwin':
    # Macbook
    data_path = '../../data/translation/wmt14-en-de/'
elif op_system == 'Linux':
    # Colab
    from google.colab import drive
    drive.mount('/content/drive')

    data_path = '/content/drive/MyDrive/colab_data/from_scratch/data/translation/wmt14-en-de/'

Mounted at /content/drive


# Config

In [5]:


run_name = "self_implemented_transformer_not_converging"
wandb_project_name = "from_scratch_transformer_colab"

check_point_folder_path = data_path + "/check_point"

device_type = "cuda"
if not torch.cuda.is_available():
  if op_system == 'Darwin':
    device_type = "mps"
  elif op_system == 'Linux':
    device_type = "cpu"
device = torch.device(device_type)




BATCH_SIZE = 32 if device_type == "mps" else 12
SEQ_LEN = 64 if device_type == "mps" else 512
ENCODER_LAYER_NUM = 6
DECODER_LAYER_NUM = 6
D_MODEL = 256 if device_type == "mps" else 512
HIDDEN_DIM = 512 if device_type == "mps" else 2048
NUM_HEADS = 8
DROPOUT = 0.1
tokenizer = AutoTokenizer.from_pretrained("gpt2",pad_token="<pad>",bos_token="<sos>",eos_token="<eos>",
                                                       add_bos_token=True, add_eos_token=True,max_length=SEQ_LEN, padding="max_length")
VOCAB_SIZE = len(tokenizer.vocab)
EPOCHS = 50 if device_type == "mps" else 3
STEPS = 1000000
BETA1 = 0.9
BETA2 = 0.98
EPSILON = 1e-9
LEARNING_RATE = 0.00001
WARMUP_STEPS = 4000
TRAIN_DATA_SIZE = 5000 if device_type == "mps" else "all"
TEST_DATA_SIZE = 1000 if device_type == "mps" else "all"

STEP_LOSS_REPORT = 100 if device_type == "mps" else 200
TEST_BLEU_REPORT = 200 if device_type == "mps" else 2000
REPORT_WANDB = False if device_type in ["mps", "cpu"] else True
seed_value = 42
torch.manual_seed(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [6]:

print(len(tokenizer.vocab))
print(device_type)
print(device)

50260
cuda
cuda


# Main

## Data Loader

In [7]:
class WMT14ENDEDatasetHuggingFace(data.Dataset):

    def __init__(self, en_raw_file_path= "",de_raw_file_path="",
               max_len=512, device="cuda",data_size:str="all"):
        self.device = device
        if data_size == "all":
            data_size = "all"
        else:
            data_size = int(data_size)
        with open(en_raw_file_path, "r") as f:
            if data_size != "all":
                en_sentence = f.readlines()[:data_size]
            else:
                en_sentence = f.readlines()
        with open(de_raw_file_path, "r") as f:
            if data_size !="all":
                de_sentence = f.readlines()[:data_size]
            else:
                de_sentence = f.readlines()
        assert len(en_sentence) == len(de_sentence), "The number of english and german sentences should be the same"
        self.data = list(zip(en_sentence, de_sentence))
        self.max_len = max_len
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2",pad_token="<pad>",bos_token="<sos>",eos_token="<eos>",
                                                       add_bos_token=True, add_eos_token=True,max_length=max_len)
        self.tokenizer.add_special_tokens({"pad_token": "<pad>", "bos_token": "<sos>", "eos_token": "<eos>"})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        en_sentence_str,de_sentence_str = self.data[idx]
        en_sentence_str = "<sos> " + en_sentence_str.strip() + " <eos>"
        de_sentence_str = "<sos> " + de_sentence_str.strip() + " <eos>"

        # run huggingface tokenizer
        en_sentence = self.tokenizer(en_sentence_str, padding="max_length", truncation=True,
                                     max_length=self.max_len, return_tensors="pt",add_special_tokens=True)
        de_sentence = self.tokenizer(de_sentence_str, padding="max_length", truncation=True,
                                     max_length=self.max_len, return_tensors="pt",add_special_tokens=True
                                     )


        en_sentence_id = en_sentence["input_ids"].squeeze().to(self.device)
        de_sentence_id = de_sentence["input_ids"].squeeze().to(self.device)
        en_padding_mask = en_sentence["attention_mask"].squeeze().to(self.device)
        de_padding_mask = de_sentence["attention_mask"].squeeze().to(self.device)
        return {
            "en_input_ids": en_sentence_id,
            "de_input_ids": de_sentence_id,
            "en_padding_mask": en_padding_mask,
            "de_padding_mask": de_padding_mask,
            "en_sentence_str": en_sentence_str,
            "de_sentence_str": de_sentence_str
        }

## Evaluation BLEU Score


In [8]:

def _get_ngrams(segment, max_order):
    """
    Extracts all n-grams up to a given maximum order from an input segment.

    :param segment:
        text segment from which n-grams will be extracted
        list of tokens
        ["token1", "token2", "token3", "token4", "token5"]
    :param max_order:
        maximum length of n-grams
    :return:
        a Counter with n-gram counts
    """
    # create a counter to store the n-gram counts
    ngram_counts = collections.Counter()
    # run through all the n-gram from 1 to max_order
    for order in range(1, max_order + 1):
        # run through all the n-gram in the segment
        for i in range(len(segment) - order + 1):
            # get the n-gram, need to convert the n-gram to tuple since list is not hashable
            ngram = tuple(segment[i:i + order])
            # increment the n-gram count
            ngram_counts[ngram] += 1
    # return the n-gram counts, in which the key is the form 1 to max_order n-gram, the value is the frequency of the
    # n-gram
    return ngram_counts

def compute_bleu(reference_corpus, translation_corpus, max_order=4,
                 smooth=False, smooth_value=0.0):
    """
     Implementation of BLEU score.
    :param reference_corpus:
        list of list of reference sentences, each sentence is a list of tokens
        1 st level: number of "list of reference sentences", len is the number of translations
        [
            [
                ["token1", "token2", "token3", "token4", "token5"],
                ["token1", "token2", "token3"]
            ],
            [
                ["token1", "token2 ", "token3", "token4", "token5"],
                ["token1", "token2", "token3"]
            ]
        ]
        2 nd level: number of "reference sentences" for a single translation, len is the number of references
        [
            ["token1", "token2", "token3", "token4", "token5"],
            ["token1", "token2", "token3"]
        ]
        3 rd level: number of tokens in a single reference sentence, len is the number of tokens in a single sentence
        ["token1", "token2", "token3", "token4", "token5"]
    :param translation_corpus:
        list of translated sentences, each sentence is a list of tokens, those sentences are the predicted sentences
        that we want to evaluate
        1 st level: number of "list of translated sentences", len is the number of translations
        [
            ["token1", "token2", "token3", "token4", "token5"],
            ["token1", "token2", "token3"]
        ]
        2 nd level: number of tokens in a single translated sentence, len is the number of tokens in a single sentence
        ["token1", "token2", "token3", "token4", "token5"]
    :param max_order:
        the maximum n-gram order to use when computing BLEU score, usually 4
    :param smooth:
        whether to apply smoothing, default is False, if do not apply smoothing, then the n-gram modified
        precision will be 0 if there is no n-gram overlap, that will make the log of 0, which is undefineda
    :param smooth_value:
        the value to use when applying smoothing, default is 0.0
    :return:
        the BLEU score, the value is between 0 and 1, the higher, the better
    """

    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    reference_length = 0
    translation_length = 0

    for (references, translation) in zip(reference_corpus, translation_corpus):
        # when compute the brevity penalty, we have to consider the shortest reference sentence
        # for the translation sentences, we need to add them up
        reference_length += min(len(r) for r in references)
        translation_length += len(translation)

        # create a counter to store the n-gram counts of the merged reference sentences
        merged_ref_ngram_counts = collections.Counter()
        for reference in references:
            # compute the n-gram counts of every reference sentence, and merge them
            # the merge is not accumulative, it is the keep the maximum count of the n-gram
            # for counter, the + operator will sum the count of the same key, where the | operator will keep the maximum
            merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
        # get the n-gram counts of the translation sentence
        translation_ngram_counts = _get_ngrams(translation, max_order)
        # get the n-gram overlap of the translation sentence and the merged reference sentences
        # the & operator will return the minimum count of the n-gram, if the n-gram is not in the merged reference
        # sentences, then the count will be 0
        overlap = translation_ngram_counts & merged_ref_ngram_counts

        for ngram in overlap:
            # increment the n-gram overlap count
            # len(ngram) is to calculate the order of the n-gram, since the n-gram is a tuple, the length of the tuple
            # minus 1 will be the order of the n-gram
            matches_by_order[len(ngram) - 1] += overlap[ngram]
        for order in range(1, max_order + 1):
            # when calculate precision of the n-gram, we have to consider the possible matches as the denominator
            # if the sentence len is x, then the possible n-gram is x - order + 1
            possible_matches = len(translation) - order + 1
            # to avoid the division by 0, we have to check if the possible matches is greater than 0,
            # that only happens when the order is greater than the length of the sentence
            # for example, if the sentence is ["the", "cat"], the possible bigram is 1, the possible trigram is 0
            if possible_matches > 0:
                # increment the possible n-gram matches count
                possible_matches_by_order[order - 1] += possible_matches
        precision = [0] * max_order
        for i in range(0,max_order):
            if smooth:
                # if one of the n-gram order has no possible matches, then the precision will be 0
                # but we have to avoid the division by 0, so we have to add a smooth value
                precision[i] = (matches_by_order[i] + smooth_value) / (possible_matches_by_order[i] + smooth_value)
            else:
                if possible_matches_by_order[i] > 0:
                    precision[i] = matches_by_order[i] / possible_matches_by_order[i]
                else:
                    precision[i] = 0

        if min(precision) > 0:
            # the reason using geometric mean is that the n-gram precision is highly correlated
            # if they are independent, then we could use the arithmetic mean, but if the triple-gram precision is high,
            # then the bigram precision will be high, so we have to use the geometric mean

            # but using geometric mean will make the result underflow, since the precision is between 0 and 1
            # we will multiply a number between 0 and 1 multiple times, the result will be really small
            # so we have to use the log to avoid the underflow
            # and at the end, we have to use exp to get the final result back, since log then exp will cancel each other
            p_log_sum = sum((1 / max_order) * math.log(p) for p in precision)
            geo_mean = math.exp(p_log_sum)
        else:
            geo_mean = 0

        # compute the brevity penalty
        ratio =  float(translation_length) / reference_length
        if ratio > 1.0:
            bp = 1
        else:
            bp = math.exp(1 - 1.0 / ratio)
        bleu = geo_mean * bp
    return {"bleu": bleu, "geo_mean": geo_mean, "bp": bp,"unigram": precision[0], "bigram": precision[1],
            "trigram": precision[2], "fourgram": precision[3]}

## Model

In [9]:

def clone(component: nn.Module, num_of_copy: int) -> nn.ModuleList:
    """
    In the transformer structure, there will a lot of repeat component, for example, the identical layer of encoders and
    decoders. In order to create those identical components, we will need this clone function to create a list ModuleList
    :param component: the component will be copied
    :param num_of_copy: the number of copies will be in the final module list
    :return: a module list contain num_of_copy component
    """
    return nn.ModuleList([copy.deepcopy(component) for _ in range(num_of_copy)])


def _get_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """
    The padding mask is used to prevent the model to look at the padding token
    :param attention_mask:
        the mask is generated by tokenizer, usually the dim is [batch_size * seq_len] contains of 1 and 0, where 1
        represent the position of the corresponding sentence is a meaningful token, otherwise it is a padding.
    :return:
    """
    # create padding mask
    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

    attention_mask = attention_mask[:, None, None, :]
    attention_mask = 1.0 - attention_mask
    attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), torch.finfo(dtype).min)
    return attention_mask


def _get_causal_mask(attention_mask: torch.Tensor, input_shape, dtype: torch.dtype,
                     ) -> torch.Tensor:
    """
    The causal mask is used in decoder, the mask is used to prevent the model to look ahead the future token
    :param attention_mask:
        the mask is generated by tokenizer, usually the dim is [batch_size * seq_len] contains of 1 and 0, where 1
        represent the position of the corresponding sentence is a meaningful token, otherwise it is a padding.
    :param input_shape:
        the shape of the input tensor, tuple of batch_size and seq_len of decoder input
    :param inputs_embeds:
        the embedding of the input tensor, the dim of the input tensor is [batch_size * seq_len * d_model]
    :return:
    """
    # add the past_key_values_length to the seq_len of the input tensor
    key_value_length = input_shape[-1]

    # 4d mask is passed through the layers
    # if the attention_mask is 2D,

    input_shape = (attention_mask.shape[0], input_shape[-1])

    # create causal mask
    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    causal_4d_mask = None
    if input_shape[-1] > 1:
        if key_value_length is None:
            raise ValueError(
                "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
            )

        bsz, tgt_len = input_shape
        # create a mat that have the same size as attention weight
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=attention_mask.device)
        # rang a one dim mat only on conditional
        mask_cond = torch.arange(mask.size(-1), device=attention_mask.device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

        mask = mask.to(dtype)

        causal_4d_mask = mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

        bsz, src_len = attention_mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = attention_mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        inverted_mask = 1.0 - expanded_mask

        expanded_attn_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

        if causal_4d_mask is not None:
            expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)

        # expanded_attn_mask + causal_4d_mask can cause some overflow
        expanded_4d_mask = expanded_attn_mask

    return expanded_4d_mask

### Config Class

In [10]:
class TransformerConfig:
    """
    The config class is an easy way to parse those hyper params into model
    Since nowadays model architecture could be really deep, packing all the hypers into one config obj and pass this obj
    from one component to deeper component is more neat than every __init__ func have a bunch of param
    """

    def __init__(self,
                 d_model: int = 512,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 batch_size: int = 16,
                 seq_len: int = 256,
                 d_ff: int = 2048,
                 vocab_size: int = 37000,
                 device: str = "cuda",
                 encoder_layer_num: int = 6,
                 decoder_layer_num: int = 6,
                 eps: float = 1e-6
                 ):
        # the main model size of the transformer model, in the whole model,  we will use d_model number vector to
        # represent meaning of the word (the location of this word in the word embedding space)
        self.d_model = d_model
        # number of the heads define when we do the attention operation, parallel, there will be [num_heads] heads using
        # the same inputs but different learnable params to the same operation, the concat res will be the final res of
        # attention operation
        self.num_heads = num_heads
        # the dropout layer is critical in deep learning model, dropout is fantastic technical that can proven the
        # model overfit. what the dropout layer doing is it "cover/cut" random a percentage of input when it is
        # running, so the model won't over-relay on a certain feature/path of the model. it will increase the
        # robustness of the model
        self.dropout = dropout
        self.batch_size = batch_size
        # seq_len is the max number of token the model could process in one operation, not like RNN the model process
        # the input token by token, all the attention computation in transformer could be done at teh same time, we have
        # to define the max number of token, so the model can create weight mat accordingly
        self.seq_len = seq_len
        # the inner layer dim of fully-connected feed-forward component
        self.d_ff = d_ff
        # the vocab size of the tokenized, will be used to generate embedding layer and final fully connected layer
        self.vocab_size = vocab_size
        # indicate where the whole model will be running, all the tensor involved in the computation need to be moved
        # on the same device
        self.device = device
        self.encoder_layer_num = encoder_layer_num
        self.decoder_layer_num = decoder_layer_num
        self.eps = eps

### MultiHeadAttention

In [11]:
class MultiHeadAttention(nn.Module):
    """
    Multi Head attention is a foundation component of transformer model,
    What is does just repeat the scaled dot product attention operation several times parallel, each time we call it a
    Head
    """

    def __init__(self, config: TransformerConfig, is_cross: bool = False):

        super(MultiHeadAttention, self).__init__()
        # since those head are doing the attention operation at the same time, we better put them in a same matrix
        # to make it efficient. In that case, if we define single head dim as d_single_head
        # d_model = num_heads * d_single_head. before we do the scaled dot-product we have to assert, otherwise we can't
        # split the d_model evenly into heads
        self.d_model = config.d_model
        self.num_heads = config.num_heads
        assert self.d_model % self.num_heads == 0, "the number of head need to be divided by d_model"
        # this linear nn.ModuleList contains the W_q,W_k,W_v,W_o. All of them have the same size the purpose the W_q,
        # W_k,W_v is for projection. to do the scale dot-production attention, we have to use query(q) * key(k) to
        # get score between q and k then use the score as weight to retrieve info from the v, but there is an issue,
        # the original input the attention is general. For example in self attention, the original input of attention
        # is the same, 3 identical matrix represent a general meaning of the sentence. to get a better result. We
        # want project the general meaning into a specific space (query space, key space and value space) and use
        # those projected(professional) value to do the scale dot-product this W_o is used when we concat and
        # aggregate each head's value into final attention res since those head might have the same result,
        # some may focus on less important relation between q and k, we need a learnable params to assign weight to
        # each head and their dim
        self.linears = clone(nn.Linear(self.d_model, self.d_model), 4)
        self.dropout = nn.Dropout(p=config.dropout)
        self.seq_len = config.seq_len
        self.d_k = self.d_model // self.num_heads
        self.is_cross = is_cross

    def _scaled_dot_product(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            mask: torch.Tensor, dropout: nn.Dropout) -> torch.Tensor:
        """
        this function will actually do the scaled dot product mentioned in equation (1)
        for this function, all q k v need to be prepared, which it has already been split into head dim
        for the mask, it also should adjust to proper dim for broadcasting operation
        :param q: [batch_size * num_heads * seq_len * d_k]
        :param k: [batch_size * num_heads * seq_len * d_k]
        :param v: [batch_size * num_heads * seq_len * d_k]
        :param mask: [ batch_size * 1 * seq_len * 1]
        :param dropout: the dropout defined in the outer layer
        :return: the scaled dot-product result [batch_size * num_heads * seq_len * d_k]
        """
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # TODO explain the mask fill and why there is a small value
            scores = scores + mask
        scores = F.softmax(scores, dim=-1)
        # TODO explain why dropout before
        return torch.matmul(scores, v)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """
        this function implement the function in section 3.2.2
        :param q:[batch_size * seq_len * d_model]
        :param k:[batch_size * seq_len * d_model]
        :param v:[batch_size * seq_len * d_model]
        :param mask:[batch_size * seq_len]
        :param is_decoder: if the component is decoder, the mask should be different
        :return: result of multi head attention [batch_size * seq_len * d_model]
        """
        # get the batch since the q k v need to be reshaped latter. the number of sentence in the batch won't be all the
        # time the same, for example, the last batch may not be full
        batch_size = q.size(0)
        # TODO explain why have mask here

        # the mask is generated by tokenizer, usually the dim is [batch_size * seq_len] contains of 1 and 0
        # where 1 represent the position of the corresponding sentence is a meaningful token, otherwise it is a
        # padding. in order to use it, mask_fill the score, it has to meet the requirement of broadcasting with score
        # since the dim of score is [batch_size * num_heads * seq_len * d_k], the mask has to un-squeeze at dim 1 and
        # dim -1
        if mask is not None:
            if self.is_cross and self.training:
                # get the causal mask, the mask is used to prevent the model to look ahead the future token
                # the dim of the mask is already [ 1 * 1 * seq_len * seq_len]
                # the shape of the mask is [batch_size * 1 * key_len * memory_len]

                mask = _get_causal_mask(attention_mask=mask,
                                        input_shape = mask.size(),
                                        dtype=q.dtype)

            else:
                # the mask is used to prevent the model to look at the padding token
                # the dim of the mask is already [batch_size * 1 * seq_len * 1]
                if not self.training:
                    mask = _get_padding_mask(attention_mask=mask, dtype=q.dtype)
                else:
                    mask = _get_padding_mask(attention_mask=mask, dtype=q.dtype)

        query, key, value = [l(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) for l, x in
                             zip(self.linears, (q, k, v))]
        x = self._scaled_dot_product(query, key, value, mask=mask, dropout=self.dropout)
        # project the input q,k,v into according space to get actual query, key and value
        # q, k, v = [w(mat) for mat, w in zip([q, k, v], self.linears)]
        # reshape the q , k  and v to into heads, constitute the multi head
        # q, k, v = [mat.view(batch_size, self.seq_len, self.num_heads, -1) for mat in [q, k, v]]
        # transpose the number since the matmul only work on last two dim, to calculate the attention, we want to
        # compute q [...... seq_len * d_q] * k [..... d_k, seq_len]
        # after reshaping, the dim is [batch_size * seq_len, num_heads, d_k]
        # so dim 1 and dim 2 need to transpose
        # q, k, v = [torch.transpose(mat, 1, 2) for mat in [q, k, v]]
        # after everything be prepared, the scaled dot-product will be conducted.
        # the output of that func is split into head, we need transpose the dim back and reshape the same dim
        # as the input, so in the transformer the following identical layer could keeping do the same attention
        # operation again and again

        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # TODO why need contiguous()
        # x = (self._scaled_dot_product(q=q, k=k, v=v, mask=mask, dropout=self.dropout)
        #      .transpose(1, 2).contiguous().view(batch_size, -1, self.d_model))
        return self.linears[-1](x)

### FeedForward Class

In [12]:
class FeedForward(nn.Module):
    """
    # todo what is purpose this feed forward layer
    """

    def __init__(self, config: TransformerConfig):
        super(FeedForward, self).__init__()
        self.w_1 = nn.Linear(config.d_model, config.d_ff)
        self.w_2 = nn.Linear(config.d_ff, config.d_model)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        this function implement the equation (2)
        :param x: [batch_size * seq_len * d_model]
        :return: [batch_size * seq_len * d_model]
        """
        # TODO explain the position of the dropout
        # according to the equation, this fully connected feed forward layer, this consists of two linear
        # transformations with a ReLU activation in between.
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

### Embedding Class

In [13]:

class Embedding(nn.Module):
    """
    # todo
    """

    def __init__(self, config: TransformerConfig):
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.d_model = torch.tensor(config.d_model).to(config.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.embedding(x) * torch.sqrt(self.d_model)


class PositionalEmbedding(nn.Module):
    """
    Todo
    """

    def __init__(self, config: TransformerConfig):
        super(PositionalEmbedding, self).__init__()
        # todo
        self.dropout = nn.Dropout(p=config.dropout)
        pe = torch.zeros(config.seq_len, config.d_model)
        position = torch.arange(0, config.seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """

        :param x:
        :return:
        """
        x = embedding + Variable(self.pe[:, :embedding.size(1)], requires_grad=False)
        return self.dropout(x)

### LayerNorm Class

In [14]:
class LayerNorm(nn.Module):

    def __init__(self, config: TransformerConfig):
        super(LayerNorm, self).__init__()
        self.one_mat = nn.Parameter(torch.ones(config.d_model))
        self.zero_mat = nn.Parameter(torch.zeros(config.d_model))
        self.eps = config.eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """

        :param x:
        :return:
        """
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.one_mat * (x - mean) / (std + self.eps) + self.zero_mat

### Sublayer Class

In [15]:

class Sublayer(nn.Module):

    def __init__(self, config: TransformerConfig):
        super(Sublayer, self).__init__()
        self.norm = LayerNorm(config)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x: torch.Tensor, module: nn.Module) -> torch.Tensor:

        return self.dropout(module(self.norm(x))) + x



### EncoderLayer Class

In [16]:

class EncoderLayer(nn.Module):

    def __init__(self, config: TransformerConfig):
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.sublayers = clone(Sublayer(config), 2)

    def forward(self, x: torch.Tensor, src_masking: torch.Tensor) -> torch.Tensor:
        x = self.sublayers[0](x, lambda x: self.self_attention(x, x, x, src_masking))
        x = self.sublayers[1](x, self.ffn)
        return x



### Encoder Class


In [17]:
class Encoder(nn.Module):

    def __init__(self, config: TransformerConfig):
        super(Encoder, self).__init__()
        self.encoder_layer_list = clone(EncoderLayer(config), config.encoder_layer_num)
        self.norm = LayerNorm(config)

    def forward(self, x: torch.Tensor, src_masking: torch.Tensor) -> torch.Tensor:
        for encoder_layer in self.encoder_layer_list:
            x = encoder_layer(x, src_masking)
        return self.norm(x)

### DecoderLayer Class


In [18]:

class DecoderLayer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(config)
        self.cross_attention = MultiHeadAttention(config, is_cross=True)
        self.ffn = FeedForward(config)
        self.sublayers = clone(Sublayer(config), 3)

    def forward(self, memory: torch.Tensor, x: torch.Tensor, src_masking: torch.Tensor,
                tgt_masking: torch.Tensor) -> torch.Tensor:
        x = self.sublayers[0](x, lambda x: self.self_attention(x, x, x, tgt_masking))
        if self.training:
            x = self.sublayers[1](x, lambda x: self.cross_attention(x, memory, memory, src_masking))
        else:
            x = self.sublayers[1](x, lambda x: self.cross_attention(x, memory, memory, src_masking))
        x = self.sublayers[2](x, self.ffn)
        return x

### Decoder Class


In [19]:

class Decoder(nn.Module):

    def __init__(self, config: TransformerConfig):
        super(Decoder, self).__init__()
        self.decoder_layer_list = clone(DecoderLayer(config), config.decoder_layer_num)
        self.norm = LayerNorm(config)

    def forward(self, memory: torch.Tensor, x: torch.Tensor,
                src_masking: torch.Tensor, tgt_masking: torch.Tensor) -> torch.Tensor:

        for decoder_layer in self.decoder_layer_list:
            x = decoder_layer(memory, x, src_masking, tgt_masking)
        return self.norm(x)



### Transformer Class


In [20]:
class Transformer(nn.Module):

    def __init__(self, config: TransformerConfig):
        super(Transformer, self).__init__()
        self.embedding = Embedding(config)
        self.pe = PositionalEmbedding(config)
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.linear = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, src_x: torch.Tensor, tgt_x: torch.Tensor,
                src_masking: torch.Tensor, tgt_masking: torch.Tensor,
                memory: tokenizer=None) -> torch.Tensor:
        src_embedding = self.embedding(src_x)
        tgt_embedding = self.embedding(tgt_x)
        src_pe = self.pe(src_embedding)
        tgt_pe = self.pe(tgt_embedding)
        if memory is None:
            memory = self.encoder(src_pe, src_masking)
        output = self.decoder(memory, tgt_pe, src_masking, tgt_masking)

        logits = self.linear(output)
        return {
            "logits": logits,
            "memory": memory,
            "output": output

        }


# RUN

## prepare the data

In [21]:


if not os.path.exists(check_point_folder_path):
    os.makedirs(check_point_folder_path)

if op_system == "Darwin":
   train_data_size = str(TRAIN_DATA_SIZE)
else:
    train_data_size = "all"

checkpoint_files = os.listdir(check_point_folder_path)
checkpoint_file_name = f"checkpoint_{train_data_size}_batch_size-{BATCH_SIZE}_seq_len-{SEQ_LEN}_encoder_layer_num-{ENCODER_LAYER_NUM}_decoder_layer_num-{DECODER_LAYER_NUM}_d_model-{D_MODEL}_hidden_dim-{HIDDEN_DIM}_num_heads-{NUM_HEADS}_dropout-{DROPOUT}_vocab_size-{VOCAB_SIZE}_epochs-{EPOCHS}_steps-{STEPS}_beta1-{BETA1}_beta2-{BETA2}_epsilon-{EPSILON}_learning_rate-{LEARNING_RATE}_warmup_steps-{WARMUP_STEPS}"
if checkpoint_file_name in checkpoint_files:
    # load the model from the checkpoint
    transformer = torch.load(check_point_folder_path + "/" + checkpoint_file_name)
else:
    if REPORT_WANDB:
        wandb.init(
        # set the wandb project where this run will be logged
        project=wandb_project_name,
        name=run_name,

        # track hyperparameters and run metadata
        config={
            "batch_size": BATCH_SIZE,
            "seq_len": SEQ_LEN,
            "encoder_layer_num": ENCODER_LAYER_NUM,
            "decoder_layer_num": DECODER_LAYER_NUM,
            "d_model": D_MODEL,
            "hidden_dim": HIDDEN_DIM,
            "num_heads": NUM_HEADS,
            "dropout": DROPOUT,
            "vocab_size": VOCAB_SIZE,
            "epochs": EPOCHS,
            "steps": STEPS,
            "beta1": BETA1,
            "beta2": BETA2,
            "epsilon": EPSILON,
            "learning_rate": LEARNING_RATE,
            "warmup_steps": WARMUP_STEPS,
            "device": device.type,
            "timestamp": time()
        }
        )

transformer_config = TransformerConfig(
    batch_size=BATCH_SIZE,
    seq_len=SEQ_LEN,
    encoder_layer_num=ENCODER_LAYER_NUM,
    decoder_layer_num=DECODER_LAYER_NUM,
    d_model=D_MODEL,
    d_ff=HIDDEN_DIM,
    num_heads=NUM_HEADS,
    dropout=DROPOUT,
    vocab_size=VOCAB_SIZE,
    device=device,
    eps = 1e-6,
)
# fbceb167aa3fe9025474dbf8f741194f210bbc82

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## initialize the model and dataset

In [22]:

transformer = Transformer(transformer_config)
transformer.to(device)

# adam with beta1 = 0.9, beta2 = 0.98, epsilon = 1e-9
optimizer = torch.optim.Adam(transformer.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2), eps=EPSILON)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100000, gamma=0.5)
criterion = torch.nn.CrossEntropyLoss()


wmt14_en_de_tokenizer_dataset = WMT14ENDEDatasetHuggingFace(
    en_raw_file_path= data_path + "raw/train/train.en",
    de_raw_file_path= data_path + "raw/train/train.de",
    device=device, max_len=SEQ_LEN, data_size=train_data_size)

dataloader = DataLoader(wmt14_en_de_tokenizer_dataset, batch_size=BATCH_SIZE, shuffle=True)


## Inferencing Function

### Save the model function

In [23]:
def save_model(model, file_path):
    torch.save(model, file_path)
    print(f"Model saved at {file_path}")

### Compute BLEU Score Function

In [24]:
def model_evaluate(model, tokenizer, device, max_len=SEQ_LEN):
    wmt14_en_de_test_tokenizer_dataset = WMT14ENDEDatasetHuggingFace(
        en_raw_file_path= data_path + "raw/test/newstest2015.en",
        de_raw_file_path= data_path + "raw/test/newstest2015.de",
        device=device, max_len=SEQ_LEN, data_size=TEST_DATA_SIZE)

    test_dataloader = DataLoader(wmt14_en_de_test_tokenizer_dataset, batch_size=8, shuffle=True)
    model.eval()
    ref_sents_list = []
    pred_sents_list = []
    for step, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
        batch_en_tensor = data["en_input_ids"]
        padding_mask_en_tensor = data["en_padding_mask"]
        ref_sents = data["de_sentence_str"]
        # assmble the tgt tensor using the token id of the bos token
        batch_de_tensor = torch.tensor([[tokenizer.bos_token_id]] * len(batch_en_tensor), device=device)
        padding_mask_de_tensor = torch.ones_like(batch_de_tensor)
        memory = None

        # data display
        # first src sentence in string
        # print(f"src sentence: {data['en_sentence_str'][0]}")
        # first src sentence in token id
        # print(f"src sentence token id: {batch_en_tensor[0]}")

        # first tgt sentence in string initially is just the bos token
        # print(f"tgt sentence: {tokenizer.decode(batch_de_tensor[0])}")
        # first tgt sentence in token id
        # print(f"tgt sentence token id: {batch_de_tensor[0]}")

        # first ground truth sentence
        # print(f"ground truth sentence: {ref_sents[0]}")


        for i in range(max_len):
            if memory is None:
                res = model(batch_en_tensor, batch_de_tensor, padding_mask_en_tensor, padding_mask_de_tensor)
            else:
                res = model(batch_en_tensor, batch_de_tensor, padding_mask_en_tensor, padding_mask_de_tensor, memory)
            logit = res["logits"]
            memory = res["memory"]
            logit = torch.softmax(logit, dim=-1)
            pred_sents_ids = torch.argmax(logit, dim=-1)
            # append the last token of the pred_sents_ids to the tgt_tensor
            batch_de_tensor = torch.cat([batch_de_tensor, pred_sents_ids[:, -1].unsqueeze(-1)], dim=-1)
            padding_mask_de_tensor = torch.ones_like(batch_de_tensor)
            # if all the last token of the pred_sents_ids is padding token, then break
            if torch.sum(pred_sents_ids[:, -1] == tokenizer.eos_token_id) == len(pred_sents_ids):
               break
        # data display
        # print(f"pred sentence token id: {batch_de_tensor[0]}")
        decoded_sents = tokenizer.batch_decode(batch_de_tensor, remove_special_tokens=True)
        print(f"decoded_sents: {decoded_sents[0]}")

        pred_sents_list.extend(decoded_sents)
        ref_sents_list.extend([[ref_sent] for ref_sent in ref_sents])
        blue_score = compute_bleu(ref_sents_list, pred_sents_list, smooth=True, max_order=4)
    return blue_score

## Main loop

In [None]:
# check if the checkpoint file is already exist, if it is, load the model from the checkpoint
if not os.path.exists(check_point_folder_path):
    os.makedirs(check_point_folder_path)

is_exist = False
checkpoint_files = os.listdir(check_point_folder_path)
if checkpoint_file_name in checkpoint_files:
    # load the model from the checkpoint
    transformer = torch.load(check_point_folder_path + "/" + checkpoint_file_name)
    is_exist = True
else:
    total_step = 0
    for epoch in range(EPOCHS):
        epoch_loss = 0
        for epoch_step, data in tqdm(enumerate(dataloader), total=len(dataloader)):
            # get the batch
            batch_en_tensor = data["en_input_ids"]
            batch_de_tensor = data["de_input_ids"]
            padding_mask_en_tensor = data["en_padding_mask"]
            padding_mask_de_tensor = data["de_padding_mask"]
            # forward pass
            optimizer.zero_grad()
            model_res = transformer(batch_en_tensor, batch_de_tensor, padding_mask_en_tensor, padding_mask_de_tensor)
            logit = model_res["logits"]
            loss = criterion(logit.view(-1, VOCAB_SIZE), batch_de_tensor.view(-1))
            # backward pass
            loss.backward()
            optimizer.step()
            # scheduler.step()
            epoch_loss += loss.item()
            total_step += 1
            if total_step % 100 == 0:
                if REPORT_WANDB:
                    wandb.log({"step": total_step, "loss": loss.item()})
            if total_step % STEP_LOSS_REPORT == 0:
                print(f"Step: {total_step}, Loss: {loss.item()}")
            if total_step % TEST_BLEU_REPORT == 0:
                blue_dict = model_evaluate(transformer, tokenizer, device, max_len=SEQ_LEN)

                if REPORT_WANDB:
                    wandb.log({"step": total_step, "bleu": blue_dict["bleu"]})

                print(f"Step: {total_step}, BLEU Score: {blue_dict}")
        if REPORT_WANDB:
            wandb.log({"epoch": epoch, "loss": epoch_loss/len(dataloader)})

        print(f"Epoch: {epoch}, Loss: {epoch_loss/len(dataloader)}")

save_model(transformer, check_point_folder_path + "/" + checkpoint_file_name)


blue_dict = model_evaluate(transformer, tokenizer, device, max_len=SEQ_LEN)
print(f"BLEU Score: {blue_dict}")









if REPORT_WANDB:
    wandb.finish()






  0%|          | 200/372404 [01:11<36:46:18,  2.81it/s]

Step: 200, Loss: 1.4369796514511108


  0%|          | 400/372404 [02:21<35:51:47,  2.88it/s]

Step: 400, Loss: 0.9987669587135315


  0%|          | 600/372404 [03:31<35:37:15,  2.90it/s]

Step: 600, Loss: 0.6785897612571716


  0%|          | 800/372404 [04:41<35:44:02,  2.89it/s]

Step: 800, Loss: 0.7104177474975586


  0%|          | 1000/372404 [05:51<35:44:48,  2.89it/s]

Step: 1000, Loss: 0.5335140824317932


  0%|          | 1200/372404 [07:00<36:55:30,  2.79it/s]

Step: 1200, Loss: 0.5194545388221741


  0%|          | 1400/372404 [08:10<35:51:36,  2.87it/s]

Step: 1400, Loss: 0.4435024559497833


  0%|          | 1600/372404 [09:20<35:43:17,  2.88it/s]

Step: 1600, Loss: 0.3893871009349823


  0%|          | 1800/372404 [10:30<35:47:43,  2.88it/s]

Step: 1800, Loss: 0.3602859675884247


  1%|          | 1999/372404 [11:39<35:23:43,  2.91it/s]

Step: 2000, Loss: 0.29966822266578674



  0%|          | 0/272 [00:00<?, ?it/s][A
  0%|          | 1/272 [00:16<1:12:47, 16.12s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  1%|          | 2/272 [00:31<1:10:43, 15.72s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  1%|          | 3/272 [00:47<1:09:56, 15.60s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  1%|▏         | 4/272 [01:02<1:09:37, 15.59s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  2%|▏         | 5/272 [01:18<1:10:11, 15.77s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  2%|▏         | 6/272 [01:34<1:10:41, 15.95s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  3%|▎         | 7/272 [01:50<1:10:02, 15.86s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  3%|▎         | 8/272 [02:06<1:09:40, 15.83s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  3%|▎         | 9/272 [02:22<1:09:27, 15.85s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  4%|▎         | 10/272 [02:39<1:10:38, 16.18s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  4%|▍         | 11/272 [02:55<1:10:39, 16.24s/it][A

decoded_sents: <sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos><sos>


  4%|▍         | 12/272 [03:11<1:09:54, 16.13s/it][A