In [77]:
import os
import pickle
import torch
import pandas as pd

from tqdm import tqdm
from sklearn.metrics import f1_score
from collections import Counter

text_counter = Counter()
tqdm.pandas()
mps_device = torch.device("mps")
torch.cuda.empty_cache()

In [88]:
WEIGHTS_PATH = './../weights/'
DATASET_PATH = './../data/processed/processed_data.csv'
INTERMEDIATE_PATH = './../intermediate/'
RESULTS_PATH = './../results/'
VALIDATION = './../data/Validation.txt'


embedding_size = 512
no_of_rows     = 30000
window_size    = 4
max_epochs     = 10
learning_rate  = 0.001
reg_lambda     = 1
batch_size     = 2048
train_split    = 0.8
is_train       = True
new_weights    = True
use_biases     = False
problem        = 'cbow'

epochs_stat = []

In [89]:
class Net(object):
    
    def __init__(self):
        self.y_pred = None
        self.emb    = None

        if(os.path.exists(WEIGHTS_PATH + 'cbow_weights_' + str(vocab_size) + '.pt') and (new_weights == False)):
            self.weights = torch.load(WEIGHTS_PATH + 'cbow_weights_' + str(vocab_size) + '.pt')
            self.biases = torch.load(WEIGHTS_PATH + 'cbow_biases_' + str(vocab_size) + '.pt')
        else:
            self.weights = []
            self.biases  = []
            
            self.weights.append(torch.rand((vocab_size, embedding_size), device=mps_device) * 2 - 1)
            self.weights.append(torch.rand((embedding_size, vocab_size), device=mps_device) * 2 - 1)
            if(use_biases):
                self.biases.append(torch.rand((embedding_size), device=mps_device) * 2 - 1)
                self.biases.append(torch.rand((vocab_size), device=mps_device) * 2 - 1)
            else:
                self.biases.append(torch.zeros((embedding_size), device=mps_device))
                self.biases.append(torch.zeros((vocab_size), device=mps_device))

    def __call__(self, X):
        self.emb = (torch.matmul(X,self.weights[0]) + self.biases[0])
        self.y_pred = torch.softmax(torch.matmul(self.emb, self.weights[1]) + self.biases[1], dim=1)
        return self.y_pred

    def backward(self, X, y, lamda):
        
        del_W = []
        del_b = []

        delta = self.y_pred - y
        del_b.insert(0,torch.sum(delta, axis = 0, keepdims = True))
        del_W.insert(0,torch.matmul(self.emb.T, delta) + lamda * (self.weights[1]))

        delta = torch.matmul(delta, self.weights[1].T) * (self.emb)
        del_b.insert(0,torch.sum(delta, axis = 0, keepdims = True))
        del_W.insert(0,torch.matmul(X.T, delta) + lamda * (self.weights[0]))
        return del_W, del_b

    
class Optimizer(object):
    '''
    '''
    def __init__(self, learning_rate, weights, biases, optimizer="gradient"):
        
        
        self.optimizer = optimizer
        
        self.m_dw = [torch.zeros((w.shape), device=mps_device) for w in weights]
        self.m_db = [torch.zeros((b.shape), device=mps_device) for b in biases]
        self.v_dw = [torch.zeros((w.shape), device=mps_device) for w in weights]
        self.v_db = [torch.zeros((b.shape), device=mps_device) for b in biases]
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.epsilon = 1e-8
        self.eta = learning_rate
        self.t = 0

    def step(self, weights, biases, delta_weights, delta_biases):
        
        if(self.optimizer == "gradient"):
            return self.gradient(weights, biases, delta_weights, delta_biases)
        elif(self.optimizer == "adam"):
            return self.adam(weights, biases, delta_weights, delta_biases)
    
    def adam(self, weights, biases, delta_weights, delta_biases):
        self.t += 1

        self.m_dw = [self.beta1 * m + (1 - self.beta1) * del_w for m, del_w in zip(self.m_dw, delta_weights)]
        self.m_db = [self.beta1 * m + (1 - self.beta1) * del_b for m, del_b in zip(self.m_db, delta_biases)]
        self.v_dw = [self.beta2 * v + (1 - self.beta2) * (del_w**2) for v, del_w in zip(self.v_dw, delta_weights)]
        self.v_db = [self.beta2 * v + (1 - self.beta2) * (del_b**2) for v, del_b in zip(self.v_db, delta_biases)]

        # bias correction
        m_hat_dw = [m / (1 - self.beta1 ** self.t) for m in self.m_dw]
        v_hat_dw = [v / (1 - self.beta2 ** self.t) for v in self.v_dw]

        m_hat_db = [m / (1 - self.beta1 ** self.t) for m in self.m_db]
        v_hat_db = [v / (1 - self.beta2 ** self.t) for v in self.v_db]

        # update weights and biases
        weights = [w - self.eta * m_hat / ((torch.sqrt(v_hat) + self.epsilon)) for w, m_hat, v_hat in zip(weights, m_hat_dw, v_hat_dw)] 
        biases = [b - self.eta * m_hat / ((torch.sqrt(v_hat) + self.epsilon)) for b, m_hat, v_hat in zip(biases, m_hat_db, v_hat_db)]
        return weights, biases

    
    def gradient(self, weights, biases, delta_weights, delta_biases):
        for i in range(len(weights)):
            weights[i] = weights[i] - self.eta*delta_weights[i]
            biases[i] = biases[i] - self.eta*delta_biases[i]
        return weights, biases

