In [1]:
#  step through years and visualize in PCA space

import mysql.connector as mysql
import pickle
import json

import numpy as np
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['svg.fonttype'] = 'none'

import time

### mysql client

In [12]:
db_name = 'test_pubmed'  # db name collisons? https://stackoverflow.com/questions/14011968/user-cant-access-a-database
client_config = {'unix_socket':'/home/brendanchambers/.sql.sock',
                'database': db_name,
                'use_pure': True}  # for python connector
output_path = '/project2/jevans/brendan/pubmed_data_processing/year_pmids/'

## load pre-fit pca model

In [3]:
pca_path = 'develop_samples/pca_model1.pkl'  # more located at /project2...
with open(pca_path, 'rb') as file:
    pca_model = pickle.load(file)

# load year publication pmids  & join to embeddings 

(todo join to text as well)

In [None]:
start_year = 1958
end_year = 2018
D_truncate = 300
path2dir = '/project2/jevans/brendan/pubmed_data_processing/year_pmids/'

year_pubs = {}
for year in range(start_year, end_year+1):
    
    print('{}...'.format(year))
    
    db = mysql.connect(**client_config)

    filename = 'pubmed_state_{}'.format(year)
    path2pmids = path2dir + filename
    with open(path2pmids,'r') as f:
        data = json.load(f)
    
    year_pub_pmids = data['publications']
    N_pubs = len(year_pub_pmids)
    print("N pubs: {}".format(N_pubs))
    del data # clean up
    
    str_fmt = ', '.join([str(pmid) for pmid in year_pub_pmids])
    
    sql = '''SELECT E.pmid, E.embedding
            FROM scibert_mean_embedding as E
            WHERE E.pmid IN ({})'''.format(str_fmt)
    
    start_time = time.time()
    cursor = db.cursor(buffered=False)
    cursor.execute(sql)
    end_time = time.time()
    elapsed = end_time - start_time
    print("SQL join executed in {} s".format(elapsed))

    start_time = time.time()
    pub_embeddings = []
    pub_pmids = []
    for i,row in enumerate(cursor):
        print_block_len = 100000
        if i % print_block_len == 0:
            print('fetched {} rows...'.format(print_block_len))
        pub_pmids.append(row[0])
        pub_embeddings.append(np.frombuffer(row[1],dtype='float64').tolist())
    cursor.close()
    end_time = time.time()
    elapsed = end_time - start_time
    print("SQL results fetched and cast in {} s".format(elapsed))
    
    start_time = time.time()
    #year_pubs[year] = pca_model.transform(pub_embeddings)[:,:D_truncate]
    end_time = time.time()
    elapsed = end_time - start_time
    print("pca transform finished in {} s".format(elapsed))
    
    start_time = time.time()
    path = output_path + 'publication_embeddings/' + str(year) + '.json'
    save_obj = {'pmids': pub_pmids,
                'embeddings': pub_embeddings}
    with open(path,'w') as f:
        json.dump(save_obj, f)
    end_time = time.time()
    elapsed = end_time - start_time
    print('finished writing output file in {} s...'.format(elapsed))
        
    print()

1958...
N pubs: 109432
SQL join executed in 0.38938164710998535 s
fetched 100000 rows...
fetched 100000 rows...
SQL results fetched and cast in 60.227108001708984 s
pca transform finished in 4.76837158203125e-07 s
finished writing output file in 117.28796410560608 s...

1959...
N pubs: 109952
SQL join executed in 0.46601271629333496 s
fetched 100000 rows...
fetched 100000 rows...
SQL results fetched and cast in 74.25713181495667 s
pca transform finished in 7.152557373046875e-07 s
finished writing output file in 119.27769827842712 s...

1960...
N pubs: 112169
SQL join executed in 0.9044733047485352 s
fetched 100000 rows...
fetched 100000 rows...
SQL results fetched and cast in 125.56717133522034 s
pca transform finished in 2.384185791015625e-07 s
finished writing output file in 119.87176775932312 s...

1961...
N pubs: 120124
SQL join executed in 0.6382238864898682 s
fetched 100000 rows...
fetched 100000 rows...
SQL results fetched and cast in 130.4963755607605 s
pca transform finished i

finished writing output file in 377.114905834198 s...

