In [1]:
from annoy import AnnoyIndex

In [2]:
from trialstreamer import dbutil
import psycopg2

In [3]:
import numpy as np

In [4]:
cur = dbutil.db.cursor(cursor_factory=psycopg2.extras.DictCursor, name='fetch_large_result')
# cursor name means server side = not loading all to RAM

In [5]:
f = 768
t = AnnoyIndex(f)  # Length of item vector that will be indexed


In [6]:
cur.execute('select pmid, p_v from pubmed_pico;')



In [7]:
import tqdm
int_to_pmid = {}
count = 0

i = 0
batch_size = 1000
while True:
        # consume result over a series of iterations
        # with each iteration fetching a batch of records
        records = cur.fetchmany(size=batch_size)

        if not records:
            break
        
        for r in tqdm.tqdm(records, desc = "iter {} ({} done)".format(i, i*batch_size)):
            vecs = r['p_v']
            if vecs:
                for vec in vecs:
                    int_to_pmid[count] = r['pmid'] 
                    t.add_item(count, vec)
                    count += 1
        i += 1

iter 0 (0 done): 100%|██████████| 1000/1000 [00:00<00:00, 24374.72it/s]
iter 1 (1000 done): 100%|██████████| 1000/1000 [00:00<00:00, 25944.42it/s]
iter 2 (2000 done): 100%|██████████| 1000/1000 [00:00<00:00, 24610.70it/s]
iter 3 (3000 done): 100%|██████████| 1000/1000 [00:00<00:00, 24776.29it/s]
iter 4 (4000 done): 100%|██████████| 1000/1000 [00:00<00:00, 25276.48it/s]
iter 5 (5000 done): 100%|██████████| 1000/1000 [00:00<00:00, 24162.69it/s]
iter 6 (6000 done): 100%|██████████| 1000/1000 [00:00<00:00, 28208.00it/s]
iter 7 (7000 done): 100%|██████████| 1000/1000 [00:00<00:00, 24648.45it/s]
iter 8 (8000 done): 100%|██████████| 1000/1000 [00:00<00:00, 26782.01it/s]
iter 9 (9000 done): 100%|██████████| 1000/1000 [00:00<00:00, 22718.58it/s]
iter 10 (10000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27672.02it/s]
iter 11 (11000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27777.58it/s]
iter 12 (12000 done): 100%|██████████| 1000/1000 [00:00<00:00, 21482.04it/s]
iter 13 (13000 done): 

iter 209 (209000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27258.40it/s]
iter 210 (210000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27288.37it/s]
iter 211 (211000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27979.19it/s]
iter 212 (212000 done): 100%|██████████| 1000/1000 [00:00<00:00, 25687.49it/s]
iter 213 (213000 done): 100%|██████████| 1000/1000 [00:00<00:00, 24626.74it/s]
iter 214 (214000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27715.73it/s]
iter 215 (215000 done): 100%|██████████| 1000/1000 [00:00<00:00, 26488.72it/s]
iter 216 (216000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27963.52it/s]
iter 217 (217000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27515.36it/s]
iter 218 (218000 done): 100%|██████████| 1000/1000 [00:00<00:00, 26493.58it/s]
iter 302 (302000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27999.73it/s]
iter 303 (303000 done): 100%|██████████| 1000/1000 [00:00<00:00, 30233.79it/s]
iter 304 (304000 done): 100%|██████████| 1000/1000 [

iter 498 (498000 done): 100%|██████████| 1000/1000 [00:00<00:00, 26237.69it/s]
iter 499 (499000 done): 100%|██████████| 1000/1000 [00:00<00:00, 29136.01it/s]
iter 500 (500000 done): 100%|██████████| 1000/1000 [00:00<00:00, 26780.64it/s]
iter 501 (501000 done): 100%|██████████| 1000/1000 [00:00<00:00, 31450.75it/s]
iter 502 (502000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27817.75it/s]
iter 503 (503000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27919.40it/s]
iter 504 (504000 done): 100%|██████████| 1000/1000 [00:00<00:00, 27500.21it/s]
iter 505 (505000 done): 100%|██████████| 1000/1000 [00:00<00:00, 28442.52it/s]
iter 506 (506000 done): 100%|██████████| 1000/1000 [00:00<00:00, 24866.63it/s]
iter 507 (507000 done): 100%|██████████| 1000/1000 [00:00<00:00, 26643.87it/s]
iter 508 (508000 done): 100%|██████████| 1000/1000 [00:00<00:00, 26235.55it/s]
iter 509 (509000 done): 100%|██████████| 1000/1000 [00:00<00:00, 26934.74it/s]
iter 510 (510000 done): 100%|██████████| 1000/1000 [

In [8]:
t.build(10)

True

In [9]:
from bert_serving.client import BertClient

In [10]:
bc = BertClient()
q = bc.encode(['type II diabetes'])

In [22]:
q = bc.encode(['adults with primary hypertension', 'acupuncture'])

In [23]:
# top 5 most similar...
annoy_idx = t.get_nns_by_vector(q[0], n=5000, search_k=500000)
pmids = [int_to_pmid[i] for i in annoy_idx]

In [21]:
print('most similar PMIDs')
print(pmids)

most similar PMIDs
['26380159', '25492835', '30112195', '29493562', '23404465', '6848468', '30688979', '26817716', '15824294', '28864691', '12534445', '20137542', '9331017', '1720476', '12409967', '30541254', '30541254', '11833827', '25880068', '29778251', '20016796', '27640987', '27640987', '23680334', '29030456', '30713028', '11497204', '11281235', '11078175', '8613267', '15820487', '15149889', '22853848', '26791478', '30715100', '17617280', '29122812', '28968870', '22814178', '24635770', '10535718', '11726010', '18754021', '18021136', '26890438', '26915709', '21354549', '10561625', '6360868', '1380609', '2412056', '7029318', '6405876', '19609054', '7196467', '8217026', '2550243', '8364940', '8340155', '8466742', '1974505', '3230151', '3330989', '3519243', '3561160', '6151343', '10605328', '773494', '10072988', '2634218', '10361448', '7511751', '11693466', '11693466', '14732734', '14732734', '2349138', '8892775', '20021921', '27974526', '6386484', '24200748', '8210563', '22027785', '