In [11]:
import random
import time

import torch
import torchtext
from torch import nn
from torchtext.data import BucketIterator, Field
from torchtext.datasets import Multi30k

In [12]:
# Both sides of the parallel corpus are preprocessed the same way (spaCy
# tokenisation, lower-casing, <sos>/<eos> markers); only the tokenizer
# language differs.
_SHARED_FIELD_OPTIONS = dict(
    tokenize="spacy",
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
)

# Source side: German sentences.
SRC = Field(tokenizer_language="de", **_SHARED_FIELD_OPTIONS)

# Target side: English sentences.
TRG = Field(tokenizer_language="en", **_SHARED_FIELD_OPTIONS)

In [13]:
# Load the train/validation/test splits from ./data. The order of the file
# extensions must line up with the order of the fields: ".de" files are
# processed by SRC and ".en" files by TRG.
train_data, validation_data, test_data = Multi30k.splits(
    exts=(".de", ".en"),
    fields=(SRC, TRG),
    root="data",
)

In [14]:
# Sanity check: dump the first training example as a dict of its tokenised
# `src` / `trg` fields to confirm tokenisation and lower-casing were applied.
print(vars(train_data.examples[0]))

{'src': ['zwei', 'junge', 'weiße', 'männer', 'sind', 'im', 'freien', 'in', 'der', 'nähe', 'vieler', 'büsche', '.'], 'trg': ['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.']}


In [15]:
# Vocabularies are built from the training split only, so no validation/test
# tokens leak into them. Both languages use the same limits: at most 10000
# entries, and a token must occur at least twice to be included.
_VOCAB_OPTIONS = {"max_size": 10000, "min_freq": 2}
SRC.build_vocab(train_data, **_VOCAB_OPTIONS)
TRG.build_vocab(train_data, **_VOCAB_OPTIONS)

In [16]:
# Seed the Python and PyTorch global RNGs up front so that a fresh
# "Restart & Run All" is repeatable: weight initialisation and dropout draw
# from torch's generator, and the teacher-forcing decision uses
# `random.random()`. (GPU kernels may still introduce small nondeterminism.)
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 128
# Prefer the GPU when one is visible; everything below follows `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One iterator per split, all sharing the same batch size and target device.
train_iterator, validation_iterator, test_iterator = BucketIterator.splits(
    (train_data, validation_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
)

In [17]:
class Encoder(nn.Module):
    """Embed a source sequence and summarise it with a single-layer GRU.

    Args:
        vocab_size: Number of entries in the source vocabulary.
        embed_dim: Width of the token embeddings.
        hidden_size: Width of the GRU hidden state.
        dropout: Dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, dropout):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.gru = nn.GRU(input_size=embed_dim, hidden_size=hidden_size)

    def forward(self, x):
        """Return the GRU's final hidden state for a batch of token indices.

        Args:
            x: Integer tensor of token indices. The GRU uses its default
                sequence-first layout, so this is read as (src_len, batch).

        Returns:
            Final hidden state of shape (1, batch, hidden_size); the
            per-step GRU outputs are discarded.
        """
        token_vectors = self.embed(x)
        _, final_state = self.gru(self.dropout(token_vectors))
        return final_state

In [18]:
class Decoder(nn.Module):
    """Single-step GRU decoder that re-reads the encoder context at every step.

    The context vector is concatenated with the token embedding before the
    GRU, and again with the embedding and the new hidden state before the
    output projection.

    Args:
        vocab_size: Size of the target vocabulary (and of each score vector).
        embed_dim: Width of the target-token embeddings.
        hidden_size: Width of the GRU hidden state. The layer sizes below
            assume the context vector has this same width.
        dropout: Dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, dropout):
        super().__init__()
        self.vocab_size = vocab_size
        self.dropout = nn.Dropout(dropout)
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # GRU input = [embedding ; context]
        self.gru = nn.GRU(embed_dim + hidden_size, hidden_size)
        # Output layer input = [embedding ; hidden ; context]
        self.fc = nn.Linear(embed_dim + hidden_size * 2, vocab_size)

    def forward(self, x, hidden, context):
        """Score the next token for every sequence in the batch.

        Args:
            x: Token indices for the current step, shape (batch,).
            hidden: Previous decoder state, shape (1, batch, hidden_size).
            context: Encoder summary, shape (1, batch, hidden_size).

        Returns:
            Tuple ``(predictions, hidden)``: unnormalised vocabulary scores of
            shape (batch, vocab_size) and the updated hidden state of shape
            (1, batch, hidden_size).
        """
        # Add a length-1 sequence axis so the GRU sees (1, batch).
        x = x.unsqueeze(0)
        embedding = self.dropout(self.embed(x))  # (1, batch, embed_dim)
        # Use torch.cat here as well: torch.concat is only an alias in newer
        # PyTorch releases, and the rest of this class already uses torch.cat.
        embed_context = torch.cat((embedding, context), dim=2)
        # embed_context: (1, batch, embed_dim + hidden_size)
        _, hidden = self.gru(embed_context, hidden)
        # Drop the length-1 leading axis before the linear layer.
        outputs = torch.cat(
            (embedding.squeeze(0), hidden.squeeze(0), context.squeeze(0)),
            dim=1,
        )
        predictions = self.fc(outputs)
        return predictions, hidden

