In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torchtext.legacy.datasets import Multi30k
from torchtext.legacy.data import Field, BucketIterator

import matplotlib.pyplot as plt
import spacy
import numpy as np

from copy import deepcopy
import random
import math
import time

In [2]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Data processing
- copied from https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb

In [3]:
spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')

In [4]:
def tokenize_de(text):
    """
    Tokenizes German text from a string into a list of strings
    """
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    """
    Tokenizes English text from a string into a list of strings
    """
    return [tok.text for tok in spacy_en.tokenizer(text)]

____

# Configuration

In [5]:
src_seq_len = 30
trg_seq_len = 30-1
BATCH_SIZE = 128

SRC = Field(tokenize = tokenize_de, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True, 
            fix_length=src_seq_len,
            batch_first = True)

TRG = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True,
            fix_length=src_seq_len,
            batch_first = True)

train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), 
                                                    fields = (SRC, TRG))

SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
     batch_size = BATCH_SIZE,
     device = device)

for i in train_iterator : 
    break

# Input Embedding

In [6]:
class InputEmbedding(nn.Module) : 
    def __init__(self, vocab_size, seq_length, d_model) : 
        super().__init__()
        self.d_model = d_model
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(seq_length, d_model)
    
    def generate_enc_mask_m(self, src) :       
        mask_m = (src != 1).unsqueeze(1).unsqueeze(2)
        return mask_m

    def generate_dec_mask_m(self, trg) :
        trg_pad_mask = (trg != 1).unsqueeze(1).unsqueeze(2)
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len)), diagonal=0).bool().to(trg_pad_mask.device)
        mask_m = trg_pad_mask & trg_sub_mask
        return mask_m
    
    def forward(self, x) : 
        emb = self.tok_emb(x)
        pos = torch.arange(0, emb.shape[1]).unsqueeze(0).repeat(emb.shape[0], 1).to(emb.device)
        summed = emb / math.sqrt(self.d_model) + self.pos_emb(pos)
        return summed

In [29]:
for i in train_iterator : 
    break

# Scaled Dot-Product Attention

In [7]:
class ScaledDotProductAttention(nn.Module) : 
    def __init__(self, d_model) : 
        super().__init__()
        self.d_model = d_model
        self.fc = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask) :         
        score = torch.matmul(q, k.permute(0,1,3,2).contiguous())/math.sqrt(d_model)
        score = score.masked_fill(mask==0, -1e10)
        scaled_score = torch.softmax(score, dim=-1)
        
        attention = torch.matmul(scaled_score, v).permute(0,2,3,1).contiguous()
        attention = attention.view(attention.shape[0], attention.shape[1], self.d_model)
        
        return self.fc(attention)

___

# Multi-Head Attention

In [8]:
class MultiHeadAttention(nn.Module) : 
    def __init__(self, d_model, seq_length, n_head) : 
        super().__init__()
        assert d_model % n_head == 0, f"n_head({n_head}) does not divide d_model({d_model})"

        self.n_div_head = d_model//n_head
        self.d_model = d_model
        self.seq_len = seq_length
        self.n_head = n_head

        self.Q = nn.Linear(d_model,  d_model)
        self.K = nn.Linear(d_model,  d_model)
        self.V = nn.Linear(d_model,  d_model)
        
    def div_and_sort_for_multiheads(self, projected, seq_len) : 
        div = projected.view(projected.shape[0], self.n_head, seq_len, self.n_div_head)
        return div
    
    def forward(self, emb, enc_inputs=None) :
        q = self.div_and_sort_for_multiheads(self.Q(emb), self.seq_len)
    
        if enc_inputs is not None : # enc-dec attention
            seq_len = enc_inputs.shape[1] # takes target sequence length for k and v
            k = self.div_and_sort_for_multiheads(self.K(enc_inputs), seq_len)
            v = self.div_and_sort_for_multiheads(self.V(enc_inputs), seq_len)
        else : # self-attention
            k = self.div_and_sort_for_multiheads(self.K(emb), self.seq_len)
            v = self.div_and_sort_for_multiheads(self.V(emb), self.seq_len)

        return q,k,v

# Post-process the sub-layer
- layer normalization
- residual conection
- residual dropout

In [9]:
class PostProcessing(nn.Module) : 
    def __init__(self, d_model, p=0.1) : 
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p)
        
    def forward(self, emb, attn) : 
        return self.ln(emb+self.dropout(attn))

# Position-wise FFN

In [10]:
class PositionwiseFFN(nn.Module) : 
    def __init__(self, d_model, d_ff) : 
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x) : 
        return self.fc2(torch.relu(self.fc1(x)))

# Encoder

