In [1]:
import math
import random

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
class Vocab:
    """Bidirectional token <-> index mapping with four reserved special tokens.

    Indices 0-3 are reserved for pad, unk, sos and eos (in that order). The
    remaining tokens from ``counter`` get consecutive indices in the counter's
    iteration order.

    Args:
        counter: mapping token -> count (e.g. ``collections.Counter``).
        sos, eos, pad, unk: the special token strings.
        min_freq: a token is kept only if its count is strictly greater than
            this value; None means 0 (keep every token with a positive count).
    """

    def __init__(self, counter, sos, eos, pad, unk, min_freq=None):
        self.sos = sos
        self.eos = eos
        self.pad = pad
        self.unk = unk

        self.pad_idx = 0
        self.unk_idx = 1
        self.sos_idx = 2
        self.eos_idx = 3

        self._token2idx = {
            self.sos: self.sos_idx,
            self.eos: self.eos_idx,
            self.pad: self.pad_idx,
            self.unk: self.unk_idx,
        }
        self._idx2token = {idx: token for token, idx in self._token2idx.items()}

        threshold = 0 if min_freq is None else min_freq

        for token, count in counter.items():
            # Never re-index a token that is already mapped (e.g. a corpus token
            # equal to a special token). Re-indexing would desync _token2idx from
            # _idx2token and leave vocab_size smaller than the largest index.
            if count > threshold and token not in self._token2idx:
                new_idx = len(self._token2idx)
                self._token2idx[token] = new_idx
                self._idx2token[new_idx] = token

        self.vocab_size = len(self._token2idx)
        self.tokens     = list(self._token2idx.keys())

    def token2idx(self, token):
        """Return the index of ``token``, or ``unk_idx`` if it is out of vocabulary."""
        return self._token2idx.get(token, self.unk_idx)

    def idx2token(self, idx):
        """Return the token for ``idx``, or the pad token for an unknown index."""
        return self._idx2token.get(idx, self.pad)

    def __len__(self):
        """Number of distinct tokens, special tokens included."""
        return len(self._token2idx)
    
def padding(sequences, pad_idx):
    '''
    Right-pad every sequence to the length of the longest one.

    Inputs:
        sequences: list of list of tokens
        pad_idx:   value appended to the shorter sequences
    Returns:
        A new list of new lists; the input lists are not modified.
        An empty input yields an empty list.
    '''
    # default=0 keeps an empty batch from raising "max() arg is an empty sequence".
    max_length = max(map(len, sequences), default=0)

    return [seq + [pad_idx] * (max_length - len(seq)) for seq in sequences]



import csv
from collections import Counter

def words_tokenize(line):
    """Character-level tokenizer: one token per character of ``line``."""
    return [char for char in line]

def trans_tokenize(line):
    """Whitespace tokenizer: splits ``line`` on runs of whitespace, dropping empty pieces."""
    tokens = line.split(sep=None)
    return tokens

