In [269]:
### Imports ###

import matplotlib.pyplot as plt
plt.switch_backend('TkAgg')
import matplotlib.ticker as ticker
import numpy as np
import random

img_path = 'imgs/'

from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, random_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [126]:
SOW_token = 0
EOW_token = 1

class Chars:
    def __init__(self, name):
        self.name = name
        self.char2index = {}
        self.char2count = {}
        self.index2char = {0: "SOW", 1: "EOW"}
        self.n_chars = 2

    def addWord(self, word):
        for char in word:
            self.addChar(char)

    def addChar(self, char):
        if char not in self.char2index:
            self.char2index[char] = self.n_chars
            self.char2count[char] = 1
            self.index2char[self.n_chars] = char
            self.n_chars += 1
        else:
            self.char2count[char] += 1

In [127]:
def readChars(lang1, lang2, reverse=False):
    print("Reading lines...")
    #lang1 = 'spa'
    #lang2 = 'fre'
    path = lang1 + '_' + lang2
    data_path = 'data/'+path+'.txt'

    f = open(data_path, 'r')
    source_words = []
    target_words = []
    pairs = []

    for line in f:
        word = line.strip().split('$')
        pairs.append(word)
        source_words.append(word[0])
        target_words.append(word[1])

    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Chars(lang2)
        output_lang = Chars(lang1)
    else:
        input_lang = Chars(lang1)
        output_lang = Chars(lang2)

    return input_lang, output_lang, pairs

In [128]:
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readChars(lang1, lang2, reverse)
    print("Read {} sentence pairs".format(len(pairs)))
    print("Counting chars...")
    for pair in pairs:
        input_lang.addWord(pair[0])
        output_lang.addWord(pair[1])
    print("Counted chars:")
    print(input_lang.name, input_lang.n_chars)
    print(output_lang.name, output_lang.n_chars)
    return input_lang, output_lang, pairs

In [129]:
input_lang, output_lang, pairs = prepareData('spa', 'fre')
print(random.choice(pairs))

Reading lines...
Read 3369 sentence pairs
Counting chars...
Counted chars:
spa 31
fre 39
['ɾeaktibos', 'ʁeaktif']


In [130]:
MAX_LENGTH_INPUT = max(len(pair[0]) for pair in pairs)
MAX_LENGTH_OUTPUT = max(len(pair[1]) for pair in pairs)
MAX_LENGTH = max(MAX_LENGTH_INPUT, MAX_LENGTH_OUTPUT)+1

In [131]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden

In [132]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOW_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden  = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden, None # We return `None` for consistency in the training loop

    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden

In [133]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOW_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                decoder_input = target_tensor[:, i].unsqueeze(1) 
            else:
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions


    def forward_step(self, input, hidden, encoder_outputs):
        embedded =  self.dropout(self.embedding(input))

        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights

In [252]:
def indexesFromWord(lang, word):
    return [lang.char2index[char] for char in word]

def tensorFromWord(lang, word):
    indexes = indexesFromWord(lang, word)
    indexes.append(EOW_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromWord(input_lang, pair[0])
    target_tensor = tensorFromWord(output_lang, pair[1])
    return (input_tensor, target_tensor)

def get_dataloader(batch_size):
    input_lang, output_lang, pairs = prepareData('spa', 'fre')

    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromWord(input_lang, inp)
        tgt_ids = indexesFromWord(output_lang, tgt)
        
        inp_ids.append(EOW_token)
        tgt_ids.append(EOW_token)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    all_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device))

    #train_sampler = RandomSampler(train_data)
    train_test_gen = torch.Generator().manual_seed(42)
    train_data, test_data = random_split(all_data, [0.8, 0.2], generator=train_test_gen)
    
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    
    test_sampler = RandomSampler(test_data)
    test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

    return input_lang, output_lang, train_dataloader, test_dataloader, test_data

In [253]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):

    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor = data

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [254]:
import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [255]:
def train(train_dataloader, test_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
               print_every=100, plot_every=100, patience=5):
    start = time.time()
    plot_losses = []
    print_loss_total = 0 
    plot_loss_total = 0 
    
    val_losses = []
    validation_list = []
    print_val_loss_total = 0
    plot_val_loss_total = 0

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss() #try cross entropy?

    for epoch in range(1, n_epochs + 1):
        start = time.time()
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss
        
        val_loss = validate(test_dataloader, encoder, decoder)

        validation_list.append(val_loss)
        
        print_val_loss_total += val_loss
        plot_val_loss_total += val_loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            
            print_val_loss_avg = print_val_loss_total / print_every
            print_val_loss_total = 0
            
            print('Epoch: {}/{},\tTime Taken: {:.2f} seconds,\tTraining Loss: {:.4f},\tValidation Loss: {:.4f}'.format(epoch, n_epochs, time.time()-start, print_loss_avg, print_val_loss_avg))
        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
            
            plot_val_loss_avg = plot_val_loss_total / plot_every
            val_losses.append(plot_val_loss_avg)
            plot_val_loss_total = 0

        if len(validation_list) > patience+1 and val_loss > max(validation_list[(-1 * patience) - 1:-1]):
            print("Validation loss has not gone down for " + str(patience) + " epochs. Implementing Early Stopping")
            break

    showPlot(plot_losses, 'train_loss')
    showPlot(val_losses, 'val_loss')

