In [None]:
import pickle
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import copy
import faiss

In [None]:
import torch
import transformers
from transformers import AutoModel, AutoTokenizer#!/usr/bin/env python
# coding: utf-8


# Hugging Face hub id of the pretrained checkpoint; the encoder and its
# tokenizer are both loaded from it.
pt_model = 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext'

tokenizer = AutoTokenizer.from_pretrained(pt_model)

# Load the encoder and move it to the GPU in one expression.
model = AutoModel.from_pretrained(pt_model).to('cuda')

In [None]:
print('Loading Strings')

# Tab-separated table; its first column becomes the frame's index.
umls_strings_csv = '/data/Bodenreider_UMLS_DL/Interns/Bernal/sorted_umls2020_auis.csv'
sorted_umls_df = pd.read_csv(umls_strings_csv, sep='\t', index_col=0)

print('Start Encoding')

In [None]:
# Order rows by the '0' column, largest value first. Presumably '0' holds the
# string length that the encoding loop uses to size batches — TODO confirm.
sorted_umls_df = sorted_umls_df.sort_values('0',ascending=False)

In [None]:
# Alias (not a copy) of the sorted frame; the encoding loop iterates over `sort`.
sort = sorted_umls_df

In [None]:
# Let pandas display up to 2000 rows before truncating a frame.
# The fully qualified key is used on purpose: the bare pattern 'max_rows' also
# matches 'styler.render.max_rows' on pandas >= 1.4, which makes set_option
# raise OptionError ("Pattern matched multiple keys").
pd.set_option('display.max_rows', 2000)

In [None]:
# Encode every string in `sort` with the SapBERT model and pickle the [CLS]
# vectors to disk in numbered files of 100 batches each. A file that already
# exists is neither re-encoded nor overwritten, so an interrupted run can be
# resumed by re-executing this cell.
vec_file_template = '/data/Bodenreider_UMLS_DL/Interns/Bernal/umls2020_sapbert_vecs_{}.p'


def _save_vec_file(cls_batches, file_num):
    """Stack `cls_batches` row-wise and pickle the result to file `file_num`.

    Does nothing when that file already exists (resume support) or when there
    are no batches to stack: np.vstack raises ValueError on an empty list,
    which used to crash the final save whenever the total number of batches
    was an exact multiple of 100.
    """
    vec_path = vec_file_template.format(file_num)
    if len(cls_batches) == 0 or os.path.exists(vec_path):
        return
    # Context manager so the handle is flushed and closed deterministically.
    with open(vec_path, 'wb') as vec_file:
        pickle.dump(np.vstack(cls_batches), vec_file)


all_cls = []

with torch.no_grad():

    num_strings_proc = 0
    vec_save_batch_num = 0
    batch_sizes = []

    text_batch = []
    pad_size = 0

    curr_vecs = 0

    for i, row in tqdm(sort.iterrows(), total=len(sort)):

        string = str(row['strings'])
        # First column by position — exactly what `row[0]` resolved to. Integer
        # lookups on a label-indexed Series are deprecated in pandas 2.x.
        # Presumably this is the string's length; TODO confirm.
        length = row.iloc[0]

        text_batch.append(string)
        num_strings_proc += 1

        if length > pad_size:
            pad_size = length

        # Close the batch once (longest length) x (batch size) exceeds the
        # budget, or when the last string has been consumed.
        if pad_size * len(text_batch) > 6000 or num_strings_proc == len(sort):

            # Skip the forward pass for batches whose output file is already on disk.
            if not os.path.exists(vec_file_template.format(vec_save_batch_num)):
                # NOTE(review): model.config.max_length is the text-generation
                # length default in transformers' PretrainedConfig (20 unless the
                # checkpoint overrides it), not the encoder's input limit.
                # Confirm this truncation length is intended. Left unchanged so
                # new vectors stay comparable with files already on disk.
                encoding = tokenizer(text_batch, return_tensors='pt', padding=True,
                                     truncation=True, max_length=model.config.max_length)
                input_ids = encoding['input_ids'].to('cuda')
                attention_mask = encoding['attention_mask'].to('cuda')

                outputs = model(input_ids, attention_mask=attention_mask)
                # Keep only the first-token ([CLS]) hidden state of each string.
                all_cls.append(outputs[0][:, 0, :].cpu().numpy())

            batch_sizes.append(len(text_batch))
            curr_vecs += 1

            text_batch = []
            pad_size = 0

            if curr_vecs == 100:
                print('Latest_batch_size {}'.format(batch_sizes[-1]))
                print(sum(batch_sizes))
                _save_vec_file(all_cls, vec_save_batch_num)

                vec_save_batch_num += 1
                all_cls = []
                curr_vecs = 0

    # Flush whatever is left after the last full group of 100 batches.
    _save_vec_file(all_cls, vec_save_batch_num)

