In [1]:
import pymysql
import pickle
import json

import math
import numpy as np
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

import os
import time
import hdbscan


In [2]:
processed_data_dir = 'track_communities_data_processing/'
figures_dir = 'track_communities_figures/'
save_prefix = 'year_clusters_fract2_sample1'

START_YEAR = 1998  # ~500k total abstracts in dataset
END_YEAR = 2018    #   ~1m total abstracts in dataset
N_years = END_YEAR - START_YEAR + 1

sample_fraction = 0.2  # 0.2  # fraction of total pmids to sample

# pmids for each year (for convenience)
path2dir = '/home/brendan/FastData/pubmed2019/pubmed_data_processing/year_pmids/'  # knowledge-garden



In [3]:
# set up mysql connection

In [4]:

config_path = '/home/brendan/Projects/AttentionWildfires/attention_wildfires/mysql_config.json'
db_name = 'test_pubmed'  # db name collisons? https://stackoverflow.com/questions/14011968/user-cant-access-a-database
                        # todo should move this db_name into config file
with open(config_path, 'r') as f:
    config_data = json.load(f)
    
client_config = {'database': db_name,
                'user': config_data['user'],
                 'password': config_data['lock']}

## init db connection
db = pymysql.connect(**client_config)


In [5]:
# load the pre-fit umap model
#   todo do this in higher dimensionality

In [6]:
print('pop open pickle jar: umap model...')
umap_path = "/home/brendan/FastData/pubmed2019/pubmed_data_processing/dimensionality_reduction_models/umap2D/umap_model0.pkl"
with open(umap_path, 'rb') as file:
    umap_model = pickle.load(file)

pop open pickle jar: umap model...


In [7]:
# helper functions

In [8]:
def load_PMIDs(year):
    '''
    sample pmids from this year
        note: relies on parameters specified above
    '''
    filename = 'pubmed_state_{}'.format(year)
    path2pmids = path2dir + filename
    with open(path2pmids,'r') as f:
        data = json.load(f)

    year_pub_pmids = data['publications']
    N_pubs = len(year_pub_pmids)
    print("N pubs: {}".format(N_pubs))
    del data
    
    K_sample = int(N_pubs * sample_fraction)
    print("K samples: {}".format(K_sample))
    sample_pmids = np.random.choice(year_pub_pmids, K_sample)
    return sample_pmids


In [9]:
def get_embedding_vectors(sample_pmids):
    '''
    get embedding coordinates from database based on PMID list
    '''
    print('fetching embedding vectors from database...')
    start_time = time.time()
        
    str_fmt = ', '.join([str(pmid) for pmid in sample_pmids])
    sql = '''SELECT E.pmid, E.embedding
            FROM scibert_mean_embedding as E
            WHERE E.pmid IN ({})'''.format(str_fmt)
    cursor = db.cursor()
    cursor.execute(sql)
    
    pub_embeddings = []
    pub_pmids = []
    for i,row in enumerate(cursor):
        pub_pmids.append(row[0])
        pub_embeddings.append(np.frombuffer(row[1],dtype='float16').tolist())
    cursor.close()
    
    end_time = time.time()
    elapsed = end_time - start_time
    print("SQL query composed and executed in {} s".format(elapsed))
    
    return pub_pmids, pub_embeddings

def get_compressed_embedding_vectors(sample_pmids):
    '''
    calls get_embedding vectors
    then runs dimensionality reduction
    '''
    
    # return the pmids which have corresponding embeddings
    pmids, embeddings = get_embedding_vectors(sample_pmids)    
    print('compressing embedding vectors...')
    
    return pmids, umap_model.transform(embeddings)

In [10]:
R_MAX = 10
def serialize_point(x,y):
    '''
    take a point in R2 and assign a score,
        resulting in an ordering of points on the plane
    '''
    theta = math.atan2(x,y)
    theta = theta / math.pi # rescale onto (-1,1)
    if theta < 0:  # restitch 0-angle
        theta = theta + 2
    theta = round(theta,1) # limit precision of theta and break ties with r
    r = math.sqrt(x*x + y*y)
    r = r / R_MAX  # rescale onto (0,1)
    score = 10*theta + r  # use angle first, break ties with radius
    return score