In [90]:
def get_embeddings(word):
    try:
        val = word2idx[word]
        input_arr = torch.zeros((vocab_size), dtype=torch.float, device=mps_device)
        input_arr[val] = 1
        emb = torch.matmul(input_arr,net.weights[0]) + net.biases[0]
        return emb
    except:
        return torch.empty(0, dtype=torch.float, device=mps_device)

def get_result(word1, word2, word3, problem='skipgram', window=4):
    w1_emb = get_embeddings(word1)
    w2_emb = get_embeddings(word2)    
    w3_emb = get_embeddings(word3)    
    if((torch.numel(w1_emb) != 0) and (torch.numel(w2_emb) != 0) and (torch.numel(w3_emb) != 0)):
        w4_emb = w1_emb + w3_emb - w2_emb
        output = torch.softmax(torch.matmul(w4_emb, net.weights[1]) + net.biases[1], dim=1)
        if(problem=='skipgram'):
            topk_preds = torch.topk(output, k=window).indices.tolist()
            ans = [idx2word[x] for x in topk_preds[0]]
        else:
            ans = idx2word[torch.argmax(output).item()]
        return ans
    return "UNK"


def get_embeddings(word):
    try:
        val = word2idx[word]
        input_arr = torch.zeros((vocab_size), dtype=torch.float, device=mps_device)
        input_arr[val] = 1
        emb = torch.matmul(input_arr,net.weights[0]) + net.biases[0]
        return emb
    except:
        return torch.empty(0, dtype=torch.float, device=mps_device)
    
def get_accuracy(validation):
    accuracy = torch.empty(0, dtype=torch.float, device=mps_device)
    for _, row in validation.iterrows():
        w1_emb = get_embeddings(row['word1'])
        w2_emb = get_embeddings(row['word2'])    
        w3_emb = get_embeddings(row['word3'])
        w4_emb = get_embeddings(row['word4'])
        
        if((torch.numel(w1_emb) != 0) and (torch.numel(w2_emb) != 0) and (torch.numel(w3_emb) != 0)):
            pred_emb = w1_emb + w3_emb - w2_emb
            accuracy = torch.cat(accuracy, cosine_similarity(pred_emb, w4_emb), dim=0)
    return torch.mean(accuracy, dim=0)

def cosine_similarity(A, B):
    C = torch.squeeze(A)
    D = torch.squeeze(B)
    return torch.dot(C, D) / (torch.norm(C) * torch.norm(D))

def cross_entropy_loss(y_pred, y_true):
    if(mps_device.type == 'mps'):
        return -torch.mean(custom_nansum(y_true * torch.log(y_pred), axis=-1))
    else:
        return -torch.mean(torch.nansum(y_true * torch.log(y_pred), axis=-1))

def get_dataset(words, window_size):
    data = pd.DataFrame(columns=["word", "context_words"])
    for index, word in enumerate(words):
        context_words = words[max(0, index - window_size): index] + words[index + 1: index + window_size + 1]
        context_words = list(set(context_words))
        data.loc[len(data.index)] = [word, context_words]
    global dataset
    dataset = pd.concat([dataset,data])
    