In [19]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes one target step at a time.

    Args:
        encoder: Module mapping a source batch to a context tensor.
        decoder: Module with a ``vocab_size`` attribute, called as
            ``decoder(x, hidden, context)`` and returning
            ``(predictions, hidden)``.
        device: Device on which the output tensor is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, target, teacher_force_ratio=0.5):
        """Decode ``target.size(0) - 1`` steps and collect the scores.

        Args:
            source: Source batch, passed to the encoder unchanged.
            target: Target token indices, indexed as (trg_len, batch);
                ``target[0]`` is the first decoder input.
            teacher_force_ratio: Probability, per step, of feeding the
                ground-truth token as the next input instead of the model's
                own prediction. ``0`` decodes fully from the model's own
                predictions; ``1`` always feeds the ground truth.

        Returns:
            Tensor of shape (trg_len, batch, vocab_size). Row 0 is left as
            zeros because no prediction is made for the first token.
        """
        seq_len = target.size(0)
        batch_size = target.size(1)
        outputs = torch.zeros(
            seq_len, batch_size, self.decoder.vocab_size, device=self.device
        )

        # The encoder summary is both the decoder's initial state and a fixed
        # context that is fed to the decoder again at every step.
        context = self.encoder(source)
        hidden = context
        x = target[0]

        for t in range(1, seq_len):
            predictions, hidden = self.decoder(x, hidden, context)
            outputs[t] = predictions
            # Teacher forcing feeds the ground-truth token; otherwise the
            # decoder continues from its own most likely token.
            teacher_force = random.random() < teacher_force_ratio
            x = target[t] if teacher_force else predictions.argmax(1)

        return outputs

In [20]:
# Vocabulary sizes come straight from the fields' built vocabularies.
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)

# Embedding width and dropout probability for each side.
ENC_EMB_DIM, DEC_EMB_DIM = 256, 256
ENC_DROPOUT, DEC_DROPOUT = 0.5, 0.5
# One hidden width for both sides: the encoder's final state is used as the
# decoder's initial state and as its context, so the two must match.
HID_DIM = 512

encoder = Encoder(
    vocab_size=INPUT_DIM,
    embed_dim=ENC_EMB_DIM,
    hidden_size=HID_DIM,
    dropout=ENC_DROPOUT,
)
decoder = Decoder(
    vocab_size=OUTPUT_DIM,
    embed_dim=DEC_EMB_DIM,
    hidden_size=HID_DIM,
    dropout=DEC_DROPOUT,
)
model = Seq2Seq(encoder=encoder, decoder=decoder, device=device)
model = model.to(device)