In [256]:
def validate(test_dataloader, encoder, decoder):
    encoder.eval()
    decoder.eval()
    
    criterion = nn.NLLLoss()
    
    total_loss = 0
    for data in test_dataloader:
    
        input_tensor, target_tensor = data
    
        encoder_outputs, encoder_hidden = encoder(input_tensor) 
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        
        total_loss += loss.item()
        
    encoder.train()
    decoder.train()
    
    return total_loss/len(test_dataloader)
    

In [257]:
def showPlot(points, loss_type):
    fig, ax = plt.subplots()
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    if loss_type == 'val_loss':
        ax.plot(points, color='red')
        ax.set_title("Validation Loss")
    else:
        ax.plot(points)
        ax.set_title("Training Loss")
    plt.savefig(img_path + loss_type + '.png')
    plt.close(fig)

In [318]:
def evaluate(encoder, decoder, word, input_lang, output_lang):
    with torch.no_grad():
        #input_tensor = tensorFromWord(input_lang, word)
        input_tensor = word
        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_chars = []
        for idx in decoded_ids:
            if idx.item() == EOW_token:
                decoded_chars.append('<EOW>')
                break
            decoded_chars.append(output_lang.index2char[idx.item()])
        
    return decoded_chars, decoder_attn

In [377]:
def showAttention(input_word, output_chars, correct_output, attentions):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')
    fig.colorbar(cax)

    ax.set_title(input_word + ' -> ' + correct_output)
    
    ax.set_xticklabels([''] + [*input_word] +
                       ['<EOW>'], rotation=90)
    ax.set_yticklabels([''] + output_chars)
    
    ax.set(xlim=(0, len(input_word)+1))
    
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.savefig(img_path+input_word+'_attention.png', bbox_inches='tight')
    plt.close(fig)

In [378]:
def evaluateRandomly(encoder, decoder, n=10):
    for i in range(n):
        random_index = random.randint(0, len(test_data))
        pair = test_data.dataset[random_index]
        pair = (torch.unsqueeze(pair[0], 0), torch.unsqueeze(pair[1], 0))
        
        output_chars, attentions = evaluate(encoder, decoder, pair[0], input_lang, output_lang)
        output_word = ''.join(output_chars)
        decoded_input = []
        
        in0=[]
        for idx in torch.squeeze(pair[0],0):
            if idx.item() == EOW_token:
                in0.append('<EOW>')
                break
            in0.append(input_lang.index2char[idx.item()])
            
        in1=[]
        for idx in torch.squeeze(pair[1],0):
            if idx.item() == EOW_token:
                in1.append('<EOW>')
                break
            in1.append(output_lang.index2char[idx.item()])
            
        input_word = ''.join(decoded_input)
        print('>', ''.join(in0))
        print('=', ''.join(in1))
        print('<', output_word)
        print('')
        showAttention(''.join(in0[0:-1]), output_chars, ''.join(in1[0:-1]), attentions[0, :len(output_chars), :])

In [259]:
hidden_size = 128
batch_size = 32

input_lang, output_lang, train_dataloader, test_dataloader, test_data = get_dataloader(batch_size)

encoder = EncoderRNN(input_lang.n_chars, hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size, output_lang.n_chars).to(device)

train(train_dataloader, test_dataloader, encoder, decoder, 40, print_every=1, plot_every=1, patience=5)

