In [1]:
import torch
import torch.nn.functional as F
from heapq import nlargest

from src.model import CBOW
from src.constants import EMBEDDING_DIMS

In [12]:
def get_word_embedding(model, word, vocab):
    """Look up the learned embedding vector for a single word.

    Args:
        model: Object exposing an ``embeddings`` lookup layer. The layer is
            called with an index tensor, and its ``weight`` attribute is used
            to decide which device the index is created on.
        word: Token to look up; it is resolved to an index via ``vocab[word]``.
        vocab: Mapping-like object supporting ``vocab[word]``.

    Returns:
        torch.Tensor: The embedding of ``word``, computed without gradient
        tracking. For a scalar index this is a 1-D tensor whose length is the
        embedding width.
    """
    idx = vocab[word]
    # Create the index on the embedding table's device so the lookup also
    # works when the model has been moved off the CPU.
    index = torch.tensor(idx, device=model.embeddings.weight.device)
    with torch.no_grad():
        embedding = model.embeddings(index)
    # A scalar index already yields a 1-D vector. Only strip a leading batch
    # axis when there is one; squeezing a 1-D result unconditionally would
    # collapse a width-1 embedding into a 0-d scalar.
    if embedding.dim() > 1:
        embedding = embedding.squeeze(0)
    return embedding


def find_most_similar_words(model, word, vocab, N=5):
    """Rank vocabulary words by cosine similarity to ``word``.

    Every word returned by ``vocab.get_itos()`` is scored against the query
    in a single batched embedding lookup and a single cosine-similarity call.
    The query word itself is part of the vocabulary scan, so it normally
    appears first with a similarity of about 1.0.

    Args:
        model: Object exposing an ``embeddings`` lookup layer with a
            ``weight`` attribute (used to pick the device for index tensors).
        word: Query token; resolved to an index via ``vocab[word]``.
        vocab: Vocabulary supporting ``vocab[token]`` and ``get_itos()``.
        N: Maximum number of results to return.

    Returns:
        list[tuple[str, float]]: Up to ``N`` ``(word, similarity)`` pairs,
        ordered from most to least similar. Ties keep vocabulary order.
    """
    # Drop repeated entries while keeping first-seen order so each word is
    # scored exactly once.
    candidates = list(dict.fromkeys(vocab.get_itos()))

    with torch.no_grad():
        device = model.embeddings.weight.device
        query = model.embeddings(torch.tensor(vocab[word], device=device))
        rows = torch.tensor(
            [vocab[token] for token in candidates],
            dtype=torch.long,
            device=device,
        )
        # (1, dims) against (V, dims) broadcasts to one score per candidate.
        scores = F.cosine_similarity(
            query.unsqueeze(0), model.embeddings(rows), dim=1
        ).tolist()

    # Rank candidate positions by score; nlargest keeps the earlier position
    # first when two scores are equal.
    best = nlargest(N, range(len(candidates)), key=scores.__getitem__)

    return [(candidates[i], scores[i]) for i in best]


In [13]:
# Load the vocabulary object saved alongside the model checkpoints.
# NOTE(review): torch.load unpickles Python objects, so only open checkpoint
# files that come from a trusted source.
vocab = torch.load("./checkpoints/vocab.pt")

# Rebuild the CBOW architecture, then restore the trained weights into it.
model = CBOW(vocab_size=len(vocab), dims=EMBEDDING_DIMS)
# map_location="cpu" lets a checkpoint written during a GPU run be
# deserialised on a machine without CUDA.
model.load_state_dict(
    torch.load("./checkpoints/model_epoch_5.pt", map_location="cpu")
)
# Switch to inference mode; as the cell's last expression this also displays
# the model summary.
model.eval()

CBOW(
  (embeddings): Embedding(10001, 300)
  (linear): Linear(in_features=300, out_features=10001, bias=True)
)

In [14]:
# get_word_embedding(model=model, word="tree", vocab=vocab)

In [22]:
# Demo: list the N vocabulary entries closest to a query word.
word = 'house'
N = 10

top_n_similar_words = find_most_similar_words(
    model=model,
    word=word,
    vocab=vocab,
    N=N,
)

# One "<word> <similarity>" line per result, best match first.
for w, sim in top_n_similar_words:
    print(f"{w} {sim}")

house 0.9999998807907104
building 0.22196169197559357
preparatory 0.19675834476947784
shock 0.19357812404632568
sir 0.1928451806306839
theodore 0.18704940378665924
cathedral 0.18146641552448273
theater 0.1805182844400406
gas 0.17933134734630585
abbey 0.17814625799655914
