In [1]:
import torch
import torch.nn.functional as F
from heapq import nlargest

from src.model import CBOW, SkipGram
from src.constants import EMBEDDING_DIMS

In [2]:
def get_word_embedding(model, word, vocab):
    """Return the embedding that `model` holds for `word`.

    Args:
        model: object exposing a callable `embeddings` attribute that maps an
            index tensor to embedding values.
        word: token to look up.
        vocab: mapping that supports `vocab[word]` and yields the token's
            integer index.

    Returns:
        The result of `model.embeddings` for that index, computed with
        gradient tracking disabled and passed through `squeeze(0)`.
        Presumably a 1-D tensor of the embedding size when `embeddings` is
        an `nn.Embedding` — TODO confirm against `src.model`.
    """
    token_index = torch.tensor(vocab[word])
    with torch.no_grad():
        # squeeze(0) drops a leading axis only when that axis has size 1
        return model.embeddings(token_index).squeeze(0)


def find_most_similar_words(model, word, vocab, N=5):
    """Rank vocabulary tokens by cosine similarity to `word`.

    Every token returned by `vocab.get_itos()` other than `word` itself is
    embedded with `get_word_embedding` and compared against the query
    embedding.

    Args:
        model: model passed through to `get_word_embedding`.
        word: query token.
        vocab: vocabulary supporting `vocab[token]` and `get_itos()`.
        N: number of neighbours to return (default 5).

    Returns:
        A list of at most `N` `(token, similarity)` tuples, highest
        similarity first.
    """
    # Add the batch axis once; cosine_similarity compares along dim=1 by default.
    query = get_word_embedding(model, word, vocab).unsqueeze(0)

    with torch.no_grad():
        scores = {
            candidate: F.cosine_similarity(
                query,
                get_word_embedding(model, candidate, vocab).unsqueeze(0),
            ).item()
            for candidate in vocab.get_itos()
            if candidate != word
        }

    # Top-N (token, score) pairs, ordered by score descending.
    return nlargest(N, scores.items(), key=lambda pair: pair[1])


In [3]:
# loading the vocabulary
# NOTE(review): torch.load unpickles Python objects, so only load checkpoints
# from a trusted source. vocab.pt presumably holds a pickled vocabulary object
# rather than plain tensors; on PyTorch >= 2.6 (where weights_only defaults to
# True) this call may need weights_only=False — confirm against the installed
# version.
vocab = torch.load("./checkpoints/vocab.pt")

# loading the trained model
model_trained = SkipGram(vocab_size=len(vocab), dims=EMBEDDING_DIMS)
# map_location="cpu" lets a checkpoint saved during a GPU run load on a
# machine without CUDA; nothing in this notebook moves the model off the CPU.
model_trained.load_state_dict(
    torch.load("./checkpoints/skipgram_epoch_10.pt", map_location="cpu")
)
model_trained.eval()

# loading an untrained model, for comparison
# Seed first so the random initialisation (and therefore the baseline
# neighbours printed further down) is identical on every Restart & Run All.
torch.manual_seed(0)
model_untrained = SkipGram(vocab_size=len(vocab), dims=EMBEDDING_DIMS)
# Same inference mode as the trained model; the trailing semicolon keeps the
# notebook from echoing the module repr as cell output.
model_untrained.eval();

In [4]:
# get_word_embedding(model=model_trained, word="tree", vocab=vocab)

In [11]:
# Nearest-neighbour query against the trained embeddings.
# Other query words tried: 'house', 'car', 'economics'.
word, N = 'fruit', 10

top_n_similar_words = find_most_similar_words(model=model_trained, word=word, vocab=vocab, N=N)

# One "token similarity" pair per line, best match first.
for w, sim in top_n_similar_words:
    print(w, sim)

flowers 0.6539772152900696
breed 0.6475479006767273
trees 0.625266969203949
sheep 0.5927416086196899
milk 0.5834458470344543
leaves 0.5726682543754578
insects 0.5712876319885254
vegetables 0.5630152225494385
fruits 0.5623157024383545
diet 0.5562630891799927


In [6]:
# Baseline: the same query (`word`, `N`) against the untrained model's
# embeddings, for comparison with the trained model's neighbours.
top_n_similar_words = find_most_similar_words(model=model_untrained, word=word, vocab=vocab, N=N)

for w, sim in top_n_similar_words:
    print(w, sim)

protesters 0.3924030661582947
mixes 0.3781982660293579
gulf 0.34597140550613403
emigrated 0.3322027325630188
metric 0.3302080035209656
audiences 0.32061439752578735
floating 0.3184584677219391
degrees 0.3180517256259918
former 0.3172435462474823
voters 0.31590351462364197