def normalize_cluster_labels(xx, yy, old_labels):
    '''
        take a set of old labels and positions
        return a set of labels assigned in a fixed order, 
            for easer comparison downstream
    '''
    
    # number of clusters
    N_clusters = int(old_labels.max()+1)
    print('num clusters: {}'.format(N_clusters))
    
    # get summary coordinate for each cluster
    cluster_median = np.zeros( (N_clusters, 2))  # x,y (would be better to use median though)
    for i_cluster in range(N_clusters):
        total_points = 0
        vector_sum = np.zeros((1,2))
        
        convenience_xx, convenience_yy = [], []
        for i_label, label in enumerate(old_labels):
            # do nothing for labels of -1
            if label==i_cluster:
                convenience_xx.append(xx[i_label])
                convenience_yy.append(yy[i_label])
                
        median_x = np.median(convenience_xx)
        median_y = np.median(convenience_yy)
        cluster_median[i_cluster,:] = (median_x, median_y) 
    
    # use summary coordinate to order the clusters based on a serialization function
    cluster_scores = []
    for i_cluster in range(N_clusters):
        score = serialize_point(cluster_median[i_cluster,0],
                               cluster_median[i_cluster,1])
        cluster_scores.append(score)
        
    # sort ascending
    print(cluster_scores)
    new_labels = np.argsort(np.argsort(cluster_scores))
    print(new_labels)   
    
    # re-assign labels after normalizing integer names for consistency
    num_samples = len(old_labels)
    newlabels_ = np.zeros( (num_samples,) ).astype(int)
    for i_label, old_label in enumerate(old_labels):
        if old_label==-1:
            newlabels_[i_label] = -1  # populate new labels
        else:
            newlabels_[i_label] = new_labels[int(old_label)]  # cast as integer (temp hack)
    
    return newlabels_
        

In [11]:
def plot_clustering(xx, yy, labels, title):
    unique_labels = np.unique(labels)
    cluster_cmap = plt.cm.viridis(np.linspace(0,1,len(unique_labels)+1))
    
    (f, axs) = plt.subplots(1,
                            3,
                           sharex='all',
                           sharey='all',
                           figsize=(10,4))
    
    #kwargs = {'levels': np.arange(0, 0.05, 0.001)}  # specify bins for colormaps
        
    # plot full distribution on 1st subplot
    sns.kdeplot(xx,
                yy,
                n_levels=100,
                shade=True,
                shade_lowest=True,
                ax=axs[0],
                cbar=True,
                cbar_kws={'orientation':'horizontal'})
    axs[0].set_title('{}: sample'.format(title))
    
    
    # plot clusters on 2nd subplot    
    #   i.e. clusters 0 through N_clusters
    for i_label, label in enumerate(unique_labels):
        
        if label >= 0:
                
            xx_ = [x for (i,x) in enumerate(xx) if labels[i]==label]
            yy_ = [y for (i,y) in enumerate(yy) if labels[i]==label]

            if label==0:  # show a colorbar for the first cluster
                sns.kdeplot(xx_,
                            yy_,
                            n_levels=10,
                            shade=True,
                            shade_lowest=False,
                            color=cluster_cmap[i_label],
                            ax=axs[1],
                            cbar=True,
                            cbar_kws={'orientation':'horizontal'})
            else:  # don't show colorbars for the other clusters
                sns.kdeplot(xx_,
                            yy_,
                            n_levels=10,
                            shade=True,
                            shade_lowest=False,
                            color=cluster_cmap[i_label],
                            ax=axs[1])
    axs[1].set_title('{}: communities'.format(title))
        
    # plot distribution of non-clustered samples on 3rd subplot
    #    i.e. the -1 labeled clusters
    xx_ = [x for (i,x) in enumerate(xx) if labels[i]==-1]
    yy_ = [y for (i,y) in enumerate(yy) if labels[i]==-1]
    sns.kdeplot(xx_,
                yy_,
                n_levels=100,
                shade=True,
                shade_lowest=True,
                color=[0.1, 0.1, 0.1, 1.0],
                ax=axs[2],
                cbar=True,
                cbar_kws={'orientation':'horizontal'})
    axs[2].set_title('{}: remainder'.format(title))
    
    plt.savefig('{}.png'.format(title))
    plt.savefig('{}.pdf'.format(title))
    plt.savefig('{}.eps'.format(title))
    
    
        
        

