### Seq2Seq
---

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchtext.datasets import TranslationDataset, Multi30k, IWSLT
from torchtext.data import Field, BucketIterator

import spacy
import numpy as np

import random
import math
import time

In [2]:
SEED = 11747

# Seed every random number generator the notebook relies on (Python, NumPy,
# PyTorch CPU and PyTorch CUDA) with the same value.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)

# Ask cuDNN to use deterministic algorithms only.
torch.backends.cudnn.deterministic = True

In [3]:
# Load the German and English spaCy pipelines by their full package names.
# The short 'de' / 'en' shortcut links were removed in spaCy v3, while the
# package names below resolve on both spaCy v2 and v3 once the models are
# installed (python -m spacy download de_core_news_sm / en_core_web_sm).
spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')

In [4]:
def tokenize_de(text):
    """
    Split a German string into tokens with the spaCy German tokenizer and
    return the token strings in reversed order (last token first).
    """
    tokens = [tok.text for tok in spacy_de.tokenizer(text)]
    tokens.reverse()
    return tokens

def tokenize_en(text):
    """
    Split an English string into tokens with the spaCy English tokenizer and
    return the token strings in their original order.
    """
    tokens = []
    for tok in spacy_en.tokenizer(text):
        tokens.append(tok.text)
    return tokens

In [6]:
# Source and target fields share the same settings: '<sos>' / '<eos>' as the
# init and end-of-sequence tokens, and lowercasing. Only the tokenizer differs.
_shared_field_args = dict(init_token = '<sos>',
                          eos_token = '<eos>',
                          lower = True)

# German source side (tokenize_de also reverses the token order).
SRC = Field(tokenize = tokenize_de, **_shared_field_args)

# English target side.
TRG = Field(tokenize = tokenize_en, **_shared_field_args)

In [7]:
# Load the Multi30k train / validation / test splits (downloaded on first use).
# The '.de' files are processed with SRC and the '.en' files with TRG.
train_data, valid_data, test_data = Multi30k.splits(
    fields = (SRC, TRG),
    exts = ('.de', '.en'))

downloading training.tar.gz


training.tar.gz: 100%|██████████| 1.21M/1.21M [00:01<00:00, 1.20MB/s]


downloading validation.tar.gz


validation.tar.gz: 100%|██████████| 46.3k/46.3k [00:00<00:00, 234kB/s]


downloading mmt_task1_test2016.tar.gz


mmt_task1_test2016.tar.gz: 100%|██████████| 66.2k/66.2k [00:00<00:00, 129kB/s]


In [8]:
# Sanity check: report how many examples each of the three splits contains.
print(f"Number of training examples: {len(train_data.examples)}")
print(f"Number of validation examples: {len(valid_data.examples)}")
print(f"Number of testing examples: {len(test_data.examples)}")

Number of training examples: 29000
Number of validation examples: 1014
Number of testing examples: 1000


In [9]:
# Show the first training example as a dict of its fields. The 'src' tokens
# appear in reverse order because tokenize_de reverses the German sentence.
print(vars(train_data.examples[0]))

{'src': ['.', 'büsche', 'vieler', 'nähe', 'der', 'in', 'freien', 'im', 'sind', 'männer', 'weiße', 'junge', 'zwei'], 'trg': ['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.']}


In [11]:
# Minimum frequency passed to build_vocab for both languages.
MIN_FREQ = 2

# Vocabularies are built from the training split only, so no tokens from the
# validation or test data leak into them.
SRC.build_vocab(train_data, min_freq = MIN_FREQ)
TRG.build_vocab(train_data, min_freq = MIN_FREQ)

print(f"Unique tokens in source (de) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (en) vocabulary: {len(TRG.vocab)}")

Unique tokens in source (de) vocabulary: 7855
Unique tokens in target (en) vocabulary: 5893


In [12]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [13]:
BATCH_SIZE = 64

# One iterator per split, all using the same batch size and placing their
# batches on `device`.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    device = device,
    batch_size = BATCH_SIZE)

#### Simple LSTM Encoder and Decoder

In [14]:
class Encoder(nn.Module):
    """
    LSTM encoder: embeds a batch of source token indices, runs them through a
    (possibly multi-layer) LSTM and returns the final hidden and cell states.

    Args:
        ninp: size of the source vocabulary (number of embedding rows).
        nembed: embedding dimension.
        nhid: LSTM hidden size.
        nlayers: number of stacked LSTM layers.
        dropout: probability used for the embedding dropout and passed to nn.LSTM.
    """
    def __init__(self, ninp, nembed, nhid, nlayers, dropout=0.2):
        super().__init__()
        self.nhid, self.nlayers = nhid, nlayers
        self.embedding = nn.Embedding(num_embeddings=ninp, embedding_dim=nembed)
        self.rnn = nn.LSTM(input_size=nembed,
                           hidden_size=nhid,
                           num_layers=nlayers,
                           dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """
        Encode a batch of source sequences.

        Args:
            src: (s, b) tensor of token indices (sequence-first layout).

        Returns:
            (hidden, cell): the LSTM's final states, each (nlayers, b, nhid).
        """
        # (s, b) -> (s, b, nembed), with dropout applied to the embeddings
        embedded = self.dropout(self.embedding(src))
        # The per-step outputs are not needed; only the final states are returned.
        _, (final_hidden, final_cell) = self.rnn(embedded)
        return final_hidden, final_cell
    
class Decoder(nn.Module):
    """
    LSTM decoder that consumes one token per call and produces scores over the
    target vocabulary for the next token.

    Args:
        nout: size of the target vocabulary (embedding rows and output scores).
        nembed: embedding dimension.
        nhid: LSTM hidden size.
        nlayers: number of stacked LSTM layers.
        dropout: probability used for the embedding dropout and passed to nn.LSTM.
    """
    def __init__(self, nout, nembed, nhid, nlayers, dropout=0.2):
        super(Decoder, self).__init__()
        self.nout = nout
        self.nhid = nhid
        self.nlayers = nlayers
        self.embedding = nn.Embedding(nout, nembed)
        self.rnn = nn.LSTM(nembed, nhid, nlayers, dropout=dropout)
        self.fc_out = nn.Linear(nhid, nout)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, src, hidden, cell):
        """
        Run a single decoding step; the decoder processes one token at a time.

        Args:
            src: (b,) tensor holding the current input token index for each
                batch element.
            hidden: LSTM hidden state, (nlayers, b, nhid).
            cell: LSTM cell state, (nlayers, b, nhid).

        Returns:
            pred: (b, nout) unnormalised scores over the target vocabulary.
            hidden: updated hidden state, same shape as the input.
            cell: updated cell state, same shape as the input.
        """
        # src: (b, ) -> (1, b): add a sequence dimension of length one
        src = src.unsqueeze(0)
        src = self.dropout(self.embedding(src))
        # src: (1, b, nembed)
        out, (hidden, cell) = self.rnn(src, (hidden, cell))
        # out: (1, b, nhid). The original read an undefined name `output` here,
        # which raised NameError on every call.
        pred = self.fc_out(out.squeeze(0))
        # pred: (b, nout)
        return pred, hidden, cell