In [4]:
import glob
import os
from multiprocessing import Pool, cpu_count

import numpy as np
from smart_open import smart_open


def process_row(row):
    """Split one space-separated line into its leading token and the rest.

    Trailing whitespace (including the newline) is removed before splitting on
    single spaces. Returns a ``(head, rest)`` tuple where *rest* is a list of
    the remaining string tokens (possibly empty).
    """
    head, *rest = row.rstrip().split(' ')
    return head, rest


def load_vocabulary(vocab_file):
    """Yield the first space-separated token of every line of *vocab_file*, in file order.

    Anything after the first token on a line is ignored.
    """
    # Read-only access is all that is needed ('r+' requested write access).
    # Lines are parsed in-process: a split is far cheaper than pickling each
    # line to a worker process, and Pool workers cannot import process_row
    # when it is defined interactively (e.g. in a notebook) under 'spawn'.
    with smart_open(vocab_file, 'r') as vocabulary:
        for row in vocabulary:
            word, _ = process_row(row)
            yield word


def load_vectors(glove_vectors_file):
    """Read a text vector file into a dict mapping each word to a float ndarray.

    Each line is ``word v1 v2 ...`` separated by single spaces. If a word occurs
    more than once, the last occurrence wins.
    """
    # Read-only access is all that is needed ('r+' requested write access).
    # Lines are parsed in-process: the per-line work is cheaper than the
    # pickling round-trip to a Pool worker, and Pool workers cannot import
    # process_row when it is defined interactively under 'spawn'.
    with smart_open(glove_vectors_file, 'r') as glove_vectors:
        vectors = {}
        for row in glove_vectors:
            word, values = process_row(row)
            # np.float(values) raised TypeError (float() of a list) and the alias
            # is gone in NumPy >= 1.24; convert the token list element-wise.
            vectors[word] = np.array(values, dtype=float)
        return vectors


def laod_data_and_evaluate(vocab_file, glove_vectors_file, path=''):
    """Build a unit-length embedding matrix from vocab and vector files, then evaluate it.

    Args:
        vocab_file: file whose lines' first tokens are the vocabulary words; the
            line order defines each word's row index.
        glove_vectors_file: file with one ``word v1 v2 ...`` entry per line.
        path: directory of question files, passed through to evaluate_vectors.

    Raises:
        KeyError: if the first vocabulary word has no vector (its vector sets
            the embedding dimension).
    """
    # Was load_data(...), which is not defined anywhere in this module.
    words = list(load_vocabulary(vocab_file))
    vectors = load_vectors(glove_vectors_file)

    vocab = {w: idx for idx, w in enumerate(words)}
    ivocab = dict(enumerate(words))

    vector_dim = len(vectors[ivocab[0]])
    W = np.zeros((len(words), vector_dim))
    for word, v in vectors.items():
        # '<unk>' and any other word outside the vocabulary has no row in W.
        if word == '<unk>' or word not in vocab:
            continue
        W[vocab[word], :] = v

    # Scale each row to unit L2 length so the dot products taken in
    # evaluate_vectors are cosine similarities. All-zero rows (vocabulary words
    # with no vector) stay zero instead of becoming NaN, which np.argmax would
    # otherwise treat as the maximum.
    norms = np.sqrt(np.sum(W ** 2, axis=1))
    norms[norms == 0] = 1.0
    W_norm = W / norms[:, np.newaxis]
    evaluate_vectors(W_norm, vocab, ivocab, path)


