In [2]:
import pickle
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import torch
import transformers
from transformers import AutoModel, BertTokenizer, AutoTokenizer, AutoModelForNextSentencePrediction

import sys

In [3]:
# Configuration: path handed to the tokenizer constructor and the checkpoint
# directory handed to AutoModelForNextSentencePrediction.from_pretrained in
# the loading cell.
tokenizer_filename = '/data/Bodenreider_UMLS_DL/thilini/EXPERIMENTS/aui_vec/umls-vocab.txt'
pt_model = '/data/Bodenreider_UMLS_DL/thilini/EXPERIMENTS/1_UMLS_ONLY/train_sp/out_all_correct_metric_from_32/checkpoint-290020_2/'

In [20]:
# Alternative configuration: a tokenizer identifier and a different checkpoint
# path. Running this cell overwrites the values assigned in the previous cell.
# NOTE(review): this cell's execution count (20) is higher than the loading
# cell's (5), so the saved outputs were presumably produced with the previous
# cell's paths -- confirm before relying on Restart & Run All.
tokenizer_filename =  'cambridgeltl/SapBERT-from-PubMedBERT-fulltext'
pt_model = "/data/Bodenreider_UMLS_DL/thilini/EXPERIMENTS/6_SAPBERT_SP/train/out_from_sapbert_from_3/checkpoint-986068/"

In [5]:
# First try to build a BertTokenizer directly from `tokenizer_filename`; if
# that fails, resolve the same value through AutoTokenizer.from_pretrained.
# `except Exception` (instead of a bare `except:`) keeps the fallback but no
# longer swallows KeyboardInterrupt / SystemExit.
try:
    tokenizer = BertTokenizer(tokenizer_filename)
except Exception:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_filename)

# Sequence-pair classifier loaded from the configured checkpoint directory.
model = AutoModelForNextSentencePrediction.from_pretrained(pt_model)
# model.to('cuda')

In [None]:
print('Loading Strings')

# Load the full set of test edges. The file is opened in a `with` block so the
# handle is closed as soon as unpickling finishes.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('/data/Bodenreider_UMLS_DL/Interns/Bernal/uva_test_edges.p', 'rb') as pairs_file:
    real_world_pairs = pickle.load(pairs_file)

print('Start Classifying')



In [None]:
# Shuffle all pairs with a fixed seed, then, within each group of column 2,
# reshuffle and carve off the first 1% as validation; the rest is the test split.
real_world_pairs = np.random.RandomState(42).permutation(real_world_pairs)
testing_data_df = pd.DataFrame(real_world_pairs)

validation_parts, testing_parts = [], []

for _, group in testing_data_df.groupby(2):
    shuffled = group.sample(len(group), random_state=np.random.RandomState(42))
    cutoff = int(len(group) * 0.01)

    validation_parts.append(shuffled[:cutoff])
    testing_parts.append(shuffled[cutoff:])

validation_set = pd.concat(validation_parts)
testing_set = pd.concat(testing_parts)

# (column 0, column 1, column 2) triples of the validation split.
subset = list(zip(validation_set[0], validation_set[1], validation_set[2]))

In [6]:
print('Loading Strings')

# Load the pre-computed subset of test edges. The file is opened in a `with`
# block so the handle is closed as soon as unpickling finishes.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('/data/Bodenreider_UMLS_DL/Interns/Bernal/uva_test_subset.p', 'rb') as subset_file:
    subset = pickle.load(subset_file)

print('Start Classifying')

Loading Strings
Start Classifying


In [7]:
# Restrict transformers' logging to error-level messages for the cells below.
transformers.logging.set_verbosity_error()

In [8]:
# Number of entries in `subset`; each one is unpacked as a 3-tuple by the
# classification loop.
len(subset)

17199

In [9]:
# Score every pair in both directions -- (head, tail) and (tail, head) -- and
# collect the model's first output for each encoded sequence in `all_cls`.
# Sequences are appended in forward/backward order, so even rows of `all_cls`
# belong to the forward direction and odd rows to the backward direction.
#
# Batches are sized dynamically: a batch is flushed once the longest estimated
# sequence length times the number of sequences in the batch exceeds a budget,
# or when the last pair has been queued.
all_cls = []

SEP_MARKER = ' [SEP] '         # only used to estimate the length of a joined pair
CHARS_PER_TOKEN = 3            # presumably a rough characters-per-token ratio -- TODO confirm
MAX_PADDED_BATCH_SIZE = 6000   # budget for (longest estimate) * (sequences in batch)

