In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torchnlp.datasets import wmt_dataset, iwslt_dataset  # doctest: +SKIP

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Data

In [21]:
# Vocabulary indices reserved for the sentence markers; they match the
# {0: "SOS", 1: "EOS"} entries that Lang.index2word starts with.
SOS_token = 0
EOS_token = 1

class Lang:
    """Vocabulary of one language: word <-> index maps plus word counts.

    Indices 0 and 1 are pre-assigned to the "SOS" and "EOS" markers, so real
    words receive indices from 2 upward in order of first appearance.
    """

    def __init__(self, name):
        """Create an empty vocabulary labelled `name`."""
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        # SOS and EOS are already part of the vocabulary size.
        self.n_words = len(self.index2word)

    def addSentence(self, sentence):
        """Register every token of `sentence`, splitting on single spaces."""
        tokens = sentence.split(' ')
        for token in tokens:
            self.addWord(token)

    def addWord(self, word):
        """Count one occurrence of `word`, giving it the next free index if new."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        new_index = self.n_words
        self.word2index[word] = new_index
        self.word2count[word] = 1
        self.index2word[new_index] = word
        self.n_words = new_index + 1
            

We will be using the IWSLT dataset (English–German sentence pairs), loaded through `torchnlp`'s `iwslt_dataset` helper.

In [22]:
# Load the train, dev and test splits in one call.
# NOTE(review): the name `train` is rebound to a function by `def train(...)`
# in the Training section, so re-running `prepareData(train)` after that cell
# has executed will receive the function instead of the data - consider
# renaming one of the two.
train, valid, test = iwslt_dataset(train=True, dev=True, test=True)

In [23]:
def readLangs(dataset, reverse=False):
    """Build the source/target vocabularies and sentence pairs from a dataset.

    @param dataset: non-empty list of dicts, each holding one sentence per
        language under the same two keys,
        eg. [{'en': 'aaaaa', 'de': 'bbbbb'}, ....]
    @param reverse: when True, translate in the opposite direction: the
        dataset's second language becomes the source and the first the target.
    @return: (source Lang, target Lang, pairs), where pairs is a list of
        (source sentence, target sentence) tuples in dataset order.
    @raise IndexError: if dataset is an empty list.
    @raise ValueError: if the first example does not have exactly two keys.
    """
    print("Reading lines...", end='')

    # The language names are taken from the keys of the first example.
    src_key, tgt_key = dataset[0].keys()
    if reverse:
        # Swap once, up front, so that each Lang is named after the language
        # whose sentences it actually stores.
        src_key, tgt_key = tgt_key, src_key

    Lang1, Lang2 = Lang(src_key), Lang(tgt_key)
    pairs = []
    for pair in dataset:
        src_sentence, tgt_sentence = pair[src_key], pair[tgt_key]
        Lang1.addSentence(src_sentence)
        Lang2.addSentence(tgt_sentence)
        pairs.append((src_sentence, tgt_sentence))
    print("complete!")
    return Lang1, Lang2, pairs

In [24]:
def prepareData(dataset, reverse=False):
    """Read `dataset` into two vocabularies plus sentence pairs and report sizes.

    Delegates to readLangs(dataset, reverse), prints the name and word count
    of the source and target vocabularies, and returns the same triple
    (source Lang, target Lang, pairs) that readLangs produced.
    """
    src_vocab, tgt_vocab, sentence_pairs = readLangs(dataset, reverse)

    print("n words total: ")
    for label, vocab in (("Source lang: ", src_vocab), ("Target lang: ", tgt_vocab)):
        print(label, vocab.name, vocab.n_words)

    return src_vocab, tgt_vocab, sentence_pairs

In [25]:
%time source_lang, target_lang, train_pairs = prepareData(train)

Reading lines...complete!
n words total: 
Source lang:  en 123158
Target lang:  de 207319
CPU times: user 3.67 s, sys: 72.3 ms, total: 3.75 s
Wall time: 3.76 s


#### Models

In [30]:
class EncoderRNN(nn.Module):
    """Single-layer GRU encoder that consumes one source token per call.

    Args:
        ninp: number of rows in the embedding table (source vocabulary size).
        nhid: embedding width, also used as the GRU input and hidden size.
    """

    def __init__(self, ninp, nhid):
        super(EncoderRNN, self).__init__()
        self.nhid = nhid
        self.embedding = nn.Embedding(ninp, nhid)
        self.gru = nn.GRU(nhid, nhid)

    def forward(self, src, hidden):
        """Encode a single token.

        Args:
            src: index tensor holding one token id.
            hidden: previous GRU hidden state, shape (1, 1, nhid).

        Returns:
            (out, hidden): the GRU output for this step and the updated
            hidden state, both of shape (1, 1, nhid).
        """
        # Reshape the embedding to (seq_len=1, batch=1, nhid) for the GRU.
        embed = self.embedding(src).view(1, 1, -1)
        out, hidden = self.gru(embed, hidden)
        return out, hidden

    def initHidden(self):
        """Return an all-zero initial hidden state of shape (1, 1, nhid)."""
        # The hidden width is stored as `nhid`; this class has no `hidden_size`.
        return torch.zeros(1, 1, self.nhid, device=device)
    
    
class DecoderRNN(nn.Module):
    """Single-layer GRU decoder that emits one target token distribution per call.

    Args:
        nhid: embedding width, also used as the GRU input and hidden size.
        nout: number of rows in the embedding table and width of the output
            layer (target vocabulary size).
    """

    def __init__(self, nhid, nout):
        super(DecoderRNN, self).__init__()
        self.nhid = nhid
        self.embedding = nn.Embedding(nout, nhid)
        self.gru = nn.GRU(nhid, nhid)
        self.out = nn.Linear(nhid, nout)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, src, hidden):
        """Decode a single step.

        Args:
            src: index tensor holding one token id (the previous target token).
            hidden: previous GRU hidden state, shape (1, 1, nhid).

        Returns:
            (out, hidden): log-probabilities over the output vocabulary with
            shape (1, nout), and the updated hidden state of shape
            (1, 1, nhid).
        """
        # Reshape the embedding to (seq_len=1, batch=1, nhid) for the GRU.
        out = self.embedding(src).view(1, 1, -1)
        out = F.relu(out)
        out, hidden = self.gru(out, hidden)
        out = self.softmax(self.out(out[0]))
        # Hand the results back to the caller; without this the call yields None.
        return out, hidden

    def initHidden(self):
        """Return an all-zero initial hidden state of shape (1, 1, nhid)."""
        return torch.zeros(1, 1, self.nhid, device=device)
        

class AttnDecoderRNN(nn.Module):
    """GRU decoder with attention over a fixed-length window of encoder outputs.

    Args:
        nhid: embedding width, also used as the GRU input and hidden size.
        nout: number of rows in the embedding table and width of the output
            layer (target vocabulary size).
        dropout: dropout probability applied to the embedded input token.
        max_length: number of encoder positions the attention layer scores;
            `encoder_outputs` passed to forward() must have this many rows.
    """

    def __init__(self, nhid, nout, dropout=0.2, max_length=100):
        super(AttnDecoderRNN, self).__init__()
        self.nhid = nhid
        self.nout = nout
        # Keep the probability under its own name so it is not overwritten
        # by the nn.Dropout module stored in `self.dropout` below.
        self.dropout_p = dropout
        self.max_length = max_length

        self.embedding = nn.Embedding(self.nout, self.nhid)
        self.attn = nn.Linear(self.nhid * 2, self.max_length)
        # The hidden width is stored as `nhid`; this class has no `hidden_size`.
        self.attn_combine = nn.Linear(self.nhid * 2, self.nhid)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.nhid, self.nhid)
        self.out = nn.Linear(self.nhid, self.nout)

    def forward(self, src, hidden, encoder_outputs):
        """Decode a single step while attending over the encoder outputs.

        Args:
            src: index tensor holding one token id (the previous target token).
            hidden: previous GRU hidden state, shape (1, 1, nhid).
            encoder_outputs: tensor of shape (max_length, nhid).

        Returns:
            (output, hidden, attn_weights): log-probabilities over the output
            vocabulary with shape (1, nout), the updated hidden state of shape
            (1, 1, nhid), and the attention weights of shape (1, max_length).
        """
        embed = self.embedding(src).view(1, 1, -1)
        embed = self.dropout(embed)

        # Score every encoder position from the current input and hidden state.
        attn_weights = F.softmax(
            self.attn(torch.cat((embed[0], hidden[0]), 1)), dim=1)
        # Weighted sum of encoder outputs: (1, 1, max_length) x (1, max_length, nhid).
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embed[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        # Rebind `output` to the GRU result so the projection below uses it
        # rather than the pre-GRU tensor.
        output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        """Return an all-zero initial hidden state of shape (1, 1, nhid)."""
        return torch.zeros(1, 1, self.nhid, device=device)
        
        

#### Data: converting sentences to tensors

In [31]:
def indexesFromSentence(lang, sentence):
    """Map each single-space-separated token of `sentence` to its index.

    Looks every token up in `lang.word2index`; a token that is missing from
    that mapping raises KeyError.
    """
    ids = []
    for token in sentence.split(' '):
        ids.append(lang.word2index[token])
    return ids

def tensorFromSentence(lang, sentence):
    """Convert `sentence` into a column tensor of word indices ending in EOS.

    The sentence is mapped to indices with indexesFromSentence(lang, sentence),
    EOS_token is appended, and the result is returned as a torch.long tensor
    of shape (number of tokens + 1, 1) on `device`.
    """
    indexes = indexesFromSentence(lang, sentence)
    # The end-of-sentence constant is spelled `EOS_token` in this notebook.
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

def tensorsFromPair(pair, src_lang, tgt_lang):
    """Turn a (source sentence, target sentence) pair into a pair of tensors.

    pair[0] is converted with `src_lang` and pair[1] with `tgt_lang`, each via
    tensorFromSentence; the two results are returned as a tuple in that order.
    """
    return (
        tensorFromSentence(src_lang, pair[0]),
        tensorFromSentence(tgt_lang, pair[1]),
    )

#### Training

In [None]:
teacher_forcing_ratio = 0.5

def train(src, tgt, encoder, decoder, encoder_optimizer, 
          decoder_optimizer, criterion, max_length=100):
    
    encoder_hidden = encoder.initHidden()
    encoder_optimizer.zero_grad()