def evaluate_vectors(W, vocab, ivocab, path, split_size=100, num_semantic=5):
    """Evaluate word vectors on the analogy question files in *path* and print accuracies.

    Every ``*.txt`` file directly inside *path* is one question set with one
    question per line, ``a b c d``. A question is answered with the vocabulary
    index whose row of *W* has the largest dot product with
    ``W[b] - W[a] + W[c]``, excluding a, b and c themselves. The answer is
    correct if it equals d. Lines that are not exactly four words, or that
    contain a word missing from *vocab*, are not scored. They still count
    toward the "Questions seen/total" denominator.

    Args:
        W: 2-D array with one row per vocabulary index. The dot products are
            cosine similarities only if the rows were L2-normalised beforehand.
        vocab: mapping from word to row index of *W*.
        ivocab: index-to-word mapping; not used here.
        path: directory holding the question files ('' = current directory).
        split_size: number of questions scored per matrix product. The
            product has shape ``(len(W), split_size)``, so lower this to reduce
            peak memory.
        num_semantic: the first *num_semantic* files, in sorted filename order,
            are tallied as semantic and the remainder as syntactic.
            NOTE(review): the default 5 presumably matches the standard GloVe
            question files (capital-common-countries, capital-world,
            city-in-state, currency, family sort first) — confirm for other data.
    """
    # Was glob.glob('*.txt'.format(path)), which ignored *path* entirely; sort
    # so the semantic/syntactic split by position is deterministic.
    pattern = os.path.join(glob.escape(path), '*.txt')
    filenames = sorted(glob.glob(pattern))

    full_count = 0  # all questions, including those that could not be scored
    count_sem = correct_sem = 0
    count_syn = correct_syn = 0

    for i, filename in enumerate(filenames):
        with open(filename, 'r') as f:
            full_data = [line.rstrip().split(' ') for line in f]
        full_count += len(full_data)
        data = [row for row in full_data
                if len(row) == 4 and all(word in vocab for word in row)]

        hits = _analogy_hits(W, vocab, data, split_size)
        n_seen = len(hits)
        n_correct = int(np.sum(hits))
        if i < num_semantic:
            count_sem += n_seen
            correct_sem += n_correct
        else:
            count_syn += n_seen
            correct_syn += n_correct

        print("%s:" % filename)
        print('ACCURACY TOP1: %.2f%% (%d/%d)' %
              (_percent(n_correct, n_seen), n_correct, n_seen))

    count_tot = count_sem + count_syn
    correct_tot = correct_sem + correct_syn
    print('Questions seen/total: %.2f%% (%d/%d)' %
          (_percent(count_tot, full_count), count_tot, full_count))
    print('Semantic accuracy: %.2f%%  (%i/%i)' %
          (_percent(correct_sem, count_sem), correct_sem, count_sem))
    print('Syntactic accuracy: %.2f%%  (%i/%i)' %
          (_percent(correct_syn, count_syn), correct_syn, count_syn))
    print('Total accuracy: %.2f%%  (%i/%i)' %
          (_percent(correct_tot, count_tot), correct_tot, count_tot))


def _analogy_hits(W, vocab, data, split_size):
    """Score ``[a, b, c, d]`` questions (all words in *vocab*); return a bool array of hits."""
    if not data:
        # np.array([]) has no second axis, so the 4-way unpacking would fail.
        return np.zeros(0, dtype=bool)

    indices = np.array([[vocab[word] for word in row] for row in data])
    ind1, ind2, ind3, ind4 = indices.T
    n_questions = len(indices)
    predictions = np.empty(n_questions, dtype=np.intp)

    for start in range(0, n_questions, split_size):
        subset = np.arange(start, min(start + split_size, n_questions))
        pred_vec = W[ind2[subset]] - W[ind1[subset]] + W[ind3[subset]]
        # One column of scores per question in this chunk.
        dist = np.dot(W, pred_vec.T)
        # The three query words themselves may not be returned as the answer.
        cols = np.arange(len(subset))
        dist[ind1[subset], cols] = -np.inf
        dist[ind2[subset], cols] = -np.inf
        dist[ind3[subset], cols] = -np.inf
        predictions[subset] = np.argmax(dist, axis=0)

    return ind4 == predictions


def _percent(numerator, denominator):
    """Return numerator/denominator as a percentage, or 0.0 when the denominator is 0."""
    return 100.0 * numerator / denominator if denominator else 0.0
