In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torchtext.legacy.datasets import Multi30k
from torchtext.legacy.data import Field, BucketIterator

import matplotlib.pyplot as plt
import spacy
import numpy as np

from copy import deepcopy
import random
import math
import time

In [2]:
SEED = 1234

# Seed every RNG that can influence data order, weight initialisation or
# dropout, in order: Python, NumPy, PyTorch (CPU), PyTorch (current CUDA device).
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(SEED)

# Ask cuDNN for deterministic kernels so GPU runs are repeatable.
torch.backends.cudnn.deterministic = True

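# Data processing
- The code in this section is copied from Ben Trevett's *pytorch-seq2seq* tutorial, notebook 6 ("Attention is All You Need"): https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb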

In [3]:
# Load the small German (source) and English (target) spaCy pipelines.
# Only their tokenizers are used, by tokenize_de / tokenize_en below.
# NOTE(review): both models must be installed beforehand, e.g.
# `python -m spacy download de_core_news_sm` and `... en_core_web_sm`.
spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')

In [4]:
def tokenize_de(text):
    """
    Split a German string into a list of token strings.

    Only spaCy's German tokenizer is run (not the full pipeline), so each
    returned item is the surface text of one token, in order.
    """
    doc = spacy_de.tokenizer(text)
    return [token.text for token in doc]

def tokenize_en(text):
    """
    Split an English string into a list of token strings.

    Only spaCy's English tokenizer is run (not the full pipeline), so each
    returned item is the surface text of one token, in order.
    """
    doc = spacy_en.tokenizer(text)
    return [token.text for token in doc]

____

# Configuration

In [5]:
# Every batch is padded to a fixed width, so the model can be built with
# static sequence lengths (checked by the assertions at the end of this cell).
src_seq_len = 30
# The decoder consumes trg[:, :-1] (the padded target minus its last
# position), so the model's target length is one less than the padded TRG width.
trg_seq_len = src_seq_len - 1
BATCH_SIZE = 128

SRC = Field(tokenize = tokenize_de, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True, 
            fix_length=src_seq_len,
            batch_first = True)

# Padded TRG width is trg_seq_len + 1 (== 30, unchanged) so that dropping one
# position for teacher forcing leaves exactly trg_seq_len decoder inputs.
TRG = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True,
            fix_length=trg_seq_len + 1,
            batch_first = True)

train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), 
                                                    fields = (SRC, TRG))

# Vocabularies are built from the training split only, keeping tokens seen at least twice.
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
     batch_size = BATCH_SIZE,
     device = device)

# Sanity check: batches are batch-first and exactly as wide as the model expects.
sample_batch = next(iter(train_iterator))
assert sample_batch.src.shape[1] == src_seq_len, sample_batch.src.shape
assert sample_batch.trg.shape[1] == trg_seq_len + 1, sample_batch.trg.shape

____

# Load Model

In [8]:
import sys
sys.path.append("..")

from transformer.seq2seq import enc_dec

In [9]:
# Transformer hyper-parameters.
d_model = 256            # model width (presumably the embedding/hidden size)
d_ff = 512               # presumably the inner size of the feed-forward sublayers
n_head = 8               # number of attention heads (d_model must divide evenly: 256 / 8 = 32)
batch_size = BATCH_SIZE  # NOTE(review): not used by any visible cell below
src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)
dropout_p = 0.1
n_enc_layer, n_dec_layer = 3,3

# NOTE(review): arguments are positional; their order must match
# transformer.seq2seq.enc_dec.EncoderDecoder's signature (not visible here).
# trg_seq_len is the decoder input length used in training (trg[:, :-1]),
# and the two pad ids are looked up from the vocabularies rather than hard-coded.
model = enc_dec.EncoderDecoder(src_vocab_size,
                                 trg_vocab_size,
                                 src_seq_len,
                                 trg_seq_len,
                                 SRC.vocab.stoi['<pad>'],
                                 TRG.vocab.stoi['<pad>'],
                                 d_model,
                                 d_ff,
                                 n_head,
                                 dropout_p,
                                 n_enc_layer,
                                 n_dec_layer).to(device)

___

# Train and Test

In [10]:
LEARNING_RATE = 0.0005

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# Padding positions must not contribute to the loss. Look the pad id up from
# the target vocabulary (as the model construction does) instead of
# hard-coding 1, which silently relies on '<pad>' happening to receive id 1.
TRG_PAD_IDX = TRG.vocab.stoi['<pad>']
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

In [11]:
def train(model, iterator, optimizer, criterion, clip):
    """
    Run one teacher-forced training epoch and return the mean batch loss.

    Each target sequence is split into a decoder input (every position but the
    last) and a gold sequence (every position but the first), so output
    position t is scored against target token t + 1.

    Args:
        model: seq2seq module called as ``model(src, trg_in)``; its output is
            assumed to be ``(batch, trg_len - 1, vocab)`` logits.
        iterator: sized iterable of batches exposing ``.src`` and ``.trg``
            (batch-first in this notebook).
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss taking flattened logits and flattened gold ids.
        clip: max gradient norm passed to ``clip_grad_norm_``.

    Returns:
        float: sum of per-batch losses divided by ``len(iterator)``.
    """
    model.train()
    running_loss = 0

    for batch in iterator:
        decoder_input = batch.trg[:, :-1]
        gold = batch.trg[:, 1:]

        optimizer.zero_grad()

        logits = model(batch.src, decoder_input)
        vocab_size = logits.shape[-1]
        loss = criterion(logits.contiguous().view(-1, vocab_size),
                         gold.contiguous().view(-1))

        loss.backward()
        # Bound the update size; helps keep early Transformer training stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(iterator)

def evaluate(model, iterator, criterion):
    """
    Compute the mean teacher-forced loss of ``model`` over ``iterator``.

    The model is switched to eval mode and no gradients are tracked, so
    parameters are left untouched. Input/target shifting matches ``train``:
    the decoder sees ``trg[:, :-1]`` and is scored against ``trg[:, 1:]``.

    Args:
        model: seq2seq module called as ``model(src, trg_in)``.
        iterator: sized iterable of batches exposing ``.src`` and ``.trg``.
        criterion: loss taking flattened logits and flattened gold ids.

    Returns:
        float: sum of per-batch losses divided by ``len(iterator)``.
    """
    model.eval()
    total = 0

    with torch.no_grad():
        for batch in iterator:
            logits = model(batch.src, batch.trg[:, :-1])
            flat_logits = logits.contiguous().view(-1, logits.shape[-1])
            flat_gold = batch.trg[:, 1:].contiguous().view(-1)
            total += criterion(flat_logits, flat_gold).item()

    return total / len(iterator)

def epoch_time(start_time, end_time):
    """
    Convert a wall-clock interval in seconds into whole minutes and leftover seconds.

    Returns:
        tuple[int, int]: ``(minutes, seconds)``, each truncated toward zero.
    """
    seconds_total = end_time - start_time
    minutes = int(seconds_total / 60)
    return minutes, int(seconds_total - minutes * 60)

In [None]:
N_EPOCHS = 10
CLIP = 1  # max gradient norm, passed to train() -> clip_grad_norm_

best_valid_loss = float('inf')

# Train for N_EPOCHS, validating after every epoch.
for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Checkpoint only when validation loss improves, so the file on disk always
    # holds the best model seen so far; the following cell reloads it.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut6-model.pt')

    # PPL (perplexity) is exp of the mean per-token cross-entropy loss.
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

In [None]:
# Reload the best checkpoint written by the training loop and report test loss/PPL.
# map_location maps the saved tensors onto the current device, so a checkpoint
# saved on GPU still loads when this kernel only has a CPU (and vice versa).
model.load_state_dict(torch.load('tut6-model.pt', map_location=device))
test_loss = evaluate(model, test_iterator, criterion)
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

___

# Test: example translations

In [None]:
# Print a few test examples: source, reference target, and the model's
# prediction. The decoder is fed the reference prefix (teacher forcing), so
# "pred" is a next-token prediction per position, NOT free-running decoding.
N_EXAMPLES = 10


def ids_to_tokens(ids, vocab, skip_idx=None, stop_idx=None):
    """
    Map a 1-D tensor of token ids to their vocabulary strings.

    Args:
        ids: 1-D tensor of token ids.
        vocab: vocabulary object providing ``itos`` (id -> string).
        skip_idx: id to drop (e.g. padding), or None to keep every id.
        stop_idx: id at which to stop, excluded from the result, or None.

    Returns:
        list[str]: the decoded tokens, in order.
    """
    tokens = []
    for idx in ids.tolist():
        if idx == stop_idx:
            break
        if idx == skip_idx:
            continue
        tokens.append(vocab.itos[idx])
    return tokens


src_pad_idx = SRC.vocab.stoi['<pad>']
trg_pad_idx = TRG.vocab.stoi['<pad>']
trg_eos_idx = TRG.vocab.stoi['<eos>']

# Fetch ONE batch so src, trg and pred for each row are guaranteed to match.
batch = next(iter(test_iterator))

model.eval()
with torch.no_grad():
    # Single forward pass for the whole batch; rows are picked out below.
    logits = model(batch.src, batch.trg[:, :-1])
predictions = logits.argmax(-1)

for example_idx in range(min(N_EXAMPLES, batch.src.shape[0])):
    print(f'src = {ids_to_tokens(batch.src[example_idx], SRC.vocab, skip_idx=src_pad_idx)}')
    # Mask the target with its OWN padding (previously the src mask was used).
    print(f'trg = {ids_to_tokens(batch.trg[example_idx], TRG.vocab, skip_idx=trg_pad_idx)}')
    print(f'pred = {ids_to_tokens(predictions[example_idx], TRG.vocab, stop_idx=trg_eos_idx)}')
    print("#"*100)