# Create Document Vectors

This notebook creates document vectors for each of the research papers based on the paper contents.

In [1]:
from cord.core import JSON_CATALOGS, BIORXIV_MEDRXIV, COMM_USE_SUBSET, CUSTOM_LICENSE, NONCOMM_USE_SUBSET, cord_support_dir
from cord.jsonpaper import load_json_cache
from cord import ResearchPapers
import pandas as pd
from pathlib import Path, PurePath
import numpy as np

In [2]:
# Dimensionality of the Doc2Vec document vectors; the same value is embedded
# in the file names of the saved model and of the saved document vectors.
VECTOR_SIZE = 20

## 1. Load the Cached JSON Index Tokens

Load the precached JSON index tokens for every catalog and combine them into a single DataFrame.

In [3]:
# Load the cached index tokens of every catalog, then concatenate the
# per-catalog results into a single frame with a fresh 0..n-1 index.
catalog_caches = [load_json_cache(catalog_name) for catalog_name in JSON_CATALOGS]
json_tokens = pd.concat(catalog_caches, ignore_index=True)

Loading json cache files for comm_use_subset
Loaded comm_use_subset json cache in 33 seconds
Loading json cache files for biorxiv_medrxiv
Loaded biorxiv_medrxiv json cache in 1 seconds
Loading json cache files for noncomm_use_subset
Loaded noncomm_use_subset json cache in 6 seconds
Loading json cache files for custom_license
Loaded custom_license json cache in 52 seconds


## 2. Extract the PMCID

In [4]:
# Pull a PMC identifier (e.g. "PMC1054884") out of `sha` values of the form
# "<PMCID>.xml". The raw string keeps `\.` a regex escape rather than an
# invalid Python string escape, and expand=False yields a Series instead of a
# one-column DataFrame.
extracted_pmcid = json_tokens['sha'].str.extract(r'(PMC[0-9]+)\.xml', expand=False)
if 'pmcid' in json_tokens.columns:
    # Re-running this cell: `sha` has already been blanked for PMC rows, so
    # nothing is extracted for them. Keep the identifiers found previously
    # instead of overwriting them with NaN.
    extracted_pmcid = extracted_pmcid.fillna(json_tokens['pmcid'])
json_tokens['pmcid'] = extracted_pmcid

# Rows identified by a PMCID do not keep the file name in `sha`.
json_tokens.loc[json_tokens['pmcid'].notna(), 'sha'] = np.nan

# Keep only the columns used below. .copy() makes the result an independent
# frame so later column assignments do not raise SettingWithCopyWarning.
json_tokens = json_tokens[['sha', 'pmcid', 'index_tokens']].copy()

In [5]:
# Spot-check: show any rows whose extracted PMCID is PMC1054884.
json_tokens[json_tokens['pmcid'] == 'PMC1054884']

Unnamed: 0,sha,pmcid,index_tokens


## 3. Train a Gensim Doc2vec Model

In [None]:
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

# Tag each paper's token list with its position in json_tokens.index_tokens.
documents = [
    TaggedDocument(words=tokens, tags=[position])
    for position, tokens in enumerate(json_tokens.index_tokens)
]

# Train the model on the tagged documents.
model = Doc2Vec(
    documents=documents,
    vector_size=VECTOR_SIZE,
    window=2,
    min_count=1,
    workers=8,
)

## 4. Save the Doc2Vec Model

In [None]:
# Persist the trained model; the relative file name records the vector size.
doc2vec_model_file = f'Doc2Vec_{VECTOR_SIZE}.model'
model.save(doc2vec_model_file)

## 5. Create Document Vectors

In [9]:
def get_vector(tokens):
    """Infer a document vector for ``tokens`` using the trained ``model``."""
    inferred_vector = model.infer_vector(tokens)
    return inferred_vector

# Time the inference of a single document vector, using the tokens of the row
# labelled 0 as the sample input.
%timeit get_vector(json_tokens.loc[0].index_tokens)

1.59 ms ± 91.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [10]:
# Infer a vector for every paper's token list and store the results in a new
# 'document_vector' column; %time reports how long the full pass takes.
%time json_tokens['document_vector'] = json_tokens.index_tokens.apply(model.infer_vector)

Wall time: 19min 4s


In [2]:
# Load previously saved document vectors from DOCUMENT_VECTOR_PATH.
# NOTE(review): this cell carries an out-of-order execution count, and
# `document_vectors` is not read by any later cell shown here — confirm
# whether the cell is still needed.
from cord.core import DOCUMENT_VECTOR_PATH
import pandas as pd
document_vectors = pd.read_parquet(DOCUMENT_VECTOR_PATH)

## 6. Create Downsampled Vectors

In [None]:
# Seed used by the clustering and t-SNE helpers defined below.
RANDOM_STATE = 42

# Stack the per-document vectors along a new first axis (one entry per document).
docvector_arr = np.stack(json_tokens['document_vector'].to_numpy())

def kmean_labels(docvectors, n_clusters=6, random_state=RANDOM_STATE):
    """Cluster the document vectors with KMeans and return the cluster labels.

    Args:
        docvectors: document vectors passed to ``KMeans.fit``.
        n_clusters: number of clusters to form.
        random_state: seed passed to KMeans.

    Returns:
        The ``labels_`` attribute of the fitted KMeans estimator.
    """
    print('Setting cluster labels')
    from sklearn.cluster import KMeans

    clusterer = KMeans(n_clusters=n_clusters, random_state=random_state)
    clusterer.fit(docvectors)
    return clusterer.labels_


