In [1]:
import unicodedata
import string
import re
import random
import numpy as np
from io import open

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
SOS_token = 0
EOS_token = 1

In [19]:
class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.word2count = {}
        self.n_words = 2
    
    def add_sentence(self, sentence):
        for word in sentence.split(" "):
            self.add_word(word)
    
    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1


class Encoder(nn.Module):
    def __init(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
    
    def forward(self, x, h0):
        embedded = self.embedding(x).view(1, 1, -1)
        x = embedded
        x, h = self.gru(x, h0)
        return x, h
    
    def zero_hidden(self, device):
        return torch.zeros(1, 1, self.hidden_size, device=device)


# TODO: try training a simple seq-to-seq model
class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Input & output size are the same (since they are word embeddings)
        self.input_size = input_size
        self.output_size = input_size
        self.hidden_size = hidden_size
        
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, self.output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, x, h0):
        embedded = self.embedding(x).view(1, 1, -1)
        x = F.relu(embedded)
        x, h = self.gru(x, h0)
        # GRU output has leading batch dimension, so index x[0]
        x = self.out(x[0])
        x = self.softmax(x)
        return x
    
    def zero_hidden(self, device):
        return torch.zeros(1, 1, self.hidden_size, device=device)

    
class AttnDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, max_length, dropout_p=0.1):
        super().__init__()
        self.input_size = input_size
        self.output_size = input_size
        self.hidden_size = hidden_size
        self.dropout_p = dropout_p
        self.max_length = max_length
        
        self.embedding = nn.Embedding(self.input_size, self.hidden_size)
        # Weights for each word in sentence of max length
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(p = self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)
        
    def forward(self, x, h0, yhats_encoder):
        embedded = self.embedding(x).view(1, 1, -1)
        embedded = self.dropout(embedded)
        
        x = torch.concat((embedded[0], h0[0]), dim=1)
        attn_weights = F.softmax(
            self.attn(x),
            dim=1
        )
        
        # unsqueeze(0) --> add 'batch' dimension back to
        # attention weights & add to encoder outputs
        attn_applied = torch.bmm(
            attn_weights.unsqueeze(0),
            yhats_encoder.unsqueeze(0)
        )
        
        # subscript embedded & attn_applied to emove 'batch' dimension
        x = torch.concat((embedded[0], attn_applied[0]), dim=1)
        # unsqueeze result to add 'batch' dimension back in
        x = self.attn_combine(x).unsqueeze(0)
        x = F.relu(x)
        
        x, h = self.gru(x, h0)
        # GRU output has leading batch dimension, so index x[0]
        x = self.out(x[0])
        x = F.log_softmax(x, dim=1)
        return x, h, attn_weights
    
    def zero_hidden(self, device):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [27]:
def unicode_to_ascii(s):
    """Turn a Unicode string to plain ASCII.
    """
    return "".join(
        c for c in unicodedata.normalize("NFD", s)
        if unicodedata.category(c) != "Mn"
    )


def normalize_string(s):
    """Lowercase, trim, and remove non-letter characters.
    """
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s


def read_langs(lang1, lang2, data_path="../data/rnn_seq_to_seq_data", reverse=False):
    print("Reading lines...")
    
    # Read file & split into lines
    lines = open(f"{data_path}/{lang1}-{lang2}.txt", encoding="utf-8").read().strip().split('\n')
    
    # Split every line into pairs and normalize
    pairs = [[normalize_string(s) for s in l.split('\t')] for l in lines]
    
    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
        
    return input_lang, output_lang, pairs


def filter_pair(p, max_length, eng_prefixes):
    return len(p[0].split(" ")) < max_length and \
        len(p[1].split(" ")) < max_length and \
        p[1].startswith(eng_prefixes)


def filter_pairs(pairs, max_length, eng_prefixes):
    return [pair for pair in pairs if filter_pair(pair, max_length, eng_prefixes)]


def prepare_data(lang1, lang2, max_length, eng_prefixes, reverse=False):
    input_lang, output_lang, pairs = read_langs(lang1, lang2, reverse=reverse)
    print(f"Read {len(pairs)} sentence pairs")
    pairs = filter_pairs(pairs, max_length, eng_prefixes)
    print(f"Trimmed to {len(pairs)} sentence pairs")
    print("Counting words...")
    for pair in pairs:
        input_lang.add_sentence(pair[0])
        output_lang.add_sentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


def indexes_from_sentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(" ")]


def tensor_from_sentence(lang, sentence, device):
    indexes = indexes_from_sentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


def tensor_from_pairs(pair, input_lang, output_lang, device):
    input_tensor = tensor_from_sentence(input_lang, pair[0], device)
    target_tensor = tensor_from_sentence(output_lang, pair[1], device)
    return input_tensor, target_tensor


def train(encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, input_tensor, target_tensor, max_length, device, teacher_forcing_ratio=0.5):
    h_encoder = encoder.zero_hidden()
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    input_length = input_tensor.shape[0]
    target_length = target_tensor.shape[0]
    
    yhats_encoder = torch.zeros(max_length, encoder.hidden_size, device=device)
    
    loss = 0.0
    for ei in range(input_length):
        yhat_encoder, h_encoder = encoder(input_tensor[ei], h_encoder)
        yhats_encoder[ei] = yhat_encoder[0, 0]
        
    x_decoder = torch.tensor([[SOS_token]], device=device)
    h_decoder = h_encoder
    
    use_teacher_forcing = True if np.random.uniform() < teacher_forcing_ratio else False
    
    # Teacher forcing = feed target in as next input
    if use_teacher_forcing:
        for di in range(target_length):
            # Pass in full encoder output
            yhat_decoder, h_decoder, attn_decoder = decoder(x_decoder, h_decoder, yhats_encoder)
            loss += criterion(yhat_decoder, target_tensor[di])
            
            x_decoder = target_tensor[di]
    # Non-teacher forcing = feed previous output of decoder as input
    else:
        for di in range(target_length):
            yhat_decoder, h_decoder, attn_decoder = decoder(x_decoder, h_decoder, yhats_encoder) 
            loss += criterion(yhat_decoder, target_tensor[di])
            
            topv, topi = yhat_decoder.topk(1)
            # When used as input, should be detached from backprop history
            x_decoder = topi.squeeze().detach()
            if x_decoder.item() == EOS_token:
                break
    
    loss.backward()
    
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return loss.item() / target_length

In [26]:
np.random.uniform()

0.37291441617784726

In [5]:
MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

In [6]:
input_lang, output_lang, pairs = prepare_data("eng", "fra", MAX_LENGTH, eng_prefixes, reverse=True)

Reading lines...
Read 135842 sentence pairs
Trimmed to 10599 sentence pairs
Counting words...
Counted words:
fra 4345
eng 2803


In [50]:
random.choice(pairs)

['elle est assez grande pour voyager toute seule .',
 'she is old enough to travel by herself .']