Evaluate Representations pipeline: 'Gather' step
            Consolidate the encoding outputs (CSV files) into MySQL tables, one table per embedding type.

In [1]:
import numpy as np
import csv
import os
import time
import json
import mysql.connector as mysql

In [2]:
# Pure-python mysql connector approach to the mysql connection:
#   offers easy control over the BLOB type embeddings.
# The socket path can be overridden with the PUBMED_MYSQL_SOCKET environment variable.
client_config = {'unix_socket': os.environ.get('PUBMED_MYSQL_SOCKET',
                                               '/home/brendanchambers/.sql.sock'),
                 'database': 'test_pubmed',
                 'use_pure': True}  # pure python mode is important for reading blobs of bytes
write_size = 10000  # write to db every 10K rows

# Directory holding the embedding csv files; override with PUBMED_EMBEDDINGS_DIR.
path2embeddings = os.environ.get(
    'PUBMED_EMBEDDINGS_DIR',
    '/project2/jevans/brendan/pubmed_data_processing/validation_sets/jneurophysiol_vs_neuroimage_results/')

# Each embedding set is a list of (csv file, destination mysql table) pairs.
# Keeping file and table together means the two can never fall out of step
# (zip() over two separate lists would silently drop unmatched entries).
# These are in single csv files - not spread across multiple chunk files.
embedding_sets = {
    # BERT approaches ("longtoken" / "tokenwise" means mean over tokens, not including [CLS])
    'bert': [('scibert__longtoken_mean.csv', 'emb_scibert_longtokens_mean'),
             ('scibert__tokenwise_mean.csv', 'emb_scibert_tokens_mean'),
             ('scibert__cls.csv', 'emb_scibert_cls'),
             ('vanilla__longtoken_mean.csv', 'emb_bert_longtokens_mean'),
             ('vanilla__tokenwise_mean.csv', 'emb_bert_tokens_mean'),
             ('vanilla__cls.csv', 'emb_bert_cls')],
    # word2vec approaches
    'w2v': [('w2v_pubmed.csv', 'emb_w2v_pm'),
            ('w2v_pubmed_pmc.csv', 'emb_w2v_pm_pmc'),
            ('w2v_wikipedia_pubmed_pmc.csv', 'emb_w2v_wiki_pm_pmc')],
}
embedding_set = 'w2v'  # which set to gather; switch to 'bert' for the BERT encodings

# csv files to read, and the mysql tables to insert them into (same order)
embedding_names = [csv_name for csv_name, _ in embedding_sets[embedding_set]]
table_names = [table for _, table in embedding_sets[embedding_set]]


# Create one MySQL table per embedding type

In [3]:
# Create one (pmid, embedding) table per embedding type.
# IF NOT EXISTS makes re-running this cell a no-op instead of an error.
for table_name in table_names:

    print('creating table {}...'.format(table_name))

    # table_name comes from the config cell (trusted); identifiers cannot be bound
    # as query parameters, so it is formatted in and backtick-quoted.
    # NOTE(review): MySQL BLOB holds at most 65,535 bytes (~8K float64 values);
    # confirm if much larger embeddings are ever added.
    sql = '''CREATE TABLE IF NOT EXISTS `{}`
            (pmid int NOT NULL,
            embedding BLOB NOT NULL,
            PRIMARY KEY (pmid))'''.format(table_name)
    db = None
    try:
        db = mysql.connect(**client_config)
        cursor = db.cursor()
        try:
            cursor.execute(sql)
        finally:
            cursor.close()
        db.commit()
    except mysql.Error as e:
        # report the database error and carry on with the remaining tables
        print('Warning during table creation:   {}'.format(e))
    finally:
        if db is not None:
            db.close()  # always release the connection, even when execute() fails



creating table emb_scibert_longtokens_mean...
creating table emb_scibert_tokens_mean...
creating table emb_scibert_cls...
creating table emb_bert_longtokens_mean...
creating table emb_bert_tokens_mean...
creating table emb_bert_cls...


# Helper function: batch-upsert rows into a table

In [4]:
def write_data_to_db(entries, table_name):
    """Upsert a batch of ``(pmid, embedding_blob)`` rows into ``table_name``.

    Rows whose pmid already exists get their embedding overwritten
    (``ON DUPLICATE KEY UPDATE``). A fresh connection is opened from the
    module-level ``client_config`` and is always closed before returning.

    Args:
        entries: sequence of ``(pmid, embedding_blob)`` tuples. An empty
            sequence is a no-op and does not touch the database.
        table_name: destination table name from the config cell (trusted).
            It is formatted into the SQL because identifiers cannot be bound
            as query parameters.

    Returns:
        None.
    """
    if not entries:
        return  # nothing to write; don't open a connection for an empty batch

    # pmid is the primary key, so on a duplicate only the embedding needs updating
    sql = '''INSERT INTO `{}` (pmid, embedding)
             VALUES (%s, %s)
             ON DUPLICATE KEY UPDATE
             embedding=VALUES(embedding)'''.format(table_name)
    db = mysql.connect(**client_config)
    try:
        cursor = db.cursor()
        try:
            cursor.executemany(sql, entries)
        finally:
            cursor.close()
        db.commit()
    finally:
        db.close()  # release the connection even if the insert fails

# Insert each embedding CSV into its table

In [5]:
# Stream each csv into its table in batches of `write_size` rows, so a whole
# file never has to be held in memory at once.
for table_name, embedding_name in zip(table_names, embedding_names):

    print('inserting {} into {} table...'.format(embedding_name, table_name))
    start_time = time.time()

    data = []
    # newline='' is how the csv module expects files to be opened
    with open(os.path.join(path2embeddings, embedding_name), newline='') as f:
        # note: there is no header in these csv files; each row is
        # "<pmid> <json-encoded list>" separated by a single space
        csvreader = csv.reader(f, delimiter=' ')

        for row in csvreader:
            pmid = int(row[0])

            # Pin the dtype: the blob is raw bytes with no dtype metadata, so a row
            # whose JSON happens to contain only integers must not become int64.
            embedding_blob = np.asarray(json.loads(row[1]),
                                        dtype=np.float64).tobytes(order='C')
            data.append((pmid, embedding_blob))

            if len(data) >= write_size:  # write to db every `write_size` rows
                print('dumping to db')
                write_data_to_db(data, table_name)  # csv -> mysql
                data = []  # reset the buffer after writing

    write_data_to_db(data, table_name)  # flush leftover rows (no-op if the buffer is empty)

    end_time = time.time()
    print('elapsed: {}'.format(end_time - start_time))
    print()
    print()

inserting scibert__longtoken_mean.csv into emb_scibert_longtokens_mean table...
dumping to db
dumping to db
dumping to db
elapsed: 42.83782410621643

inserting scibert__tokenwise_mean.csv into emb_scibert_tokens_mean table...
dumping to db
dumping to db
dumping to db
elapsed: 44.003576040267944

inserting scibert__cls.csv into emb_scibert_cls table...
dumping to db
dumping to db
dumping to db
elapsed: 40.23398995399475

inserting vanilla__longtoken_mean.csv into emb_bert_longtokens_mean table...
dumping to db
dumping to db
dumping to db
elapsed: 45.219751834869385

inserting vanilla__tokenwise_mean.csv into emb_bert_tokens_mean table...
dumping to db
dumping to db
dumping to db
elapsed: 39.46491026878357

inserting vanilla__cls.csv into emb_bert_cls table...
dumping to db
dumping to db
dumping to db
elapsed: 38.50262999534607

