In [1]:
import pymysql
import pickle
import json

import math
import numpy as np
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['svg.fonttype'] = 'none'

import time
import hdbscan


In [2]:
# NOTE(review): presumably the output file for the community-tracking results;
# it is not referenced by any of the visible cells — confirm where it is written
target_file = 'data_processing/try_track_communities.json'

# inclusive range of years the per-year loops iterate over
START_YEAR = 2003  # ~500k total abstracts in dataset
END_YEAR = 2018    #   ~1m total abstracts in dataset
N_years = END_YEAR - START_YEAR + 1  # number of years, counting both endpoints

# load_PMIDs draws int(N_pubs * sample_fraction) pmids per year
sample_fraction = 0.02 # 0.2  # fraction of total pmids to sample

# pmids for each year (for convenience)
# NOTE: absolute, machine-specific path; load_PMIDs appends 'pubmed_state_<year>' to it
path2dir = '/home/brendan/FastData/pubmed2019/pubmed_data_processing/year_pmids/'  # knowledge-garden



In [3]:
# set up mysql connection

In [4]:

# credentials are read from a JSON config file rather than written into the notebook
# NOTE: absolute, machine-specific path
config_path = '/home/brendan/Projects/AttentionWildfires/attention_wildfires/mysql_config.json'
db_name = 'test_pubmed'  # db name collisions? https://stackoverflow.com/questions/14011968/user-cant-access-a-database
                        # todo should move this db_name into config file
with open(config_path, 'r') as f:
    config_data = json.load(f)
    
# NOTE(review): the password is taken from the 'lock' key of the config file —
# confirm that key name is intentional
client_config = {'database': db_name,
                'user': config_data['user'],
                 'password': config_data['lock']}

## init db connection
# module-level connection; get_embedding_vectors runs its query through `db`
db = pymysql.connect(**client_config)


In [5]:
# load the pre-fit umap model
#   todo do this in higher dimensionality

In [6]:
print('pop open pickle jar: umap model...')
# NOTE: absolute, machine-specific path; it is also recorded in
# year_data['dimensionality_reduction'] by the clustering cell
umap_path = "/home/brendan/FastData/pubmed2019/pubmed_data_processing/dimensionality_reduction_models/umap2D/umap_model0.pkl"
# SECURITY: pickle.load can execute arbitrary code while unpickling —
# only load model files that come from a trusted source
with open(umap_path, 'rb') as file:
    umap_model = pickle.load(file)

pop open pickle jar: umap model...


In [7]:
# helper functions

In [8]:
def load_PMIDs(year):
    '''
    sample pmids from this year
        note: relies on parameters specified above (path2dir, sample_fraction)

    Reads the JSON file 'pubmed_state_<year>' under `path2dir`, takes its
    'publications' list and draws int(N_pubs * sample_fraction) entries
    uniformly at random WITHOUT replacement, so no pmid appears twice.

    year: value formatted into the file name
    returns: numpy array holding the sampled pmids
    '''
    path2pmids = path2dir + 'pubmed_state_{}'.format(year)
    with open(path2pmids, 'r') as f:
        year_pub_pmids = json.load(f)['publications']

    N_pubs = len(year_pub_pmids)
    print("N pubs: {}".format(N_pubs))

    # clamp: a sample_fraction above 1 must not request more draws than exist,
    # which is an error when sampling without replacement
    K_sample = min(N_pubs, int(N_pubs * sample_fraction))
    print("K samples: {}".format(K_sample))

    # replace=False: np.random.choice samples WITH replacement by default, which
    # returns duplicate pmids and silently shrinks the number of distinct samples
    sample_pmids = np.random.choice(year_pub_pmids, K_sample, replace=False)
    return sample_pmids