In [12]:
# load pmids, sample, cluster

In [13]:
year_data = {}
year_data['dimensionality_reduction'] = umap_path  # version of dimensionality reduction used
year_data['sample_fraction'] = 0.2
'''
year_data
    dim_reduction - e.g. umap version
    year
        sample_pmids

'''

for i_year, year in enumerate(range(START_YEAR, END_YEAR+1)):
    print("year {}".format(year))
    year_data[year] = {} # init data structure
    
    # get sample PMIDs published this year
    sample_pmids = load_PMIDs(year)  # some of these don't have abstracts
    (pmids, embeddings) = get_compressed_embedding_vectors(sample_pmids)
    print("N abstracts fetched: {}".format(len(pmids)))
    year_data[year]['sample_pmids'] = np.copy(pmids).tolist()
    year_data[year]['embedding'] = np.copy(embeddings).tolist()


year 1998
N pubs: 474479
K samples: 94895
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 50.77551794052124 s
compressing embedding vectors...


The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.

To find out why, try turning on parallel diagnostics, see http://numba.pydata.org/numba-doc/latest/user/parallel.html#diagnostics for help.

File "../../../../../../brendanchambers/.conda/envs/embedding-base/lib/python3.7/site-packages/umap/nndescent.py", line 123:
<source missing, REPL/exec in use?>

  state.func_ir.loc))
The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.

To find out why, try turning on parallel diagnostics, see http://numba.pydata.org/numba-doc/latest/user/parallel.html#diagnostics for help.

File "../../../../../../brendanchambers/.conda/envs/embedding-base/lib/python3.7/site-packages/umap/nndescent.py", line 134:
<source missing, REPL/exec in use?>

  state.func_ir.loc))


N abstracts fetched: 63135
year 1999
N pubs: 493485
K samples: 98697
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 48.97369837760925 s
compressing embedding vectors...
N abstracts fetched: 64424
year 2000
N pubs: 530246
K samples: 106049
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 48.99577808380127 s
compressing embedding vectors...
N abstracts fetched: 71561
year 2001
N pubs: 543554
K samples: 108710
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 49.539939403533936 s
compressing embedding vectors...
N abstracts fetched: 75838
year 2002
N pubs: 558647
K samples: 111729
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 47.79432559013367 s
compressing embedding vectors...
N abstracts fetched: 78488
year 2003
N pubs: 583939
K samples: 116787
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 51.981850147247314 s
compressing embedding vectors...
N abstracts fetched: 81910
year 2004
N pubs: 619853
K samples: 123970
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 58.036935329437256 s
compressing embedding vectors...
N abstracts fetched: 88626
year 2005
N pubs: 656109
K samples: 131221
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 54.89845681190491 s
compressing embedding vectors...
N abstracts fetched: 94950
year 2006
N pubs: 684653
K samples: 136930
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 59.72939085960388 s
compressing embedding vectors...
N abstracts fetched: 96980
year 2007
N pubs: 710132
K samples: 142026
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 53.0354163646698 s
compressing embedding vectors...
N abstracts fetched: 102377
year 2008
N pubs: 750874
K samples: 150174
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 61.55712294578552 s
compressing embedding vectors...
N abstracts fetched: 112057
year 2009
N pubs: 784150
K samples: 156830
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 93.40994262695312 s
compressing embedding vectors...
N abstracts fetched: 116634
year 2010
N pubs: 822964
K samples: 164592
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 109.50313663482666 s
compressing embedding vectors...
N abstracts fetched: 120582
year 2011
N pubs: 875685
K samples: 175137
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 83.12316060066223 s
compressing embedding vectors...
N abstracts fetched: 125977
year 2012
N pubs: 939925
K samples: 187985
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 55.15540385246277 s
compressing embedding vectors...
N abstracts fetched: 141603
year 2013
N pubs: 994460
K samples: 198892
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 53.73466491699219 s
compressing embedding vectors...
N abstracts fetched: 147829
year 2014
N pubs: 1041775
K samples: 208355
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 89.16910409927368 s
compressing embedding vectors...
N abstracts fetched: 155492
year 2015
N pubs: 1089303
K samples: 217860
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 52.713213205337524 s
compressing embedding vectors...
N abstracts fetched: 162623
year 2016
N pubs: 1110653
K samples: 222130
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 58.220489263534546 s
compressing embedding vectors...
N abstracts fetched: 168420
year 2017
N pubs: 1120990
K samples: 224198
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 54.531826972961426 s
compressing embedding vectors...
N abstracts fetched: 166673
year 2018
N pubs: 1205220
K samples: 241044
fetching embedding vectors from database...


  result = self._query(query)


