In [35]:
import os

import numpy as np
from scipy.spatial.distance import cosine


class SkipGram:
    """Skip-gram embeddings: a word-to-index dictionary plus two matrices.

    ``dictionary`` maps each word to an integer index, and ``u`` / ``v`` are
    indexed by that integer along their first axis. ``inverse_dictionary`` is
    the reverse mapping (index -> word), derived from ``dictionary``.
    """

    def __init__(self, dictionary, u, v):
        """Store the vocabulary and both matrices and build the reverse lookup.

        Args:
            dictionary: mapping ``word -> index``.
            u: matrix whose rows are scored against a word's ``v`` row in
                ``find_related_words``.
            v: matrix whose rows are returned by ``model[word]``.
        """
        self.dictionary = dictionary
        self.u = u
        self.v = v
        # Distinct loop names so the comprehension does not visually shadow
        # the ``v`` parameter.
        self.inverse_dictionary = {index: word for word, index in dictionary.items()}

    @staticmethod
    def _read_matrix(path):
        """Read a text file of whitespace-separated floats, one row per line."""
        rows = []
        with open(path) as fp:
            for line in fp:
                fields = line.split()
                # Skip blank lines (e.g. a trailing newline at end of file)
                # instead of appending an empty, ragged row.
                if fields:
                    rows.append([float(field) for field in fields])
        return np.array(rows)

    @classmethod
    def load(cls, dir):
        """Build a model from ``dictionary.txt``, ``u.txt`` and ``v.txt`` in ``dir``.

        ``dictionary.txt`` holds one ``word index`` pair per line; ``u.txt`` and
        ``v.txt`` hold one row of whitespace-separated floats per line. Blank
        lines in any of the three files are ignored.

        Args:
            dir: directory containing the three files (``str`` or path-like).

        Returns:
            A new instance of ``cls``.

        Raises:
            OSError: if one of the files cannot be opened.
            ValueError: if a non-blank dictionary line does not have exactly
                two fields, or a field cannot be parsed as a number.
        """
        dictionary = {}
        with open(os.path.join(dir, "dictionary.txt")) as fp:
            for line in fp:
                fields = line.split()
                if not fields:
                    continue
                key, value = fields
                dictionary[key] = int(value)

        u = cls._read_matrix(os.path.join(dir, "u.txt"))
        v = cls._read_matrix(os.path.join(dir, "v.txt"))

        return cls(dictionary, u, v)

    def __getitem__(self, word):
        """Return the row of ``v`` for ``word``.

        Raises:
            ValueError: if ``word`` is not in the dictionary.
        """
        if word not in self.dictionary:
            raise ValueError(f"Word {word} not in dictionary")

        return self.v[self.dictionary[word]]

    def find_related_words(self, word, top_n=None):
        """Rank every other word by the dot product of its ``u`` row with ``word``'s ``v`` row.

        The score is a raw dot product, not a normalized cosine similarity, so
        it can be negative or greater than one.

        Args:
            word: query word; must be in the dictionary.
            top_n: if given, return only the ``top_n`` highest-scoring words;
                ``None`` (the default) returns all of them.

        Returns:
            A list of ``(word, score)`` tuples sorted by descending score,
            excluding ``word`` itself.

        Raises:
            KeyError: if ``word`` is not in the dictionary.
            ValueError: if ``top_n`` is negative.
        """
        if top_n is not None and top_n < 0:
            raise ValueError(f"top_n must be non-negative, got {top_n}")

        word_index = self.dictionary[word]
        similarities = np.dot(self.u, self.v[word_index])
        # argsort is ascending; reverse it to put the highest score first.
        sorted_indices = np.argsort(similarities)[::-1]
        related = [
            (self.inverse_dictionary[i], similarities[i])
            for i in sorted_indices
            if i != word_index
        ]

        return related if top_n is None else related[:top_n]

In [36]:
# Load the model from the "model" directory, which must contain
# dictionary.txt, u.txt and v.txt (path is relative to the working directory).
model = SkipGram.load("model")

In [37]:
# scipy's `cosine` is a cosine *distance* (1 - cosine similarity):
# smaller values mean the two vectors point in more similar directions.
for other_word in ("fox", "dog"):
    print(cosine(model["quick"], model[other_word]))

0.7238526940419728
0.9924486935554684


In [38]:
# scipy's `cosine` is a cosine *distance* (1 - cosine similarity), so a value
# above 1 means the two vectors have a negative cosine similarity.
for other_word in ("floor", "jumps"):
    print(cosine(model["sleeping"], model[other_word]))

0.34288478902809705
1.3475802533550254


In [39]:
# Rank every other vocabulary word against "dog" by the dot product of its
# u row with dog's v row, highest score first (raw dot products, not cosines).
model.find_related_words("dog")

[('the', 12.247625667768599),
 ('lazy', 11.831040302156199),
 ('is', 11.047314786561099),
 ('sleeping', 10.630493579004801),
 ('on', 10.1893088855406),
 ('floor', 9.358005146392),
 ('over', 5.9063701767443995),
 ('jumps', 5.731913711576159),
 ('brown', 5.1802646088586),
 ('fox', 4.701875522564601),
 ('quick', 3.9465768108359804)]

In [40]:
# Same ranking for "fox": scores are raw dot products of each word's u row
# with fox's v row, so they may all be negative.
model.find_related_words("fox")

[('jumps', -0.4788774300060623),
 ('over', -0.6746951559898003),
 ('quick', -0.7448703275659101),
 ('brown', -1.3987276450360007),
 ('sleeping', -3.3576652785411603),
 ('on', -3.464300799432),
 ('floor', -3.60176855121024),
 ('the', -3.8751410524206005),
 ('is', -4.7541096709045005),
 ('lazy', -5.08718442204498),
 ('dog', -5.41276050947304)]