1987...
N pubs: 367778
SQL join executed in 1.5504462718963623 s
fetched 100000 rows...
fetched 100000 rows...
fetched 100000 rows...
fetched 100000 rows...
SQL results fetched and cast in 339.93264293670654 s
pca transform finished in 4.76837158203125e-07 s
finished writing output file in 397.4475622177124 s...

1988...
N pubs: 386677
SQL join executed in 3.1321213245391846 s
fetched 100000 rows...
fetched 100000 rows...
fetched 100000 rows...
fetched 100000 rows...
SQL results fetched and cast in 315.91997241973877 s
pca transform finished in 2.384185791015625e-07 s
finished writing output file in 413.0181543827057 s...

1989...
N pubs: 402283
SQL join executed in 1.9011433124542236 s
fetched 100000 rows...
fetched 100000 rows...
fetched 100000 rows...
fetched 100000 rows...
SQL results fetched and cast in 323.1846158504486 s
pca transform finished in 4.76837158203125e-07 s
finished writing output file in 419.6739354133606 s...



# load year citation pmids, join to embeddings

In [None]:
year_cites = {}

for year in range(start_year, end_year+1):
    
    print('{}...'.format(year))
    
    db = mysql.connect(**client_config)

    filename = 'pubmed_state_{}'.format(year)
    path2pmids = path2dir + filename
    with open(path2pmids,'r') as f:
        data = json.load(f)
    
    year_cite_pmids = data['citations']
    del data # clean up
    N_citations = len(year_cite_pmids)
    print("N citations: {}".format(N_citations))
    
    str_fmt = ', '.join([str(pmid) for pmid in year_cite_pmids])
    
    sql = '''SELECT E.pmid, E.embedding
            FROM scibert_mean_embedding as E
            WHERE E.pmid IN ({})'''.format(str_fmt)
    
    start_time = time.time()
    cursor = db.cursor(buffered=False)
    cursor.execute(sql)
    end_time = time.time()
    elapsed = end_time - start_time
    print("SQL join executed in {} s".format(elapsed))

    start_time = time.time()
    cite_embeddings = []
    cite_pmids = []
    for i,row in enumerate(cursor):
        print_block_len = 100000
        if i % print_block_len == 0:
            print('fetched {} rows...'.format(print_block_len))
        cite_pmids.append(row[0])
        cite_embeddings.append(np.frombuffer(row[1],dtype='float64').tolist())

    cursor.close()
    print('fetched')

    end_time = time.time()
    elapsed = end_time - start_time
    print("SQL results fetched and cast in {} s".format(elapsed))

    start_time = time.time()
    #year_cites[year] = pca_model.transform(cite_embeddings)[:,:D_truncate]
    end_time = time.time()
    elapsed = end_time - start_time
    print("pca transform finished in {} s".format(elapsed))
        
    start_time = time.time()
    path = output_path + 'citation_embeddings/' + str(year) + '.json'
    save_obj = {'pmids': cite_pmids,
                'embeddings': cite_embeddings}
    with open(path,'w') as f:
        json.dump(save_obj, f)
    end_time = time.time()
    elapsed = end_time - start_time
    print('finished writing output file in {} s'.format(elapsed))
    
    db.close()
    print()
    

### plot publications and citations

In [None]:
# todo - use this as raw material for a separate plotting script

def plot_pubs_and_cites(start_year, end_year):
    (f, ax) = plt.subplots(end_year - start_year + 1,
                       2,
                       sharex='all', sharey='all',
                       figsize=(4,4))

    for i_year, year in enumerate(range(start_year, end_year+1)):

        print(i_year, year)

        sns.kdeplot(year_pubs[year][:,0],
                    year_pubs[year][:,1],
                    ax=ax[i_year,0],
                    shade=True,
                    cmap='Blues')
        ax[i_year,0].set_title('published: year {}'.format(year))


        sns.kdeplot(year_cites[year][:,0],
                    year_cites[year][:,1],
                    ax=ax[i_year,1],
                    shade=True,
                    cmap='Reds')
        ax[i_year,1].set_title('cited: {}'.format(year))

    plt.savefig('publications and citations prototype {} - {}.png'.format(start_year, end_year))
    plt.savefig('publications and citations prototype {} - {}.svg'.format(start_year, end_year))
    plt.show()

In [None]:
plot_pubs_and_cites(start_year, end_year)