# Top-k, Nucleus Sampling and beam-search decoding strategies


Author: **Rafael Ito**  
e-mail: ito.rafael@gmail.com

# 0. Dataset and Description

**Name:**  ParaCrawl corpus  
**Description:** in this notebook we will implement three decoding strategies:
- Top-k
- Nucleus Sampling
- beam-search  

We will use the T5 model with translation tasks to compare if the decoder is implemented right.


# 1. Libraries and packages

## 1.1 Check device

In [1]:
import torch
device = torch.device('cpu')
if torch.cuda.is_available():
    device_model = torch.cuda.get_device_name(0)
print('GPU model:', device_model)

GPU model: Tesla P100-PCIE-16GB


## 1.2 Install packages

In [0]:
!pip install -q \
    numpy       \
    torch       \
    sklearn     \
    skorch      \
    matplotlib  \
    sacrebleu   \
    transformers        \
    pytorch-lightning   \

Install apex lib (16-bit precision)

In [3]:
# to use 16-bit precision, apex lib should be installed
# however, installing it with pip ('! pip install -q apex') crash the code
# installing the /
! git clone -q https://www.github.com/nvidia/apex
! pip -q install -v --no-cache-dir /content/apex/
#! pip install -q -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" /content/apex/

fatal: destination path 'apex' already exists and is not an empty directory.
Processing ./apex
Building wheels for collected packages: apex
  Building wheel for apex (setup.py) ... [?25l[?25hdone
  Created wheel for apex: filename=apex-0.1-cp36-none-any.whl size=177174 sha256=afbdfac66a8bef0a078de80f7b4809df26d8bbd9ab5c4c669813804d305707dd
  Stored in directory: /tmp/pip-ephem-wheel-cache-gpt3ljae/wheels/b1/3a/aa/d84906eaab780ae580c7a5686a33bf2820d8590ac3b60d5967
Successfully built apex
Installing collected packages: apex
  Found existing installation: apex 0.1
    Uninstalling apex-0.1:
      Successfully uninstalled apex-0.1
Successfully installed apex-0.1


## 1.3 Import libraries

In [4]:
#-------------------------------------------------
# general
import numpy as np
import pandas as pd
import os
import sys
import pdb
import random
import gzip
import itertools
import collections
from google.colab import drive
from argparse import Namespace
from typing import Dict, List, Tuple
import functools, traceback
import tensorboard
%load_ext tensorboard
#-------------------------------------------------
# HW status
import psutil
import nvidia_smi
from multiprocessing import cpu_count
#-------------------
nvidia_smi.nvmlInit()
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
#-------------------------------------------------
# NLP
import re
import nltk
import sacrebleu
from transformers import T5ForConditionalGeneration, T5Tokenizer
#-------------------------------------------------
# PyTorch
import torch
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import Linear
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, Dataset, DataLoader
import torch.nn.functional as F
#-------------------
# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
#-------------------------------------------------
# random seed generator
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
#-------------------------------------------------
# Suppress some of the logging
# commands got from Diedre's notebook
import logging
logging.getLogger("transformers.configuration_utils").setLevel(logging.WARNING)
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARNING)
logging.getLogger("transformers.tokenization_utils").setLevel(logging.WARNING)
logging.getLogger("lightning").setLevel(logging.WARNING)
#-------------------------------------------------
# package version
print('Torch version:', torch.__version__)
print('Pytorch Lightning version:', pl.__version__)

Torch version: 1.5.0+cu101
Pytorch Lightning version: 0.7.6


## 1.4 Device info

In [5]:
import torch
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
    device_model = torch.cuda.get_device_name(0)
    device_memory = torch.cuda.get_device_properties(device).total_memory / 1e9
#----------------------------
print('Device:', device)
print('GPU model:', device_model)
print('GPU memory: {0:.2f} GB'.format(device_memory))
print('#-------------------')
print('CPU cores:', cpu_count())

Device: cuda
GPU model: Tesla P100-PCIE-16GB
GPU memory: 17.07 GB
#-------------------
CPU cores: 4


## 1.5 Mount Google Drive

In [6]:
# mount Drive to save checkpoints and continue the training from last stop
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 1.6 Constants definition

In [0]:
# initial config
model_name = "t5-small"
batch_size = 128
accumulate_grad_batches = 2
source_max_length = 32
target_max_length = 32
learning_rate = 5e-3

# 2. Custom functions and classes

## 2.1 Functions

### 2.1.1 Network parameters

Function that calculates the number of parameters of a network

In [0]:
'''
description:
    - given a model, this function returns its number of parameters (weight, bias)
#-------------------
positional args:
    - model [torch.nn.Module]: instance of the network
optional args:
    - verbose (default=False) [bool]: if True, print a report with the parameters of each layer
    - all_parameters (default=False) [bool]: 
        if True, return number of all parameters, if False, return only trainable parameters
#-------------------
return:
    - [int] total parameters of the network
''';

In [0]:
def nparam(model, verbose=False, all_parameters=False):
    if(verbose):
        i = 0
        total = 0
        for name, param in model.named_parameters():
            if (param.requires_grad):
                #print('layer ', i, ' name: ', name)
                j = 1
                for dim in param.data.shape:
                    j = j * dim
                print('layer ', i, ': ', name, '; parameters: ', j, sep='')
                i += 1
                total += j
        print('total parameters = ', total)
        return
    else:
        if (all_parameters):
            return sum(p.numel() for p in model.parameters())
        else:
            return sum(p.numel() for p in model.parameters() if p.requires_grad)

### 2.1.2 Memory Leakage on Exception

In [10]:
'''
source:
https://docs.fast.ai/troubleshoot.html#memory-leakage-on-exception
#-------------------
Decorator used to reclaim GPU RAM (check the source website)
'''

'\nsource:\nhttps://docs.fast.ai/troubleshoot.html#memory-leakage-on-exception\n#-------------------\nDecorator used to reclaim GPU RAM (check the source website)\n'

In [0]:
# decorator from Paulo's notebook
def gpu_mem_restore(func):
    "Reclaim GPU RAM if CUDA out of memory happened, or execution was interrupted"
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            type, val, tb = sys.exc_info()
            traceback.clear_frames(tb)
            raise type(val).with_traceback(tb) from None
    return wrapper

### 2.1.3 Hardware status

This function can be used with tqdm to monitor CPU & GPU in status bar

function based on Diedre's notebook.

In [0]:
# '''
# function to monitor GPU usage 
# (which should be near 100% during the training loop)
# '''
# def gpu_usage():
#     global handle
#     return str(nvidia_smi.nvmlDeviceGetUtilizationRates(handle).gpu) + '%'

In [0]:
def hardware_status():
    '''
    function that monitors CPU & GPU and returns a 
    dictionary with some of their status
    '''
    res = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
    return {"cpu": str(psutil.cpu_percent()) + '%',
            "mem": str(psutil.virtual_memory().percent) + '%',
            "gpu": str(res.gpu) + '%',
            "gpu_mem": str(res.memory) + '%'}

### 2.1.4 Function to load texts

In [0]:
'''
function that reads .txt files and return them as a list
#------------------------------------------------------
parameters:
    path: directory where the .txt files are
'''
# function that reads .txt files and return them as a list
def load_text_pairs(path):
    text_pairs = []
    for line in gzip.open(path, mode='rt'):
        text_pairs.append(line.strip().split('\t'))
    return text_pairs

### 2.1.5 Function to get optimizer

In [0]:
'''
function that returns the optimizer associated to a string
#------------------------------------------------------
parameters:
    opt:
    - 'Adam': returns the Adam optimizer
    - 'SGD':  returns the Stochastic Gradient Descent optimizer
    - otherwise raise an error
    lr: learning rate
'''
def get_optimizer(opt, lr, model):
    if (opt == 'Adam'):
        return Adam( [p for p in model.parameters() if p.requires_grad],
            lr=lr, eps=1e-08)
    elif (opt == 'SGD'):
        return SGD(model.parameters(), lr=lr)
    else:
        raise ValueError(f"Unsupported optimizer: {opt}")

### 2.1.6 Function to get loss function

In [0]:
'''
function that returns the loss function associated to a string
#------------------------------------------------------
parameters:
    loss_func:
    - 'CE':   returns the Cross Entropy loss function
    - 'MSE':  returns the Mean Squared Error loss function
    - otherwise raise an error
'''
def get_loss_func(loss_func):
    if (loss_func == 'CE'):
        return CrossEntropyLoss()
    elif (loss_func == 'MSE'):
        return MSELoss()
    else:
        raise ValueError(f'Unsupported loss function: {loss_func}')


### 2.1.7 Function to get tokenizer

In [0]:
'''
function that returns the tokenizer associated to a string
#------------------------------------------------------
parameters:
    tokenizer:
    - 'BERT': returns the BERT tokenizer
    - otherwise raise an error
'''
def get_tokenizer(tokenizer, model_name):
    if (tokenizer == 'BERT'):
        return BertTokenizer.from_pretrained('bert-base-uncased')
    elif (tokenizer == 'T5'):
        return T5Tokenizer.from_pretrained(model_name, use_bfloat16=True)
    else:
        raise ValueError(f'Unsupported tokenizer: {tokenizer}')

### 2.1.8 Function to get model

In [0]:
'''
function that returns the the network model associated to a string
#------------------------------------------------------
parameters:
    model_name:
    - BERT models:
        - 'bert-base-uncased' (110 M params)
        - 'bert-large-uncased' (340 M params)
    - T5 models:
        - 't5-small' (60 M params)
        - 't5-base'  (220 M params)
        - 't5-large' (770 M params)
        - 't5-3B'    (2.8 B params)
        - 't5-11B'   (11 B params)
    - otherwise raise an error
'''
def get_model(model_name):
    if ((model_name == 'bert-base-uncased') or 
        (model_name == 'bert-large-uncased')):
        return BertTokenizer.from_pretrained(model_name)
    elif ((model_name == 't5-small') or 
            (model_name == 't5-base') or
            (model_name == 't5-large') or 
            (model_name == 't5-3B') or
            (model_name == 't5-11B')):
        return T5ForConditionalGeneration.from_pretrained(model_name, use_bfloat16=True)
    else:
        raise ValueError(f'Unsupported model: {model_name}')

## 2.2 Classes

## 2.2.1 Dataset Class

In [0]:
''' 
max_length is added by 5 because of the prefix: 
["translate", "English", "to", "Portuguese", ":"] 
'''
class ParaCrawlDS(Dataset):
    def __init__(self, text_pairs: List[Tuple[str]], tokenizer,
                 source_max_length: int = 32, target_max_length: int = 32):
        self.tokenizer = tokenizer
        self.text_pairs = text_pairs
        self.source_max_length = source_max_length
        self.target_max_length = target_max_length
        
    def __len__(self):
        return len(self.text_pairs)
    
    def __getitem__(self, idx):
        source, target = self.text_pairs[idx]
        #---------------------------
        # encode source sequence
        source_encode = tokenizer.encode_plus(
            text=f'translate English to Portuguese: {source} </s>',
            max_length=self.source_max_length,
            pad_to_max_length=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # get mask and tokens' ids of encoded source sequence
        source_mask = torch.squeeze(source_encode['attention_mask'], dim=0)
        source_token_ids = torch.squeeze(source_encode['input_ids'], dim=0)
        #---------------------------
        # encode target sequence
        target_encode = tokenizer.encode_plus(
            text=f'{target} </s>',
            max_length=self.target_max_length,
            pad_to_max_length=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # get mask and tokens' ids of encoded target sequence
        target_mask = torch.squeeze(target_encode['attention_mask'], dim=0)
        target_token_ids = torch.squeeze(target_encode['input_ids'], dim=0)
        #---------------------------
        return (source_token_ids, source_mask, target_token_ids, 
            target_mask, source, target)

# 3. Dataset Pre-processing

## 3.1 Download dataset

In [0]:
# download dataset (1M train, 20k test)
! wget -q -nc https://storage.googleapis.com/neuralresearcher_data/unicamp/ia376e_2020s1/paracrawl_enpt_train.tsv.gz
! wget -q -nc https://storage.googleapis.com/neuralresearcher_data/unicamp/ia376e_2020s1/paracrawl_enpt_test.tsv.gz

## 3.2 Loading the dataset

Function to load dataset

In [0]:
def load_text_pairs(path):
    text_pairs = []
    for line in gzip.open(path, mode='rt'):
        text_pairs.append(line.strip().split('\t'))
    return text_pairs

Load dataset

In [0]:
x_train = load_text_pairs('paracrawl_enpt_train.tsv.gz')
x_test = load_text_pairs('paracrawl_enpt_test.tsv.gz')

## 3.3 Splitting the dataset

Training/Development split

In [0]:
# shuffle data before spliting
random.shuffle(x_train)
# create validation set with 5k pairs
x_val = x_train[100000:105000]
# truncate the first 100k pairs for training
x_train = x_train[:100000]

Print a few pair samples

In [24]:
for set_name, x in [('training', x_train), ('validation', x_val), ('test', x_test)]:
    print(f'\n{len(x)} samples from {set_name} set')
    print(f'First 3 samples of {set_name} set:')
    for i, (source, target) in enumerate(x[:3]):
        print(f'{i}: source: {source}\n   target: {target}')


100000 samples from training set
First 3 samples of training set:
0: source: The French General THIBAULT was killed, and BRENNIER, who had a command in the affair of the 17th, taken prisoner.
   target: O General francês THIBAULT foi morto, e BRENNIER, que fora um dos comandantes no combate do dia 17, foi feito prisioneiro.
1: source: Before KBDSW09.DLL scanning, make sure your PC no virus and trojan. We recommend HitMalware.
   target: Antes de KBDSW09.DLL digitalização, verifique se o seu PC não e vírus Trojan. Recomendamos HitMalware.
2: source: There are several means of transport through which you can explore the different parts of this country. One can use buses, taxis or car rentals, which is the best and most convenient means of transport.
   target: Existem vários meios de transporte, através da qual pode explorar as diferentes partes do país. Pode-se usar ônibus, táxis ou carro de aluguel, que é os melhores e mais convenientes de meios de transporte.

5000 samples from valid

## 3.4 Testing the Dataloader


In [0]:
tokenizer = T5Tokenizer.from_pretrained(model_name)

In [26]:
text_pairs = [('I like pizza', 'eu gosto de pizza')]
dataset_debug = ParaCrawlDS(
    text_pairs=text_pairs,
    tokenizer=tokenizer,
    source_max_length=source_max_length,
    target_max_length=target_max_length)

dataloader_debug = DataLoader(dataset_debug, batch_size=10, shuffle=True, 
                              num_workers=0)

source_token_ids, source_mask, target_token_ids, target_mask, _, _ = next(iter(dataloader_debug))
#---------------------------
print('source_token_ids:\n', source_token_ids)
print('source_mask:\n', source_mask)
print('target_token_ids:\n', target_token_ids)
print('target_mask:\n', target_mask)
print('#-------------------')
print('source_token_ids.shape:', source_token_ids.shape)
print('source_mask.shape:', source_mask.shape)
print('target_token_ids.shape:', target_token_ids.shape)
print('target_mask.shape:', target_mask.shape)

source_token_ids:
 tensor([[13959,  1566,    12, 21076,    10,    27,   114,  6871,     1,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0]])
source_mask:
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])
target_token_ids:
 tensor([[   3,   15,   76,  281,    7,  235,   20, 6871,    1,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]])
target_mask:
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])
#-------------------
source_token_ids.shape: torch.Size([1, 32])
source_mask.shape: torch.Size([1, 32])
target_token_ids.shape: torch.Size([1, 32])
target_mask.shape: torch.Size([1, 32])