SQL query composed and executed in 58.36372113227844 s
compressing embedding vectors...
N abstracts fetched: 168321


In [14]:
target_path = processed_data_dir + save_prefix + '_samples.json'

with open(target_path,'w') as f:
    json.dump(year_data, f)  # already turned numpy to lists
    
# send back to numpy array
for i_year, year in enumerate(range(START_YEAR, END_YEAR+1)):
    year_data[year]['embedding'] = np.asarray(year_data[year]['embedding'])

In [15]:
# cluster the compressed embeddings

clusterers = {}
for i_year, year in enumerate(range(START_YEAR, END_YEAR+1)):
    print("year {}".format(year))
    
    pmids = year_data[year]['sample_pmids']
    clustering_scale = int(len(pmids) * 0.005)
    min_samples = int(len(pmids) * 0.001)
    min_samples_param = int(np.min([1000, min_samples])) # hdbscan can't handle really big values here
    print("scale: {}  | min_samples: {} ".format(clustering_scale, min_samples_param))
    ####################
    clusterers[i_year] = hdbscan.HDBSCAN(min_cluster_size=clustering_scale, # 500 for 25K # 1000 for 50K # 50 fro 2000
                            min_samples=min_samples_param,   # 500, 1000, 50
                            cluster_selection_method='leaf') # eom')  # euclidean distance
    clusterers[i_year].fit(year_data[year]['embedding'])  # samples x features

    # number of clusters
    print('num clusters: {}'.format(clusterers[i_year].labels_.max()+1))
    



year 1998
scale: 315  | min_samples: 63 
num clusters: 23
year 1999
scale: 322  | min_samples: 64 
num clusters: 26
year 2000
scale: 357  | min_samples: 71 
num clusters: 25
year 2001
scale: 379  | min_samples: 75 
num clusters: 27
year 2002
scale: 392  | min_samples: 78 
num clusters: 30
year 2003
scale: 409  | min_samples: 81 
num clusters: 32
year 2004
scale: 443  | min_samples: 88 
num clusters: 31
year 2005
scale: 474  | min_samples: 94 
num clusters: 35
year 2006
scale: 484  | min_samples: 96 
num clusters: 33
year 2007
scale: 511  | min_samples: 102 
num clusters: 32
year 2008
scale: 560  | min_samples: 112 
num clusters: 33
year 2009
scale: 583  | min_samples: 116 
num clusters: 32
year 2010
scale: 602  | min_samples: 120 
num clusters: 33
year 2011
scale: 629  | min_samples: 125 
num clusters: 32
year 2012
scale: 708  | min_samples: 141 
num clusters: 31
year 2013
scale: 739  | min_samples: 147 
num clusters: 32
year 2014
scale: 777  | min_samples: 155 
num clusters: 27
year 2

In [16]:
# normalize cluster labels

