In [1]:
def warn(*args, **kwargs):
    """No-op stand-in for ``warnings.warn``: accept any arguments and ignore them."""
    return None

import warnings
# Swap the module attribute so that later calls made as `warnings.warn(...)` are
# silently dropped. This sits ahead of the remaining imports, presumably so that
# warnings emitted while importing them are suppressed too.
warnings.warn = warn

import gensim
import numpy as np
from tqdm import tqdm_notebook as tqdm
import pickle

In [2]:
### Pretrained word2vec
# Google's pre-trained Word2Vec model (1.5 GB): word vectors for a vocabulary of
# 3 million words and phrases, trained on roughly 100 billion words from a
# Google News dataset.
# https://code.google.com/archive/p/word2vec/
model = gensim.models.KeyedVectors.load_word2vec_format(
    './word2vec_pretrained/GoogleNews-vectors-negative300.bin',
    binary=True,
)

In [3]:
# Sanity check: collect every key known to the loaded vectors and report how many there are.
words = list(model.key_to_index)
print("Vocab size: ", len(words))

Vocab size:  3000000


In [4]:
def find_n_closest_for_analogy(filename, n, vectors=None):
    """Predict the fourth word of every analogy listed in a text file.

    Each non-blank line of ``filename`` is split on whitespace into tokens
    ``a b c d`` (read as "a is to b as c is to d"). For every line the word
    vectors are queried with ``positive=[b, c]`` and ``negative=[a]`` and the
    ``n`` best-ranked words are kept.

    Args:
        filename: Path to the analogy file, one analogy per line.
        n: Number of candidate words to retrieve per analogy.
        vectors: Object exposing a gensim-style
            ``most_similar(positive=..., negative=..., topn=...)`` method that
            returns ``(word, score)`` pairs. Defaults to the notebook-level
            ``model`` when omitted.

    Returns:
        A tuple ``(closest_words_list, answer)`` where ``closest_words_list[i]``
        is the list of candidate words for the i-th analogy and ``answer[i]``
        is the fourth token of the i-th line.

    Raises:
        IndexError: If a non-blank line contains fewer than four tokens.
    """
    if vectors is None:
        # Fall back to the embeddings loaded at notebook level.
        vectors = model

    # Read inside a context manager so the file handle is closed as soon as
    # parsing finishes instead of lingering until garbage collection.
    with open(filename) as analogy_file:
        analogies = [line.split() for line in analogy_file if line.strip()]

    # Extract the expected answers first: a malformed line fails here, before
    # any (potentially slow) similarity query is issued.
    answer = [tokens[3] for tokens in analogies]

    closest_words_list = []
    for tokens in analogies:
        ranked = vectors.most_similar(positive=[tokens[1], tokens[2]],
                                      negative=[tokens[0]], topn=n)
        # Keep only the words; the similarity scores are not used downstream.
        closest_words_list.append([word for (word, _) in ranked])

    return closest_words_list, answer

In [5]:
def evaluate_accuracy(closest_found, correct, n=3):
    """Compute top-1 and top-``n`` accuracy of ranked predictions.

    Args:
        closest_found: One ranked candidate list per example; each list must
            hold at least ``n`` entries.
        correct: The expected answer for each example, in the same order.
        n: How many leading candidates count towards the second accuracy.

    Returns:
        A tuple ``(top1, topn)``: the fraction of examples whose first
        candidate equals the answer, and the fraction whose answer appears
        among the first ``n`` candidates.
    """
    expected = np.array(correct)
    total = len(correct)

    def hits_at(rank):
        # Boolean mask: does the candidate at this rank match the answer?
        return np.array([candidates[rank] for candidates in closest_found]) == expected

    first_hit = hits_at(0)
    any_hit = first_hit
    for rank in range(1, n):
        any_hit = np.logical_or(any_hit, hits_at(rank))

    return np.sum(first_hit) / total, np.sum(any_hit) / total

In [6]:
def predict_analogies_for_list(filename):
    """Evaluate one analogy file and print its top-1 and top-3 accuracy.

    Retrieves the three best candidates per analogy via
    ``find_n_closest_for_analogy`` and scores them with ``evaluate_accuracy``.
    Results are printed as percentages; nothing is returned.
    """
    print('Testing file: ', filename)

    predictions, expected = find_n_closest_for_analogy(filename, 3)
    accuracy_at_1, accuracy_at_3 = evaluate_accuracy(predictions, expected)

    # Accuracies come back as fractions; scale to percentages for display.
    percent_at_1 = accuracy_at_1 * 100
    percent_at_3 = accuracy_at_3 * 100
    print("     Top1 accuracy:", percent_at_1, "%")
    print("     Top3 accuracy:", percent_at_3, "%")
    print()

In [7]:
# Evaluate the five analogy files word_lists/list1.txt ... word_lists/list5.txt.
for list_number in range(1, 6):
    predict_analogies_for_list('word_lists/list%d.txt' % list_number)

Testing file:  word_lists/list1.txt
     Top1 accuracy: 85.0 %
     Top3 accuracy: 90.0 %

Testing file:  word_lists/list2.txt
     Top1 accuracy: 85.0 %
     Top3 accuracy: 90.0 %

Testing file:  word_lists/list3.txt
     Top1 accuracy: 35.0 %
     Top3 accuracy: 45.0 %

Testing file:  word_lists/list4.txt
     Top1 accuracy: 20.0 %
     Top3 accuracy: 25.0 %

Testing file:  word_lists/list5.txt
     Top1 accuracy: 60.0 %
     Top3 accuracy: 75.0 %