with torch.no_grad():

    num_strings_proc = 0
    batch_sizes = []

    text_batch = []
    pad_size = 0

    for head, tail, syn in tqdm(subset):

        # Both orderings are built from the same pieces, so a single estimate
        # covers the forward and the backward sequence. A non-string head or
        # tail raises TypeError here instead of being silently concatenated.
        length = (len(head) + len(tail) + 2 * len(SEP_MARKER)) / CHARS_PER_TOKEN

        text_batch.append((head, tail))
        text_batch.append((tail, head))

        num_strings_proc += 1

        pad_size = max(pad_size, length)

        if pad_size * len(text_batch) > MAX_PADDED_BATCH_SIZE or num_strings_proc == len(subset):

            encoding = tokenizer.batch_encode_plus(text_batch, return_tensors='pt', padding=True, truncation=True, max_length=512)

            # Follow whatever device the model currently lives on, so the loop
            # works unchanged whether or not the model was moved to a GPU.
            device = model.device
            input_ids = encoding['input_ids'].to(device)
            token_type_ids = encoding['token_type_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)

            outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
            all_cls.append(outputs[0].cpu().numpy())

            batch_sizes.append(len(text_batch))
            text_batch = []

            pad_size = 0

    # One row per encoded sequence, stacked across all batches.
    all_cls = np.vstack(all_cls)

100%|████████████████████████████████████████████████████████████████████████████████████████████| 17199/17199 [39:40<00:00,  7.22it/s]


In [10]:
# Keep at most as many rows as sequences were encoded. The count check in the
# following cell shows both numbers are equal, so nothing is dropped here.
all_cls = all_cls[:sum(batch_sizes)]

In [11]:
# Sanity check: sequences encoded, two per entry of `subset`, and output rows
# should all be the same number.
sum(batch_sizes),len(subset)*2,len(all_cls)

(34398, 34398, 34398)

In [12]:
# Rows were appended as (head, tail) then (tail, head) for each pair, so even
# rows are the forward direction and odd rows the backward direction.
forward_preds = list(all_cls[0::2])
backward_preds = list(all_cls[1::2])

In [13]:
# Shape of the stacked model outputs: one row per encoded sequence.
all_cls.shape

(34398, 2)

In [14]:
# Number of pairs; expected to be half the number of rows in `all_cls`, since
# each pair is encoded in both directions.
len(subset)

17199

In [15]:
# One row per entry of `subset`; the tuple positions become the integer-named
# columns 0, 1 and 2 (column 2 is later read as the label).
edges_to_test_df = pd.DataFrame(subset)

In [16]:
# Predicted class per direction: the column index of the larger output score
# in each row.
edges_to_test_df['forward_pred'] = np.argmax(np.vstack(forward_preds), axis=1)
edges_to_test_df['backward_pred'] = np.argmax(np.vstack(backward_preds), axis=1)

In [17]:
# Attach the per-direction model output rows next to the predictions.
# NOTE(review): despite the "probs" column names, these are the values taken
# directly from the model's first output; no softmax is applied in this notebook.
edges_to_test_df['f_probs'] = forward_preds
edges_to_test_df['b_probs'] = backward_preds

In [18]:
# Display the pairs together with their predictions and raw output scores.
edges_to_test_df

Unnamed: 0,0,1,2,forward_pred,backward_pred,f_probs,b_probs
0,OXYGEN 99 L in 100 L RESPIRATORY (INHALATION) GAS,OXYGEN 990 mL in 1 L RESPIRATORY (INHALATION) GAS,0,0,0,"[7.168312, -7.668907]","[7.1089563, -7.727689]"
1,chlorambucil,"para-N,N-di(b-chloroethyl)aminophenylbutyric acid",0,0,0,"[6.863706, -7.969083]","[3.7867756, -3.911776]"
2,Aspirin 325 mg in 325 mg ORAL TABLET [Value Ph...,"ASPIRIN 325 mg ORAL TABLET, COATED [aspirin pa...",0,0,0,"[8.269983, -5.4136357]","[7.5090084, -7.3260393]"
3,"SPINOCEREBELLAR ATAXIA, X-LINKED 3","Spinocerebellar ataxia, X-linked, 3",0,1,1,"[-0.08090535, 0.10617712]","[-0.09343867, 0.1295889]"
4,toldimfos sodium 200 MG/ML Injectable Solution...,TOLDIMFOS SODIUM 200 mg in 1 mL INTRAMUSCULAR ...,0,0,0,"[7.3894987, -7.448188]","[2.0362377, -2.3767476]"
...,...,...,...,...,...,...,...
17194,eprinomectin,eprinomectin 50 MG/ML,1,1,1,"[-2.8573313, 11.776593]","[-2.7095888, 11.607813]"
17195,Regurgitation,mitral regurgitation due to acute myocardial i...,1,1,1,"[-2.8647547, 11.786915]","[-2.8571477, 11.776012]"
17196,Reticular dysgenesia,Nervous system of pectoral girdle,1,1,1,"[-2.8814619, 11.813187]","[-2.8857632, 11.821387]"
17197,COQ9 gene,Genes,1,1,1,"[-2.875341, 11.802918]","[-2.8411045, 11.755026]"


In [21]:
from sklearn import metrics

In [19]:
# Gold labels from column 2, coerced to plain Python ints.
labels = list(map(int, edges_to_test_df[2]))

In [22]:
# Precision / recall / F1 for each direction, with label 0 treated as the
# positive class.
binary_kwargs = dict(pos_label=0, average='binary')

f_p, f_r, forward_ubert_f1, f_s = metrics.precision_recall_fscore_support(
    labels,
    edges_to_test_df['forward_pred'],
    **binary_kwargs,
)
b_p, b_r, backward_ubert_f1, b_s = metrics.precision_recall_fscore_support(
    labels,
    edges_to_test_df['backward_pred'],
    **binary_kwargs,
)

In [23]:
# Forward-direction F1, precision and recall (label 0 is the positive class).
forward_ubert_f1, f_p, f_r

(0.93015332197615, 0.8863636363636364, 0.978494623655914)

In [28]:
# Spot-check the tokenizer on a single string: split it into tokens, then map
# those tokens to their vocabulary ids.
aui_string = 'OXYGEN 99 L in 100 L RESPIRATORY (INHALATION) GAS'
aui_tokens = tokenizer.tokenize(aui_string)
tokenizer.convert_tokens_to_ids(aui_tokens)

[3208, 2883, 54, 194, 549, 54, 1351, 12, 3617, 13, 1482]

In [29]:
# Decode the ids from the previous cell's output back to text to see what the
# tokenizer preserved of the original string.
tokenizer.decode([3208, 2883, 54, 194, 549, 54, 1351, 12, 3617, 13, 1482])

'oxygen 99 l in 100 l respiratory ( inhalation ) gas'