for i_year, year in enumerate(range(START_YEAR, END_YEAR+1)):
    
    xx = year_data[year]['embedding'][:,0]
    yy = year_data[year]['embedding'][:,1]
    
    new_labels = normalize_cluster_labels(xx,
                                          yy,
                                          clusterers[i_year].labels_)

    # overwrite old labels with normalized labels
    for i_label, old_label in enumerate(clusterers[i_year].labels_):  # rewrite with new labels
        clusterers[i_year].labels_[i_label] = new_labels[i_label]

num clusters: 23
[7.431275044852526, 18.37368799929633, 18.44055308159413, 5.435716043438744, 10.274855429154066, 6.238982837979988, 6.102024852273956, 4.494637012223374, 3.4534421491577403, 11.480239983124388, 12.471728989097835, 15.506345451652942, 15.39146579617445, 16.227225170551613, 17.36647529043253, 16.441410572963125, 16.507510163647073, 20.370501774323248, 14.172055091836501, 2.4128024391805027, 1.3889083914991067, 2.2652266848797606, 2.15657704783825]
[ 9 20 21  6 10  8  7  5  4 11 12 15 14 16 19 17 18 22 13  3  0  2  1]
num clusters: 26
[7.425017553256766, 6.534652063380986, 18.444079472312904, 18.376078208153142, 10.293250253290411, 11.21774412801829, 5.427683570145447, 6.242023556230721, 6.1081451424642825, 12.496370194452185, 11.479389539811846, 12.369370692221647, 15.502532101529951, 16.23152282229581, 20.37304826009962, 14.185169167019826, 16.4372098525331, 17.367178264454875, 16.50847522789759, 15.39444039639604, 4.498833151331205, 1.4071892818948992, 2.47653316811709

[18.3706840739877, 18.493984662791163, 18.431981624372533, 5.65407767457283, 6.1266196513485145, 10.28562855094885, 5.6823656702587915, 7.395791978027465, 4.489143221272016, 4.3270922032470995, 17.364711631809826, 6.253804385978788, 6.555615536104704, 6.360282421777163, 5.422892741589585, 5.553155739709818, 5.486911388002604, 14.505647198770943, 15.382523235044799, 16.4469790004033, 16.508757623392125, 14.23320655942181, 2.4492574174392248, 20.38904299669292, 13.366649977098668, 13.357507314414287, 17.218422145633433, 19.258689690545793, 11.49554929073144, 1.1875538113646642, 1.3808618810338595, 11.436840853898833, 12.441081862820573]
[28 30 29  8 10 15  9 14  4  3 27 11 13 12  5  7  6 22 23 24 25 21  2 32
 20 19 26 31 17  0  1 16 18]
num clusters: 32
[18.371524319947937, 18.431133510195334, 18.494237928186323, 5.65371580078311, 5.66635099935704, 10.279387516871255, 4.492384015798836, 6.13566939724812, 20.382581827777436, 4.330094940989494, 7.394935198905229, 6.356125821592397, 6.25565

In [17]:
    # TODO
    # measure center, radius, # pmids

In [None]:
# plot clusters

for i_year, year in enumerate(range(START_YEAR, END_YEAR+1)):
    
    xx = year_data[year]['embedding'][:,0]
    yy = year_data[year]['embedding'][:,1]
    
    title = '{}{}_{}'.format(figures_dir, save_prefix, year)
    plot_clustering(xx,yy,clusterers[i_year].labels_, title)
    
    

In [None]:
# save things

#target_path = os.path.join(target_prefix, 'clusterers.json')
#with open(target_path,'w') as f:
#    json.dump(clusterers, f)

for i_year, year in enumerate(range(START_YEAR, END_YEAR+1)):
    print(i_year, year)
    year_data[year]['labels'] = clusterers[i_year].labels_.tolist()
    try:
        year_data[year]['embedding'] = year_data[year]['embedding'].tolist()
    except:
        print('embedding is already a list')
    try:
        year_data[year]['sample_pmids'] = year_data[year]['sample_pmids'].tolist()
    except:
        print('sample_pmids is already a list')
    
target_path = processed_data_dir + save_prefix + '_samples.json'
with open(target_path,'w') as f:
    json.dump(year_data, f)
    
print('success!')

In [None]:
# figure todo - dark background on clusters
#    average subtraction over multiple samples?