In [11]:
class EncoderLayer(nn.Module) : 
    def __init__(self, vocab_size, seq_length, d_model, d_ff, n_head, dropout_p) : 
        super().__init__()
        
        self.ma = MultiHeadAttention(d_model, seq_length, n_head).to(device)
        self.sdp = ScaledDotProductAttention(d_model)
        
        self.pp1 = PostProcessing(d_model, dropout_p)
        self.pp2 = PostProcessing(d_model, dropout_p)
        
        self.positionwise_ffn = PositionwiseFFN(d_model, d_ff)
            
    def forward(self, emb, mask_m) :

        q,k,v = self.ma(emb)    
        attn = self.sdp(q,k,v, mask=mask_m)
        
        attn = self.pp1(emb, attn)
        z = self.positionwise_ffn(attn)

        return self.pp2(attn, z)

# Decoder

In [12]:
class DecoderLayer(nn.Module) : 
    def __init__(self, vocab_size, seq_length, d_model, d_ff, n_head, dropout_p) : 
        super().__init__()
        
        self.seq_length = seq_length
        
        self.ma_self = MultiHeadAttention(d_model, seq_length, n_head).to(device)
        self.ma_enc_dec = MultiHeadAttention(d_model, seq_length, n_head).to(device)
        self.sdp_self = ScaledDotProductAttention(d_model)
        self.sdp_enc_dec = ScaledDotProductAttention(d_model)
        
        self.pp1 = PostProcessing(d_model, dropout_p)
        self.pp2 = PostProcessing(d_model, dropout_p)
        self.pp3 = PostProcessing(d_model, dropout_p)

        self.positionwise_ffn = PositionwiseFFN(d_model, d_ff)
    
    def forward(self, emb, mask_m_src, mask_m_trg, enc_hidden) : 
        
        q,k,v = self.ma_self(emb)
        attn = self.sdp_self(q,k,v, mask=mask_m_trg)
        attn1 = self.pp1(emb, attn)
        
        dec_q,enc_k,enc_v = self.ma_enc_dec(attn1, enc_hidden)
        attn2 = self.sdp_enc_dec(dec_q, enc_k, enc_v, mask_m_src)
        sub_layer_output = self.pp2(attn1, attn2)

        z = self.positionwise_ffn(sub_layer_output)

        return self.pp3(sub_layer_output, z)

# Encoder-Decoder

In [20]:
class EncoderDecoder(nn.Module) : 
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 src_seq_length,
                 trg_seq_length,
                 d_model,
                 d_ff,
                 n_head,
                 dropout_p,
                 n_enc_layer,
                 n_dec_layer) : 
        
        super().__init__()
        
        self.src_embber = InputEmbedding(src_vocab_size, src_seq_length, d_model)
        self.trg_embber = InputEmbedding(trg_vocab_size, trg_seq_length, d_model)
        
        enc = EncoderLayer(src_vocab_size, src_seq_length, d_model, d_ff, n_head, dropout_p)
        dec = DecoderLayer(trg_vocab_size, trg_seq_length, d_model, d_ff, n_head, dropout_p)
        
        self.enc = nn.ModuleList([deepcopy(enc) for _ in range(n_enc_layer)])
        self.dec = nn.ModuleList([deepcopy(dec) for _ in range(n_dec_layer)])
        
        self.fc = nn.Linear(d_model, trg_vocab_size)
        
    def forward(self, src, trg) : 
        
        src_emb, trg_emb = self.src_embber(src), self.trg_embber(trg)
        src_mask_m = self.src_embber.generate_enc_mask_m(src)
        trg_mask_m = self.trg_embber.generate_dec_mask_m(trg)
        
        for enc_layer in self.enc : 
            src_emb = enc_layer(src_emb, src_mask_m)
        
        for dec_layer in self.dec : 
            trg_emb = dec_layer(trg_emb, src_mask_m, trg_mask_m, src_emb)
        
        return self.fc(trg_emb)

In [43]:
d_model = 256
d_ff = 512
n_head = 8
batch_size = BATCH_SIZE
src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)
dropout_p = 0.1
n_enc_layer, n_dec_layer = 3,3

In [44]:
model = EncoderDecoder(src_vocab_size,
                         trg_vocab_size,
                         src_seq_len,
                         trg_seq_len,
                         d_model,
                         d_ff,
                         n_head,
                         dropout_p,
                         n_enc_layer,
                         n_dec_layer).to(device)

____

# Train and Test

In [45]:
LEARNING_RATE = 0.0005

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index = 1)