In [None]:
# Reload the pickled vector files in numeric order.
# NOTE(review): 167 is hard-coded and must equal the number of files written by
# the encoding cell — confirm before re-running on different data.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
vecs = []
for i in range(167):
    # `with` closes each handle immediately instead of leaking 167 open files.
    with open('/data/Bodenreider_UMLS_DL/Interns/Bernal/umls2020_sapbert_vecs_{}.p'.format(i), 'rb') as vec_file:
        vecs.append(pickle.load(vec_file))

In [None]:
# Concatenate the per-file arrays row-wise into one matrix.
vecs = np.vstack(vecs)

In [None]:
# The pickle holds a pair: (entries of the original release, entries of the update).
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('aui_string_map_UMLS2020_update.p', 'rb') as aui_map_file:
    original_umls_2020, new_umls_2020 = pickle.load(aui_map_file)

# First element of every original entry, as a set for O(1) membership tests.
original_auis = {x[0] for x in original_umls_2020}

In [None]:
# Mapping looked up by AUI below to obtain that AUI's synonyms.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('new_umls_synonym_aui_dict.p', 'rb') as synonym_file:
    synonym_dict = pickle.load(synonym_file)

In [None]:
# For every row, record whether its AUI is absent from the original release
# (`new`) and, for such AUIs, which of their synonyms do exist in the original
# release (`synonym_list`; None for AUIs that are not new).
new = []
synonym_list = []

for aui in tqdm(sorted_umls_df.auis):

    is_new = aui not in original_auis
    new.append(is_new)

    if is_new:
        # A distinct name for the synonym avoids shadowing the outer loop's `aui`.
        synonym_list.append([syn_aui for syn_aui in synonym_dict[aui] if syn_aui in original_auis])
    else:
        synonym_list.append(None)

In [None]:
# Attach the per-row results of the loop above: '2020AB?' is True for AUIs not
# found in original_auis; '2020AA_synonyms' is the filtered synonym list (None
# for rows that are not new).
sorted_umls_df['2020AB?'] = new
sorted_umls_df['2020AA_synonyms'] = synonym_list

In [None]:
# Non-null cell counts of every other column, split by the '2020AB?' flag.
sorted_umls_df.groupby('2020AB?').count()

In [None]:
# One vector (row of `vecs`) per frame row, assigned by position — this relies
# on `vecs` being in the same row order as sorted_umls_df.
sorted_umls_df['sapbert_vecs'] = list(vecs)

In [None]:
# Display the annotated frame.
sorted_umls_df

In [None]:
# Rows whose '2020AB?' flag is False: keep their metadata columns as a frame
# and their vectors as a single stacked matrix.
aa_mask = sorted_umls_df['2020AB?'] == False
umls2020AA_df = sorted_umls_df[aa_mask][['0','strings','auis']]
umls2020AA_vecs = np.vstack(sorted_umls_df[aa_mask]['sapbert_vecs'])

In [None]:
# Rows whose '2020AB?' flag is True: metadata (including the synonym lists) as
# a frame, vectors as a single stacked matrix.
ab_mask = sorted_umls_df['2020AB?']
umls2020AB_df = sorted_umls_df[ab_mask][['0','strings','auis','2020AA_synonyms']]
umls2020AB_vecs = np.vstack(sorted_umls_df[ab_mask]['sapbert_vecs'])