Reading lines...
Read 3369 sentence pairs
Counting chars...
Counted chars:
spa 31
fre 39
Epoch: 1/40,	Time Taken: 6.96 seconds,	Training Loss: 1.3591,	Validation Loss: 0.9789
Epoch: 2/40,	Time Taken: 6.09 seconds,	Training Loss: 0.9193,	Validation Loss: 0.8144
Epoch: 3/40,	Time Taken: 6.28 seconds,	Training Loss: 0.7593,	Validation Loss: 0.6431
Epoch: 4/40,	Time Taken: 6.62 seconds,	Training Loss: 0.5467,	Validation Loss: 0.4409
Epoch: 5/40,	Time Taken: 6.12 seconds,	Training Loss: 0.3767,	Validation Loss: 0.3178
Epoch: 6/40,	Time Taken: 6.80 seconds,	Training Loss: 0.2842,	Validation Loss: 0.2587
Epoch: 7/40,	Time Taken: 6.35 seconds,	Training Loss: 0.2319,	Validation Loss: 0.2248
Epoch: 8/40,	Time Taken: 7.25 seconds,	Training Loss: 0.2036,	Validation Loss: 0.1997
Epoch: 9/40,	Time Taken: 9.78 seconds,	Training Loss: 0.1749,	Validation Loss: 0.1816
Epoch: 10/40,	Time Taken: 10.77 seconds,	Training Loss: 0.1590,	Validation Loss: 0.1769
Epoch: 11/40,	Time Taken: 12.41 seconds,	Training

In [None]:
def checkpoint(model, filename):
    torch.save(model.state_dict(), filename)
    
def resume(model, filename):
    model.load_state_dict(torch.load(filename))

In [427]:
from torchmetrics.text import BLEUScore

def calculate_bleu_score(test_data):
    bleu_score_candidate = []
    bleu_score_reference = []
    for pair in test_data:
        pair = (torch.unsqueeze(pair[0], 0), torch.unsqueeze(pair[1], 0))
        
        output_chars, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)
        
        in1 = []
        for idx in torch.squeeze(pair[1], 0):
            if idx.item() == EOW_token:
                break
            in1.append(output_lang.index2char[idx.item()])

        ref_string = ''.join(in1)
        candidate_string = ''.join(output_chars[:-1])

        bleu_score_reference.append([ref_string])
        bleu_score_candidate.append(candidate_string)
        
    bleu = BLEUScore(n_gram=1)
    return bleu(bleu_score_candidate, bleu_score_reference)

print(calculate_bleu_score(test_data))

tensor(0.5364)


In [425]:
from tfa.metrics import F1Score

#Not working
def calculate_f1_score_char_level(test_data):
    all_preds = []
    all_targets = []

    for pair in test_data:
        input_seq, target_seq = torch.unsqueeze(pair[0], 0), torch.unsqueeze(pair[1], 0)

        output_chars, _ = evaluate(encoder, decoder, input_seq, input_lang, output_lang)
        
        encoded_preds = [output_lang.char2index[char] for char in output_chars[:-1]]  # Assuming output_chars is a list of characters
        encoded_targets = [output_lang.char2index[output_lang.index2char[idx.item()]] for idx in torch.squeeze(target_seq, 0) if (idx.item() != EOW_token) and idx.item() != SOW_token]

        all_preds.append(encoded_preds)
        all_targets.append(encoded_targets)

    preds_tensor = torch.nn.utils.rnn.pad_sequence([torch.tensor(p, dtype=torch.int64) for p in all_preds], batch_first=True, padding_value=-100)
    targets_tensor = torch.nn.utils.rnn.pad_sequence([torch.tensor(t, dtype=torch.int64) for t in all_targets], batch_first=True, padding_value=-100)

    #f1 = F1Score(num_classes=len(output_lang.char2index), average='macro', ignore_index=-100, task='multiclass')

    f1 = F1Score(num_classes=len(output_lang.char2index), average='marco', threshold=0.5)
    
    f1.update_state(preds_tensor, targets_tensor)

    return f1.result()

print(calculate_f1_score_char_level(test_data))

ModuleNotFoundError: No module named 'tfa'

In [380]:
encoder.eval()
decoder.eval()
evaluateRandomly(encoder, decoder)

> kulto<EOW>
= kylt<EOW>
< kylt<EOW>



  ax.set_xticklabels([''] + [*input_word] +
  ax.set_yticklabels([''] + output_chars)


> espektɾo<EOW>
= spɛktʁ<EOW>
< spɛktʁ<EOW>

> akusasiones<EOW>
= akyzasjɔ̃<EOW>
< akysasjɔ̃<EOW>

> aɾtifisialmente<EOW>
= aʁtifisjɛləmɑ̃<EOW>
< aʁtifisjɛləmɑ̃<EOW>

> korexiɾ<EOW>
= koʁiʒe<EOW>
< koʁiʒe<EOW>

> enkaɾnaɾ<EOW>
= ɛ̃kaʁne<EOW>
< ɑ̃kaʁne<EOW>

> eskaneɾ<EOW>
= skɑne<EOW>
< skɑne<EOW>

> baɾba<EOW>
= baʁb<EOW>
< vaʁb<EOW>

> kwotas<EOW>
= kota<EOW>
< kota<EOW>

> komunmente<EOW>
= komynemɑ̃<EOW>
< komynəmɑ̃<EOW>

