In [1]:
from gensim.models import Word2Vec
from gensim.models import KeyedVectors

In [2]:
def load_glove_model(glove_file_path):
    glove_model = KeyedVectors.load_word2vec_format(glove_file_path, binary=False, no_header=True)
    return glove_model

def get_corpus():
    with open("../data/processed/glove_corpus", "r") as f:
        sentences = f.readlines()
    return sentences

In [3]:
sentences = get_corpus()
tokens = [sent.split() for sent in sentences]

In [4]:
base_model = Word2Vec(vector_size=300, min_count = 20, epochs=20)
base_model.build_vocab(tokens)
total_examples = base_model.corpus_count

In [5]:
base_model.wv.most_similar('#person1#', topn=10)

[('perfect!', 0.19918687641620636),
 ('peking', 0.18685618042945862),
 ('speaking', 0.1850905865430832),
 ('inside.', 0.1844509094953537),
 ('chinese.', 0.17851804196834564),
 ('24', 0.17759983241558075),
 ('cake', 0.17517882585525513),
 ('help', 0.17481420934200287),
 ('education.', 0.17363373935222626),
 ('test', 0.17225077748298645)]

In [6]:
corpus_path = '../embeds/GloVe/glove.corpus.300d.txt'
corpus_model = load_glove_model(corpus_path)
base_model.build_vocab([list(corpus_model.key_to_index.keys())], update=True)

In [7]:
corpus_model.most_similar('#person1#', topn=10)

[('tells', 0.5800400376319885),
 ('asks', 0.5097846984863281),
 ('#person2#', 0.44948500394821167),
 ('thinks', 0.4351900815963745),
 ('suggests', 0.42950737476348877),
 ('<sos>', 0.40993642807006836),
 ('recommends', 0.4068792760372162),
 ('s', 0.3993434011936188),
 ("#person2#'", 0.3977951109409332),
 ('wants', 0.3972680866718292)]

In [8]:
base_model.train(tokens, total_examples=total_examples, epochs=base_model.epochs)
base_model_wv = base_model.wv


In [9]:
base_model_wv.most_similar('#person1#', topn=10)

[('#person2#', 0.7265139818191528),
 ('sam', 0.663916289806366),
 ('kate', 0.624889075756073),
 ('#person3#', 0.6135007739067078),
 ('mike', 0.6108843684196472),
 ('shirley', 0.5966283082962036),
 ('jack', 0.5953760743141174),
 ('alice', 0.5870617032051086),
 ('ann', 0.5827251672744751),
 ('lucy', 0.5810714364051819)]

In [10]:
base_model_wv.save_word2vec_format('../models/GloVe-Word2Vec/glove.bin')