In [1]:
import gensim
import nltk
import pandas as pd
import numpy as np
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

# Fetch the NLTK data packages this notebook relies on: the stop-word list,
# Open Multilingual WordNet and WordNet (used by the lemmatizer), and the
# Punkt models used by `word_tokenize`.
for nltk_resource in ('stopwords', 'omw-1.4', 'punkt', 'wordnet'):
    nltk.download(nltk_resource)

# Shared preprocessing objects used by `prep` below.
stop_words = set(stopwords.words('english'))
lemma = WordNetLemmatizer()

In [2]:
def prep(rowitem):
    """Normalise a single raw text value for Doc2Vec training.

    Values whose string form has fewer than 10 whitespace-separated words are
    rejected by returning ``np.nan`` so a later ``dropna`` can discard them.
    Otherwise the text is tokenised with NLTK; tokens that are not purely
    alphabetic are dropped, the rest are lower-cased, English stop words are
    removed, and each remaining token is lemmatised. The surviving tokens are
    returned joined by single spaces.
    """
    text = str(rowitem)
    if len(text.split()) < 10:
        return np.nan

    kept_tokens = []
    for token in word_tokenize(text):
        if not token.isalpha():
            continue
        token = token.lower()
        if token in stop_words:
            continue
        kept_tokens.append(lemma.lemmatize(token))
    return ' '.join(kept_tokens)

In [3]:
# Load the articles; 'Unnamed: 0' (presumably a saved index) and 'label' are dropped.
raw_df = pd.read_csv('../input/to-emb-or-not-to-emb/medium_test_data.csv').drop(columns=['Unnamed: 0', 'label'])

# Apply `prep` element-wise to the text column. `DataFrame.apply` would hand
# whole columns to `prep`, and its result must be assigned to take effect.
# `prep` returns NaN for texts under 10 words; those rows are then dropped.
# reset_index makes row labels match the positional tags used for Doc2Vec.
df = (
    raw_df
    .assign(content=raw_df['content'].apply(prep))
    .dropna()
    .reset_index(drop=True)
)

In [4]:
# One TaggedDocument per row of `df`, tagged with the row's 0-based position
# (as a string) so similarity results can be mapped back via `df.iloc`.
data = df['content'].to_numpy().tolist()
tagged_data = [
    gensim.models.doc2vec.TaggedDocument(
        words=word_tokenize(text.lower()),
        tags=[str(position)],
    )
    for position, text in enumerate(data)
]

In [6]:
# Doc2Vec hyper-parameters.
max_epochs = 5    # total passes over the corpus
vec_size = 20     # dimensionality of each document vector
alpha = 0.025     # initial learning rate; gensim decays it linearly to min_alpha

model = gensim.models.doc2vec.Doc2Vec(
    vector_size=vec_size,
    alpha=alpha,
    min_alpha=0.00025,
    min_count=1,
    dm=1,
    epochs=max_epochs,
)
model.build_vocab(tagged_data)

# Train once and let gensim run `max_epochs` passes with its own learning-rate
# decay. Calling train() in a loop with epochs=model.epochs would run
# max_epochs * model.epochs passes, and hand-editing alpha/min_alpha between
# calls freezes the rate — gensim's docs advise against both.
model.train(
    tagged_data,
    total_examples=model.corpus_count,
    epochs=model.epochs,
)

model.save("d2v.model")
print("Model Saved")

In [10]:
# Reload the model written by the training cell so the query can be run
# without retraining.
model = gensim.models.doc2vec.Doc2Vec.load("d2v.model")

query_text = "the 8th edition of international day of yoga will see many firsts but one of"
# NOTE(review): if the training texts went through `prep`, consider passing
# the query through the same preprocessing for consistency.
test_data = word_tokenize(query_text.lower())
v1 = model.infer_vector(test_data)

# gensim >= 4 exposes document vectors as `model.dv`; `docvecs` is only a
# deprecated alias there. Fall back to it for older gensim versions.
doc_vectors = model.dv if hasattr(model, "dv") else model.docvecs
similar_doc = doc_vectors.most_similar(positive=[v1])
print(similar_doc)

In [None]:
# Optional keep-alive, presumably meant to stop a hosted kernel from idling out.
# A bare `while True: pass` spins a CPU core at 100% and makes
# "Restart & Run All" hang, so this is opt-in and sleeps instead of spinning.
KEEP_ALIVE = False

if KEEP_ALIVE:
    import time

    while True:
        time.sleep(60)