# **Implement Seq2Seq from scratch**

---
Model: Seq2Seq Bi-GRU with Attention

Dataset: Hugging Face's `mt_eng_vietnamese` (configuration `iwslt2015-vi-en`)


### Create model

In [None]:
import torch.nn as nn
import torch
import random

In [None]:
class RNNEncoder(nn.Module):
    """GRU encoder (bidirectional by default) over embedded source tokens.

    When ``bidirectional`` is true, each direction gets ``hidden_size // 2``
    units so that the concatenated forward/backward features have width
    ``hidden_size`` (for an even ``hidden_size``).

    Args:
        source_vocab_size: number of rows in the source embedding table.
        emb_size: embedding dimension.
        hidden_size: total feature width produced per time step and per layer.
        num_layers: number of stacked GRU layers.
        dropout_ratio: dropout probability applied to the embeddings and
            between stacked GRU layers.
        bidirectional: whether the GRU reads the sequence in both directions.
    """

    def __init__(self, source_vocab_size, emb_size=100, hidden_size=512, num_layers=2, dropout_ratio=0.2, bidirectional=True):
        super().__init__()
        self.src_vocab_size = source_vocab_size
        # Per-direction width; the two directions are concatenated in forward().
        self.hidden_size = hidden_size // 2 if bidirectional else hidden_size
        self.n_layers = num_layers
        self.n_directions = 2 if bidirectional else 1

        self.dropout = nn.Dropout(p=dropout_ratio)
        self.embedding = nn.Embedding(num_embeddings=source_vocab_size, embedding_dim=emb_size)
        # nn.GRU applies dropout only between stacked layers and warns when a
        # non-zero value is given for a single layer, so pass 0 in that case.
        inter_layer_dropout = dropout_ratio if num_layers > 1 else 0.0
        self.gru = nn.GRU(input_size=emb_size, hidden_size=self.hidden_size, num_layers=num_layers,
                          bidirectional=bidirectional, dropout=inter_layer_dropout)

    def forward(self, inputs):
        """Encode a batch of source token ids.

        Args:
            inputs: token ids, [max_input_length, bs].

        Returns:
            Tuple ``(out, hid)``: ``out`` is the GRU output for every time step,
            [max_input_length, bs, n_directions * self.hidden_size]; ``hid`` is
            the final hidden state, [n_layers, bs, n_directions * self.hidden_size].
        """
        # inputs: [max_input_length, bs]

        emb = self.dropout(self.embedding(inputs))
        out, hid = self.gru(emb)

        if self.n_directions == 2:
            # [n_layers * 2, bs, h] -> [n_layers, 2, bs, h], then join the
            # forward and backward states of each layer along the feature axis.
            hid = hid.view(self.n_layers, self.n_directions, -1, self.hidden_size)
            hid = torch.cat((hid[:, 0, :, :], hid[:, 1, :, :]), dim=2)

        return out, hid

    def load_pretrained_embedding(self):
        """Placeholder; not implemented."""
        pass