# 4. Network Model

## 4.1 Hiperparameters

In [0]:
hyperparameters = {
    # description
    'experiment_name': 'T5_v1',
    #---------------------------
    # training / early stopping
    'max_epochs': 1,
    'patience': 1,
    #---------------------------
    # dataset
    'dataset_class': ParaCrawlDS,
    'training_size': 100_000,
    #'training_size': 1_000,
    'split_train_val': 0.98,
    #---------------------------
    # dataloader
    'batch_size': 64,
    'accumulate_grad_batches': 1,
    'nworkers': 4,
    #---------------------------
    # network architecture
    'tokenizer': 'T5',
    'model_name': "t5-small",
    'source_max_length': 105,
    'target_max_length': 100,
    #---------------------------
    # optimizer
    'loss_func': 'CE',
    'opt_name': 'Adam',
    'lr': 5e-3,
    'scheduling_factor': 0.95,
    #---------------------------
    # others
    'manual_seed': 42,  # RNG seed
}

## 4.2 Network model definition

In [0]:
'''
LightnigModule based on the reference notebook from Diedre.
'''
class T5EnPtTranslator(pl.LightningModule):
    def __init__(self, hparams, dl_shuffle=True, decoder='greedy', k=10, p=0.95, num_beams=1, hugging_face=False):
        super().__init__()
        self.hparams = hparams
        #---------------------------
        self.dl_shuffle = dl_shuffle
        self.DatasetClass = self.hparams.dataset_class
        #---------------------------
        self.model = get_model(self.hparams.model_name)
        self.tokenizer = get_tokenizer(self.hparams.tokenizer, self.hparams.model_name)
        self.loss_func = get_loss_func(self.hparams.loss_func)
        #---------------------------
        self.decoder = decoder
        self.k = k
        self.p = p
        self.num_beams = num_beams
        self.hugging_face = hugging_face

    def forward(self, source_token_ids, source_mask, target_token_ids):
        if self.training:
            # All labels set to -100 are ignored (masked), the loss is only computed for labels in [0, ..., config.vocab_size]
            # https://huggingface.co/transformers/model_doc/t5.html#t5forconditionalgeneration
            target_token_ids[target_token_ids == self.tokenizer.pad_token_id] = -100
            outputs = self.model(
                input_ids = source_token_ids,
                attention_mask = source_mask, 
                lm_labels = target_token_ids)
            loss = outputs[0]
            return loss
        #---------------------------
        # use hugging face decoding
        #---------------------------    
        elif self.hugging_face:
            if self.decoder == 'top-k':
                predicted_token_ids = self.model.generate(
                    input_ids = source_token_ids,
                    max_length = self.hparams.target_max_length,
                    do_sample = True,
                    top_k = self.k,
                )
            elif self.decoder == 'top-p':
                predicted_token_ids = self.model.generate(
                    input_ids = source_token_ids,
                    max_length = self.hparams.target_max_length,
                    do_sample = True,
                    top_p = self.p,
                )
            elif self.decoder == 'beam-search':
                predicted_token_ids = self.model.generate(
                    input_ids = source_token_ids,
                    max_length = self.hparams.target_max_length,
                    do_sample = True,
                    num_beams = self.num_beams,
                )
            else:
                ''' greedy '''
                predicted_token_ids = self.model.generate(
                    input_ids = source_token_ids,
                    max_length = self.hparams.target_max_length,
                    do_sample = True,
                )
            return predicted_token_ids
        #---------------------------
        # use my decoding implementations
        #---------------------------
        else:
            if self.decoder == 'top-k':
                # fill tensor with start token ID
                decoded_ids = torch.full(
                    size = (source_token_ids.shape[0], 1),
                    fill_value = self.model.config.decoder_start_token_id,
                    dtype = torch.long)
                # send tensor to device
                decoded_ids = decoded_ids.to(source_token_ids.device)
                #---------------------------
                encoder_hidden_states = self.model.get_encoder()(source_token_ids, attention_mask=source_mask)
                for step in range(self.hparams.target_max_length):
                    logits, _, _ = self.model(
                        decoder_input_ids = decoded_ids,
                        encoder_outputs = encoder_hidden_states,
                        attention_mask = source_mask)
                    next_token_logits = logits[:, -1, :]