In [9]:
def get_embedding_vectors(sample_pmids):
    '''
    get embedding coordinates from database based on PMID list

    sample_pmids: iterable of pmids; each one is converted with int()
    returns: (pub_pmids, pub_embeddings) — two parallel lists with one entry
             per row the query returned: the row's pmid, and its embedding
             column decoded as a float16 buffer into a list of floats.
             Both lists are empty when sample_pmids is empty.
    '''
    print('fetching embedding vectors from database...')
    start_time = time.time()

    # plain ints are passed as query parameters instead of being pasted into
    # the SQL text (numpy integers are converted too, so the driver only ever
    # sees built-in ints)
    pmid_params = [int(pmid) for pmid in sample_pmids]

    pub_embeddings = []
    pub_pmids = []
    if pmid_params:  # an empty list would render as the invalid SQL "IN ()"
        placeholders = ', '.join(['%s'] * len(pmid_params))
        sql = '''SELECT E.pmid, E.embedding
            FROM scibert_mean_embedding as E
            WHERE E.pmid IN ({})'''.format(placeholders)
        cursor = db.cursor()
        try:
            cursor.execute(sql, pmid_params)
            for row in cursor:
                pub_pmids.append(row[0])
                pub_embeddings.append(np.frombuffer(row[1], dtype='float16').tolist())
        finally:
            # release the cursor even when the query or the decoding raises
            cursor.close()

    elapsed = time.time() - start_time
    print("SQL query composed and executed in {} s".format(elapsed))

    return pub_pmids, pub_embeddings

def get_compressed_embedding_vectors(sample_pmids):
    '''
    Look up the embeddings for a pmid list and project them with the
    pre-fit umap model.

    The lookup is delegated to get_embedding_vectors, so the pmids handed
    back are the ones that lookup returned.

    returns: (pmids, umap_model.transform(embeddings))
    '''
    found_pmids, raw_vectors = get_embedding_vectors(sample_pmids)

    print('compressing embedding vectors...')
    compressed = umap_model.transform(raw_vectors)

    return found_pmids, compressed

In [10]:
def plot_clustering(xx, yy, labels):
    '''
    Draw one 2D density panel per cluster label, stacked in a single column.

    xx, yy: coordinates of the points (same length as labels)
    labels: cluster label of each point; every distinct value gets its own
            panel, in the sorted order np.unique returns
    returns: None. Nothing is drawn when labels is empty.
    '''
    xx = np.asarray(xx)
    yy = np.asarray(yy)
    labels = np.asarray(labels)

    unique_labels = np.unique(labels)
    n_panels = len(unique_labels)
    if n_panels == 0:
        return  # plt.subplots cannot build a grid with zero rows

    cluster_cmap = plt.cm.viridis(np.linspace(0, 1, n_panels + 1))

    # squeeze=False: always get a 2D array of axes. With a single label
    # plt.subplots would otherwise return a bare Axes and indexing it fails.
    (f, axs) = plt.subplots(n_panels,
                            1,
                            sharex='all',
                            sharey='all',
                            figsize=(4, 4*n_panels),
                            squeeze=False)

    for i_label, label in enumerate(unique_labels):

        # boolean mask selects this cluster's points in one vectorised pass
        in_cluster = labels == label

        sns.kdeplot(xx[in_cluster],
                    yy[in_cluster],
                    shade=True,
                    shade_lowest=False,
                    color=cluster_cmap[i_label],
                    ax=axs[i_label, 0])
        
        

In [11]:
# load pmids, sample, cluster

In [None]:
# Results container:
#   year_data
#       'dimensionality_reduction' - path of the umap model that was used
#       <year>
#           'sample_pmids' - pmids for which the database returned an embedding
#           'embedding'    - umap coordinates of those pmids
year_data = {}
year_data['dimensionality_reduction'] = umap_path  # version of dimensionality reduction used