In [46]:
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    
    for i, batch in enumerate(iterator) :
        
        src = batch.src
        trg = batch.trg
        
        optimizer.zero_grad()
        
        output = model(src, trg[:,:-1])                
        output_dim = output.shape[-1]
            
        output = output.contiguous().view(-1, output_dim)
        trg = trg[:,1:].contiguous().view(-1)
            
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            src = batch.src
            trg = batch.trg

            output = model(src, trg[:,:-1])            
            output_dim = output.shape[-1]
            
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)
            
            loss = criterion(output, trg)
            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [47]:
N_EPOCHS = 10
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut6-model.pt')

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Epoch: 01 | Time: 0m 31s
	Train Loss: 4.867 | Train PPL: 129.942
	 Val. Loss: 4.400 |  Val. PPL:  81.477
Epoch: 02 | Time: 0m 31s
	Train Loss: 4.331 | Train PPL:  76.014
	 Val. Loss: 3.822 |  Val. PPL:  45.709
Epoch: 03 | Time: 0m 32s
	Train Loss: 3.560 | Train PPL:  35.158
	 Val. Loss: 3.117 |  Val. PPL:  22.573
Epoch: 04 | Time: 0m 32s
	Train Loss: 3.045 | Train PPL:  21.003
	 Val. Loss: 2.754 |  Val. PPL:  15.710
Epoch: 05 | Time: 0m 32s
	Train Loss: 2.763 | Train PPL:  15.848
	 Val. Loss: 2.521 |  Val. PPL:  12.439
Epoch: 06 | Time: 0m 32s
	Train Loss: 2.560 | Train PPL:  12.940
	 Val. Loss: 2.362 |  Val. PPL:  10.609
Epoch: 07 | Time: 0m 32s
	Train Loss: 2.416 | Train PPL:  11.196
	 Val. Loss: 2.171 |  Val. PPL:   8.765
Epoch: 08 | Time: 0m 32s
	Train Loss: 2.266 | Train PPL:   9.642
	 Val. Loss: 2.165 |  Val. PPL:   8.712
Epoch: 09 | Time: 0m 32s
	Train Loss: 2.143 | Train PPL:   8.525
	 Val. Loss: 1.971 |  Val. PPL:   7.181
Epoch: 10 | Time: 0m 32s
	Train Loss: 1.999 | Train PPL

In [48]:
model.load_state_dict(torch.load('tut6-model.pt'))
test_loss = evaluate(model, test_iterator, criterion)
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

| Test Loss: 1.812 | Test PPL:   6.125 |


___

# Test

In [49]:
for example_idx in range(10) : 
    for i in test_iterator : 
        break

    sent = []
    for wi in i.src[example_idx][i.src[example_idx] != 1] : 
        wi = wi.cpu().data.numpy().item()
        txt = SRC.vocab.itos[wi]
        sent.append(txt)
    print(f'src = {sent}')

    for i in test_iterator : 
        break

    sent = []
    for wi in i.trg[example_idx][i.src[example_idx] != 1] : 
        wi = wi.cpu().data.numpy().item()
        txt = TRG.vocab.itos[wi]
        sent.append(txt)
    print(f'trg = {sent}')

    model.eval()
    output = model(i.src, i.trg[:,:-1])
    predictions = output[example_idx].argmax(1)

    sent = []
    for wi in predictions : 
        wi = wi.cpu().data.numpy().item()
        if wi == TRG.vocab.stoi['<eos>'] : 
            break
        txt = TRG.vocab.itos[wi]
        sent.append(txt)
    print(f'pred = {sent}')

    print("#"*100)

src = ['<sos>', 'zwei', 'mittelgroße', 'hunde', 'laufen', 'über', 'den', 'schnee', '.', '<eos>']
trg = ['<sos>', 'two', 'medium', 'sized', 'dogs', 'run', 'across', 'the', 'snow', '.']
pred = ['two', 'scuba', 'martial', 'dogs', 'up', 'across', 'the', 'snow', '.']
####################################################################################################
src = ['<sos>', 'vier', 'personen', 'spielen', 'fußball', 'auf', 'einem', 'strand', '.', '<eos>']
trg = ['<sos>', 'four', 'people', 'are', 'playing', 'soccer', 'on', 'a', 'beach', '.']
pred = ['four', 'people', 'are', 'playing', 'water', 'on', 'a', 'beach', '.']
####################################################################################################
src = ['<sos>', 'ein', 'junge', 'fährt', 'skateboard', 'auf', 'einer', 'skateboardrampe', '.', '<eos>']
trg = ['<sos>', 'a', 'boy', 'riding', 'a', 'skateboard', 'on', 'a', 'skateboarding', 'ramp']
pred = ['a', 'boy', 'playing', 'a', 'skateboard', 'on', 'a', 'playground', 