In [1]:
import torch
import torch.nn.functional as F
from heapq import nlargest

from src.model import CBOW, SkipGram
from src.constants import EMBEDDING_DIMS

In [2]:
def get_word_embedding(model, word, vocab):
    """Look up the embedding vector the model has learned for ``word``.

    Args:
        model: Object exposing an ``embeddings`` lookup that can be called on
            a tensor holding a token index.
        word: Token to look up; it is resolved to an index via ``vocab[word]``.
        vocab: Mapping from token to integer index.

    Returns:
        The output of ``model.embeddings`` for that index, computed without
        gradient tracking, with dimension 0 squeezed away if it has size 1.
    """
    token_index = torch.tensor(vocab[word])
    # Pure inference: no autograd graph is needed for a lookup.
    with torch.no_grad():
        vector = model.embeddings(token_index)
    return torch.squeeze(vector, dim=0)


def find_most_similar_words(model, word, vocab, N=5):
    """Return the ``N`` vocabulary words whose embeddings are closest to ``word``.

    Closeness is the cosine similarity between the vectors produced by
    ``model.embeddings``. The query word itself is never part of the result.

    Args:
        model: Object exposing an ``embeddings`` lookup that can be called on
            a tensor of token indices.
        word: Query token; it must be resolvable through ``vocab[word]``.
        vocab: Object supporting ``vocab[token] -> index`` and ``get_itos()``,
            which returns the list of known tokens.
        N: Maximum number of neighbours to return.

    Returns:
        A list of at most ``N`` ``(token, similarity)`` tuples ordered by
        decreasing similarity. The list is empty when the vocabulary holds
        no token other than ``word``.
    """
    with torch.no_grad():
        # Resolve the query first so an unknown word fails before any other work.
        query = model.embeddings(torch.tensor([vocab[word]]))

        candidates = [token for token in vocab.get_itos() if token != word]
        if not candidates:
            return []

        # One batched lookup and one similarity call replace a Python-level
        # loop that did both once per vocabulary entry.
        candidate_ids = torch.tensor([vocab[token] for token in candidates])
        scores = F.cosine_similarity(
            query, model.embeddings(candidate_ids), dim=1
        ).tolist()

    similarities = dict(zip(candidates, scores))

    # Get the top N most similar words
    most_similar = nlargest(N, similarities, key=similarities.get)

    return [(token, similarities[token]) for token in most_similar]


In [4]:
# NOTE: torch.load unpickles whatever file it is given, so only point these
# paths at checkpoints you trust.

# loading cbow model: vocabulary first (it fixes the embedding table size),
# then the epoch-10 weights, then switch to inference mode.
cbow_vocab = torch.load("./checkpoints/cbow_vocab.pt")
cbow_model_trained = CBOW(vocab_size=len(cbow_vocab), dims=EMBEDDING_DIMS)
# map_location="cpu" lets a checkpoint saved from another device load on a
# machine that does not have that device; this cell never moves the model.
cbow_model_trained.load_state_dict(
    torch.load("./checkpoints/cbow_epoch_10.pt", map_location="cpu")
)
cbow_model_trained.eval()

# loading the skipgram model, same steps as above
skipgram_vocab = torch.load("./checkpoints/skipgram_vocab.pt")
skipgram_model_trained = SkipGram(vocab_size=len(skipgram_vocab), dims=EMBEDDING_DIMS)
skipgram_model_trained.load_state_dict(
    torch.load("./checkpoints/skipgram_epoch_10.pt", map_location="cpu")
)
skipgram_model_trained.eval()

# # loading an untrained model, for comparison
# model_untrained = SkipGram(vocab_size=len(skipgram_vocab), dims=EMBEDDING_DIMS)

SkipGram(
  (embeddings): Embedding(10001, 100)
  (linear): Linear(in_features=100, out_features=10001, bias=True)
)

In [4]:
# get_word_embedding(model=cbow_model_trained, word="tree", vocab=cbow_vocab)

In [5]:
# Example usage: pick a query word and compare its nearest neighbours under
# both trained models.
# Other query words to try: 'car', 'economics', 'fruit'
word = "house"
N = 10

# Each model is paired with the vocabulary it was trained with.
top_n_cbow, top_n_skipgram = (
    find_most_similar_words(trained_model, word, trained_vocab, N=N)
    for trained_model, trained_vocab in (
        (cbow_model_trained, cbow_vocab),
        (skipgram_model_trained, skipgram_vocab),
    )
)

In [6]:
# One "<word> <similarity>" line per CBOW neighbour, in the order stored in top_n_cbow.
for neighbour, score in top_n_cbow:
    print(f"{neighbour} {score}")

hotel 0.43081992864608765
husband 0.39525988698005676
death 0.37713193893432617
theater 0.37603211402893066
church 0.35864508152008057
home 0.3533965051174164
demons 0.34911173582077026
apartment 0.3458288013935089
execution 0.3455311059951782
grandfather 0.3419297933578491


In [7]:
# One "<word> <similarity>" line per SkipGram neighbour, in the order stored in top_n_skipgram.
for neighbour, score in top_n_skipgram:
    print(f"{neighbour} {score}")

manor 0.5029898285865784
representatives 0.4920727014541626
parliament 0.49096861481666565
lady 0.4853280186653137
chamber 0.46166378259658813
hall 0.45709162950515747
houses 0.4455133378505707
family 0.44139552116394043
widow 0.43966275453567505
queen 0.4377342462539673