def custom_nansum(input_tensor, axis=None):
    # Replace NaN values with 0
    input_tensor = torch.where(torch.isnan(input_tensor), torch.tensor(0., device=input_tensor.device), input_tensor)
    
    # Compute the sum over the specified axis
    if axis is None:
        return torch.sum(input_tensor)
    else:
        return torch.sum(input_tensor, dim=axis)
    
    
def get_batch_data(start_index, end_index):

    batch_dataset = dataset[start_index:end_index].reset_index(drop=True)
    batch_size = batch_dataset.shape[0]
    batch_input = torch.zeros([batch_size, vocab_size], dtype=torch.float, device=mps_device)
    batch_output = torch.zeros([batch_size, vocab_size], dtype=torch.float, device=mps_device)

    for index, data in batch_dataset.iterrows():
        batch_input[index, data['word']] = 1
        for ind in data['context_words']:
            batch_output[index, ind] = 1
            
    return batch_output, batch_input


def train(net, optimizer, lamda, max_epochs, dev_input, dev_target, batch_size, train_size):

    stop = False
    value = 999999999
    
    for e in range(max_epochs):
        first_loss = True
        epoch = {}
        batches = []
        losses = []
        for start_index in range(0, train_size, batch_size):
            end_index = min(start_index + batch_size, train_size)
            batch_target, batch_input = get_batch_data(start_index, end_index)
            pred = net(batch_input)

            # Compute gradients of loss w.r.t. weights and biases
            dW, db = net.backward(batch_input, batch_target, lamda)

            # Get updated weights based on current weights and gradients
            weights_updated, biases_updated = optimizer.step(net.weights, net.biases, dW, db)

            # Update model's weights and biases
            net.weights = weights_updated
            net.biases = biases_updated
            loss = cross_entropy_loss(pred, batch_target)
            print(e, start_index, loss.item())
            batches.append(start_index)
            losses.append(loss.item())
            if(torch.isnan(loss) or torch.isinf(loss)):
                stop = True
                break

            if(first_loss):
                first_loss = False
                if(value<loss.item()):
                    stop = True
                    break
                else:
                    value = loss.item()

        epoch['batches'] = batches
        epoch['losses'] = losses
            

        dev_pred = net(dev_input)
        indices = torch.topk(dev_pred, k=window_size, dim=1)[1][:, -window_size:]
        converted_matrix = torch.zeros_like(dev_pred)
        converted_matrix[torch.arange(dev_pred.shape[0])[:, None], indices] = 1
        numpy_dev_target = dev_target.cpu().numpy()
        converted_matrix = converted_matrix.cpu().numpy()
        score = f1_score(numpy_dev_target, converted_matrix, average='micro')
        print('F1 Score on dev data: {:.5f}'.format(score))
        epoch['f1_score'] = score
        epochs_stat.append(epoch)
        if(stop):
            break

### Loading Data

In [91]:
df = pd.read_csv(DATASET_PATH ,nrows = no_of_rows)
df = df.dropna()

df['sentences'] = df['sentences'].apply(lambda sentence : sentence.split())
_ = df['sentences'].apply(text_counter.update)
text_counter.update(['UNK'])
vocab_size = len(text_counter)

print(sum(text_counter.values()))

words, _ = zip(*text_counter.most_common(vocab_size))
word2idx = {w: i for i, w in enumerate(words)}
idx2word = {i: w for i, w in enumerate(words)}


# dataset = pd.read_pickle('dataset.pkl')
df['sentences'] = df['sentences'].apply(lambda words : [word2idx[word] for word in words])

if(no_of_rows==None and os.path.exists(INTERMEDIATE_PATH + 'dataset.pkl')):
   dataset = pd.read_pickle(INTERMEDIATE_PATH + 'dataset.pkl') 

elif(os.path.exists(INTERMEDIATE_PATH + 'dataset_'+str(no_of_rows)+'.pkl')):
    dataset = pd.read_pickle(INTERMEDIATE_PATH + 'dataset_'+str(no_of_rows)+'.pkl')
