In [1]:
#Core Python, Pandas, and kaldi_io
import numpy as np
import pandas as pd
import string
from collections import Counter,OrderedDict 
import kaldi_io
from datetime import datetime

#ngrams
import nltk,re
import nltk.corpus
from nltk.corpus import switchboard
from nltk.util import ngrams

#Scikit
from sklearn import manifold
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import pairwise_distances,average_precision_score
from sklearn.metrics.pairwise import pairwise_kernels,paired_distances
from scipy import stats
from scipy.spatial.distance import pdist

#Plotting
from matplotlib import pyplot as plt
import seaborn as sns

#BigPhoney
from big_phoney import BigPhoney


#Torch and utilities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset,DataLoader,random_split,ConcatDataset

#Import User defined classes
from data_helpers import DataHelper
from models import SimpleNet
from train_test_helpers import accuracy,train_model,evaluate_model,evaluate_model_paper,test_model,plot_learning_curves
from sfba4.utils import alignSequences
from models import SimpleNet, SiameseNet
from siamese_dataset import SiameseTriplets
from train_test_helpers import accuracy,train_model,evaluate_model,evaluate_model_paper,test_model,plot_learning_curves

################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################

Using TensorFlow backend.


In [2]:
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
#Load source model
source_net = SimpleNet(9974)
source_net = source_net.to(dev)
source_net_save_path = "./Models/test/awe_best_model.pth"
source_net.load_state_dict(torch.load(source_net_save_path))

<All keys matched successfully>

In [4]:
#word_embedding_dict = np.load('Data/word_embedding_dict.npy', allow_pickle = True)

In [5]:
#words = list(word_embedding_dict.item().keys())

In [6]:
num_examples = np.Inf
frequency_bounds = (0,155)
train_sm_dataset = SiameseTriplets(num_examples = num_examples, split_set = "train", frequency_bounds = frequency_bounds)
val_sm_dataset = SiameseTriplets(num_examples = num_examples, split_set = "val", frequency_bounds = frequency_bounds)
test_sm_dataset = SiameseTriplets(num_examples = num_examples, split_set = "test", frequency_bounds = frequency_bounds)

Length before filtering on char length 317927
Length after filtering on char length 173657
Length before filtering on frequency_bounds 173657
Length after filtering on frequency_bounds 125006
Finished Loading the Data, 125006 examples
Number of Unique words  9974
torch.Size([59844, 3, 40, 100])
Length before filtering on char length 317927
Length after filtering on char length 173657
Length before filtering on frequency_bounds 173657
Length after filtering on frequency_bounds 125006
Finished Loading the Data, 125006 examples
Number of Unique words  9974
torch.Size([19948, 3, 40, 100])
Length before filtering on char length 317927
Length after filtering on char length 173657
Length before filtering on frequency_bounds 173657
Length after filtering on frequency_bounds 125006
Finished Loading the Data, 125006 examples
Number of Unique words  9974
torch.Size([19948, 3, 40, 100])


In [7]:
train_dl = torch.utils.data.DataLoader(train_sm_dataset, shuffle = True, batch_size = 64, pin_memory = True)
val_dl = torch.utils.data.DataLoader(val_sm_dataset, shuffle = True, batch_size = 64, pin_memory = True)
test_dl = torch.utils.data.DataLoader(test_sm_dataset, shuffle = True, batch_size = 64, pin_memory = True)

In [8]:
word_to_num,num_to_word = train_sm_dataset.word_to_num,train_sm_dataset.num_to_word

In [9]:
def batch_letter_ngrams(words):
    letter_ngrams = []
    for word in words:
        letter_ngrams.append(give_letter_ngram(word))
    
    return np.stack(letter_ngrams)

In [12]:
def process_words(word):
    #Remove punctuation
    word = word.translate(str.maketrans('', '', string.punctuation))
    return "["+word.lower()+"]"

In [13]:
def give_common_ngrams(num = 50000):
    switchboard.ensure_loaded()
    words = switchboard.words()
    #Add start and end of word markers and make words lower case
    words = list(map(process_words,words))
    #Filter empty words
    words = list(filter(lambda x: x!="[]", words))

    #get all n_grams up to n=10
    n = 8
    ngrams_list = []

    for word in words:
        ngrams_list.append(list(filter(lambda x: x!=tuple('[') and x!= tuple(']'),list(ngrams(list(word),1)))))
        for i in range(2,n+1):
            ngrams_list.append(list(ngrams(list(word),i)))

    flatten = lambda l: [item for sublist in l for item in sublist]
    #Unroll the list
    ngrams_list = flatten(ngrams_list)

    ngrams_counter = Counter(ngrams_list)
    print(len(ngrams_counter.keys()))

    common_ngrams = []
    for index,(key,value) in enumerate(ngrams_counter.most_common(num)):
        common_ngrams.append(key)
    
    return common_ngrams



