## Homework 5: Sequence-to-sequence

Author: Cao Yanfei

### Sequence-to-Sequence Introduction
- Typical sequence-to-sequence (seq2seq) models are **encoder-decoder models**, which consist of two parts: an encoder and a decoder. Both parts can be implemented with a **recurrent neural network (RNN)** or a **transformer**, mainly so that input and output sequences of varying length can be handled.
- The **Encoder** encodes an input sequence, such as text, video or audio, into a single vector. This vector can be viewed as an abstract representation of the input that contains information about the whole sequence.
- The **Decoder** decodes the encoder's output vector one step at a time until the output sequence is complete. Every decoding step depends on the previous step(s). Generally, **"< BOS >"** is added at the beginning of the sequence to mark the start of decoding, and **"< EOS >"** at the end to mark the end of decoding.

![seq2seq](https://i.imgur.com/0zeDyuI.png)
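The step-by-step decoding described above can be sketched as a greedy loop that starts from BOS and stops at EOS. This is only an illustration, not fairseq's decoder (which uses beam search over several hypotheses): `step` is a hypothetical stand-in for the decoder, mapping the tokens generated so far to the next token, and the lookup table is toy data.

```python
from typing import Callable, List

BOS, EOS = "<BOS>", "<EOS>"

def greedy_decode(step: Callable[[List[str]], str], max_len: int = 20) -> List[str]:
    """Generate tokens one at a time until EOS or max_len (hypothetical helper)."""
    tokens = [BOS]                      # decoding starts from BOS
    for _ in range(max_len):
        next_tok = step(tokens)         # every step conditions on the previous outputs
        if next_tok == EOS:             # EOS marks the end of decoding
            break
        tokens.append(next_tok)
    return tokens[1:]                   # drop BOS from the result

# Toy "decoder": picks the next token from the last one via a lookup table.
TOY_NEXT = {BOS: "汤姆", "汤姆": "是", "是": "个", "个": "学生", "学生": "。", "。": EOS}

print(greedy_decode(lambda toks: TOY_NEXT[toks[-1]]))
```

A real decoder would also receive the encoder output; `max_len` guards against a model that never emits EOS.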

### Homework Description
- English to Chinese (Traditional) Translation
    - Input: An English sentence (e.g. Tom is a student.)
    - Output: The Chinese translation (e.g. 汤姆是个学生。)
- TODO
    - Train a simple **RNN seq2seq** to achieve translation.
    - Switch to **transformer** model to boost performance.
    - Apply **Back-translation** to further boost performance.

### Download and Import Required Packages

In [1]:
# !pip install 'torch>=1.6.0' editdistance matplotlib sacrebleu sacremoses sentencepiece tqdm wandb
# !pip install --upgrade jupyter ipywidgets

# !git clone https://github.com/pytorch/fairseq.git
# !cd fairseq && git checkout 9a1c497
# !pip install --upgrade ./fairseq/

In [26]:
import sys, pdb, pprint, logging, os, random     # pprint: Data pretty printer

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils import data
import numpy as np
import tqdm.auto as tqdm
from pathlib import Path
from argparse import Namespace
from fairseq import utils

import matplotlib.pyplot as plt

### Fix Random Seed

In [3]:
seed = 73

random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### Dataset Information

### En-Zh Bilingual Parallel Corpus

- [TED2020](#reimers-2020-multilingual-sentence-bert)
    - Raw: 398,066 (sentences)
    - Processed: 393,980 (sentences)

### Testing Data

- Size: 4,000 (sentences)
- The Chinese translation is undisclosed. The provided (.zh) file is a pseudo-translation: each line is a '。'

### Dataset Download

### Install [Megatools](https://megous.com/git/megatools/about/) (optional)

In [4]:
#!apt-get install megatools

### Download and Extract

In [5]:
data_dir = './dataset/rawdata'
dataset_name = 'ted2020'
urls = (
    '"https://onedrive.live.com/download?cid=3E549F3B24B238B4&resid=3E549F3B24B238B4%214989&authkey=AGgQ-DaR8eFSl1A"', 
    '"https://onedrive.live.com/download?cid=3E549F3B24B238B4&resid=3E549F3B24B238B4%214987&authkey=AA4qP_azsicwZZM"',
# # If the above links die, use the following instead. 
#     "https://www.csie.ntu.edu.tw/~r09922057/ML2021-hw5/ted2020.tgz",
#     "https://www.csie.ntu.edu.tw/~r09922057/ML2021-hw5/test.tgz",
# # If the above links die, use the following instead. 
#     "https://mega.nz/#!vEcTCISJ!3Rw0eHTZWPpdHBTbQEqBDikDEdFPr7fI8WxaXK9yZ9U",
#     "https://mega.nz/#!zNcnGIoJ!oPJX9AvVVs11jc0SaK6vxP_lFUNTkEcK2WbxJpvjU5Y", 
)
file_names = (
    'ted2020.tgz',    # train & dev
    'test.tgz',       # test
)
prefix = Path(data_dir).absolute() / dataset_name

# prefix.mkdir(parents=True, exist_ok=True)
# for u, f in zip(urls, file_names):
#     path = prefix / f
#     if not path.exists():
#         if 'mega' in u:
#             !megadl {u} --path {path}
#         else:
#             !wget {u} -O {path}
#     if path.suffix == '.tgz':
#         !tar -xvf {path} -C {prefix}
#     elif path.suffix == ".zip":
#         !unzip -o {path} -d {prefix}

# !mv {prefix/'raw.en'} {prefix/'train_dev.raw.en'}
# !mv {prefix/'raw.zh'} {prefix/'train_dev.raw.zh'}
# !mv {prefix/'test.en'} {prefix/'test.raw.en'}
# !mv {prefix/'test.zh'} {prefix/'test.raw.zh'}
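The commented-out cell relies on notebook shell magics (`!wget`, `!tar`, `!unzip`, `!megadl`), which may be unavailable on Windows. A minimal pure-Python sketch, assuming the links are plain HTTP(S) URLs (the OneDrive/NTU ones, with the surrounding double quotes stripped); `download` and `extract` are hypothetical helper names, and mega.nz links would still need megatools.

```python
import tarfile
import urllib.request
import zipfile
from pathlib import Path

def download(url: str, path: Path) -> None:
    """Fetch url into path unless the file already exists (hypothetical helper)."""
    url = url.strip('"')                          # links in `urls` are wrapped in quotes for the shell
    if path.exists():
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(url, str(path))

def extract(path: Path, dest: Path) -> None:
    """Unpack a .tgz or .zip archive into dest, like the tar/unzip branches above."""
    if path.suffix == '.tgz':
        with tarfile.open(path, 'r:gz') as tar:
            tar.extractall(dest)
    elif path.suffix == '.zip':
        with zipfile.ZipFile(path) as zf:
            zf.extractall(dest)

# Usage (assuming `urls`, `file_names` and `prefix` from the cell above):
# for u, f in zip(urls, file_names):
#     download(u, prefix / f)
#     extract(prefix / f, prefix)
# (prefix / 'raw.en').rename(prefix / 'train_dev.raw.en')
# likewise raw.zh -> train_dev.raw.zh, test.en -> test.raw.en, test.zh -> test.raw.zh
```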

### Language

In [6]:
src_lang = 'en'
tgt_lang = 'zh'

data_prefix = str(prefix / 'train_dev.raw')    # pathlib joins with the OS-specific separator (portable)
test_prefix = str(prefix / 'test.raw')

# Show contents of the first five lines in designated file. 
# !head {data_prefix+'.'+src_lang} -n 5
# !head {data_prefix+'.'+tgt_lang} -n 5

### Preprocess Files

In [7]:
import re

def str_fw2hw(ustring):
    '''
    Full-width -> Half-width
    '''
    ss = []
    for s in ustring:
        rstring = ''
        for uchar in s:
            inside_code = ord(uchar)                                 # ord() returns the Unicode code point for a one-character string
            if inside_code == 12288:                                 # Full-width space: direct conversion -> Half-width space
                inside_code = 32
            elif (inside_code >= 65281 and inside_code <= 65374):    # Full-width chars (except space) conversion
                inside_code -= 65248
            rstring += chr(inside_code)                              # chr() returns a Unicode string of one character with ordinal i; 0 <= i <= 0x10ffff.
        ss.append(rstring)
    return ''.join(ss)                                               # str.join() concatenates the strings in an iterable into a new string.

def clean_s(s, lang):
    if lang == 'en':
        s = re.sub(r'\([^()]*\)', '', s)                             # remove ([text])
        s = s.replace('-', '')                                       # str.replace() returns a copy with all occurrences of substring old replaced by new.
        s = re.sub('([.,;!?()\"])', r' \1', s)                       # insert a space before punctuation so it becomes a separate token
    elif lang == 'zh': 
        s = str_fw2hw(s)
        s = re.sub(r'\([^()]*\)', '', s)
        s = s.replace(' ', '')
        s = s.replace('—', '')
        s = s.replace('“', '')
        s = s.replace('”', '')
        s = s.replace('_', '')
        s = re.sub('([。,;!?()\"~「」])', r' \1', s)
    s = ' '.join(s.strip().split())                                  # str.strip() returns a copy of the string with leading and trailing whitespace removed.
                                                                     # str.split() returns a list of the words in the string, using sep as the delimiter string.
    return s

def len_s(s, lang):
    if lang == 'zh':
        return len(s)
    return len(s.split())

def clean_corpus(prefix, l1, l2, ratio=9, max_len=1000, min_len=1):
    if Path(f'{prefix}.clean.{l1}').exists() and Path(f'{prefix}.clean.{l2}').exists(): 
        print(f'{prefix}.clean.{l1} & {prefix}.clean.{l2} exists. Skip cleaning.')
        return
    
    with open(f'{prefix}.{l1}', 'r', encoding='UTF-8') as l1_in_f:
        with open(f'{prefix}.{l2}', 'r', encoding='UTF-8') as l2_in_f:
            with open(f'{prefix}.clean.{l1}', 'w', encoding='UTF-8') as l1_out_f:
                with open(f'{prefix}.clean.{l2}', 'w', encoding='UTF-8') as l2_out_f:
                    for s1 in l1_in_f:
                        s1 = s1.strip()
                        s2 = l2_in_f.readline().strip()
                        s1 = clean_s(s1, l1)
                        s2 = clean_s(s2, l2)
                        s1_len = len_s(s1, l1)
                        s2_len = len_s(s2, l2)
                        if min_len > 0:    # Remove short sentence
                            if s1_len < min_len or s2_len < min_len:
                                continue
                        if max_len > 0:    # Remove long sentence
                            if s1_len > max_len or s2_len > max_len:
                                continue
                        if ratio > 0:      # Remove by ratio of length
                            if s1_len / s2_len > ratio or s2_len / s1_len > ratio:
                                continue
                        print(s1, file=l1_out_f)
                        print(s2, file=l2_out_f)
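`str_fw2hw` converts the string character by character (its outer loop over `ustring` already yields single characters, so the inner loop runs once each). An equivalent, shorter sketch builds the same code-point mapping once and applies it with `str.translate`; `FW2HW_TABLE` and `str_fw2hw_fast` are hypothetical names.

```python
# Same mapping as str_fw2hw: U+3000 (12288) -> space,
# U+FF01..U+FF5E (65281..65374) shifted down by 0xFEE0 (65248).
FW2HW_TABLE = {0x3000: 0x20}
FW2HW_TABLE.update({cp: cp - 0xFEE0 for cp in range(0xFF01, 0xFF5F)})

def str_fw2hw_fast(ustring: str) -> str:
    """Full-width -> half-width conversion in a single pass (hypothetical helper)."""
    return ustring.translate(FW2HW_TABLE)

print(str_fw2hw_fast('ＴＥＤ　２０２０！'))   # -> TED 2020!
```

Characters outside the mapped ranges, such as the ideographic full stop '。' (U+3002), are left unchanged, which is why `clean_s` still matches '。' explicitly.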

In [8]:
clean_corpus(data_prefix, src_lang, tgt_lang)
clean_corpus(test_prefix, src_lang, tgt_lang, ratio=-1, max_len=-1, min_len=-1)

D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\train_dev.raw.clean.en & D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\train_dev.raw.clean.zh exists. Skip cleaning.
D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\test.raw.clean.en & D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\test.raw.clean.zh exists. Skip cleaning.
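The three filters in `clean_corpus` amount to one predicate on the two sentence lengths (words for English, characters for Chinese, as computed by `len_s`). A hypothetical stand-alone version makes the thresholds easy to check; as in the test-set call above, a non-positive threshold disables that filter.

```python
def keep_pair(s1_len: int, s2_len: int, ratio: float = 9, max_len: int = 1000, min_len: int = 1) -> bool:
    """Return True if a sentence pair survives clean_corpus's filters (hypothetical helper)."""
    if min_len > 0 and (s1_len < min_len or s2_len < min_len):
        return False                              # too short
    if max_len > 0 and (s1_len > max_len or s2_len > max_len):
        return False                              # too long
    if ratio > 0 and (s1_len / s2_len > ratio or s2_len / s1_len > ratio):
        return False                              # lengths too unbalanced to be a plausible translation
    return True
```

The `min_len` check runs first, so the ratio division never sees a zero length when both filters are enabled.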


In [9]:
# !head {data_prefix+'.clean.'+src_lang} -n 5
# !head {data_prefix+'.clean.'+tgt_lang} -n 5

### Split into Training or Validation

In [10]:
valid_ratio = 0.01    # 3000~4000 would suffice
train_ratio = 1 - valid_ratio

if (prefix / f'train.clean.{src_lang}').exists() \
and (prefix / f'train.clean.{tgt_lang}').exists() \
and (prefix / f'valid.clean.{src_lang}').exists() \
and (prefix / f'valid.clean.{tgt_lang}').exists():
    print('training or validation splits exist. Skip splitting.')
else:
    with open(f'{data_prefix}.clean.{src_lang}', encoding='UTF-8') as f:
        line_num = sum(1 for _ in f)       # count lines with a generator expression
    labels = list(range(line_num))
    random.shuffle(labels)                 # random.shuffle(x): shuffle list x in place, and return None.
    # The same shuffled labels are reused for both languages, so aligned lines stay aligned.
    for lang in [src_lang, tgt_lang]:
        with open(prefix / f'train.clean.{lang}', 'w', encoding='UTF-8') as train_f, \
             open(prefix / f'valid.clean.{lang}', 'w', encoding='UTF-8') as valid_f, \
             open(f'{data_prefix}.clean.{lang}', 'r', encoding='UTF-8') as in_f:
            for count, line in enumerate(in_f):
                if labels[count] / line_num < train_ratio:
                    train_f.write(line)
                else:
                    valid_f.write(line)

training or validation splits exist. Skip splitting.

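The split stays aligned because the same shuffled `labels` list is used for both languages, so line *i* of the English file and line *i* of the Chinese file always land in the same split. A minimal sketch of that rule on indices alone; `split_indices` is a hypothetical name, and it uses its own `random.Random(seed)` instance rather than the global RNG, so its permutation differs from the cell above.

```python
import random

def split_indices(line_num: int, valid_ratio: float = 0.01, seed: int = 73):
    """Assign each line index to train or valid with the same rule as the cell above."""
    rng = random.Random(seed)
    labels = list(range(line_num))
    rng.shuffle(labels)                          # a random permutation of 0..line_num-1
    train_ratio = 1 - valid_ratio
    train = [i for i in range(line_num) if labels[i] / line_num < train_ratio]
    valid = [i for i in range(line_num) if labels[i] / line_num >= train_ratio]
    return train, valid

train_idx, valid_idx = split_indices(1000)
print(len(train_idx), len(valid_idx))            # -> 990 10
```

Because `labels` is a permutation, exactly the lines whose label falls in the top `valid_ratio` fraction go to validation, so the validation size is deterministic even though its members are random.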

### Subword Units

Out-of-vocabulary (OOV) words have been a major problem in machine translation. Splitting words into subword units alleviates this, because rare or unseen words can be composed from known pieces.
- We will use the [sentencepiece](#kudo-richardson-2018-sentencepiece) package
- Select either the **'unigram' or the 'byte-pair encoding (BPE)' algorithm**

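To make the BPE option concrete, here is a toy merge learner in pure Python. It illustrates the BPE idea (repeatedly merge the most frequent adjacent symbol pair) and is not sentencepiece's implementation; it ignores details such as the `▁` word-boundary marker. The unigram model selected below works differently: it starts from a large candidate vocabulary and prunes pieces by likelihood.

```python
from collections import Counter

def learn_bpe(words, num_merges):
    """Learn up to num_merges BPE merges from a list of words (toy illustration)."""
    vocab = Counter(tuple(w) for w in words)      # each word starts as a tuple of characters
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq               # count adjacent symbol pairs, weighted by frequency
        if not pairs:
            break                                 # every word is already a single symbol
        best = max(pairs, key=lambda p: (pairs[p], p))   # most frequent pair; ties broken deterministically
        merges.append(best)
        new_vocab = Counter()
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])   # apply the merge
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] += freq
        vocab = new_vocab
    return merges, vocab

merges, vocab = learn_bpe(['low', 'low', 'lower', 'lowest'], num_merges=2)
print(merges)
```

After two merges, `low` is a single piece, so an unseen word like `lowers` could still be segmented into known pieces (`low`, `e`, `r`, `s`) instead of becoming OOV.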
In [11]:
import sentencepiece as spm
vocab_size = 8000

if (prefix / f'spm{vocab_size}.model').exists():
    print(f'{prefix}/spm{vocab_size}.model exists. Skip spm_training.')
else:
    spm.SentencePieceTrainer.train(
        input=','.join([f'{prefix}/train.clean.{src_lang}',     # one-sentence-per-line raw corpus file
                        f'{prefix}/valid.clean.{src_lang}',
                        f'{prefix}/train.clean.{tgt_lang}', 
                        f'{prefix}/valid.clean.{tgt_lang}']), 
        model_prefix=prefix / f'spm{vocab_size}',               # output model name prefix. <model_name>.model and <model_name>.vocab are generated.
        vocab_size=vocab_size, 
        character_coverage=1,                                   # amount of characters covered by the model
        model_type='unigram',                                   # 'bpe' works as well
        input_sentence_size=1e6, 
        shuffle_input_sentence=True, 
        normalization_rule_name='nmt_nfkc_cf', 
    )

D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020/spm8000.model exists. Skip spm_training.


In [12]:
spm_model = spm.SentencePieceProcessor(model_file=str(prefix / f'spm{vocab_size}.model'))
in_tag = {
    'train': 'train.clean', 
    'valid': 'valid.clean', 
    'test': 'test.raw.clean', 
}

for split in ['train', 'valid', 'test']: 
    for lang in [src_lang, tgt_lang]: 
        out_path = prefix / f'{split}.{lang}'
        if out_path.exists(): 
            print(f'{out_path} exists. Skip spm_encode.')
        else: 
            with open(prefix / f'{split}.{lang}', 'w', encoding='UTF-8') as out_f: 
                with open(prefix / f'{in_tag[split]}.{lang}', 'r', encoding='UTF-8') as in_f: 
                    for line in in_f: 
                        line = line.strip()
                        tok = spm_model.encode(line, out_type=str)
                        print(' '.join(tok), file=out_f)

D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\train.en exists. Skip spm_encode.
D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\train.zh exists. Skip spm_encode.
D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\valid.en exists. Skip spm_encode.
D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\valid.zh exists. Skip spm_encode.
D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\test.en exists. Skip spm_encode.
D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\test.zh exists. Skip spm_encode.
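Each encoded line is a space-separated list of sentencepiece pieces, where `▁` (U+2581) marks the start of a word. Undoing this is the job of `post_process='sentencepiece'` at decoding time; our understanding is that fairseq removes the spaces between pieces and turns `▁` back into spaces. A minimal sketch of that rule (`spm_detokenize` is a hypothetical name; decoding with the trained `spm_model` itself is the authoritative route):

```python
def spm_detokenize(line: str) -> str:
    """Turn a line of space-separated sentencepiece pieces back into plain text."""
    return line.replace(' ', '').replace('\u2581', ' ').strip()

print(spm_detokenize('▁helping ▁ha iti ▁,'))   # -> helping haiti ,
```

For Chinese, where pieces rarely carry `▁` inside a sentence, the same rule simply glues the pieces together.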


In [13]:
# !head {data_dir+'/'+dataset_name+'/train.'+src_lang} -n 5
# !head {data_dir+'/'+dataset_name+'/train.'+tgt_lang} -n 5

### Binarize the Data with Fairseq

In [14]:
binpath = Path('./dataset/data-bin', dataset_name)
# if binpath.exists(): 
#     print(binpath, 'exists, will not be overwritten!')
# else:
#     !python -m fairseq_cli.preprocess \
#         --source-lang {src_lang} \
#         --target-lang {tgt_lang} \
#         --trainpref {prefix / 'train'} \
#         --validpref {prefix / 'valid'} \
#         --testpref {prefix / 'test'} \
#         --destdir {binpath} \
#         --joined-dictionary \
#         --workers 2
# 
# 2022-06-24 21:48:10 | INFO | fairseq_cli.preprocess | Namespace(no_progress_bar=False, log_interval=100, log_format=None, tensorboard_logdir=None, wandb_project=None, azureml_logging=False, seed=1, cpu=False, tpu=False, bf16=False, memory_efficient_bf16=False, fp16=False, memory_efficient_fp16=False, fp16_no_flatten_grads=False, fp16_init_scale=128, fp16_scale_window=None, fp16_scale_tolerance=0.0, min_loss_scale=0.0001, threshold_loss_scale=None, user_dir=None, empty_cache_freq=0, all_gather_list_size=16384, model_parallel_size=1, quantization_config_path=None, profile=False, reset_logging=False, suppress_crashes=False, criterion='cross_entropy', tokenizer=None, bpe=None, optimizer=None, lr_scheduler='fixed', scoring='bleu', task='translation', source_lang='en', target_lang='zh', trainpref='D:\\codes\\pytorch_learning\\jupyter\\8-HW5_Seq2seq\\dataset\\rawdata\\ted2020\\train', validpref='D:\\codes\\pytorch_learning\\jupyter\\8-HW5_Seq2seq\\dataset\\rawdata\\ted2020\\valid', testpref='D:\\codes\\pytorch_learning\\jupyter\\8-HW5_Seq2seq\\dataset\\rawdata\\ted2020\\test', align_suffix=None, destdir='dataset\\data-bin\\ted2020', thresholdtgt=0, thresholdsrc=0, tgtdict=None, srcdict=None, nwordstgt=-1, nwordssrc=-1, alignfile=None, dataset_impl='mmap', joined_dictionary=True, only_source=False, padding_factor=8, workers=2)
# 2022-06-24 21:49:11 | INFO | fairseq_cli.preprocess | [en] Dictionary: 8000 types
# 2022-06-24 21:50:09 | INFO | fairseq_cli.preprocess | [en] D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\train.en: 390060 sents, 12207217 tokens, 0.0% replaced by <unk>
# 2022-06-24 21:50:09 | INFO | fairseq_cli.preprocess | [en] Dictionary: 8000 types
# 2022-06-24 21:50:13 | INFO | fairseq_cli.preprocess | [en] D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\valid.en: 3940 sents, 122292 tokens, 0.0% replaced by <unk>
# 2022-06-24 21:50:13 | INFO | fairseq_cli.preprocess | [en] Dictionary: 8000 types
# 2022-06-24 21:50:17 | INFO | fairseq_cli.preprocess | [en] D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\test.en: 4000 sents, 122808 tokens, 0.0% replaced by <unk>
# 2022-06-24 21:50:17 | INFO | fairseq_cli.preprocess | [zh] Dictionary: 8000 types
# 2022-06-24 21:51:03 | INFO | fairseq_cli.preprocess | [zh] D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\train.zh: 390060 sents, 9321522 tokens, 0.0% replaced by <unk>
# 2022-06-24 21:51:03 | INFO | fairseq_cli.preprocess | [zh] Dictionary: 8000 types
# 2022-06-24 21:51:07 | INFO | fairseq_cli.preprocess | [zh] D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\valid.zh: 3940 sents, 93085 tokens, 0.00537% replaced by <unk>
# 2022-06-24 21:51:07 | INFO | fairseq_cli.preprocess | [zh] Dictionary: 8000 types
# 2022-06-24 21:51:11 | INFO | fairseq_cli.preprocess | [zh] D:\codes\pytorch_learning\jupyter\8-HW5_Seq2seq\dataset\rawdata\ted2020\test.zh: 4000 sents, 8000 tokens, 0.0% replaced by <unk>
# 2022-06-24 21:51:11 | INFO | fairseq_cli.preprocess | Wrote preprocessed data to dataset\data-bin\ted2020
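`--joined-dictionary` builds one vocabulary shared by English and Chinese, and binarization maps each piece to an integer id. A toy sketch of that mapping, assuming fairseq's default special symbols come first (`<s>`=0, `<pad>`=1, `</s>`=2, `<unk>`=3, consistent with pad = 1 and EOS = 2 in the batches printed further down) and that EOS is appended to every line. `ToyDictionary` is hypothetical and is not fairseq's `Dictionary` class.

```python
class ToyDictionary:
    """Hypothetical minimal token <-> id mapping with fairseq-style special symbols."""

    def __init__(self):
        self.symbols = ['<s>', '<pad>', '</s>', '<unk>']   # ids 0, 1, 2, 3
        self.index = {s: i for i, s in enumerate(self.symbols)}

    def add_symbol(self, sym: str) -> int:
        """Add a piece if it is new and return its id."""
        if sym not in self.index:
            self.index[sym] = len(self.symbols)
            self.symbols.append(sym)
        return self.index[sym]

    def encode_line(self, line: str) -> list:
        """Map space-separated pieces to ids; unknown pieces -> <unk>; append </s>."""
        ids = [self.index.get(tok, self.index['<unk>']) for tok in line.split()]
        return ids + [self.index['</s>']]

d = ToyDictionary()
# Joined dictionary: English and Chinese pieces share one table.
for piece in '▁tom ▁is ▁a ▁student ▁. ▁汤姆 是 个 学生 。'.split():
    d.add_symbol(piece)
print(d.encode_line('▁tom ▁is ▁a ▁teacher ▁.'))   # '▁teacher' is unknown -> 3
```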

### Configuration for Experiments

In [15]:
config = Namespace(                               # argparse.Namespace
    datadir='./dataset/data-bin/ted2020', 
    savedir='./checkpoints/rnn', 
    source_lang='en', 
    target_lang='zh',
    
    # CPU threads when fetching & processing data.
    num_workers=2, 
    
    # Batch size in terms of tokens. Gradient accumulation increases the effective batch size.
    max_tokens=8192, 
    accum_steps=2, 
    
    # The lr is calculated by the Noam lr scheduler. You can tune the maximum lr with this factor.
    lr_factor=2, 
    lr_warmup=4000, 
    
    # Clipping the gradient norm helps alleviate exploding gradients
    clip_norm=1.0,
    
    # Maximum epochs for training
    max_epoch=30, 
    start_epoch=1,
    
    # Beam size for beam search
    beam=5, 
    
    # Generate sequences of maximum length ax + b, where x is the source length
    max_len_a=1.2, 
    max_len_b=10, 
    # When decoding, post process sentence by removing sentencepiece symbols and jieba tokenization.
    post_process='sentencepiece', 
    
    # Checkpoints
    keep_last_epochs=5, 
    resume=None,        # checkpoint name (under config.savedir) to resume from; None means train from scratch
    
    # logging
    use_wandb=False, 
)
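The config comments refer to the Noam learning-rate schedule (from the original Transformer paper): linear warmup over the warmup steps (4000 here), then decay proportional to the inverse square root of the step, scaled by `lr_factor`. A sketch of the formula, assuming a model dimension `d_model` that is not part of this config (256 below is only a placeholder) and steps counted from 1; `noam_lr` is a hypothetical name.

```python
def noam_lr(step: int, d_model: int, factor: float = 2.0, warmup: int = 4000) -> float:
    """Noam schedule: linear warmup, then inverse-square-root decay (sketch)."""
    step = max(step, 1)                               # avoid 0 ** -0.5 at the very first call
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

for s in (1, 2000, 4000, 8000, 16000):
    print(s, f'{noam_lr(s, d_model=256):.2e}')
```

The two terms inside `min` are equal at `step == warmup`, so the peak lr is `factor * d_model**-0.5 * warmup**-0.5`. Separately, with `max_tokens=8192` and `accum_steps=2`, each parameter update sees up to 8192 × 2 = 16384 tokens.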

### Logging

- logging package logs ordinary messages.
- Wandb logs the loss, BLEU, etc. during training.

In [16]:
logging.basicConfig(
    format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 
    datefmt='%Y-%m-%d %H:%M:%S', 
    level='INFO',     # 'DEBUG' 'WARNING' 'ERROR'
    stream=sys.stdout
)

proj = 'hw5.seq2seq'
logger = logging.getLogger(proj)
if config.use_wandb:
    import wandb    # wandb (Weights & Biases): an experiment-tracking and visualization platform
    wandb.init(project=proj, name=Path(config.savedir).stem, config=config)

### CUDA Environment

In [17]:
# cuda_env = utils.CudaEnvironment()
# utils.CudaEnvironment.pretty_print_cuda_env_list([cuda_env])
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Dataloading

### We Borrow the Translation Task from fairseq

- Used to load the binarized data created above
- Well-implemented data iterator (dataloader)
- Built-in task.source_dictionary and task.target_dictionary are also handy
- Well-implemented beam search decoder

In [18]:
from fairseq.tasks.translation import TranslationConfig, TranslationTask

## Setup task
task_cfg = TranslationConfig(
    data=config.datadir, 
    source_lang=config.source_lang, 
    target_lang=config.target_lang, 
    train_subset='train', 
    required_seq_len_multiple=8, 
    dataset_impl='mmap', 
    # upsample_primary=1, 
)
task = TranslationTask.setup_task(task_cfg)

2022-06-27 13:51:21 | INFO | fairseq.tasks.translation | [en] dictionary: 8000 types
2022-06-27 13:51:21 | INFO | fairseq.tasks.translation | [zh] dictionary: 8000 types


In [19]:
logger.info('Loading data for epoch 1.')
task.load_dataset(split='train', epoch=1, combine=True)    # Combine if you have back-translation data.
task.load_dataset(split='valid', epoch=1)

2022-06-27 13:51:21 | INFO | hw5.seq2seq | Loading data for epoch 1.
2022-06-27 13:51:21 | INFO | fairseq.data.data_utils | loaded 390,060 examples from: ./dataset/data-bin/ted2020\train.en-zh.en
2022-06-27 13:51:21 | INFO | fairseq.data.data_utils | loaded 390,060 examples from: ./dataset/data-bin/ted2020\train.en-zh.zh
2022-06-27 13:51:21 | INFO | fairseq.tasks.translation | ./dataset/data-bin/ted2020 train en-zh 390060 examples
2022-06-27 13:51:22 | INFO | fairseq.data.data_utils | loaded 3,940 examples from: ./dataset/data-bin/ted2020\valid.en-zh.en
2022-06-27 13:51:22 | INFO | fairseq.data.data_utils | loaded 3,940 examples from: ./dataset/data-bin/ted2020\valid.en-zh.zh
2022-06-27 13:51:22 | INFO | fairseq.tasks.translation | ./dataset/data-bin/ted2020 valid en-zh 3940 examples


In [20]:
sample = task.dataset('valid')[2]
pprint.pprint(sample)
pprint.pprint(
    'Source: ' + \
    task.source_dictionary.string(
        sample['source'], 
        config.post_process, 
    )
)

pprint.pprint(
    'Target: ' + \
    task.target_dictionary.string(
        sample['target'], 
        config.post_process,
    )
)

{'id': 2,
 'source': tensor([  20,   14,    5,   86,  373, 2019,   61,  205, 2638,    8,  268,   45,
         205, 2638,  659,   23,  461,  158,  138,    4,  102,  742,   23,   13,
         271,    5,    4,  102, 2662,   23,   13,  121,  696,  513,    7,    2]),
 'target': tensor([   6,  566, 1057,  308, 3559, 2955, 3186, 1700,  633,  134,    4, 1315,
        1379,  524, 3188,  898,    4,  298, 7201, 4498, 4646, 3451, 2362, 3559,
        2955,   94,  660, 1013,  570,    9, 1323,   10,    2])}
("Source: it's not something everyone can get behind the way they get behind "
 'helping haiti , or ending aids , or fighting a famine .')
'Target: 人們不能像拖延援助海地 ,抗擊愛滋病 ,或賑濟饑荒那樣拖延對原住民的幫助 。'


### Dataset Iterator

- Controls every **batch** to contain no more than N tokens, which **optimizes GPU memory efficiency**.
- **Shuffles** the training set for every epoch
- **Ignores** sentences exceeding the maximum length
- **Pads** all sentences in a batch to the same length, which enables parallel computation on the GPU
- Adds **EOS** and shifts by one token
    - **Teacher forcing**: to train the model to predict the next token based on the prefix, we feed the right-shifted target sequence as the decoder input.
    - Generally, prepending **BOS** to the target would do the job (as shown below).
![seq2seq](https://i.imgur.com/0zeDyuI.png)
    - In fairseq, however, this is done by moving the EOS token to the beginning. Empirically, this has the same effect. For instance: 
```
# Target (target) and decoder input (prev_output_tokens):
    EOS = 2
    target = 419, 711, 238, 888, 792, 60, 968, 8, 2
    prev_output_tokens = 2, 419, 711, 238, 888, 792, 60, 968, 8
```
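A minimal sketch of deriving the decoder input from a right-padded target by moving EOS to the front. It reproduces the example above and the padded batch printed further down (pad = 1, EOS = 2). `make_prev_output_tokens` is a hypothetical name; fairseq builds `prev_output_tokens` internally when collating a batch.

```python
from typing import List

def make_prev_output_tokens(target: List[int], eos: int = 2, pad: int = 1) -> List[int]:
    """Move the trailing EOS of a right-padded target to the front (teacher-forcing input).

    Assumes the pad id appears only as trailing padding.
    """
    n = len(target) - target.count(pad)          # number of real tokens, EOS included
    assert target[n - 1] == eos, 'target must end with EOS before the padding'
    return [eos] + target[:n - 1] + [pad] * (len(target) - n)

print(make_prev_output_tokens([419, 711, 238, 888, 792, 60, 968, 8, 2]))
# -> [2, 419, 711, 238, 888, 792, 60, 968, 8]
```

Position *t* of the decoder input now holds the token preceding position *t* of the target, which is exactly what teacher forcing needs.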

In [21]:
def load_data_iterator(task, split, epoch=1, max_tokens=4000, num_workers=1, cached=True, seed=seed):
    batch_iterator = task.get_batch_iterator(
        dataset = task.dataset(split), 
        max_tokens=max_tokens, 
        max_sentences=None, 
        max_positions=utils.resolve_max_positions(
            task.max_positions(), 
            max_tokens, 
        ), 
        ignore_invalid_inputs=True, 
        seed=seed, 
        num_workers=num_workers, 
        epoch=epoch, 
        disable_iterator_cache=not cached, 
        # cached=True (iterator cache enabled) speeds things up, but then changing max_tokens after the first call of this method has no effect.
    )
    return batch_iterator

demo_epoch_obj = load_data_iterator(task, 'valid', epoch=1, max_tokens=20, num_workers=1, cached=False, seed=seed)
demo_iter = demo_epoch_obj.next_epoch_itr(shuffle=True)
sample = next(demo_iter)
sample



{'id': tensor([1703,  747]),
 'nsentences': 2,
 'ntokens': 20,
 'net_input': {'src_tokens': tensor([[   1,    1,    1,    1,    1,    1,    1,   53,    8,  818,  453,  830,
             16,   57,    7,    2],
          [   1,    1,    1,    1,    1,    1,    1, 2436,   12,  479,  967,   17,
            933,    5,    7,    2]]),
  'src_lengths': tensor([9, 9]),
  'prev_output_tokens': tensor([[   2,  229,  156, 1953,  103,  624,  565, 2582, 5759,   10,    1,    1,
              1,    1,    1,    1],
          [   2,    6, 1411,  108,   74, 1610,   10, 2412, 1160, 1730,    1,    1,
              1,    1,    1,    1]])},
 'target': tensor([[ 229,  156, 1953,  103,  624,  565, 2582, 5759,   10,    2,    1,    1,
             1,    1,    1,    1],
         [   6, 1411,  108,   74, 1610,   10, 2412, 1160, 1730,    2,    1,    1,
             1,    1,    1,    1]])}
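The first bullet of the iterator list (at most N tokens per batch) can be sketched as: sort examples by length, then greedily grow a batch while `batch size × longest length` stays within `max_tokens`, skipping examples that can never fit. This is a simplified illustration of the idea, not fairseq's batching code, which additionally handles `max_positions`, the sequence-length multiple and per-epoch shuffling. `batch_by_tokens` is hypothetical.

```python
from typing import List

def batch_by_tokens(lengths: List[int], max_tokens: int) -> List[List[int]]:
    """Group example indices so that (#examples x longest length) <= max_tokens."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])   # similar lengths -> less padding
    batches, current, longest = [], [], 0
    for i in order:
        if lengths[i] > max_tokens:
            continue                                  # too long for any batch: ignore it
        new_longest = max(longest, lengths[i])
        if current and new_longest * (len(current) + 1) > max_tokens:
            batches.append(current)                   # close the batch before it overflows
            current, new_longest = [], lengths[i]
        current.append(i)
        longest = new_longest
    if current:
        batches.append(current)
    return batches

print(batch_by_tokens([9, 9, 3, 25, 4, 10], max_tokens=20))
```

Sorting by length keeps the padding per batch small, which is what makes a token budget a good proxy for GPU memory.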

- Each batch is a Python dict with string keys and Tensor values. Its contents are described below:
```python
batch = {
    'id': id,    # id for each example
    'nsentences': len(samples),    # batch size (sentences)
    'ntokens': ntokens,          # batch size (tokens)
    'net_input': {
        'src_tokens': src_tokens,   # sequence in source language
        'src_lengths': src_lengths,  # sequence length of each example before padding
        'prev_output_tokens': prev_output_tokens,   # right shifted target, as mentioned above.
    }, 
    'target': target,    # target sequence
}
```
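In the sample batch above, `src_tokens` is padded on the left while `target` is padded on the right, and both are padded to length 16, apparently a multiple of 8 because of `required_seq_len_multiple=8`. A hypothetical pure-Python collate function that reproduces that layout (fairseq has its own collate utilities; this is not one of them), plus the padding mask later used as `encoder_padding_mask`:

```python
from typing import List

def collate(seqs: List[List[int]], pad: int = 1, left_pad: bool = False, multiple: int = 1) -> List[List[int]]:
    """Pad sequences to a common length rounded up to `multiple` (sketch)."""
    size = max(len(s) for s in seqs)
    size = -(-size // multiple) * multiple            # round up to a multiple of `multiple`
    out = []
    for s in seqs:
        padding = [pad] * (size - len(s))
        out.append(padding + s if left_pad else s + padding)
    return out

src = collate([[53, 8, 818, 453, 830, 16, 57, 7, 2],
               [2436, 12, 479, 967, 17, 933, 5, 7, 2]], left_pad=True, multiple=8)
mask = [[tok == 1 for tok in row] for row in src]    # True where the position is padding
print(src[0])   # 7 pad ids (1) followed by the 9 real tokens
```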

### Model Architecture

- We again inherit fairseq's encoder, decoder and model classes, so that in the testing phase we can directly leverage fairseq's beam search decoder.

In [22]:
from fairseq.models import (
    FairseqEncoder, 
    FairseqIncrementalDecoder, 
    FairseqEncoderDecoderModel
)

### Encoder

- The Encoder is an RNN or Transformer Encoder. The following description is for the **RNN**. For every input token, the Encoder generates an output vector and a hidden state vector, and the hidden state is passed on to the next step. In other words, **the Encoder sequentially reads in the input sequence, outputs a vector at each timestep, and finally outputs the final hidden states, or context vector, at the last timestep**.

- Parameters:
    - args
        - encoder_embed_dim: The dimension of embeddings, this compresses the one-hot vector into fixed dimensions, which achieves dimension reduction.
        - encoder_ffn_embed_dim: The dimension of the hidden states and output vectors
        - encoder_layers: The number of layers for Encoder RNN
        - dropout: The probability of setting a neuron's activation to 0, to prevent overfitting. Dropout is applied during training and disabled during testing.
    - dictionary: The dictionary provided by fairseq. It's used to obtain the padding index, and in turn the encoder padding mask.
    - embed_tokens: An instance of token embeddings (nn.Embedding).
- Inputs:
    - src_tokens: integer sequence representing English (e.g. 1, 28, 29, 205, 2).
- Outputs:
    - outputs: The output of RNN at each timestep, can be further processed by Attention.
    - final_hiddens: The hidden states at the last timestep (one per layer), passed to the decoder for decoding.
    - encoder_padding_mask: This tells the decoder which position to ignore.

In [24]:
class RNNEncoder(FairseqEncoder):
    def __init__(self, args, dictionary, embed_tokens): 
        super().__init__(dictionary)
        self.embed_tokens = embed_tokens
        
        self.embed_dim = args.encoder_embed_dim
        self.hidden_dim = args.encoder_ffn_embed_dim
        self.num_layers = args.encoder_layers
        
        self.dropout_in_module = nn.Dropout(args.dropout)    # applied to the input embeddings
        self.rnn = nn.GRU(    # GRU: Gated Recurrent Unit
            self.embed_dim, 
            self.hidden_dim, 
            self.num_layers, 
            dropout=args.dropout, 
            batch_first=False, 
            bidirectional=True
        )
        self.dropout_out_module = nn.Dropout(args.dropout)   # applied to the RNN outputs in forward()
        
        self.padding_idx = dictionary.pad()
    
    def combine_bidir(self, outs, bsz: int): 
        out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
        return out.view(self.num_layers, bsz, -1)
    
    def forward(self, src_tokens, **unused): 
        bsz, seqlen = src_tokens.size()
        
        # Get embeddings
        x = self.embed_tokens(src_tokens)
        x = self.dropout_in_module(x)
        
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        
        # pass thru bidirectional RNN
        h0 = x.new_zeros(2 * self.num_layers, bsz, self.hidden_dim)
        x, final_hiddens = self.rnn(x, h0)
        outputs = self.dropout_out_module(x)
        # outputs = [sequence len, batch size, hid dim * directions]
        # hidden = [num_layers * directions, batch size, hid dim]
        
        # Since the Encoder is bidirectional, we need to concatenate the hidden states of the two directions
        final_hiddens = self.combine_bidir(final_hiddens, bsz)
        # hidden = [num_layers x batch x num_directions*hidden]
        
        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
        return tuple(
            (
                outputs,                  # seq_len x batch x hidden
                final_hiddens,            # num_layers x batch x num_directions*hidden
                encoder_padding_mask,     # seq_len x batch
            )
        )
    
    def reorder_encoder_out(self, encoder_out, new_order): 
        # This is used by fairseq's beam search. How and why is not particularly important here.
        return tuple(
            (
                encoder_out[0].index_select(1, new_order), 
                encoder_out[1].index_select(1, new_order), 
                encoder_out[2].index_select(1, new_order), 
            )
        )
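`combine_bidir` relies on how PyTorch lays out a bidirectional GRU's final hidden state: shape `(num_layers * 2, batch, hidden)`, ordered layer 0 forward, layer 0 backward, layer 1 forward, and so on. Viewing it as `(num_layers, 2, batch, hidden)`, swapping the direction and batch axes, and flattening concatenates each layer's forward and backward states. A pure-Python sketch with nested lists (no torch needed; `combine_bidir_lists` is hypothetical) shows the resulting order:

```python
from typing import List

Vector = List[float]

def combine_bidir_lists(h: List[List[Vector]], num_layers: int) -> List[List[Vector]]:
    """(num_layers*2, B, H) -> (num_layers, B, 2H): concat forward and backward per layer."""
    batch = len(h[0])
    return [
        [h[2 * layer][b] + h[2 * layer + 1][b] for b in range(batch)]   # list + list concatenates
        for layer in range(num_layers)
    ]

# 2 layers, batch of 1, hidden size 2; values encode (layer, direction).
h_n = [[[0.0, 0.0]],   # layer 0, forward
       [[0.1, 0.1]],   # layer 0, backward
       [[1.0, 1.0]],   # layer 1, forward
       [[1.1, 1.1]]]   # layer 1, backward
print(combine_bidir_lists(h_n, num_layers=2))
```

Each layer's combined state has size `2 * hidden_dim`, so whatever consumes `final_hiddens` must expect twice the encoder's hidden size.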

### Attention

- When the input sequence is long, the context vector alone cannot accurately represent the whole sequence; the attention mechanism gives the Decoder more information.
- Based on the **Decoder embeddings** at the current timestep, the **Encoder outputs** are matched against them to determine their correlation, and the Encoder outputs, weighted by this correlation, are summed to form the input to the **Decoder** RNN.
- Common attention implementations use a neural network or a dot product to measure the correlation between the **query** (decoder embeddings) and the **keys** (Encoder outputs), apply **softmax** to obtain a distribution, and finally compute the weighted sum of the **values** (Encoder outputs) under that distribution.
- Parameters:
    - input_embed_dim: dimensionality of the query, i.e. the decoder vector that attends to the others.
    - source_embed_dim: dimensionality of the keys/values, i.e. the vectors being attended to (encoder outputs).
    - output_embed_dim: dimensionality of the output vector after attention, as expected by the next layer. 
- Inputs:
    - inputs: The query, the vector that attends to the others.
    - encoder_outputs: The keys/values, the vectors being attended to.
    - encoder_padding_mask: This tells the decoder which position to ignore.
- Outputs:
    - output: The context vector after attention.
    - attn_scores: The attention distribution over source positions.

In [27]:
class AttentionLayer(nn.Module): 
    def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False): 
        super().__init__()
        
        self.input_proj = nn.Linear(input_embed_dim, source_embed_dim, bias=bias)
        self.output_proj = nn.Linear(
            input_embed_dim + source_embed_dim, output_embed_dim, bias=bias
        )
    
    def forward(self, inputs, encoder_outputs, encoder_padding_mask):
        # inputs: T, B, dim
        # encoder_outputs: S x B x dim
        # padding mask: S x B
        
        # convert all to batch first
        inputs = inputs.transpose(1, 0)    # B, T, dim
        encoder_outputs = encoder_outputs.transpose(1, 0)    # B, S, dim
        encoder_padding_mask = encoder_padding_mask.transpose(1, 0)    # B, S
        
        # Project to the dimensionality of encoder_outputs
        x = self.input_proj(inputs)
        
        # Compute attention
        # (B, T, dim) x (B, dim, S) = (B, T, S)
        attn_scores = torch.bmm(x, encoder_outputs.transpose(1, 2))    # torch.bmm() returns matrix multiplication of 2 tensors.
        
        # Cancel the attention at positions corresponding to padding
        if encoder_padding_mask is not None:
            # Leveraging broadcast  B, S -> (B, 1, S)
            encoder_padding_mask = encoder_padding_mask.unsqueeze(1)
            attn_scores = (
                attn_scores.float()
                .masked_fill_(encoder_padding_mask, float('-inf'))    # set scores at padded positions (mask == True) to -inf
                .type_as(attn_scores)
            )    # FP16 support: cast to float and back
        
        # Softmax on the dimension corresponding to source sequence
        attn_scores = F.softmax(attn_scores, dim=-1)
        
        # shape (B, T, S) x (B, S, dim) = (B, T, dim) weighted sum
        x = torch.bmm(attn_scores, encoder_outputs)
        
        # shape (B, T, dim)
        x = torch.cat((x, inputs), dim=-1)
        x = torch.tanh(self.output_proj(x))    # concat + linear + tanh
        
        # restore shape (B, T, dim) -> (T, B, dim)
        return x.transpose(1, 0), attn_scores
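The core of `AttentionLayer` for a single decoder query, written out in pure Python: dot-product scores, `-inf` at padded source positions, softmax, then a weighted sum of the encoder outputs. To keep the core visible it assumes `input_proj` and `output_proj` are identities and at least one source position is unpadded; `attend` is a hypothetical name.

```python
import math
from typing import List, Tuple

def attend(query: List[float], enc_outputs: List[List[float]],
           pad_mask: List[bool]) -> Tuple[List[float], List[float]]:
    """Masked dot-product attention for one query; returns (context, weights)."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in enc_outputs]
    scores = [float('-inf') if masked else s for s, masked in zip(scores, pad_mask)]
    top = max(scores)                                  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]         # exp(-inf) == 0.0, so padding gets no weight
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(enc_outputs[0])
    context = [sum(w * v[d] for w, v in zip(weights, enc_outputs)) for d in range(dim)]
    return context, weights

ctx, w = attend([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]], [False, False, True])
print([round(x, 3) for x in w])   # the padded third position gets weight 0.0
```

Without the mask, the padded third vector would have the largest score (5.0) and dominate the context, which is why padding must be masked before the softmax.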

### Decoder


### Reference

1. Original source: https://github.com/ga642381/ML2021-Spring/blob/main/HW05/HW05.ipynb
2. Link to reference [training curves](https://wandb.ai/george0828zhang/hw5.seq2seq.new).
3. Expected run time on Colab with a Tesla T4:

|Baseline|Details|Total Time|
|-|:-:|:-:|
|Simple|2m 15s $\times$ 30 epochs|1hr 8m|
|Medium|4m $\times$ 30 epochs | 2hr|
|Strong|8m $\times$ 30 epochs (backward)<br>+1hr (back-translation)<br> + 15m $\times$ 30 epochs (forward) | 12hr 30m|
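The totals follow from the per-epoch time × 30 epochs, plus the one-off back-translation step for Strong. A quick check of the table's arithmetic (`fmt` is a hypothetical helper):

```python
EPOCHS = 30

def fmt(minutes: float) -> str:
    """Format minutes as 'Xhr Ym', rounding to the nearest minute."""
    h, m = divmod(round(minutes), 60)
    return f'{h}hr {m}m'

simple = 2.25 * EPOCHS                        # 2m 15s per epoch
medium = 4 * EPOCHS
strong = 8 * EPOCHS + 60 + 15 * EPOCHS        # backward model + back-translation + forward model
print(fmt(simple), fmt(medium), fmt(strong))
```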