Attention layer (Luong's version)

In [None]:
class Attention(nn.Module):
    """Dot-product attention between the decoder state and encoder outputs.

    The hidden states of all decoder layers are flattened into one vector,
    projected to ``hidden_size`` by a bias-free linear layer, and scored
    against every encoder time step.
    """

    def __init__(self, hidden_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.alignment = nn.Linear(n_layers*hidden_size, hidden_size, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, de_hid_state, en_outputs):
        """Return attention weights over the source positions.

        Args:
            de_hid_state: decoder hidden state, [n_layers, bs, hidden_size].
            en_outputs: encoder outputs, [MAX_INPUT_LENGTH, bs, hidden_size].

        Returns:
            Weights of shape [bs, 1, MAX_INPUT_LENGTH], normalised over the
            last axis.
        """
        batch_size = de_hid_state.size(1)

        # Batch-major, with all layers merged into one feature vector:
        # [n_layers, bs, hid_size] -> [bs, 1, n_layers*hid_size]
        query = de_hid_state.transpose(0, 1).reshape(batch_size, 1, -1)
        # [bs, 1, n_layers*hid_size] --> [bs, 1, hid_size]
        query = self.alignment(query)

        # [maxlen, bs, hid_size] -> [bs, hid_size, maxlen]
        keys = en_outputs.permute(1, 2, 0)

        # [bs, 1, hid_size] x [bs, hid_size, maxlen] -> [bs, 1, maxlen]
        return self.softmax(torch.bmm(query, keys))

In [None]:
class AttRNNDecoder(nn.Module):
    """GRU decoder that predicts one target token per call using attention.

    Args:
        target_vocab_size: size of the target vocabulary (output logits width).
        emb_size: embedding dimension.
        hidden_size: GRU hidden width; encoder outputs must have the same
            feature width because ``fc1`` consumes ``2 * hidden_size`` features.
        num_layers: number of stacked GRU layers.
        dropout_ratio: dropout probability applied to the embeddings and
            between stacked GRU layers.
    """

    def __init__(self, target_vocab_size, emb_size=100, hidden_size=512, num_layers=2, dropout_ratio=0.2):
        super().__init__()
        self.trg_vocab_size = target_vocab_size
        self.hidden_size = hidden_size
        self.n_layers = num_layers

        self.dropout = nn.Dropout(p=dropout_ratio)
        self.embedding = nn.Embedding(num_embeddings=target_vocab_size, embedding_dim=emb_size)
        # nn.GRU applies dropout only between stacked layers and warns when a
        # non-zero value is given for a single layer, so pass 0 in that case.
        inter_layer_dropout = dropout_ratio if num_layers > 1 else 0.0
        self.gru = nn.GRU(input_size=emb_size, hidden_size=hidden_size,
                          num_layers=num_layers, dropout=inter_layer_dropout)
        self.attention = Attention(hidden_size, num_layers)
        self.fc1 = nn.Linear(2*hidden_size, hidden_size)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_size, target_vocab_size)

    def forward(self, input, hidden, encoder_outputs):
        """Run a single decoding step.

        Args:
            input: current target token ids, [bs].
            hidden: previous decoder hidden state, [n_layers, bs, hidden_size].
            encoder_outputs: [MAX_INPUT_LENGTH, bs, hidden_size].

        Returns:
            Tuple ``(pred, hid)``: ``pred`` holds unnormalised scores over the
            target vocabulary, [bs, target_vocab_size]; ``hid`` is the updated
            hidden state, [n_layers, bs, hidden_size].
        """
        # The GRU expects a leading time axis: [bs] -> [1, bs]
        step_tokens = input.unsqueeze(0)

        emb = self.dropout(self.embedding(step_tokens))
        out, hid = self.gru(emb, hidden)

        # Attention is computed from the hidden state *after* the GRU update.
        att_weights = self.attention(hid, encoder_outputs)
        enc = encoder_outputs.permute(1, 0, 2)

        # [bs, 1, maxlen] x [bs, maxlen, hid_size] -> [bs, hid_size]
        context = torch.bmm(att_weights, enc).squeeze(1)
        out = out.squeeze(0)

        # [bs, 2*hid_size] --> [bs, hid_size]
        pred = self.fc1(torch.cat((context, out), dim=1))
        pred = self.gelu(pred)

        # [bs, hid_size] --> [bs, target_vocab_size]
        pred = self.fc2(pred)

        return pred, hid

    def load_pretrained_embedding(self):
        """Placeholder; not implemented."""
        pass

In [None]:
class AttSeq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time.

    Args:
        encoder: module returning ``(encoder_outputs, hidden)`` for a source
            batch and exposing an ``n_directions`` attribute.
        decoder: module called as ``decoder(input, hidden, encoder_outputs)``
            that returns ``(pred, hidden)`` and exposes ``trg_vocab_size``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.bidirectional_encoder = encoder.n_directions == 2

    def forward(self, src, trg, teacher_forcing_ratio=0):
        """Decode ``trg.shape[0] - 1`` steps and return the per-step scores.

        Args:
            src: source token ids, [max_input_length, bs].
            trg: target token ids, [max_output_length, bs]; ``trg[0]`` is fed
                as the first decoder input.
            teacher_forcing_ratio: probability, drawn once per time step, of
                feeding the ground-truth token instead of the model's own
                prediction as the next input.

        Returns:
            Tensor [max_output_length, bs, trg_vocab_size]. Index 0 is left as
            zeros because decoding starts at step 1.
        """
        batch_size = trg.shape[1]
        max_output_len = trg.shape[0]
        trg_vocab_size = self.decoder.trg_vocab_size

        # Allocate the output buffer on the same device as the batch instead of
        # relying on a module-level DEVICE global being defined.
        preds = torch.zeros(max_output_len, batch_size, trg_vocab_size, device=trg.device)

        # last hidden state of the encoder is used as the initial hidden state of the decoder
        encoder_outputs, hidden = self.encoder(src)

        # first input to the decoder is the <sos> tokens
        decoder_input = trg[0]

        for t in range(1, max_output_len):
            pred, hidden = self.decoder(decoder_input, hidden, encoder_outputs)
            preds[t] = pred
            teacher_force = random.random() < teacher_forcing_ratio
            best_pred = pred.argmax(1)
            decoder_input = trg[t] if teacher_force else best_pred

        return preds

### Prepare data

Load dataset

In [None]:
# Install the Hugging Face `datasets` package (notebook shell command), then
# load the 'iwslt2015-vi-en' configuration of the 'mt_eng_vietnamese' dataset.
# Later cells read its 'train', 'validation' and 'test' splits and the
# 'translation' field, whose entries carry 'en' and 'vi' texts.
!pip install datasets
from datasets import load_dataset
hf_dataset = load_dataset('mt_eng_vietnamese', 'iwslt2015-vi-en')

Collecting datasets
[?25l  Downloading https://files.pythonhosted.org/packages/46/1a/b9f9b3bfef624686ae81c070f0a6bb635047b17cdb3698c7ad01281e6f9a/datasets-1.6.2-py3-none-any.whl (221kB)
[K     |█▌                              | 10kB 21.1MB/s eta 0:00:01[K     |███                             | 20kB 28.4MB/s eta 0:00:01[K     |████▍                           | 30kB 33.6MB/s eta 0:00:01[K     |██████                          | 40kB 29.4MB/s eta 0:00:01[K     |███████▍                        | 51kB 28.2MB/s eta 0:00:01[K     |████████▉                       | 61kB 30.5MB/s eta 0:00:01[K     |██████████▍                     | 71kB 28.2MB/s eta 0:00:01[K     |███████████▉                    | 81kB 29.1MB/s eta 0:00:01[K     |█████████████▎                  | 92kB 30.1MB/s eta 0:00:01[K     |██████████████▊                 | 102kB 31.4MB/s eta 0:00:01[K     |████████████████▎               | 112kB 31.4MB/s eta 0:00:01[K     |█████████████████▊              | 122kB 31

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1884.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1085.0, style=ProgressStyle(description…


Downloading and preparing dataset mt_eng_vietnamese/iwslt2015-vi-en (download: 30.83 MiB, generated: 31.59 MiB, post-processed: Unknown size, total: 62.42 MiB) to /root/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-vi-en/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=18074646.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=13603614.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=188396.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=140250.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=183855.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=132264.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset mt_eng_vietnamese downloaded and prepared to /root/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-vi-en/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71. Subsequent calls will reuse this data.


Preprocess data

In [None]:
# Install Hugging Face `transformers` (notebook shell command) and load the
# pretrained 'distilbert-base-multilingual-cased' tokenizer. The same tokenizer
# is used for both the English source and the Vietnamese target text.
!pip install transformers
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-multilingual-cased')

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/b0/9e/5b80becd952d5f7250eaf8fc64b957077b12ccfe73e9c03d37146ab29712/transformers-4.6.0-py3-none-any.whl (2.3MB)
[K     |████████████████████████████████| 2.3MB 24.6MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████████████| 901kB 49.8MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 46.6MB/s 
Installing collected packages: sacremoses, tokenizers, transformers
Successfully installed sacremoses-0.0.45 tokenizers-0.10.2 transformers-4.6.0


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=995526.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1961828.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=29.0, style=ProgressStyle(description_w…




In [None]:
# Build the list of batches for one pass over a dataset, converting text into token-id tensors.
# Returns a list of batches, each batch: {'src': [max_input_length, bs], 'trg': [max_output_length, bs]}

def make_iterator(dataset, batch_size):
    """Shuffle ``dataset`` and tokenize it into a list of fixed-size batches.

    Args:
        dataset: sequence of examples, each a mapping with an 'en' (source)
            and a 'vi' (target) text.
        batch_size: number of examples per batch.

    Returns:
        A list of dicts ``{'src': [MAX_INPUT_LENGTH, bs], 'trg':
        [MAX_OUTPUT_LENGTH, bs]}`` holding token ids. Examples that do not
        fill a complete final batch are dropped.
    """
    # Shuffle a copy so the caller's list is not reordered as a side effect.
    examples = list(dataset)
    random.shuffle(examples)

    iterator = []
    for i in range(len(examples) // batch_size):
        batch = examples[batch_size * i:batch_size * (i + 1)]
        src_texts = [example['en'] for example in batch]
        trg_texts = [example['vi'] for example in batch]
        # The tokenizer returns [bs, length]; the model expects time-major input.
        src_tensors = tokenizer(src_texts, padding='max_length', max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors='pt')['input_ids'].permute(1, 0)
        trg_tensors = tokenizer(trg_texts, padding='max_length', max_length=MAX_OUTPUT_LENGTH, truncation=True, return_tensors='pt')['input_ids'].permute(1, 0)
        iterator.append({'src': src_tensors, 'trg': trg_tensors})
    return iterator

### Training

In [None]:
def train(model, iterator, criterion, optimizer):
    """Run one optimisation pass over ``iterator`` and return the mean batch loss.

    Each batch is a dict with 'src' ([max_input_length, bs]) and 'trg'
    ([max_output_length, bs]) token-id tensors. The model is called with the
    module-level TEACHER_FORCING_RATIO.
    """
    model.train()
    running_loss = 0

    for batch in iterator:
        source = batch['src'].to(DEVICE)
        target = batch['trg'].to(DEVICE)

        optimizer.zero_grad()

        # logits: [max_output_length, bs, trg_vocab_size]
        logits = model(source, target, TEACHER_FORCING_RATIO)
        vocab_dim = logits.shape[-1]

        # Drop time step 0 and merge the time and batch axes:
        #   flat_logits: [(trg len - 1) * batch size, output dim]
        #   flat_target: [(trg len - 1) * batch size]
        flat_logits = logits[1:].reshape(-1, vocab_dim)
        flat_target = target[1:].reshape(-1)

        batch_loss = criterion(flat_logits, flat_target)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(iterator)

In [None]:
def evaluate(model, iterator, criterion):
    """Return the mean batch loss over ``iterator`` without updating the model.

    Each batch is a dict with 'src' ([max_input_length, bs]) and 'trg'
    ([max_output_length, bs]) token-id tensors. The model is called with its
    default teacher-forcing ratio, and gradients are disabled.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for batch in iterator:
            source = batch['src'].to(DEVICE)
            target = batch['trg'].to(DEVICE)

            # logits: [max_output_length, bs, trg_vocab_size]
            logits = model(source, target)
            vocab_dim = logits.shape[-1]

            # Drop time step 0 and merge the time and batch axes:
            #   flat_logits: [(trg len - 1) * batch size, output dim]
            #   flat_target: [(trg len - 1) * batch size]
            flat_logits = logits[1:].reshape(-1, vocab_dim)
            flat_target = target[1:].reshape(-1)

            running_loss += criterion(flat_logits, flat_target).item()

    return running_loss / len(iterator)

In [None]:
import time
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps (in seconds) into whole minutes and seconds.

    Returns a ``(minutes, seconds)`` tuple of ints; fractional seconds are
    truncated.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

In [None]:
# Training hyperparams
BATCH_SIZE = 64
MAX_INPUT_LENGTH = 32       # source sequences are padded/truncated to this many tokens
MAX_OUTPUT_LENGTH = 64      # target sequences are padded/truncated to this many tokens
NUM_EPOCHS = 5
LEARNING_RATE = 0.0001
TEACHER_FORCING_RATIO = 0.2

# Model hyperparams
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# One shared vocabulary: the same tokenizer encodes both source and target text.
VOCAB_SIZE = tokenizer.vocab_size
encoder = RNNEncoder(VOCAB_SIZE).to(DEVICE)
decoder = AttRNNDecoder(VOCAB_SIZE).to(DEVICE)
seq2seq = AttSeq2Seq(encoder, decoder).to(DEVICE)

# Take the padding id from the tokenizer rather than hard-coding 0, so padded
# positions stay excluded from the loss if the tokenizer is swapped.
loss_function = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)     # ignore [PAD] token
optim = torch.optim.Adam(seq2seq.parameters(), lr=LEARNING_RATE)

In [None]:
# Train model and save the best checkpoint
best_valid_loss = float('inf')

for epoch in range(NUM_EPOCHS):

    # Generate train_iterator, valid_iterator
    # (rebuilt every epoch because make_iterator reshuffles the examples)
    train_iterator = make_iterator(hf_dataset['train'][:]['translation'], BATCH_SIZE)
    valid_iterator = make_iterator(hf_dataset['validation'][:]['translation'], BATCH_SIZE)

    # Timing covers training and validation, not the batching above.
    start_time = time.time()

    train_loss = train(seq2seq, train_iterator, loss_function, optim)
    valid_loss = evaluate(seq2seq, valid_iterator, loss_function)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        # Overwrite one fixed file so the inference cell, which loads
        # 'att-seq2seq.pt', always finds the best checkpoint.
        torch.save(seq2seq.state_dict(), 'att-seq2seq.pt')

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

Epoch: 01 | Time: 94m 28s
	Train Loss: 6.166
	 Val. Loss: 5.877
Epoch: 02 | Time: 94m 20s
	Train Loss: 5.772
	 Val. Loss: 5.718


### Inference

In [None]:
# Load model at best checkpoint
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
seq2seq.load_state_dict(torch.load('att-seq2seq.pt', map_location=DEVICE))

In [None]:
# Compute loss on test set
# make_iterator shuffles the examples and keeps only complete batches of
# BATCH_SIZE, so the reported value is the mean loss over those batches.
test_iterator = make_iterator(hf_dataset['test'][:]['translation'], BATCH_SIZE)
test_loss = evaluate(seq2seq, test_iterator, loss_function)
print(f'\tTest Loss: {test_loss:.3f}')