#                    pdb.set_trace()
                    #---------------------------
                    #roulette = torch.topk(next_token_logits, k=self.k, sorted=False)[1]
                    #chosen = torch.randint(self.k, (1,1)).item()
                    #next_token_id = roulette[:,chosen].unsqueeze(-1)
                    roulette = torch.topk(next_token_logits, k=self.k, sorted=False)[1]
                    next_token_id = torch.zeros((roulette.shape[0], 1), dtype=roulette.dtype, device=roulette.device)
                    for idx, topk in enumerate(roulette):
                        next_token_id[idx] = roulette[idx][torch.randint(self.k, (1,1)).item()]
                    #---------------------------
                    decoded_ids = torch.cat([decoded_ids, next_token_id], dim=-1)
                return decoded_ids                
            #---------------------------
            elif self.decoder == 'top-p':
                decoded_ids = torch.full(
                    size = (source_token_ids.shape[0], 1),
                    fill_value = self.model.config.decoder_start_token_id,
                    dtype = torch.long)
                # send tensor to device
                decoded_ids = decoded_ids.to(source_token_ids.device)
                #---------------------------
                encoder_hidden_states = self.model.get_encoder()(source_token_ids, attention_mask=source_mask)
                for step in range(self.hparams.target_max_length):
                    logits, _, _ = self.model(
                        decoder_input_ids = decoded_ids,
                        encoder_outputs = encoder_hidden_states,
                        attention_mask = source_mask)
                    next_token_logits = logits[:, -1, :]
                    #pdb.set_trace()
                    #---------------------------
                    # softmax to calculate probs from logits
                    probs = torch.softmax(next_token_logits, dim=1)         # (B x V)
                    # sort probs
                    probs_sorted, indices = torch.sort(probs, descending=True)    # (B x V)
                    # cumulative sum of sorted probs
                    cumulative = torch.cumsum(probs_sorted, dim=1)                # (B x V)
                    # get number of p tokens
                    top_p = (cumulative < self.p).sum(dim=1) + 1            # (B)
                    # set prob=0 for tokens out of p
                    for i, p in enumerate(top_p):
                        probs_sorted[i][p:] = 0
                    # sample from probs of p tokens that left
                    indices_sampled = torch.multinomial(probs_sorted, num_samples=1)
                    # restore original index from indices (returned from torch.sort)
                    next_token_id = torch.clone(top_p).unsqueeze(-1)        # (B x 1)
                    for batch, idx in enumerate(indices_sampled):
                        next_token_id[batch] = indices[batch][idx]
                    #---------------------------
                    decoded_ids = torch.cat([decoded_ids, next_token_id], dim=-1)
                return decoded_ids                
            #---------------------------
            elif self.decoder == 'beam-search':
                print('beam-search')
            #---------------------------
            else:
                ''' greedy decoding '''
                # fill tensor with start token ID
                decoded_ids = torch.full(
                    size = (source_token_ids.shape[0], 1),
                    fill_value = self.model.config.decoder_start_token_id,
                    dtype = torch.long)
                # send tensor to device
                decoded_ids = decoded_ids.to(source_token_ids.device)
                #---------------------------
                encoder_hidden_states = self.model.get_encoder()(source_token_ids, attention_mask=source_mask)
                for step in range(self.hparams.target_max_length):
                    logits, _, _ = self.model(
                        decoder_input_ids = decoded_ids,
                        encoder_outputs = encoder_hidden_states,
                        attention_mask = source_mask)
                    next_token_logits = logits[:, -1, :]
                    next_token_id = next_token_logits.argmax(1).unsqueeze(-1)
                    decoded_ids = torch.cat([decoded_ids, next_token_id], dim=-1)
                return decoded_ids

    def training_step(self, batch, batch_nb):
        # calculate loss
        source_token_ids, source_mask, target_token_ids, target_mask, _, _ = batch
        loss = self(source_token_ids, source_mask, target_token_ids)
        tensorboard_logs = {'batch_train_loss': loss} # log batch training loss in TensorBoard
        #---------------------------
        return {'loss': loss, 
                'log': tensorboard_logs,
                'progress_bar': hardware_status()}

    def training_epoch_end(self, outputs):
        # calculate epoch loss based on mini-batch average loss
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        # log training epoch loss in TensorBoard
        tensorboard_logs = {'epoch_train_loss': avg_loss}
        # send 'log' key to the logger (TensorBoard)
        return {'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        # calculate BLEU
        source_token_ids, source_mask, target_token_ids, target_mask, source_original, target_original = batch
        predicted = self(source_token_ids, source_mask, target_token_ids)
        translated = [self.tokenizer.decode(token) for token in predicted]
        #---------------------------
        bleu = torch.tensor(sacrebleu.corpus_bleu(translated, [target_original]).score)
        #print(source_original, target_original, translated, bleu, sep='\n')
        tensorboard_logs = {'batch_valid_bleu': bleu} # log batch validation loss in TensorBoard
        #---------------------------
        return {'step_val_bleu': bleu, 
                'log': tensorboard_logs,
                'progress_bar': hardware_status()}

    def validation_epoch_end(self, outputs):
        # calculate validation epoch loss and accuracy
        avg_bleu = torch.stack([x['step_val_bleu'] for x in outputs]).mean()
        #---------------------------
        tensorboard_logs = {'epoch_val_bleu': avg_bleu}
        tqdm_dict = tensorboard_logs
        #---------------------------
        return {'avg_val_bleu': avg_bleu, 
                'log': tensorboard_logs,
                'progress_bar': tqdm_dict}

    def test_step(self, batch, batch_nb):
        source_token_ids, source_mask, target_token_ids, target_mask, source_original, target_original = batch
        predicted = self(source_token_ids, source_mask, target_token_ids)
        translated = [self.tokenizer.decode(token) for token in predicted]
        bleu = torch.tensor(sacrebleu.corpus_bleu(translated, [target_original]).score)
        if (batch_nb % 100 == 0):
            print(source_original, target_original, translated, bleu, sep='\n')
        return {'test_bleu': bleu,
                'progress_bar': hardware_status()}

    def test_epoch_end(self, outputs):
        avg_test_bleu = torch.stack([x['test_bleu'] for x in outputs]).mean()
        #---------------------------
        tensorboard_logs = {'avg_test_bleu': avg_test_bleu}
        tqdm_dict = tensorboard_logs
        #---------------------------
        return {'avg_test_bleu': avg_test_bleu, 
                'log': tensorboard_logs,
                'progress_bar': tqdm_dict}

    def prepare_data(self):
        # load training/test sets
        x_train = load_text_pairs('paracrawl_enpt_train.tsv.gz')
        x_test = load_text_pairs('paracrawl_enpt_test.tsv.gz')
        #---------------------------
        # shuffle data before spliting
        random.shuffle(x_train)
        # split
        split = int(self.hparams.split_train_val * self.hparams.training_size)
        x_valid = x_train[split:]     # create validation set
        x_train = x_train[:split]   # truncate training set
        #---------------------------
        # create training set
        self.ds_train = self.DatasetClass(
            text_pairs = x_train, 
            tokenizer = self.tokenizer, 
            source_max_length = self.hparams.source_max_length,
            target_max_length = self.hparams.target_max_length)
        # create development set
        self.ds_valid = self.DatasetClass(
            text_pairs = x_valid, 
            tokenizer = self.tokenizer, 
            source_max_length = self.hparams.source_max_length,
            target_max_length = self.hparams.target_max_length)
        # create test set
        self.ds_test = self.DatasetClass(
            text_pairs = x_test, 
            tokenizer = self.tokenizer, 
            source_max_length = self.hparams.source_max_length,
            target_max_length = self.hparams.target_max_length)

    @gpu_mem_restore
    def train_dataloader(self):
        return DataLoader(
            dataset = self.ds_train, 
            batch_size = self.hparams.batch_size,
            drop_last = False,
            shuffle = self.dl_shuffle,
            num_workers=self.hparams.nworkers)

    @gpu_mem_restore
    def val_dataloader(self):                
        return DataLoader(
            dataset = self.ds_valid,
            batch_size = self.hparams.batch_size,
            drop_last = False,
            shuffle = False,
            num_workers=self.hparams.nworkers)
        
    @gpu_mem_restore
    def test_dataloader(self):
        return DataLoader(
            dataset = self.ds_test,
            batch_size = self.hparams.batch_size,
            drop_last = False,
            shuffle = False,
            num_workers=self.hparams.nworkers)

    def configure_optimizers(self):
        optimizer = get_optimizer(self.hparams.opt_name, self.hparams.lr, self.model)
        scheduler = StepLR(optimizer, 1, self.hparams.scheduling_factor)
        return [optimizer], [scheduler]

### Test dataset

In [0]:
x_train = load_text_pairs('paracrawl_enpt_train.tsv.gz')
x_test = load_text_pairs('paracrawl_enpt_test.tsv.gz')
#---------------------------
# shuffle data before spliting
random.shuffle(x_train)
# split
split = int(hyperparameters['split_train_val'] * hyperparameters['training_size'])
x_valid = x_train[split:]     # create validation set
x_train = x_train[:split]   # truncate training set
#---------------------------
# create test set
ds_test = ParaCrawlDS(
    text_pairs = x_test, 
    tokenizer = tokenizer, 
    source_max_length = hyperparameters['source_max_length'],
    target_max_length = hyperparameters['target_max_length'])

### Trainer

In [30]:
trainer_decoder_comparison = pl.Trainer(
    profiler = False,              # do not run profiler
    gpus = 1,                     # GPUs
    precision = 32,               # choose precision (32/16 bits)
    logger = False,               # do not use logging (TensorBoard)
    early_stop_callback = False,  # do not stop early
    checkpoint_callback = False,  # do not save checkpoint
    overfit_pct = 0.1,           # ratio of data to overfit
)

No environment variable for node rank defined. Set as 0.


### Model

In [60]:
model = T5EnPtTranslator(hparams=Namespace(**hyperparameters), dl_shuffle=False)
model.load_state_dict(torch.load('/content/drive/My Drive/parameters'))

<All keys matched successfully>

# Unit test

### Greedy

In [42]:
model.decoder = 'greedy'
model.hugging_face = False
#------------------------------------------------------
trainer_decoder_comparison.test(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

('In this way, the civil life of a nation matures, making it possible for all citizens to enjoy the fruits of genuine tolerance and mutual respect.', '1999 XIII. Winnipeg, Canada July 23 to August 8', "In the mystery of Christmas, Christ's light shines on the earth, spreading, as it were, in concentric circles.", 'making it viable to drill two new boreholes in the west of that peninsula.', 'His eyes were shining and his voice was cheerful.', 'Injuries, accidents, bereavement, abuse, separation, shock, rape, bullying, harassment, stress, depression, anxiety, eating.', 'Whiteness HP Maxx is a 35% hydrogen peroxide whitening gel for the whitening of vital and non-vital teeth.', 'Lines: with indication of Line Number, From and To ends, insulation, the P&ID where they are drawn.', 'The cruises depart from Manaus, capital of the State of Amazonas, a city in the jungle that prospered during the rubber boom last century and where you will find a smaller copy of the Opera House in Paris, France

### Top-k

In [43]:
model.decoder = 'top-k'
model.k = 1
model.hugging_face = False
#------------------------------------------------------
trainer_decoder_comparison.test(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

('In this way, the civil life of a nation matures, making it possible for all citizens to enjoy the fruits of genuine tolerance and mutual respect.', '1999 XIII. Winnipeg, Canada July 23 to August 8', "In the mystery of Christmas, Christ's light shines on the earth, spreading, as it were, in concentric circles.", 'making it viable to drill two new boreholes in the west of that peninsula.', 'His eyes were shining and his voice was cheerful.', 'Injuries, accidents, bereavement, abuse, separation, shock, rape, bullying, harassment, stress, depression, anxiety, eating.', 'Whiteness HP Maxx is a 35% hydrogen peroxide whitening gel for the whitening of vital and non-vital teeth.', 'Lines: with indication of Line Number, From and To ends, insulation, the P&ID where they are drawn.', 'The cruises depart from Manaus, capital of the State of Amazonas, a city in the jungle that prospered during the rubber boom last century and where you will find a smaller copy of the Opera House in Paris, France

### Nucleus Sampling

In [44]:
model.decoder = 'top-p'
model.p = 0.01
model.hugging_face = False
#------------------------------------------------------
trainer_decoder_comparison.test(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

('In this way, the civil life of a nation matures, making it possible for all citizens to enjoy the fruits of genuine tolerance and mutual respect.', '1999 XIII. Winnipeg, Canada July 23 to August 8', "In the mystery of Christmas, Christ's light shines on the earth, spreading, as it were, in concentric circles.", 'making it viable to drill two new boreholes in the west of that peninsula.', 'His eyes were shining and his voice was cheerful.', 'Injuries, accidents, bereavement, abuse, separation, shock, rape, bullying, harassment, stress, depression, anxiety, eating.', 'Whiteness HP Maxx is a 35% hydrogen peroxide whitening gel for the whitening of vital and non-vital teeth.', 'Lines: with indication of Line Number, From and To ends, insulation, the P&ID where they are drawn.', 'The cruises depart from Manaus, capital of the State of Amazonas, a city in the jungle that prospered during the rubber boom last century and where you will find a smaller copy of the Opera House in Paris, France

# Comparison: Greedy

### Hugging face decoder

In [32]:
model.decoder = 'greedy'
model.hugging_face = True
#------------------------------------------------------
trainer_decoder_comparison.test(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

('In this way, the civil life of a nation matures, making it possible for all citizens to enjoy the fruits of genuine tolerance and mutual respect.', '1999 XIII. Winnipeg, Canada July 23 to August 8', "In the mystery of Christmas, Christ's light shines on the earth, spreading, as it were, in concentric circles.", 'making it viable to drill two new boreholes in the west of that peninsula.', 'His eyes were shining and his voice was cheerful.', 'Injuries, accidents, bereavement, abuse, separation, shock, rape, bullying, harassment, stress, depression, anxiety, eating.', 'Whiteness HP Maxx is a 35% hydrogen peroxide whitening gel for the whitening of vital and non-vital teeth.', 'Lines: with indication of Line Number, From and To ends, insulation, the P&ID where they are drawn.', 'The cruises depart from Manaus, capital of the State of Amazonas, a city in the jungle that prospered during the rubber boom last century and where you will find a smaller copy of the Opera House in Paris, France

### My decoder

In [33]:
model.decoder = 'greedy'
model.hugging_face = False
#------------------------------------------------------
trainer_decoder_comparison.test(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

('In this way, the civil life of a nation matures, making it possible for all citizens to enjoy the fruits of genuine tolerance and mutual respect.', '1999 XIII. Winnipeg, Canada July 23 to August 8', "In the mystery of Christmas, Christ's light shines on the earth, spreading, as it were, in concentric circles.", 'making it viable to drill two new boreholes in the west of that peninsula.', 'His eyes were shining and his voice was cheerful.', 'Injuries, accidents, bereavement, abuse, separation, shock, rape, bullying, harassment, stress, depression, anxiety, eating.', 'Whiteness HP Maxx is a 35% hydrogen peroxide whitening gel for the whitening of vital and non-vital teeth.', 'Lines: with indication of Line Number, From and To ends, insulation, the P&ID where they are drawn.', 'The cruises depart from Manaus, capital of the State of Amazonas, a city in the jungle that prospered during the rubber boom last century and where you will find a smaller copy of the Opera House in Paris, France

# Comparison: Top-k

### Hugging face decoder

In [53]:
model.decoder = 'top-k'
model.k = 5
model.hugging_face = True
#------------------------------------------------------
trainer_decoder_comparison.test(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

('In this way, the civil life of a nation matures, making it possible for all citizens to enjoy the fruits of genuine tolerance and mutual respect.', '1999 XIII. Winnipeg, Canada July 23 to August 8', "In the mystery of Christmas, Christ's light shines on the earth, spreading, as it were, in concentric circles.", 'making it viable to drill two new boreholes in the west of that peninsula.', 'His eyes were shining and his voice was cheerful.', 'Injuries, accidents, bereavement, abuse, separation, shock, rape, bullying, harassment, stress, depression, anxiety, eating.', 'Whiteness HP Maxx is a 35% hydrogen peroxide whitening gel for the whitening of vital and non-vital teeth.', 'Lines: with indication of Line Number, From and To ends, insulation, the P&ID where they are drawn.', 'The cruises depart from Manaus, capital of the State of Amazonas, a city in the jungle that prospered during the rubber boom last century and where you will find a smaller copy of the Opera House in Paris, France

### My decoder

In [66]:
model.decoder = 'top-k'
model.k = 5
model.hugging_face = False
#------------------------------------------------------
trainer_decoder_comparison.test(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

('In this way, the civil life of a nation matures, making it possible for all citizens to enjoy the fruits of genuine tolerance and mutual respect.', '1999 XIII. Winnipeg, Canada July 23 to August 8', "In the mystery of Christmas, Christ's light shines on the earth, spreading, as it were, in concentric circles.", 'making it viable to drill two new boreholes in the west of that peninsula.', 'His eyes were shining and his voice was cheerful.', 'Injuries, accidents, bereavement, abuse, separation, shock, rape, bullying, harassment, stress, depression, anxiety, eating.', 'Whiteness HP Maxx is a 35% hydrogen peroxide whitening gel for the whitening of vital and non-vital teeth.', 'Lines: with indication of Line Number, From and To ends, insulation, the P&ID where they are drawn.', 'The cruises depart from Manaus, capital of the State of Amazonas, a city in the jungle that prospered during the rubber boom last century and where you will find a smaller copy of the Opera House in Paris, France

# Comparison: Nucleus Sampling

### Hugging face decoder

In [40]:
model.decoder = 'top-p'
model.p = 0.95
model.hugging_face = True
#------------------------------------------------------
trainer_decoder_comparison.test(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

('In this way, the civil life of a nation matures, making it possible for all citizens to enjoy the fruits of genuine tolerance and mutual respect.', '1999 XIII. Winnipeg, Canada July 23 to August 8', "In the mystery of Christmas, Christ's light shines on the earth, spreading, as it were, in concentric circles.", 'making it viable to drill two new boreholes in the west of that peninsula.', 'His eyes were shining and his voice was cheerful.', 'Injuries, accidents, bereavement, abuse, separation, shock, rape, bullying, harassment, stress, depression, anxiety, eating.', 'Whiteness HP Maxx is a 35% hydrogen peroxide whitening gel for the whitening of vital and non-vital teeth.', 'Lines: with indication of Line Number, From and To ends, insulation, the P&ID where they are drawn.', 'The cruises depart from Manaus, capital of the State of Amazonas, a city in the jungle that prospered during the rubber boom last century and where you will find a smaller copy of the Opera House in Paris, France

### My decoder

In [41]:
model.decoder = 'top-p'
model.p = 0.95
model.hugging_face = False
#------------------------------------------------------
trainer_decoder_comparison.test(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

('In this way, the civil life of a nation matures, making it possible for all citizens to enjoy the fruits of genuine tolerance and mutual respect.', '1999 XIII. Winnipeg, Canada July 23 to August 8', "In the mystery of Christmas, Christ's light shines on the earth, spreading, as it were, in concentric circles.", 'making it viable to drill two new boreholes in the west of that peninsula.', 'His eyes were shining and his voice was cheerful.', 'Injuries, accidents, bereavement, abuse, separation, shock, rape, bullying, harassment, stress, depression, anxiety, eating.', 'Whiteness HP Maxx is a 35% hydrogen peroxide whitening gel for the whitening of vital and non-vital teeth.', 'Lines: with indication of Line Number, From and To ends, insulation, the P&ID where they are drawn.', 'The cruises depart from Manaus, capital of the State of Amazonas, a city in the jungle that prospered during the rubber boom last century and where you will find a smaller copy of the Opera House in Paris, France

# Beam-search (not implemented yet)

### Hugging face decoder

In [0]:
# model.decoder = 'beam-search'
# model.num_beams = 2
# model.hugging_face = True
# #------------------------------------------------------
# trainer_decoder_comparison.test(model)

### My decoder

In [0]:
# model.decoder = 'beam-search'
# model.num_beams = 2
# model.hugging_face = False
# #------------------------------------------------------
# trainer_decoder_comparison.test(model)

# 6. Conclusion

Obs.: the comparison was made using only 10% of the test data, due to the amount of time necessary  

- **My implementation:**  
All algorithms worked. To check this, I tested with k=1, for top-k, and p=0.01 for nucleus sampling. Forcing this values, is the same of doing a greedy decoding, making it possible to compare the exact translation and BLEU score of those obtained using the transformer lib.

- **Hugging Face:**  
As expected, using the transformer lib resulted in lower test time. This happens because the transformer code is optimized.

- **Comparison:**  
For both implementations with other values (ex: k=10, p=0.95) resulted in a much slower BLEU. In th=he top-k the token was sampled uniformly and not weighted. This could explain the BLEU score.

**Unit test**

|            | greedy  | top-k   | top-p   | beam  |
| :--------: |:------: | :-----: | :-----: | :---: |
| parameters |  -----  |  k=1    | p=0.01  | ----- |
|       BLEU | 22.6276 | 22.6276 | 22.6276 | ----- |


**Comparison**

| decoding strategy | hugging face | my implementation |
| :---------------: |:-----------: | :---------------: |
| greedy            |    19.2643   |      22.6276      |
| top-k (k=5)       |    20.8351   |       0.5005      | 
| top-p (p=0.95)    |    21.0985   |      12.1342      |
| beam-search       |    -----     |       -----       |


## End of the notebook