clusterers = {}  # fitted HDBSCAN object per year, keyed by the 0-based year index
for i_year, year in enumerate(range(START_YEAR, END_YEAR+1)):

    year_data[year] = {} # init data structure

    # get sample PMIDs published this year
    sample_pmids = load_PMIDs(year)  # some of these don't have abstracts
    (pmids, embeddings) = get_compressed_embedding_vectors(sample_pmids)
    print("N abstracts fetched: {}".format(len(pmids)))
    year_data[year]['sample_pmids'] = pmids
    year_data[year]['embedding'] = embeddings

    # cluster the compressed embeddings
    # scale is 1% of the fetched abstracts (was 0.001). Clamp to 2: with fewer
    # than 200 abstracts the raw value is 0 or 1, which is not a usable
    # min_cluster_size / min_samples for HDBSCAN.
    clustering_scale = max(2, int(len(pmids) * 0.01))
    min_samples_param = min(1000, clustering_scale)
    print("scale: {}  | min_samples: {} ".format(clustering_scale, min_samples_param))
    ####################
    clusterers[i_year] = hdbscan.HDBSCAN(min_cluster_size=clustering_scale, # 500 for 25K # 1000 for 50K # 50 for 2000
                            min_samples=min_samples_param,   # 500, 1000, 50
                            cluster_selection_method='leaf') # eom')  # euclidean distance
    clusterers[i_year].fit(year_data[year]['embedding'])  # samples x features

    # number of clusters
    print('num clusters: {}'.format(clusterers[i_year].labels_.max()+1))

    # plot clusters
    xx = year_data[year]['embedding'][:,0]
    yy = year_data[year]['embedding'][:,1]

    plot_clustering(xx,yy,clusterers[i_year].labels_)
    

N pubs: 583939
K samples: 11678
fetching embedding vectors from database...
SQL query composed and executed in 1.3236620426177979 s
compressing embedding vectors...


The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.

To find out why, try turning on parallel diagnostics, see http://numba.pydata.org/numba-doc/latest/user/parallel.html#diagnostics for help.

File "../../../../../../brendanchambers/.conda/envs/embedding-base/lib/python3.7/site-packages/umap/nndescent.py", line 123:
<source missing, REPL/exec in use?>

  state.func_ir.loc))
The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.

To find out why, try turning on parallel diagnostics, see http://numba.pydata.org/numba-doc/latest/user/parallel.html#diagnostics for help.

File "../../../../../../brendanchambers/.conda/envs/embedding-base/lib/python3.7/site-packages/umap/nndescent.py", line 134:
<source missing, REPL/exec in use?>

  state.func_ir.loc))


N abstracts fetched: 8931
scale: 89  | min_samples: 89 
num clusters: 12
(13,)
0 -1
1 0
2 1
3 2
4 3
5 4
6 5
7 6
8 7
9 8
10 9
11 10
12 11
N pubs: 619853
K samples: 12397
fetching embedding vectors from database...
SQL query composed and executed in 0.9254982471466064 s
compressing embedding vectors...
N abstracts fetched: 9649
scale: 96  | min_samples: 96 
num clusters: 15
(16,)
0 -1
1 0
2 1
3 2
4 3
5 4
6 5
7 6
8 7
9 8
10 9
11 10
12 11
13 12
14 13
15 14
N pubs: 656109
K samples: 13122
fetching embedding vectors from database...
SQL query composed and executed in 1.9413821697235107 s
compressing embedding vectors...
N abstracts fetched: 10412
scale: 104  | min_samples: 104 
num clusters: 13
(14,)
0 -1
1 0
2 1
3 2
4 3
5 4
6 5
7 6
8 7
9 8
10 9
11 10
12 11
13 12
N pubs: 684653
K samples: 13693
fetching embedding vectors from database...
SQL query composed and executed in 1.0668065547943115 s
compressing embedding vectors...
N abstracts fetched: 10593
scale: 105  | min_samples: 105 
num clus

In [None]:
# re-draw the cluster panels using xx, yy and i_year as left behind by the
# per-year loop in the previous cell (i.e. whichever year it processed last)
plot_clustering(xx,yy,clusterers[i_year].labels_)

In [None]:
# plot the samples from each year: one density panel per year, side by side

(f, axs) = plt.subplots(1,
                        N_years,
                        sharex='all', sharey='all',
                        figsize=(N_years,1),
                        squeeze=False)
axs = axs[0]  # single row -> 1D array of axes, indexable even when N_years == 1

for i_year, year in enumerate(range(START_YEAR, END_YEAR+1)):

    xx = year_data[year]['embedding'][:,0]
    yy = year_data[year]['embedding'][:,1]

    sns.kdeplot(xx,
                yy,
                shade=True,
                shade_lowest=False,
                cmap='Blues',
                ax=axs[i_year])
    # title this year's own panel: plt.title only targets the current axes,
    # so it never labelled the panel that was just drawn
    axs[i_year].set_title('pubs: year {}'.format(year))