In [1]:
import os
import gensim
from gensim.utils import simple_preprocess
from gensim.models.doc2vec import TaggedDocument

data_dir = os.path.join(gensim.__path__[0], 'test', 'test_data')
train_file_path = os.path.join(data_dir, 'lee_background.cor')
test_file_path = os.path.join(data_dir, 'lee.cor')

In [57]:
from tqdm import trange
from gensim.models import Doc2Vec
from gensim.models.doc2vec import TaggedLineDocument

# each line is a single document
train_corpus = TaggedLineDocument(train_file_path)

model = Doc2Vec(min_count = 2, size = 100, workers = 7)
model.build_vocab(train_corpus)
for _ in trange(8):
    model.train(train_corpus)
    model.alpha -= 0.002  # decrease the learning rate
    model.min_alpha = model.alpha  # fix the learning rate, no decay
    
# If you’re finished training a model (=no more updates, only querying), you can do
# to trim unneeded model memory
# model.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True)

100%|██████████| 8/8 [00:01<00:00,  4.20it/s]


In [59]:
from collections import Counter

ranks = []
for idx, doc in enumerate(train_corpus):
    # each line is a TaggedDocument namedtuple, where we can
    # access the words attribute and infer their word vectors
    inferred_vector = model.infer_vector(doc.words)
    sims = model.docvecs.most_similar([inferred_vector], topn = model.corpus_count)
    rank = [doc_id for doc_id, _ in sims].index(idx)
    ranks.append(rank)

Counter(ranks)

Counter({0: 292, 1: 8})

In [61]:
inferred_vector

array([-0.58224672, -0.43096396, -0.33039987, -0.06532005,  0.47400451,
       -0.48527843,  0.01416627, -0.2556819 ,  0.56022859,  0.05959554,
       -0.75745648,  0.81947565, -0.0115818 , -0.21682028,  0.94731522,
       -0.40724245,  0.40448874, -0.54299033,  0.05646265, -0.55159432,
        0.34756666, -0.891092  ,  0.41050884,  0.19965491,  0.51197499,
        0.76365852, -0.42333227, -0.02293444,  0.06794542,  1.10781157,
        0.07878274,  0.09709717,  1.52739763,  1.52857971,  0.63935286,
        0.67650592, -0.30052426,  0.44312361, -1.32764304, -0.34013611,
       -0.75963336, -0.09873898, -0.78795689,  0.2531268 ,  0.15569344,
        0.29239166,  0.16786104, -0.95119458,  1.06164527, -0.52167094,
        0.22298038,  0.4879244 , -0.73503453, -0.3620781 , -0.82793939,
       -0.43032843,  0.34696817,  0.46614733, -0.31515405,  0.1487086 ,
       -0.72470874,  0.60460603,  0.20055507,  1.0842762 , -0.05671849,
        0.67490703,  0.47700262,  0.22134401, -0.21704741,  0.12

In [2]:
def read_corpus(file_path, tag = True):
    """For training data, add tags"""
    with open(file_path, encoding = 'iso-8859-1') as f:
        for i, line in enumerate(f):
            # simple_preprocess
            # tokenize text into individual words, 
            # remove punctuation, set to lowercase
            preprocessed = simple_preprocess(line)
            if tag:
                yield TaggedDocument(preprocessed, [i])
            else:
                yield preprocessed
                
train_corpus = list(read_corpus(train_file_path))
test_corpus = list(read_corpus(test_file_path, tag = False))

In [3]:
train_corpus[0]

TaggedDocument(words=['hundreds', 'of', 'people', 'have', 'been', 'forced', 'to', 'vacate', 'their', 'homes', 'in', 'the', 'southern', 'highlands', 'of', 'new', 'south', 'wales', 'as', 'strong', 'winds', 'today', 'pushed', 'huge', 'bushfire', 'towards', 'the', 'town', 'of', 'hill', 'top', 'new', 'blaze', 'near', 'goulburn', 'south', 'west', 'of', 'sydney', 'has', 'forced', 'the', 'closure', 'of', 'the', 'hume', 'highway', 'at', 'about', 'pm', 'aedt', 'marked', 'deterioration', 'in', 'the', 'weather', 'as', 'storm', 'cell', 'moved', 'east', 'across', 'the', 'blue', 'mountains', 'forced', 'authorities', 'to', 'make', 'decision', 'to', 'evacuate', 'people', 'from', 'homes', 'in', 'outlying', 'streets', 'at', 'hill', 'top', 'in', 'the', 'new', 'south', 'wales', 'southern', 'highlands', 'an', 'estimated', 'residents', 'have', 'left', 'their', 'homes', 'for', 'nearby', 'mittagong', 'the', 'new', 'south', 'wales', 'rural', 'fire', 'service', 'says', 'the', 'weather', 'conditions', 'which', 'c