In [1]:
import os

import numpy as np


class CBOW:
    """Word-embedding lookup for a trained CBOW model.

    The model consists of a ``word -> row index`` vocabulary plus two
    embedding matrices ``u`` and ``v`` whose rows are addressed by those
    indices. ``model[word]`` returns the word's row of ``u``;
    :meth:`find_related_words` scores every word's ``u`` row against the
    query word's ``v`` row.
    """

    def __init__(self, dictionary, u, v):
        """Wrap an already-loaded vocabulary and pair of embedding matrices.

        Args:
            dictionary: Mapping of word -> integer row index into ``u``/``v``.
            u: Array whose row ``i`` is the ``u`` vector of the word with index ``i``.
            v: Array whose row ``i`` is the ``v`` vector of the word with index ``i``.
        """
        self.dictionary = dictionary
        self.u = u
        self.v = v
        # Reverse lookup (row index -> word), used to label ranked results.
        self.inverse_dictionary = {index: word for word, index in dictionary.items()}

    @classmethod
    def load(cls, dir):
        """Load a model stored as three whitespace-separated text files.

        ``dir`` (a string or path-like) must contain:

        * ``dictionary.txt`` -- one ``<word> <index>`` pair per line;
        * ``u.txt`` / ``v.txt`` -- one row of float values per line.

        Blank lines are ignored and all files are read as UTF-8.

        Raises:
            OSError: if a file is missing or unreadable.
            ValueError: if a dictionary line is malformed or a value is not numeric.
        """
        # ``dir`` shadows the builtin, but the name is kept so keyword callers still work.
        dictionary = cls._read_dictionary(os.path.join(dir, "dictionary.txt"))
        u = cls._read_matrix(os.path.join(dir, "u.txt"))
        v = cls._read_matrix(os.path.join(dir, "v.txt"))
        return cls(dictionary, u, v)

    @staticmethod
    def _read_dictionary(path):
        """Parse ``<word> <index>`` lines from ``path`` into a ``{word: index}`` dict."""
        dictionary = {}
        with open(path, encoding="utf-8") as fp:
            for line_no, line in enumerate(fp, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing empty line
                if len(fields) != 2:
                    raise ValueError(
                        f"{path}:{line_no}: expected '<word> <index>', got {line.rstrip()!r}"
                    )
                word, index = fields
                dictionary[word] = int(index)
        return dictionary

    @staticmethod
    def _read_matrix(path):
        """Parse one row of whitespace-separated floats per non-blank line into an array."""
        with open(path, encoding="utf-8") as fp:
            rows = [[float(token) for token in line.split()] for line in fp if line.strip()]
        return np.array(rows)

    def __getitem__(self, word):
        """Return the ``u`` embedding row for ``word``.

        Raises:
            ValueError: if ``word`` is not in the vocabulary.
        """
        if word not in self.dictionary:
            raise ValueError(f"Word {word} not in dictionary")

        return self.u[self.dictionary[word]]

    def find_related_words(self, word, top_n=None):
        """Rank all other vocabulary words by their score against ``word``.

        The score of candidate ``c`` is the raw dot product ``u[c] . v[word]``.
        It is not length-normalised, so it is not a cosine similarity.

        Args:
            word: Query word; must be in the vocabulary.
            top_n: If given, return only the ``top_n`` highest-scoring words.
                ``None`` (the default) returns every other word.

        Returns:
            List of ``(word, score)`` tuples, highest score first, excluding
            ``word`` itself. Scores are plain Python floats.

        Raises:
            KeyError: if ``word`` is not in the vocabulary.
            ValueError: if ``top_n`` is negative.
        """
        if word not in self.dictionary:
            # Same exception type as the previous bare dict lookup, clearer message.
            raise KeyError(f"Word {word} not in dictionary")
        if top_n is not None and top_n < 0:
            raise ValueError(f"top_n must be non-negative, got {top_n}")

        word_index = self.dictionary[word]
        scores = np.dot(self.u, self.v[word_index])
        # argsort is ascending; reverse it for highest-score-first order.
        ranked = np.argsort(scores)[::-1]
        related = [
            (self.inverse_dictionary[i], float(scores[i]))
            for i in ranked
            if i != word_index
        ]
        # Slicing with None keeps the whole list.
        return related[:top_n]

In [2]:
# Load the trained model from the "model" directory (dictionary.txt, u.txt, v.txt).
model = CBOW.load("model")

In [3]:
# Rank every other vocabulary word by its raw dot-product score (u . v["dog"]), highest first.
model.find_related_words("dog")

[('the', 6.587040626582977),
 ('sleeping', 5.4066067605244),
 ('is', 5.2817044538012645),
 ('jumps', 5.187170890991),
 ('lazy', 4.9894404722677),
 ('fox', 4.7506221402266),
 ('floor', 4.748664496618041),
 ('on', 4.3032733195318),
 ('brown', 4.1029545888644),
 ('over', 3.9118347626515),
 ('quick', 3.4866698134417993)]

In [4]:
# Same ranking for "fox"; scores are unnormalised dot products, not cosine similarities.
model.find_related_words("fox")

[('jumps', 5.368265724052849),
 ('quick', 5.295954805663001),
 ('dog', 4.882789764324801),
 ('over', 4.825041070876351),
 ('is', 4.366507799574699),
 ('brown', 4.1849944846755),
 ('sleeping', 4.0053416438623),
 ('on', 3.9605691600534496),
 ('the', 3.7995133680233497),
 ('floor', 2.96066613761285),
 ('lazy', 2.9203850405188003)]