class Dataset(object):
    """Word -> transcription dataset loaded from a CSV file.

    The CSV has a header row followed by rows whose column 1 is the word and
    column 2 its space-separated transcription. Column 0 is ignored.
    Words are tokenized per character and transcriptions on whitespace.

    Rows are optionally shuffled, then the first ``val_size`` fraction becomes
    the validation split. Vocabularies are built from the training split only.

    Args:
        path: CSV file path.
        val_size: fraction of rows held out for validation (default 0.1).
        shuffle: shuffle rows before splitting (default True).
        seed: optional seed for the shuffle; None uses the global ``random`` state.
    """

    def __init__(self, path, val_size=0.1, shuffle=True, seed=None):
        # newline='' is what the csv module expects for files given to csv.reader.
        with open(path, 'r', newline='') as f:
            # csv.reader yields [] for blank lines; keep them out of the data.
            rows = [row for row in csv.reader(f) if row]

        # Skip the header row and keep (word, transcription) pairs.
        pairs = [(row[1], row[2]) for row in rows[1:]]

        if shuffle:
            if seed is None:
                random.shuffle(pairs)
            else:
                random.Random(seed).shuffle(pairs)

        words = [word for word, _ in pairs]
        trans = [tran for _, tran in pairs]

        n_val = int(len(words) * val_size)
        train_words, val_words = words[n_val:], words[:n_val]
        train_trans, val_trans = trans[n_val:], trans[:n_val]

        # Counters see tokens in corpus order, so vocabulary indices are assigned
        # in first-occurrence order.
        words_counter = Counter(token for line in train_words for token in words_tokenize(line))
        trans_counter = Counter(token for line in train_trans for token in trans_tokenize(line))

        sos = "<sos>"
        eos = "<eos>"
        pad = "<pad>"
        unk = "<unk>"

        self.words_vocab = Vocab(words_counter, sos, eos, pad, unk)
        self.trans_vocab = Vocab(trans_counter, sos, eos, pad, unk)

        self.train_words = self._encode(self.words_vocab, words_tokenize, train_words)
        self.val_words   = self._encode(self.words_vocab, words_tokenize, val_words)

        self.train_trans = self._encode(self.trans_vocab, trans_tokenize, train_trans)
        self.val_trans   = self._encode(self.trans_vocab, trans_tokenize, val_trans)

    @staticmethod
    def _encode(vocab, tokenize, lines):
        """Tokenize each line and map every token to its id in ``vocab``."""
        return [[vocab.token2idx(token) for token in tokenize(line)] for line in lines]

    def __len__(self):
        """Number of training examples."""
        return len(self.train_trans)

    def get_batch(self, batch_size, sort=False, val=False):
        """Sample ``batch_size`` examples uniformly at random, with replacement.

        Args:
            batch_size: number of examples to draw.
            sort: if True, reorder the batch by decreasing word length.
            val: draw from the validation split instead of train.

        Returns:
            (batch_words, batch_trans_in, batch_trans_out, words_lens, trans_lens).
            The first three are right-padded LongTensors of shape (batch, max_len).
            trans_in is ``[sos] + tokens`` and trans_out is ``tokens + [eos]``.
            words_lens and trans_lens are 1-D LongTensors holding the unpadded
            word lengths and trans_in lengths.
        """
        if val:
            words, trans = self.val_words,   self.val_trans
        else:
            words, trans = self.train_words, self.train_trans

        random_ids = np.random.randint(0, len(words), batch_size)
        batch_words = [words[idx] for idx in random_ids]
        batch_trans = [trans[idx] for idx in random_ids]

        # Teacher forcing: the decoder input is shifted right by sos and the
        # target is terminated by eos.
        batch_trans_in  = [[self.trans_vocab.sos_idx] + tran for tran in batch_trans]
        batch_trans_out = [tran + [self.trans_vocab.eos_idx] for tran in batch_trans]

        words_lens = list(map(len, batch_words))
        trans_lens = list(map(len, batch_trans_in))

        batch_words     = torch.LongTensor(padding(batch_words,     pad_idx=self.words_vocab.pad_idx))
        batch_trans_in  = torch.LongTensor(padding(batch_trans_in,  pad_idx=self.trans_vocab.pad_idx))
        batch_trans_out = torch.LongTensor(padding(batch_trans_out, pad_idx=self.trans_vocab.pad_idx))
        words_lens = torch.LongTensor(words_lens)
        trans_lens = torch.LongTensor(trans_lens)

        if sort:
            words_lens, indices = torch.sort(words_lens, descending=True)
            batch_words     = batch_words[indices]
            batch_trans_in  = batch_trans_in[indices]
            batch_trans_out = batch_trans_out[indices]
            trans_lens      = trans_lens[indices]

        return batch_words, batch_trans_in, batch_trans_out, words_lens, trans_lens

In [4]:
# Load the CSV once; this also builds both vocabularies and the train/val split.
data_path = 'data/transcriptions/train.csv'
data = Dataset(data_path)

In [5]:
# Draw one training batch of 32 examples to sanity-check tensor shapes.
(batch_words, batch_trans_in, batch_trans_out,
 words_lens, trans_lens) = data.get_batch(batch_size=32)

In [33]:
# Padded id tensors are (batch, max_len); both transcription tensors carry one
# extra position (sos or eos). The two length tensors are (batch,).
tuple(t.size() for t in (batch_words, batch_trans_in, batch_trans_out, words_lens, trans_lens))

(torch.Size([32, 12]),
 torch.Size([32, 12]),
 torch.Size([32, 12]),
 torch.Size([32]),
 torch.Size([32]))

In [34]:
# Decoder targets: each row is the transcription ids followed by eos_idx (3),
# right-padded with pad_idx (0).
batch_trans_out