In [19]:
common_ngrams = give_common_ngrams(50000)

51794


In [20]:
common_ngrams

[('e',),
 ('t',),
 ('o',),
 ('a',),
 ('h',),
 ('i',),
 ('n',),
 ('s',),
 ('u',),
 ('r',),
 ('e', ']'),
 ('[', 't'),
 ('t', ']'),
 ('l',),
 ('d',),
 ('y',),
 ('t', 'h'),
 ('[', 't', 'h'),
 ('[', 'i'),
 ('w',),
 ('[', 'a'),
 ('s', ']'),
 ('m',),
 ('h', 'e'),
 ('d', ']'),
 ('g',),
 ('[', 's'),
 ('c',),
 ('h', ']'),
 ('i', 'n'),
 ('t', 'h', 'e'),
 ('[', 'w'),
 ('y', ']'),
 ('f',),
 ('a', 'n'),
 ('[', 't', 'h', 'e'),
 ('o', 'u'),
 ('n', ']'),
 ('h', 'a'),
 ('o', ']'),
 ('b',),
 ('[', 'o'),
 ('k',),
 ('r', 'e'),
 ('[', 'y'),
 ('p',),
 ('a', 't'),
 ('i', ']'),
 ('e', 'r'),
 ('[', 'i', ']'),
 ('n', 'd'),
 ('u', 'h'),
 ('[', 'u'),
 ('i', 't'),
 ('n', 'd', ']'),
 ('v',),
 ('r', ']'),
 ('[', 'a', 'n'),
 ('e', 'a'),
 ('a', 't', ']'),
 ('[', 'b'),
 ('a', 'n', 'd'),
 ('[', 'h'),
 ('a', 'n', 'd', ']'),
 ('h', 'e', ']'),
 ('h', 'a', 't'),
 ('[', 'm'),
 ('h', 'a', 't', ']'),
 ('[', 'a', 'n', 'd'),
 ('[', 'a', 'n', 'd', ']'),
 ('v', 'e'),
 ('u', 'h', ']'),
 ('y', 'o'),
 ('n', 'g'),
 ('[', 'u', 'h'),
 ('

In [21]:
#Map common ngrams to index values for one hot encoding
ngram_to_index = {}
#ngram_to_index
for index,ngram in enumerate(common_ngrams):
    ngram_to_index[ngram] = index

In [22]:
def give_letter_ngram(word):
    
    n=10
    word_list = list(word)
    letter_ngram = np.zeros(len(common_ngrams))
    
    #Extract ngrams from the word
    ngrams_list = []
    
    ngrams_list.append(list(filter(lambda x: x!=tuple('[') and x!= tuple(']'),list(ngrams(list(word),1)))))
    for i in range(2,n+1):
        ngrams_list.append(list(ngrams(list(word),i)))
    
    #Flatten
    flatten = lambda l: [item for sublist in l for item in sublist]
    #Unroll the list
    ngrams_list = flatten(ngrams_list)
    
    for ngram in ngrams_list:
        if ngram in ngram_to_index.keys():
            letter_ngram[ngram_to_index[ngram]] = 1
        
    return letter_ngram

In [23]:
def triplet_loss(word_embedding,same_word_embedding,diff_word_embedding,cos):
    m = 0.15
    lower_bound = torch.tensor(0.0).to(dev, non_blocking = True)
    a = torch.max(lower_bound,m - cos(word_embedding, same_word_embedding) - cos(word_embedding, diff_word_embedding))
    return torch.mean(a)

In [24]:
class OrthographicNet(nn.Module):
    def __init__(self,num_input,num_output):
        super(OrthographicNet, self).__init__()
        self.fc1 = nn.Linear(num_input, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, num_output)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        #print(x.shape)
        x = F.relu(self.fc2(x))
        #print(x.shape)
        x = F.relu(self.fc3(x))
        #print(x.shape)
        #print("Done")
        return x
    
    def give_embeddings(self,x,dev):
        x = F.relu(self.fc1(x))
        #print(x.shape)
        x = F.relu(self.fc2(x))
        #print(x.shape)
        x = F.relu(self.fc3(x))
        #print(x.shape)
        #print("Done")
        return x.cpu().detach().numpy() if dev.type == 'cuda' else x.detach().numpy()

In [27]:
num_input,num_output = 50000,9974
orthographic_net = OrthographicNet(num_input,num_output)
orthographic_net = orthographic_net.float()
orthographic_net.to(dev)
optimizer = optim.SGD(orthographic_net.parameters(), lr=0.001, momentum=0.9)
cos = nn.CosineSimilarity(dim=1, eps=1e-6)

In [None]:

num_epochs = 100
verbose = True
model_save_path = "./Models/best_orthographic_model2.pth"
best_val_loss = np.Inf

for epoch in range(0,num_epochs):
    if verbose:
            print('epoch %d '%(epoch))

    train_loss = 0
    orthographic_net.train()
    for batch_idx, (train_data,train_labels) in enumerate(train_dl):

        #print(train_data.shape)
        #Move to GPU
        optimizer.zero_grad()
        with torch.no_grad():
            train_data = train_data.to(dev, non_blocking=True)
            #Get word mfcc features
            word = train_data[:,0,:]
            #Get labels
            word_labels = [num_to_word[int(train_labels[i,0])] for i in range(train_labels.shape[0])]
            diff_word_labels = [num_to_word[int(train_labels[i,1])] for i in range(train_labels.shape[0])]
        #Get letter_ngrams
        word_letter_ngrams = torch.tensor(batch_letter_ngrams(word_labels), dtype =torch.float, device = dev)
        diff_letter_ngrams = torch.tensor(batch_letter_ngrams(diff_word_labels), dtype =torch.float,device = dev)
        
        #Get the word embedding and letter_ngram embeddings
        word_embedding = source_net(word)
        word_ngram_embedding = orthographic_net(word_letter_ngrams)
        diff_word_ngram_embedding = orthographic_net(diff_letter_ngrams)
        
        
        #Calculate the triplet loss
        loss = triplet_loss(word_embedding,word_ngram_embedding,diff_word_ngram_embedding, cos)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()


    orthographic_net.eval()
    with torch.no_grad():
        val_loss = 0
        for batch_idx, (val_data,val_labels) in enumerate(val_dl):

            val_data = val_data.to(dev, non_blocking=True)
            #Get word mfcc features
            word = train_data[:,0,:]
            #Get labels
            word_labels = [num_to_word[int(val_labels[i,0])] for i in range(train_labels.shape[0])]
            diff_word_labels = [num_to_word[int(val_labels[i,1])] for i in range(train_labels.shape[0])]
            #Get letter_ngrams
            word_letter_ngrams = torch.tensor(batch_letter_ngrams(word_labels), dtype =torch.float, device = dev)
            diff_letter_ngrams = torch.tensor(batch_letter_ngrams(diff_word_labels), dtype =torch.float,device = dev)

            #Get the word embedding and letter_ngram embeddings
            word_embedding = source_net(word)
            word_ngram_embedding = orthographic_net(word_letter_ngrams)
            diff_word_ngram_embedding = orthographic_net(diff_letter_ngrams)
        
            

            #Calculate the triplet loss
            val_loss += triplet_loss(word_embedding,word_ngram_embedding,diff_word_ngram_embedding, cos)
            

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print("Best val loss %.3f Saving Model..."%(val_loss/len(val_dl)))
            torch.save(orthographic_net.state_dict(),model_save_path)


    if verbose:
        print("train loss: %.3f"%(train_loss/len(train_dl)))
        print("val loss: %.3f"%(val_loss/len(val_dl)))

epoch 0 
Best val loss 0.220 Saving Model...
train loss: 0.237
val loss: 0.220
epoch 1 
Best val loss 0.200 Saving Model...
train loss: 0.208
val loss: 0.200
epoch 2 
Best val loss 0.192 Saving Model...
train loss: 0.195
val loss: 0.192
epoch 3 
Best val loss 0.185 Saving Model...
train loss: 0.188
val loss: 0.185
epoch 4 
Best val loss 0.185 Saving Model...
train loss: 0.185
val loss: 0.185
epoch 5 
Best val loss 0.185 Saving Model...
train loss: 0.185
val loss: 0.185
epoch 6 
Best val loss 0.185 Saving Model...
train loss: 0.185
val loss: 0.185
epoch 7 
Best val loss 0.184 Saving Model...
train loss: 0.184
val loss: 0.184
epoch 8 
Best val loss 0.184 Saving Model...
train loss: 0.184
val loss: 0.184
epoch 9 
Best val loss 0.184 Saving Model...
train loss: 0.184
val loss: 0.184
epoch 10 
Best val loss 0.184 Saving Model...
train loss: 0.184
val loss: 0.184
epoch 11 
Best val loss 0.183 Saving Model...
train loss: 0.183
val loss: 0.183
epoch 12 
Best val loss 0.182 Saving Model...
trai