In [1]:
import numpy as np
import csv
import os
import time
import json
import mysql.connector as mysql

In [3]:
# Connection settings for the pure-python MySQL connector; this approach
# offers easy control over the BLOB type embeddings.
client_config = dict(
    unix_socket='/home/brendanchambers/.sql.sock',
    database='test_pubmed',
    use_pure=True,  # pure python mode is important for reading blobs of bytes
)

# Buffered-row threshold (10K) used for the intermittent writes to the db.
write_size = 10000

# Destination table. It was created beforehand in the mysql terminal client
# with a statement like this (table name filled in via .format(table_name)):
#
#   CREATE TABLE {}
#                (pmid int NOT NULL,
#                embedding BLOB NOT NULL,
#                PRIMARY KEY (pmid))
table_name = 'scibert_mean_embedding'

# Directory containing the embedding chunk files to load.
path2embeddings = '/project2/jevans/brendan/pubmed_data_processing/scibert_embedding_chunks/'


## helper function

In [4]:
def write_data_to_db(entries, table_name):
    """Upsert a batch of (pmid, embedding) rows into ``table_name``.

    A new connection is opened from the module-level ``client_config`` for
    each call and is always closed again, even if the insert fails, so that
    repeated batch writes do not accumulate open connections.

    Parameters
    ----------
    entries : sequence of (pmid, embedding) tuples
        Rows to write; each tuple is bound to the two ``%s`` placeholders.
        An empty batch is a no-op and does not open a connection.
    table_name : str
        Target table. It is interpolated directly into the SQL text, so it
        must be a trusted identifier and never user-supplied input.

    Raises
    ------
    Whatever the MySQL connector raises on connect / execute / commit; in
    that case nothing from this batch is committed.
    """
    if not entries:
        return  # nothing to write; avoid a pointless connect + commit

    sql = '''INSERT INTO {} (pmid,embedding)
             VALUES (%s, %s)
             ON DUPLICATE KEY UPDATE
             pmid=values(pmid), embedding=values(embedding)'''.format(table_name)

    db = mysql.connect(**client_config)
    try:
        cursor = db.cursor()
        try:
            cursor.executemany(sql, entries)
        finally:
            cursor.close()
        db.commit()
    finally:
        # the original never closed the connection, leaking one per batch
        db.close()

## read csv output files and write to db

In [None]:
# Stream every chunk file in `path2embeddings` into the database, buffering
# rows and flushing them in batches of `write_size`.
start_time = time.time()

# os.listdir returns entries in arbitrary order; sort for reproducible runs
for chunk_filename in sorted(os.listdir(path2embeddings)):
    # os.path.join does not depend on path2embeddings ending with a slash
    chunk_path = os.path.join(path2embeddings, chunk_filename)
    if not os.path.isfile(chunk_path):
        # sub-directories (e.g. .ipynb_checkpoints) cannot be read as chunks
        print('skipping {} (not a regular file)'.format(chunk_filename))
        continue

    data = []
    print('processing {}...'.format(chunk_filename))

    # newline='' is what the csv module requires for file objects it reads
    with open(chunk_path, newline='') as f:
        csvreader = csv.reader(f, delimiter=' ')
        # note: there is no header in these chunk files

        for row in csvreader:
            if not row:
                continue  # blank line (e.g. trailing newline at end of file)

            # column 0 is the pmid, column 1 a JSON-encoded array
            pmid = int(row[0])

            # The dtype is whatever numpy infers from the JSON values
            # (float64 for JSON floats); readers must decode the blob with
            # the same dtype.
            embedding_blob = np.array(json.loads(row[1])).tobytes(order='C')
            data.append((pmid, embedding_blob))

            if len(data) >= write_size:  # write to db every `write_size` rows
                write_data_to_db(data, table_name)
                data = []  # reset the buffer

    if data:  # empty the buffer at the end, unless it was just flushed
        write_data_to_db(data, table_name)

end_time = time.time()
print('elapsed: {}'.format(end_time - start_time))