In [None]:
import faiss

# Vector dimensionality passed to faiss.IndexFlatL2 below; presumably the width
# of the SapBERT vectors — TODO confirm against umls2020AB_vecs.shape[1].
dim = 768

In [None]:
# Sanity check: (number of rows, vector width) of the stacked 2020AB matrix.
umls2020AB_vecs.shape

In [None]:
import gc
import subprocess

In [None]:
# np.save('/data/Bodenreider_UMLS_DL/Interns/Bernal/sapbert_vecs_for_queryAA', umls2020AA_vecs)

In [None]:
# np.save('/data/Bodenreider_UMLS_DL/Interns/Bernal/sapbert_vecs_for_queryAB', umls2020AB_vecs)

In [None]:
# Reload both vector matrices from .npy files, replacing any in-memory arrays of
# the same name (presumably the files written by the commented-out np.save cells
# above — TODO confirm they are current).
umls2020AA_vecs = np.load('/data/Bodenreider_UMLS_DL/Interns/Bernal/sapbert_vecs_for_queryAA.npy')
umls2020AB_vecs = np.load('/data/Bodenreider_UMLS_DL/Interns/Bernal/sapbert_vecs_for_queryAB.npy')

In [None]:
# Exact L2 k-nearest-neighbour search of every 2020AB vector against the 2020AA
# vectors. The 2020AA matrix is split into `index_split` chunks that are indexed
# on the GPU one at a time; the queries are sent in 100 chunks. Results are kept
# per index chunk and merged in a later cell.
index_split = 3
index_chunks = np.array_split(umls2020AA_vecs,index_split)
query_chunks = np.array_split(umls2020AB_vecs,100)

k = 2000

index_chunk_D = []
index_chunk_I = []

# Row offset of the current index chunk inside umls2020AA_vecs.
current_zero_index = 0

for index_chunk in index_chunks:

    index = faiss.IndexFlatL2(dim)   # build the index

    if faiss.get_num_gpus() > 1:
        # One resources object per visible GPU.
        gpu_resources = [faiss.StandardGpuResources() for _ in range(faiss.get_num_gpus())]
        gpu_index = faiss.index_cpu_to_gpu_multiple_py(gpu_resources, index)
    else:
        gpu_resources = faiss.StandardGpuResources()
        gpu_index = faiss.index_cpu_to_gpu(gpu_resources, 0, index)

    gpu_index.add(index_chunk)
    # Report the size after adding; before `add` this is always 0.
    print(gpu_index.ntotal)

    D, I = [],[]

    for q in tqdm(query_chunks):
        d,i = gpu_index.search(q, k)

        # Convert chunk-local row numbers to rows of the full 2020AA matrix.
        # faiss pads missing results with label -1; leave those untouched so
        # they are not turned into valid-looking row numbers.
        i[i >= 0] += current_zero_index

        D.append(d)
        I.append(i)

    index_chunk_D.append(D)
    index_chunk_I.append(I)

    current_zero_index += len(index_chunk)

    # Decode so the report prints as text rather than as a bytes repr.
    print(subprocess.check_output(['nvidia-smi']).decode())

    # Release the GPU index and its resources before building the next chunk.
    del gpu_index
    del gpu_resources
    gc.collect()

In [None]:
# Inspect the raw per-index-chunk distance results.
index_chunk_D

In [None]:
# x = 0
# for d,i in zip(index_chunk_D, index_chunk_I):
#     np.save('/data/Bodenreider_UMLS_DL/Interns/Bernal/d_{}'.format(x),np.array(d))
#     np.save('/data/Bodenreider_UMLS_DL/Interns/Bernal/i_{}'.format(x),np.array(i))
#     x += 1

In [None]:
# For each index chunk, join its per-query-chunk result arrays row-wise so that
# every entry covers all queries.
stacked_D = [np.vstack(chunk_dists) for chunk_dists in index_chunk_D]
stacked_I = [np.vstack(chunk_inds) for chunk_inds in index_chunk_I]

