In [38]:
import numpy as np
from gensim.models import KeyedVectors
from itertools import repeat
model_glove = KeyedVectors.load_word2vec_format('glove.6B.50d_word2vec.txt')

In [39]:
def find_range(point, range_limits):
    n_limits = len(range_limits)
    for limit in range(n_limits - 1):
        if (point > range_limits[limit]) and (point < range_limits[limit + 1]):
            return limit
    raise ValueError

In [75]:
def colapse_into_10(space_50d):
    ndim = space_50d.ndim
    if ndim == 1:  # If a word embedding
        space_10d = np.zeros(10)
        for idx in range(10):
            space_10d[idx] = np.sum(space_50d[idx*5:(idx+1)*5])
    else:
        space_10d = np.zeros([space_50d.shape[0], 10])
        for idx in range(10):
            space_10d[idx] = np.sum(space_50d[idx*5:(idx+1)*5],)
    return space_10d

In [93]:
def embspace_to_midi(word_embedding, n_words):
    """
    word_embedding: The 50dim vector resulting of difference between multiple words embedding
    n_words: Number of words used to create the word_embedding
    """
    embedding = np.load('mappings.npy')
    reduced_10 = colapse_into_10(embedding)
    maxs = np.max(reduced_10, axis=0)*n_words
    mins = np.min(reduced_10, axis=0)*n_words
    steps = (maxs - mins) / 129
    mappings = np.array(list((map(np.arange, mins, maxs, steps))))

    reduced_embedding = colapse_into_10(word_embedding)
    

    
    midi = np.zeros(10)
    for dimension in range(10):
        midi[dimension] = find_range(reduced_embedding[dimension], mappings[dimension])
    return midi

In [94]:
diffwords = model_glove['republican'] - model_glove['party']
dist1 = embspace_to_midi(diffwords, 2)

In [99]:
diffwords = model_glove['republican'] - model_glove['democratic']
dist2 = embspace_to_midi(diffwords, 2)

In [100]:
diffwords = model_glove['republican'] - model_glove['banana']
dist3 = embspace_to_midi(diffwords, 2)

In [103]:
diffwords = model_glove['republican'] - model_glove['avocado']
dist4 = embspace_to_midi(diffwords, 2)

In [104]:
print(np.sum((dist1)**2))
print(np.sum((dist2)**2))
print(np.sum((dist3)**2))
print(np.sum((dist4)**2))


7190.0
8105.0
5477.0
5260.0


In [105]:
print(dist1)
print(dist2)
print(dist3)
print(dist4)

[28. 27. 19. 23. 20. 31. 22. 22. 33. 37.]
[28. 27. 28. 22. 26. 28. 30. 28. 34. 32.]
[32. 18.  6. 23. 27. 12. 19. 23. 24. 35.]
[32. 22. 11. 25. 32. 10. 23. 13. 28. 20.]
