In [None]:
import pickle
import numpy as np
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt

In [None]:
from matplotlib import rc

# Figure text: serif family with Palatino as the serif face, rendered through LaTeX.
rc('font', family='serif', serif=['Palatino'])
rc('text', usetex=True)

## Read in Embeddings and Word Dictionary

In [None]:
# Directory the embedding files are read from.
# NOTE(review): absolute, machine-specific path — adjust it when running elsewhere.
CHECKPOINT_DIR = '/home/jonny/Documents/finalML/saved_checkpoints/test11'

mus = np.load(CHECKPOINT_DIR + '/mu.npy')
sigmas = np.load(CHECKPOINT_DIR + '/sigma.npy')

# SECURITY: pickle.load can execute arbitrary code; only load files you trust.
# Opening through a context manager closes the file handle deterministically.
with open(CHECKPOINT_DIR + '/word2id.pkl', 'rb') as word2id_file:
    word_dictionary = pickle.load(word2id_file, encoding='latin1')

# Inverse mapping (id -> word) of word_dictionary (word -> id).
reversed_word_dictionary = {value: key for key, value in word_dictionary.items()}

In [None]:
# SECURITY: pickle.load can execute arbitrary code; only load files you trust.
# Opening through a context manager closes the file handle deterministically.
with open('/home/jonny/Documents/finalML/saved_checkpoints/test11/mixture.pkl', 'rb') as mixture_file:
    mixtures = pickle.load(mixture_file, encoding='latin1')

In [None]:
# Map every id listed under a mixture entry to the same word as the entry's own id,
# so those ids can be translated back to text as well.
for base_id in mixtures:
    for component_id in mixtures[base_id]:
        reversed_word_dictionary[component_id] = reversed_word_dictionary[base_id]

In [None]:
# Quick overview of what was loaded: one (template, value) pair per output line.
report_lines = (
    ('Shape of mus: {}', mus.shape),
    ('Shape of sigmas: {}', sigmas.shape),
    ('Vocabulary size: {}', len(word_dictionary)),
)
for template, value in report_lines:
    print(template.format(value))

In [None]:
# Number of ids that can currently be mapped back to a word.
len(reversed_word_dictionary)

In [None]:
# Id stored in the dictionary for the word 'star'.
word_dictionary['star']

In [None]:
# Entry of sigmas at index 592.
# NOTE(review): 592 is presumably the id returned for 'star' by the previous cell —
# confirm, or index with word_dictionary['star'] directly.
sigmas[592]

In [None]:
# Mixture entry stored under 801.
# NOTE(review): magic number; which word 801 belongs to is not visible here — confirm.
mixtures[801]

In [None]:
# Sanity check: neither loaded array may contain a NaN.
assert not np.isnan(mus).any()
assert not np.isnan(sigmas).any()

## Compare Covariance of Means and Mean of Covariances

In [None]:
# Covariance of means: norm of the covariance matrix of the transposed mus.
cov_of_means = np.cov(np.transpose(mus))
np.linalg.norm(cov_of_means)

In [None]:
# Mean of covariances: average the sigmas along axis 0, then take the norm.
mean_of_sigmas = sigmas.mean(axis=0)
np.linalg.norm(mean_of_sigmas)

In [None]:
# Norm of the covariance matrix computed from the transposed sigmas.
np.linalg.norm(np.cov(sigmas.T))

In [None]:
# NOTE(review): hard-coded literal, presumably copied from an earlier cell's output —
# confirm its origin or compute it from the arrays instead.
np.exp(0.0018488717822326764)

In [None]:
# Word whose row holds the largest sigma entry. np.argmax returns an index into the
# flattened array, so it is converted back to a row index before the dictionary lookup;
# for 1-D or single-column sigmas the result is the same as using the flat index.
max_sigma_row = np.unravel_index(np.argmax(sigmas), sigmas.shape)[0]
reversed_word_dictionary[max_sigma_row]


## kNN Analysis

In [None]:
# Nearest-neighbour index over mus; queries return 10 neighbours by default.
neighbor_model = NearestNeighbors(n_neighbors=10)
knn = neighbor_model.fit(mus)

In [None]:
# Word to inspect — change it here.
idx = word_dictionary['rock']
# Query with the looked-up id so that the printed sigma and the neighbours below
# always describe the same word (no separately hard-coded row index).
embedding = mus[idx].reshape(1, -1)
sigma = sigmas[idx]
print(sigma)
distances, indices = knn.kneighbors(embedding)
for i in indices.flatten():
    # Best effort: neighbour ids without a dictionary entry are skipped.
    # Any other error is no longer silenced.
    if i in reversed_word_dictionary:
        print(reversed_word_dictionary[i])

In [None]:
def unique(array):
    """Return the number of distinct values in ``array``.

    Args:
        array: Any iterable of hashable values; it may be empty.

    Returns:
        int: How many different values occur in ``array`` (0 for an empty iterable).
    """
    return len(set(array))

In [None]:
words = ['rock', 'bank', 'apple', 'star', 'cell', 'left', 'board', 'record', 'lie', 'chair', 'bar', 'lead']

# Column-major flattening copies the whole array, so do it once up front
# instead of once per neighbour inside the loops.
flat_sigmas = sigmas.flatten(order='F')

for word in words:
    idx = word_dictionary[word]
    mixtures_idx = mixtures[idx]
    # Inspect the first two ids listed for the word.
    for i in range(2):
        indices = knn.kneighbors(mus[mixtures_idx[i]].reshape(1, -1), return_distance=False)
        neighbor_ids = indices.flatten()
        neighbor_exp_sigmas = [np.exp(flat_sigmas[j]) for j in neighbor_ids]
        neighbor_words = [reversed_word_dictionary[j] for j in neighbor_ids]
        # The first hit leads the line; the remaining hits follow, ordered by
        # ascending exp(sigma), with ties broken by the word itself.
        ranked_words = [w for _, w in sorted(zip(neighbor_exp_sigmas[1:], neighbor_words[1:]))]
        print(neighbor_words[0] + '\t' + ' '.join(ranked_words))

In [None]:
# Number of distinct ids stored under each mixture entry.
hist_array = [unique(mixtures[key]) for key in mixtures]

In [None]:
# One bin per integer count (0, 1 and 2), centred on the value. Integer edges [0, 1, 2]
# would put the counts 1 and 2 into the same bar, because the last bin includes its right edge.
fig, ax = plt.subplots(figsize=(6, 4), dpi=300)
ax.hist(hist_array, bins=[-0.5, 0.5, 1.5, 2.5])
ax.set_ylim(0, 700000)
ax.set_title('Histogram for Number of Learned Representations')
ax.set_xlabel('x')
# Horizontal y-label placed above the axis instead of rotated alongside it.
h = ax.set_ylabel('Counts')
h.set_rotation(0)
ax.yaxis.set_label_coords(-0.025, 1.01)
# Keep only the left and bottom spines and their ticks.
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
ax.set_xticks([0,1,2])

fig.tight_layout()
fig.savefig('hist.pdf', bbox_inches='tight')

In [None]:
# Bins centred on 0, 1 and 2 so that each count gets its own bar; with integer edges
# [0, 1, 2] the closed last bin would lump the counts 1 and 2 together.
plt.hist(hist_array, bins=[-0.5, 0.5, 1.5, 2.5])