In [None]:
# x = 0
# for d,i in zip(stacked_D, stacked_I):
#     np.save('/data/Bodenreider_UMLS_DL/Interns/Bernal/stacked_d_{}'.format(x),np.array(d))
#     np.save('/data/Bodenreider_UMLS_DL/Interns/Bernal/stacked_i_{}'.format(x),np.array(i))
#     x += 1

In [None]:
# stacked_D = []
# stacked_I = []
# index_split = 3

# for i in tqdm(range(index_split)):
#     D = np.load('/data/Bodenreider_UMLS_DL/Interns/Bernal/stacked_d_{}.npy'.format(i))
#     I = np.load('/data/Bodenreider_UMLS_DL/Interns/Bernal/stacked_i_{}.npy'.format(i))
    
#     stacked_D.append(D)
#     stacked_I.append(I)

In [None]:
# Join the per-index-chunk results column-wise: each row then holds the
# candidates returned by every index chunk for that query.
stacked_D = np.hstack(stacked_D)
stacked_I = np.hstack(stacked_I)

In [None]:
# Merge the candidates of all index chunks: for each query, order them by
# distance and keep the k closest.
full_sort_I = []
full_sort_D = []

for dist_row, ind_row in tqdm(zip(stacked_D, stacked_I)):

    # Positions of the k smallest distances, ascending.
    closest = np.argsort(dist_row)[:k]

    full_sort_I.append(ind_row[closest])
    full_sort_D.append(dist_row[closest])

In [None]:
# np.save('/data/Bodenreider_UMLS_DL/Interns/Bernal/sapbert_2000-NN-indices', np.array(full_sort_I))
# np.save('/data/Bodenreider_UMLS_DL/Interns/Bernal/sapbert_2000-NN-dist', np.array(full_sort_D))

In [None]:
# Load saved neighbour indices and distances from disk; this replaces any
# full_sort_I / full_sort_D values already held in memory.
full_sort_I = np.load('/data/Bodenreider_UMLS_DL/Interns/Bernal/sapbert_2000-NN-indices.npy')
full_sort_D = np.load('/data/Bodenreider_UMLS_DL/Interns/Bernal/sapbert_2000-NN-dist.npy')

In [None]:
# Sanity check on the loaded distance array's dimensions.
full_sort_D.shape

In [None]:
# AUIs of the 2020AA rows as a plain list, so that neighbour indices can be
# translated to AUIs by position.
umls_2020AA_auis = list(umls2020AA_df.auis)

In [None]:
# Translate each query's neighbour row numbers into the corresponding 2020AA AUIs.
nearest_neighbors_auis = [
    [umls_2020AA_auis[nn_ind] for nn_ind in nn_inds]
    for nn_inds in tqdm(full_sort_I)
]

In [None]:
# Display the 2020AB frame.
umls2020AB_df

In [None]:
# Per-query lists of 2020AA synonym AUIs; these are the reference sets used for
# the recall computation below.
query_synonym_auis = list(umls2020AB_df['2020AA_synonyms'])

In [None]:
# Recall @ 1, 5, 10, 50, 100, 200, 500, 1000, 2000 for every query, plus a
# distance summary [closest distance, mean distance over the returned
# neighbours], collected separately for queries with and without synonyms.
recall_cutoffs = [1,5,10,50,100,200,500,1000,2000]

recall_array = []
closest_dist_true = []
closest_dist_false = []

for true_syn, neighbors, neighbor_dists in tqdm(zip(query_synonym_auis, nearest_neighbors_auis, full_sort_D)):

    true_syn = set(true_syn)
    dist_summary = [neighbor_dists[0], np.mean(neighbor_dists)]

    if len(true_syn) == 0:
        # No reference synonyms: recall is undefined, so store an empty list.
        recall_array.append([])
        closest_dist_false.append(dist_summary)
        continue

    # Fraction of the true synonyms found among the first n neighbours.
    recall_array.append(
        [len(true_syn.intersection(neighbors[:n]))/len(true_syn) for n in recall_cutoffs]
    )
    closest_dist_true.append(dist_summary)