else:
    dataset = pd.DataFrame(columns=["word", "context_words"])
    _ = df['sentences'].progress_apply(lambda x : get_dataset(x, window_size))

    dataset = dataset.sample(frac=1).reset_index(drop=True)
    if(no_of_rows==None):
        dataset.to_pickle(INTERMEDIATE_PATH + 'dataset.pkl') 
    else:
        dataset.to_pickle(INTERMEDIATE_PATH + 'dataset_'+str(no_of_rows)+'.pkl')

data_size = dataset.shape[0]

print("Vocab length :", vocab_size)
print("Dataset size :", data_size)

607836
Vocab length : 40107
Dataset size : 216522


### Training the model

In [92]:
train_size = max(int(0.90*data_size), data_size-5000)

net = Net()
optimizer = Optimizer(learning_rate, net.weights, net.biases, optimizer="adam")

In [93]:
if(is_train):
    dev_target, dev_input = get_batch_data(train_size, data_size)
    train(net, optimizer, reg_lambda, max_epochs, dev_input, dev_target, batch_size, train_size)

    torch.save(net.weights, WEIGHTS_PATH + 'cbow_weights_' + str(vocab_size) + '.pt')
    torch.save(net.biases, WEIGHTS_PATH + 'cbow_biases_' + str(vocab_size) + '.pt')
else:
    net.weights = torch.load(WEIGHTS_PATH + 'cbow_weights_' + str(vocab_size) + '.pt')
    net.biases = torch.load(WEIGHTS_PATH + 'cbow_biases_' + str(vocab_size) + '.pt')

0 0 139.7922821044922
0 2048 136.05772399902344
0 4096 137.91940307617188
0 6144 135.83111572265625
0 8192 135.20233154296875
0 10240 136.98550415039062
0 12288 135.4839324951172
0 14336 133.4728240966797
0 16384 133.40158081054688
0 18432 133.43368530273438
0 20480 131.66600036621094
0 22528 131.90745544433594
0 24576 131.57225036621094
0 26624 130.16317749023438
0 28672 131.12863159179688
0 30720 129.32839965820312
0 32768 131.1597137451172
0 34816 129.59036254882812
0 36864 127.76079559326172
0 38912 128.76026916503906
0 40960 127.50633239746094
0 43008 126.96383666992188
0 45056 129.4069061279297
0 47104 126.49707794189453
0 49152 127.63768005371094
0 51200 126.14712524414062
0 53248 126.30921173095703
0 55296 125.26234436035156
0 57344 127.58151245117188
0 59392 125.19808959960938
0 61440 122.41328430175781
0 63488 122.55841064453125
0 65536 125.22294616699219
0 67584 123.24613189697266
0 69632 125.43917083740234
0 71680 123.994140625
0 73728 124.2583236694336
0 75776 123.20292663

2 202752 71.62602996826172
2 204800 72.22657012939453
2 206848 71.74776458740234
2 208896 70.39109802246094
2 210944 72.93453979492188
F1 Score on dev data: 0.01896
3 0 72.57882690429688
3 2048 71.08993530273438
3 4096 71.66047668457031
3 6144 71.21463775634766
3 8192 71.3689956665039
3 10240 72.88203430175781
3 12288 72.42891693115234
3 14336 70.98551177978516
3 16384 70.75045013427734
3 18432 71.65196990966797
3 20480 69.95182800292969
3 22528 70.98160552978516
3 24576 70.85200500488281
3 26624 70.00276184082031
3 28672 70.70428466796875
3 30720 69.74179077148438
3 32768 71.02493286132812
3 34816 70.06451416015625
3 36864 69.51483154296875
3 38912 70.10832214355469
3 40960 69.49256134033203
3 43008 69.57696533203125
3 45056 70.63912963867188
3 47104 69.62510681152344
3 49152 69.93360900878906
3 51200 69.13990783691406
3 53248 69.66243743896484
3 55296 69.40829467773438
3 57344 70.31852722167969
3 59392 69.70999145507812
3 61440 67.93171691894531
3 63488 68.24513244628906
3 65536 69.8