tensor([[ 6, 15, 14, 10, 26, 12, 10, 21, 31,  3,  0,  0],
        [15, 18,  9, 31,  5, 30,  3,  0,  0,  0,  0,  0],
        [ 4, 37, 31,  5, 14,  3,  0,  0,  0,  0,  0,  0],
        [11, 20, 24, 33,  5, 21,  3,  0,  0,  0,  0,  0],
        [ 5, 22, 20, 24, 11,  3,  0,  0,  0,  0,  0,  0],
        [20, 36, 26,  8, 17,  3,  0,  0,  0,  0,  0,  0],
        [10, 28,  5,  4, 26, 22, 20,  7, 28, 15, 26,  3],
        [ 4, 20, 28, 34, 15,  8,  3,  0,  0,  0,  0,  0],
        [27, 14, 23, 14, 15, 17,  3,  0,  0,  0,  0,  0],
        [ 9, 28, 11, 10,  3,  0,  0,  0,  0,  0,  0,  0],
        [ 4, 24, 26, 10, 25, 24,  5, 39, 14,  7,  3,  0],
        [13, 12,  9,  5, 24, 35,  3,  0,  0,  0,  0,  0],
        [ 5, 37, 14, 15, 17,  3,  0,  0,  0,  0,  0,  0],
        [15, 22, 20, 30, 21, 31,  3,  0,  0,  0,  0,  0],
        [21, 32, 12,  5, 32,  3,  0,  0,  0,  0,  0,  0],
        [28, 20, 14,  7, 12, 20,  3,  0,  0,  0,  0,  0],
        [ 7, 14, 11, 15, 28, 20, 21,  3,  0,  0,  0,  0],
        [28,  

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence  as unpack

class Encoder(nn.Module):
    """Embedding + single-layer GRU encoder over padded token ids.

    When ``words_lens`` is given, the batch is packed. The final hidden state
    is then taken at each sequence's last real token rather than after the
    trailing pad positions.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, pad_idx):
        super(Encoder, self).__init__()

        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.hidden_size = hidden_size
        self.pad_idx = pad_idx

        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_idx)
        self.GRU = nn.GRU(emb_size, hidden_size, batch_first=True)

    def forward(self, batch_words, words_lens=None):
        """Encode a batch.

        Args:
            batch_words: LongTensor of ids, (batch, seq_len) because the GRU is batch_first.
            words_lens: optional 1-D tensor/list with each row's unpadded length.
                If omitted, every position (pads included) is fed to the GRU.

        Returns:
            (outputs, hidden): GRU outputs of shape (batch, seq_len, hidden_size)
            and the final hidden state of shape (1, batch, hidden_size).
        """
        embedded = self.embedding(batch_words)

        if words_lens is None:
            return self.GRU(embedded)

        # pack_padded_sequence needs CPU lengths. enforce_sorted=False accepts
        # unsorted batches and returns `hidden` in the caller's original order.
        lengths = torch.as_tensor(words_lens).cpu()
        packed = pack(embedded, lengths, batch_first=True, enforce_sorted=False)
        packed_outputs, hidden = self.GRU(packed)
        outputs, _ = unpack(packed_outputs, batch_first=True,
                            total_length=batch_words.size(1))
        return outputs, hidden


class Decoder(nn.Module):
    """Embedding + GRU + linear projection producing per-step vocabulary logits."""

    def __init__(self, vocab_size, emb_size, hidden_size, pad_idx):
        super(Decoder, self).__init__()

        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.hidden_size = hidden_size
        self.pad_idx = pad_idx

        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_idx)
        self.GRU = nn.GRU(emb_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, batch_trans_in, hidden):
        """Run the decoder over ``batch_trans_in`` starting from ``hidden``.

        Args:
            batch_trans_in: LongTensor of input ids, (batch, seq_len).
            hidden: initial GRU state, e.g. the encoder's final hidden state.

        Returns:
            (logits, hidden): logits flattened to (batch * seq_len, vocab_size),
            and the final GRU hidden state.
        """
        embedded = self.embedding(batch_trans_in)
        outputs, hidden = self.GRU(embedded, hidden)
        logits = self.linear(outputs)
        return logits.view(-1, self.vocab_size), hidden


class Model(nn.Module):
    """Encoder-decoder wrapper.

    ``encoder`` and ``decoder`` may be passed to the constructor or assigned as
    attributes afterwards (``model.encoder = ...``).
    """

    def __init__(self, encoder=None, decoder=None):
        super(Model, self).__init__()
        if encoder is not None:
            self.encoder = encoder
        if decoder is not None:
            self.decoder = decoder

    def forward(self, batch_words, words_lens, batch_trans_in):
        """Teacher-forced pass; returns logits of shape (batch * seq_len, vocab_size)."""
        _, hidden = self.encoder(batch_words, words_lens)
        logits, _ = self.decoder(batch_trans_in, hidden)
        return logits

    def generate(self, bos_idx, eos_idx, batch_words, max_len=100):
        """Greedy decoding for a single source sequence.

        Only batch size 1 is supported: each step feeds a (1, 1) input to the
        decoder together with the encoder's hidden state for ``batch_words``.

        Returns:
            List of token ids starting with ``bos_idx``, followed by at most
            ``max_len`` generated ids. It ends at the first ``eos_idx`` if one
            is produced.
        """
        tokens = [bos_idx]
        with torch.no_grad():
            _, hidden = self.encoder(batch_words)
            for _step in range(max_len):
                step_input = torch.LongTensor([[tokens[-1]]]).to(batch_words.device)
                logits, hidden = self.decoder(step_input, hidden)
                next_token = logits[-1].argmax().item()
                tokens.append(next_token)
                if next_token == eos_idx:
                    break
        return tokens

In [9]:
def plot(epoch, batch_idx, train_losses, val_losses, train_lv, val_lv):
    """Redraw the four training-progress panels in place.

    Clears the cell output, then plots train loss, train LV distance, validation
    loss and validation LV distance side by side. Every series must be
    non-empty, because each title shows the series' latest value.
    """
    clear_output(True)
    fig, axes = plt.subplots(1, 4, figsize=(20, 6))
    panels = [
        (train_losses, 'epoch %s. | batch: %s | loss: %s' % (epoch, batch_idx, train_losses[-1])),
        (train_lv, 'epoch %s. | batch: %s | LV distance on train: %s' % (epoch, batch_idx, train_lv[-1])),
        # This panel shows validation loss; the title previously said just "loss".
        (val_losses, 'epoch %s. | val loss: %s' % (epoch, val_losses[-1])),
        (val_lv, 'epoch %s. | LV distance on val: %s' % (epoch, val_lv[-1])),
    ]
    for ax, (series, title) in zip(axes, panels):
        ax.set_title(title)
        ax.plot(series)

    plt.show()

In [10]:
import Levenshtein as lv

In [11]:
# Quick check that the Levenshtein package imported; displays the function object.
lv.distance

<function Levenshtein._levenshtein.distance>

In [37]:
def one_step(batch_size, val):
    """Run the model on one random batch and score it.

    Uses the notebook globals ``data``, ``device``, ``model``, ``criterion`` and ``lv``.

    Args:
        batch_size: number of examples to sample.
        val: draw from the validation split instead of train.

    Returns:
        (loss, lv_dist). ``loss`` is the cross-entropy over non-pad target
        positions, returned as a tensor so the caller can backprop.
        ``lv_dist`` is the token-level Levenshtein distance between the
        teacher-forced argmax predictions and the targets, averaged over the
        sequences in the batch.
    """
    batch_words, batch_trans_in, batch_trans_out, words_lens, trans_lens = data.get_batch(batch_size, sort=True, val=val)

    batch_words     = batch_words.to(device)
    batch_trans_in  = batch_trans_in.to(device)
    batch_trans_out = batch_trans_out.to(device)
    words_lens      = words_lens.to(device)

    logits = model(batch_words, words_lens, batch_trans_in)  # (batch * seq_len, vocab)

    target = batch_trans_out.view(-1)
    mask = target != data.trans_vocab.pad_idx
    loss = criterion(logits[mask], target[mask])

    # Per-sequence, token-level edit distance. Each token id is mapped to one
    # character so lv.distance counts token edits rather than digit edits.
    n_rows, n_steps = batch_trans_out.shape
    pred_rows = logits.detach().argmax(dim=-1).view(n_rows, n_steps).cpu().tolist()
    real_rows = batch_trans_out.cpu().tolist()
    mask_rows = mask.view(n_rows, n_steps).cpu().tolist()

    total_dist = 0
    for pred_row, real_row, mask_row in zip(pred_rows, real_rows, mask_rows):
        pred_str = ''.join(chr(tok) for tok, keep in zip(pred_row, mask_row) if keep)
        real_str = ''.join(chr(tok) for tok, keep in zip(real_row, mask_row) if keep)
        total_dist += lv.distance(real_str, pred_str)

    lv_dist = total_dist / max(n_rows, 1)

    return loss, lv_dist

In [38]:
# torch.cuda.is_available must be *called*: the bare function object is always
# truthy and would select "cuda" even on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
emb_size    = 32
hidden_size = 64
encoder = Encoder(data.words_vocab.vocab_size, emb_size, hidden_size, data.words_vocab.pad_idx).to(device)
decoder = Decoder(data.trans_vocab.vocab_size, emb_size, hidden_size, data.trans_vocab.pad_idx).to(device)
model   = Model().to(device)
model.encoder = encoder
model.decoder = decoder

# Build the optimizer only after the submodules are attached, so it sees their parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

batch_size = 32
epoch      = 0
num_epochs = 5

# Histories consumed by plot().
train_losses = []
val_losses   = []
train_lv = []
val_lv = []

In [39]:
# Train until `epoch` reaches `num_epochs`. Both live in the setup cell, so
# re-running this cell after it finishes does nothing until setup is re-run.
while epoch < num_epochs:
    for batch_idx in range(len(data) // batch_size):
        loss, lv_dist = one_step(batch_size, val=False)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_lv.append(lv_dist)

        # Every 100 batches: score one validation batch (no gradients) and refresh the plots.
        if batch_idx % 100 == 0:
            with torch.no_grad():
                loss, lv_dist = one_step(batch_size, val=True)
                val_losses.append(loss.item())
                val_lv.append(lv_dist)
            plot(epoch, batch_idx, train_losses, val_losses, train_lv, val_lv)

    epoch += 1

tensor([12, 12, 12, 12, 11, 10, 10, 10,  9,  9,  9,  9,  8,  8,  8,  8,  7,  7,
         7,  7,  7,  7,  6,  6,  6,  6,  6,  6,  5,  5,  4,  3],
       device='cuda:0')
tensor([13, 10, 10, 10,  9,  9,  9,  9,  9,  8,  8,  8,  8,  7,  7,  7,  7,  7,
         7,  6,  6,  6,  6,  6,  5,  5,  5,  5,  4,  4,  3,  2],
       device='cuda:0')
tensor([12, 11, 11, 11, 11, 10, 10, 10, 10,  8,  8,  8,  8,  8,  8,  8,  8,  7,
         7,  7,  7,  7,  7,  7,  7,  6,  6,  6,  5,  4,  4,  3],
       device='cuda:0')
tensor([12, 12, 12, 11, 11, 10, 10, 10,  9,  9,  9,  9,  9,  8,  8,  8,  8,  8,
         7,  7,  7,  7,  7,  6,  6,  6,  6,  5,  5,  5,  4,  4],
       device='cuda:0')
tensor([15, 11, 11, 10,  9,  9,  9,  9,  8,  8,  8,  8,  8,  8,  7,  7,  7,  7,
         7,  7,  6,  6,  6,  6,  6,  5,  5,  5,  5,  5,  4,  4],
       device='cuda:0')


In [2]:
def _print(val):
    """Greedy-decode one random example and print source, prediction and reference.

    Uses the notebook globals ``data``, ``device`` and ``model``.

    Args:
        val: sample from the validation split if True, else from train.
    """
    batch_words, batch_trans_in, batch_trans_out, words_lens, trans_lens = data.get_batch(1, sort=True, val=val)
    batch_words = batch_words.to(device)

    trans_vocab = data.trans_vocab
    special_ids = {trans_vocab.sos_idx, trans_vocab.eos_idx, trans_vocab.pad_idx}

    # The decoder emits transcription ids, so start/stop with the trans vocab's markers.
    with torch.no_grad():
        inp = model.generate(trans_vocab.sos_idx, trans_vocab.eos_idx, batch_words)

    pred_tokens = [trans_vocab.idx2token(idx) for idx in inp if idx not in special_ids]
    real_tokens = [trans_vocab.idx2token(idx) for idx in batch_trans_out[0].tolist()
                   if idx not in special_ids]

    print('Src: ', ''.join([data.words_vocab.idx2token(idx) for idx in batch_words[0].tolist()]))
    print('Pred:', ''.join(pred_tokens))
    print('Real:', ''.join(real_tokens))

In [160]:
# Show ten greedy decodes from the validation split, separated by blank lines.
for sample_no in range(10):
    _print(val=True)
    print()

Src:  OLOUGHLIN
Pred: OWLUWGAHL
Real: OWLAWKLIHN

Src:  FIRSTFED
Pred: FERSTIHP
Real: FERSTFEHD

Src:  SEABROOKS
Pred: SIYBROWKS
Real: SIYBRUHKS

Src:  HEWELL
Pred: HHOYIHL
Real: HHYUWWEHL

Src:  KELVIN
Pred: KEHLVIHN
Real: KEHLVAHN

Src:  CUPERTINO
Pred: KAHPERIHSHAHN
Real: KUWPERTIYNOW

Src:  MASTIFS
Pred: MAESTIHFS
Real: MAESTAHFS

Src:  MOOSEHEAD
Pred: MOWSHAHD
Real: MUWSHHEHD

Src:  GAUZE
Pred: GUWZ
Real: GAOZ

Src:  METRE
Pred: MEHTER
Real: MIYTER