In [None]:
# Inspect the per-query recall lists (empty for queries without synonyms).
recall_array

In [None]:
# Summary statistics per recall cutoff (one column per cutoff); queries with an
# empty recall list become all-NaN rows, which describe() ignores.
pd.DataFrame(recall_array).describe()

In [None]:
# Lookup table from 2020AA AUI to its string (a later row wins on duplicate AUIs).
umls2020AA_aui2str = dict(tqdm(zip(umls2020AA_df.auis, umls2020AA_df.strings)))

In [None]:
# Strings of every query's nearest-neighbour AUIs, in the same order.
nearest_neighbors_strings = [
    [umls2020AA_aui2str[nn_aui] for nn_aui in nn_auis]
    for nn_auis in tqdm(nearest_neighbors_auis)
]

In [None]:
# Strings of every query's 2020AA synonym AUIs, in the same order.
synonym_strings = [
    [umls2020AA_aui2str[syn_aui] for syn_aui in syn_auis]
    for syn_auis in tqdm(umls2020AB_df['2020AA_synonyms'])
]

In [None]:
# Spot-check: neighbour strings of the first query.
nearest_neighbors_strings[0]

In [None]:
# Store the synonym strings next to each query row (assigned by position).
umls2020AB_df['synonym_strings'] = synonym_strings

In [None]:
# Store each query's neighbour strings, AUIs and distances as list-valued
# columns (assigned by position, so all three must follow the frame's row order).
umls2020AB_df['sapbert_2000-NN_strings'] = nearest_neighbors_strings
umls2020AB_df['sapbert_2000-NN_auis'] = nearest_neighbors_auis
umls2020AB_df['sapbert_2000-NN_dist'] = list(full_sort_D)

In [None]:
# Per-query recall lists (empty list for queries without synonyms).
umls2020AB_df['sapbert_2000-NN_recall'] = recall_array

In [None]:
# Number of 2020AA synonyms per query; used below to select queries that have any.
umls2020AB_df['num_syms'] = [len(s) for s in umls2020AB_df['2020AA_synonyms']]

In [None]:
# Persist the annotated 2020AB frame; the file name embeds k.
# `with` guarantees the file is flushed and closed even if pickling fails midway.
with open('/data/Bodenreider_UMLS_DL/Interns/Bernal/UMLS2020AB_{}-NN_DataFrame.p'.format(k), 'wb') as df_file:
    pickle.dump(umls2020AB_df, df_file)

In [None]:
# Show up to 500 characters per cell so long strings are not cut off.
# The fully qualified key is used because abbreviated option patterns fail
# with OptionError as soon as they match more than one registered option.
pd.set_option('display.max_colwidth', 500)

In [None]:
# Strings of the 2020AB rows whose '0' value exceeds 100.
umls2020AB_df[umls2020AB_df['0'] > 100][['strings']]

In [None]:
# First 100 2020AB rows with a '0' value below 10 and at least one 2020AA synonym.
umls2020AB_df[(umls2020AB_df['0'] < 10) & (umls2020AB_df['num_syms'] > 0)][:100]

In [None]:
# Drop the names of the large frame and vector matrices (they are not referenced
# again below), then force a garbage-collection pass.
del sorted_umls_df
del umls2020AA_vecs
del umls2020AB_vecs

gc.collect()

In [None]:
# Mean recall at each cutoff over the queries that have at least one 2020AA synonym.
# Reads 'sapbert_2000-NN_recall', the column this notebook actually writes; the
# former 'sapbert_400-NN_recall' key does not exist in the frame and raised KeyError.
np.stack(umls2020AB_df[umls2020AB_df.num_syms > 0]['sapbert_2000-NN_recall']).mean(axis=0)

In [None]:
# Column-wise mean of [closest distance, mean neighbour distance] over queries
# that have synonyms.
np.mean(closest_dist_true,axis=0)

In [None]:
# Column-wise mean of [closest distance, mean neighbour distance] over queries
# that have no synonyms.
np.mean(closest_dist_false,axis=0)

In [None]:
# Shell escape: show the current GPU status report.
!nvidia-smi