5 190464 46.705055236816406
5 192512 46.58042907714844
5 194560 46.23357391357422
5 196608 46.36640167236328
5 198656 45.82695770263672
5 200704 46.85900115966797
5 202752 45.92211151123047
5 204800 46.372703552246094
5 206848 46.356903076171875
5 208896 45.49235534667969
5 210944 46.957176208496094
F1 Score on dev data: 0.02911
6 0 inf
F1 Score on dev data: 0.02911


### Validation

In [84]:
validation = pd.read_csv(VALIDATION, sep=' ', names=['word1','word2','word3','word4'])

validation['word1'] = validation['word1'].apply(lambda x : x.lower())
validation['word2'] = validation['word2'].apply(lambda x : x.lower())
validation['word3'] = validation['word3'].apply(lambda x : x.lower())
validation['word4'] = validation['word4'].apply(lambda x : x.lower())

validation['pred'] = validation[['word1','word2','word3']].progress_apply(lambda x : get_result(x['word1'], x['word2'], x['word3'], problem='cbow'), axis = 1)
validation['found'] = validation[['word4','pred']].progress_apply(lambda x : x['word4'] in x['pred'], axis=1)

count = len(os.listdir(RESULTS_PATH))
validation.to_csv(RESULTS_PATH + "value_"+str(count)+'_'+str(vocab_size)+'_'+str(no_of_rows)+problem+'.csv',index=False)

100%|████████████████████████████████████████| 991/991 [00:05<00:00, 171.00it/s]
100%|█████████████████████████████████████| 991/991 [00:00<00:00, 126385.16it/s]


In [85]:
validation[validation['pred']==validation['word4']]

Unnamed: 0,word1,word2,word3,word4,pred,found


In [86]:
validation

Unnamed: 0,word1,word2,word3,word4,pred,found
0,walk,walks,see,sees,became,False
1,walk,walks,shuffle,shuffles,became,False
2,walk,walks,sing,sings,august,False
3,walk,walks,sit,sits,august,False
4,walk,walks,slow,slows,august,False
...,...,...,...,...,...,...
986,argentina,peso,nigeria,naira,archived,False
987,argentina,peso,iran,rial,population,False
988,argentina,peso,japan,yen,archived,False
989,india,rupee,iran,rial,retrieved,False


In [95]:
get_result('jan','january','dec' , problem='cbow')

# get_similarity('queen', 'woman')

'government'

In [72]:
validation = pd.read_csv(VALIDATION, sep=' ', names=['word1','word2','word3','word4'])

validation['word1'] = validation['word1'].apply(lambda x : x.lower())
validation['word2'] = validation['word2'].apply(lambda x : x.lower())
validation['word3'] = validation['word3'].apply(lambda x : x.lower())
validation['word4'] = validation['word4'].apply(lambda x : x.lower())

validation['pred'] = validation[['word1','word2','word3']].progress_apply(lambda x : get_result(x['word1'], x['word2'], x['word3'], problem='skipgram'), axis = 1)
validation['found'] = validation[['word4','pred']].progress_apply(lambda x : x['word4'] in x['pred'], axis=1)

count = len(os.listdir(RESULTS_PATH))
validation.to_csv(RESULTS_PATH + "value_"+str(count)+'_'+str(vocab_size)+'_'+str(no_of_rows)+problem+'.csv',index=False)

100%|████████████████████████████████████████| 991/991 [00:04<00:00, 199.44it/s]
100%|█████████████████████████████████████| 991/991 [00:00<00:00, 207102.90it/s]


In [73]:
validation[validation['found']==True]

Unnamed: 0,word1,word2,word3,word4,pred,found


In [74]:
validation

Unnamed: 0,word1,word2,word3,word4,pred,found
0,walk,walks,see,sees,"[india, largest, city, system]",False
1,walk,walks,shuffle,shuffles,"[largest, city, one, budapest]",False
2,walk,walks,sing,sings,"[largest, budapest, city, portugal]",False
3,walk,walks,sit,sits,"[largest, states, city, europe]",False
4,walk,walks,slow,slows,"[largest, many, also, country]",False
...,...,...,...,...,...,...
986,argentina,peso,nigeria,naira,"[among, india, world, state]",False
987,argentina,peso,iran,rial,"[city, india, world, national]",False
988,argentina,peso,japan,yen,"[national, north, world, india]",False
989,india,rupee,iran,rial,"[city, new, one, area]",False