def tsne_embeddings(docvectors, dimensions=2, random_state=RANDOM_STATE):
    """Project the document vectors down to ``dimensions`` components with t-SNE.

    Args:
        docvectors: document vectors passed to ``TSNE.fit_transform``.
        dimensions: number of components of the embedding.
        random_state: seed passed to TSNE. Defaults to RANDOM_STATE, so
            existing calls behave as before; exposed as a parameter to match
            the signature of ``kmean_labels``.

    Returns:
        The array returned by ``TSNE.fit_transform``.
    """
    print(f'Creating {dimensions}D  embeddings')
    from sklearn.manifold import TSNE
    tsne = TSNE(verbose=1,
                perplexity=10,
                early_exaggeration=24,
                n_components=dimensions,
                n_jobs=8,
                random_state=random_state,
                learning_rate=600)
    # The estimator is a local and is released when the function returns, so
    # no explicit `del` is needed.
    return tsne.fit_transform(docvectors)

%time json_tokens['document_vector_2d'] = tsne_embeddings(docvector_arr, dimensions=2)
%time json_tokens['document_vector_1d'] = tsne_embeddings(docvector_arr, dimensions=1)
%time json_tokens['cluster_id'] = kmean_labels(docvector_arr, 7)

In [36]:
# Display the frame to inspect the added vector, embedding and cluster columns.
json_tokens

Unnamed: 0,sha,pmcid,index_tokens,document_vector,document_vector_2d,cluster_id,document_vector_1d
0,000b7d1517ceebb34e1e3e817695b6de03e2fa78,,"[s1, phylogeny, sequences, belonging, umrv, ph...","[0.56928027, -0.3666296, -0.20493843, 0.431051...",15.698082,1,-20.126677
1,00142f93c18b07350be89e96372d240372437ed9,,"[human, beings, constantly, exposed, myriad, p...","[3.3664315, -2.1276946, 1.6158434, 0.9182286, ...",31.039991,5,28.878820
2,0022796bb2112abd2e6423ba2d57751db06049fb,,"[pathogens, vectors, transported, rapidly, aro...","[0.27237293, 0.15642405, 2.5503929, 1.1305904,...",-50.606564,0,-54.790283
3,0031e47b76374e05a18c266bd1a1140e5eacb54f,,"[a1111111111, a1111111111, a1111111111, a11111...","[0.5768533, -3.854187, 0.072966725, 0.8637349,...",-69.332512,2,-7.735852
4,00326efcca0852dc6e39dc6b7786267e1bc4f194,,"[addition, preventative, care, nutritional, su...","[1.3273811, 0.4609563, 3.068578, -0.50986123, ...",-18.875193,0,-64.668060
...,...,...,...,...,...,...,...
52092,,,"[inactivated, virus, vaccines, inactivated, wk...","[1.418982, -0.4864095, 0.19998026, 0.23552166,...",8.962655,5,27.742414
52093,,,"[types, protein, microarrays, currently, types...","[-0.5313242, -1.7436063, 1.3774712, 0.01626098...",62.282139,4,64.295235
52094,,,[],"[0.0024406752, 0.010759468, 0.0051381686, 0.00...",9.226234,1,-89.850319
52095,,,[],"[0.0024406752, 0.010759468, 0.0051381686, 0.00...",6.759057,1,-89.850319


## 7. Save Document Vectors

In [31]:
# Build the destination once and reuse it for the write, so the path written
# here and the path read back via `docvector_savepath` cannot drift apart.
docvector_savepath = Path(cord_support_dir()) / f'DocumentVectors_{VECTOR_SIZE}.pq'

# Save only the identifier, vector, embedding and cluster columns.
vector_columns = ['sha', 'pmcid', 'document_vector', 'document_vector_2d', 'document_vector_1d', 'cluster_id']
json_vectors = json_tokens[vector_columns]
json_vectors.to_parquet(docvector_savepath)

In [33]:
# Read the saved file back and display it to check what was written.
pd.read_parquet(docvector_savepath)

Unnamed: 0,sha,pmcid,document_vector,document_vector_2d,document_vector_1d,cluster_id
0,000b7d1517ceebb34e1e3e817695b6de03e2fa78,,"[0.56928027, -0.3666296, -0.20493843, 0.431051...",15.698082,-20.126677,1
1,00142f93c18b07350be89e96372d240372437ed9,,"[3.3664315, -2.1276946, 1.6158434, 0.9182286, ...",31.039991,28.878820,5
2,0022796bb2112abd2e6423ba2d57751db06049fb,,"[0.27237293, 0.15642405, 2.5503929, 1.1305904,...",-50.606564,-54.790283,0
3,0031e47b76374e05a18c266bd1a1140e5eacb54f,,"[0.5768533, -3.854187, 0.072966725, 0.8637349,...",-69.332512,-7.735852,2
4,00326efcca0852dc6e39dc6b7786267e1bc4f194,,"[1.3273811, 0.4609563, 3.068578, -0.50986123, ...",-18.875193,-64.668060,0
...,...,...,...,...,...,...
52092,,,"[1.418982, -0.4864095, 0.19998026, 0.23552166,...",8.962655,27.742414,5
52093,,,"[-0.5313242, -1.7436063, 1.3774712, 0.01626098...",62.282139,64.295235,4
52094,,,"[0.0024406752, 0.010759468, 0.0051381686, 0.00...",9.226234,-89.850319,1
52095,,,"[0.0024406752, 0.010759468, 0.0051381686, 0.00...",6.759057,-89.850319,1
