In [1]:
import _pickle as pickle
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import copy
import glob
import gc
import ipdb
import time

In [2]:
pd.set_option('precision',2)
pd.set_option('display.float_format', lambda x: '%.2f' % x)
pd.set_option('max_rows',200)

In [3]:
umls2020AB_df = pickle.load(open('/data/Bodenreider_UMLS_DL/Interns/Bernal/UMLS2020AB_SAPBERT_Source_Info.p','rb'))

In [4]:
#Stratified Split using CUIs to avoid data leakage to test set

umls2020AB_cui_num_syms_df = umls2020AB_df[['cuis','num_syms']].drop_duplicates()

In [5]:
umls2020AB_cui_num_syms_df['no_syms'] = [n == 0 for n in umls2020AB_cui_num_syms_df.num_syms]

In [6]:
training = []
validation = []
testing = []

val = 0.10
test = 0.20

for i,g in umls2020AB_cui_num_syms_df.groupby('no_syms'):
    
    perm_g = g.sample(len(g),random_state=np.random.RandomState(42)).cuis.values
    
    training.extend(perm_g[:len(g) - int(len(g)*(val + test))])
    validation.extend(perm_g[len(g) - int(len(g)*(val + test)):len(g) - int(len(g)*(test))])
    testing.extend(perm_g[len(g) - int(len(g)*test):])
    
    assert(training[-1] != validation[0])
    assert(validation[-1] != testing[0])        

In [7]:
len(training), len(validation), len(testing)

(204012, 29144, 58288)

In [8]:
training = set(training)
validation = set(validation)
testing = set(testing)

In [9]:
split = []

for cui in umls2020AB_df.cuis:
     
    if cui in training:
        split.append('train')
    elif cui in validation:
        split.append('val')
    elif cui in testing:
        split.append('test')

In [10]:
umls2020AB_df['split'] = split

In [11]:
umls2020AB_df.groupby('split').count()

Unnamed: 0_level_0,0,strings,auis,2020AA_synonyms,synonym_strings,num_syms,sapbert_2000-NN_strings,sapbert_2000-NN_auis,sapbert_2000-NN_dist,sapbert_2000-NN_recall,...,R@10_sapbert_2000-NN_source_syn_cui,R@50_sapbert_2000-NN_source_syn_cui,R@100_sapbert_2000-NN_source_syn_cui,R@200_sapbert_2000-NN_source_syn_cui,R@500_sapbert_2000-NN_source_syn_cui,R@1000_sapbert_2000-NN_source_syn_cui,R@2000_sapbert_2000-NN_source_syn_cui,sources,top10_preferred_strings,num_cuis_per_query_string
split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
test,84912,84912,84912,84912,84912,84912,84912,84912,84912,84912,...,32700,32700,32700,32700,32700,32700,32700,84912,84912,84912
train,303361,303359,303361,303361,303361,303361,303361,303361,303361,303361,...,121327,121327,121327,121327,121327,121327,121327,303361,303361,303361
val,41862,41862,41862,41862,41862,41862,41862,41862,41862,41862,...,16050,16050,16050,16050,16050,16050,16050,41862,41862,41862


In [12]:
dedup_df = []

for i,g in tqdm(umls2020AB_df.groupby(['strings','cuis'])):

    for j, row in g.iterrows():
        dedup_df.append(row)
        break

dedup_df = pd.DataFrame(dedup_df)

100%|███████████████████████████████████████████████████████████████████████████████████████▉| 414996/414998 [04:49<00:00, 1431.05it/s]


In [13]:
dedup_df.groupby('split').count()

Unnamed: 0_level_0,0,strings,auis,2020AA_synonyms,synonym_strings,num_syms,sapbert_2000-NN_strings,sapbert_2000-NN_auis,sapbert_2000-NN_dist,sapbert_2000-NN_recall,...,R@10_sapbert_2000-NN_source_syn_cui,R@50_sapbert_2000-NN_source_syn_cui,R@100_sapbert_2000-NN_source_syn_cui,R@200_sapbert_2000-NN_source_syn_cui,R@500_sapbert_2000-NN_source_syn_cui,R@1000_sapbert_2000-NN_source_syn_cui,R@2000_sapbert_2000-NN_source_syn_cui,sources,top10_preferred_strings,num_cuis_per_query_string
split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
test,82291,82291,82291,82291,82291,82291,82291,82291,82291,82291,...,31589,31589,31589,31589,31589,31589,31589,82291,82291,82291
train,291896,291896,291896,291896,291896,291896,291896,291896,291896,291896,...,114830,114830,114830,114830,114830,114830,114830,291896,291896,291896
val,40809,40809,40809,40809,40809,40809,40809,40809,40809,40809,...,15627,15627,15627,15627,15627,15627,15627,40809,40809,40809


In [14]:
dedup_df.groupby('split').num_syms.mean()

split
test    4.64
train   7.12
val     4.53
Name: num_syms, dtype: float64

In [15]:
dedup_df.groupby('split').num_syms.sum()

split
test      381422
train    2077928
val       184829
Name: num_syms, dtype: int64

In [16]:
dedup_df.groupby('split').count()/len(dedup_df)

Unnamed: 0_level_0,0,strings,auis,2020AA_synonyms,synonym_strings,num_syms,sapbert_2000-NN_strings,sapbert_2000-NN_auis,sapbert_2000-NN_dist,sapbert_2000-NN_recall,...,R@10_sapbert_2000-NN_source_syn_cui,R@50_sapbert_2000-NN_source_syn_cui,R@100_sapbert_2000-NN_source_syn_cui,R@200_sapbert_2000-NN_source_syn_cui,R@500_sapbert_2000-NN_source_syn_cui,R@1000_sapbert_2000-NN_source_syn_cui,R@2000_sapbert_2000-NN_source_syn_cui,sources,top10_preferred_strings,num_cuis_per_query_string
split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
test,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,...,0.08,0.08,0.08,0.08,0.08,0.08,0.08,0.2,0.2,0.2
train,0.7,0.7,0.7,0.7,0.7,0.7,0.7,0.7,0.7,0.7,...,0.28,0.28,0.28,0.28,0.28,0.28,0.28,0.7,0.7,0.7
val,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,...,0.04,0.04,0.04,0.04,0.04,0.04,0.04,0.1,0.1,0.1


In [17]:
dedup_df.columns

Index(['0', 'strings', 'auis', '2020AA_synonyms', 'synonym_strings',
       'num_syms', 'sapbert_2000-NN_strings', 'sapbert_2000-NN_auis',
       'sapbert_2000-NN_dist', 'sapbert_2000-NN_recall', 'cuis', 'sem_types',
       'sem_groups', 'source_syns', 'source_syns_plus',
       'sapbert_2000-NN_source_syn_recall',
       'sapbert_2000-NN_source_syn_plus_recall', '2020AA_synonyms_cuis',
       'source_syns_cuis', 'sapbert_2000-NN_cuis',
       'sapbert_2000-NN_cui_recall', 'sapbert_2000-NN_source_syn_cui_recall',
       'number_source_syn_cuis', 'number_source_syn_auis',
       'number_source_syn_plus_auis', 'R@0_sapbert_2000-NN_source_syn_cui',
       'R@1_sapbert_2000-NN_source_syn_cui',
       'R@5_sapbert_2000-NN_source_syn_cui',
       'R@10_sapbert_2000-NN_source_syn_cui',
       'R@50_sapbert_2000-NN_source_syn_cui',
       'R@100_sapbert_2000-NN_source_syn_cui',
       'R@200_sapbert_2000-NN_source_syn_cui',
       'R@500_sapbert_2000-NN_source_syn_cui',
       'R@1000_sapbert_

In [18]:
dedup_df.num_syms.mean(), dedup_df.num_syms.sum()

(6.371577075441691, 2644179)

In [19]:
dedup_df_simple = dedup_df[['auis','strings','cuis', '2020AA_synonyms', 'source_syns', 'source_syns_plus','sapbert_2000-NN_auis', 'sapbert_2000-NN_cuis', 'source_syns_cuis','split']]

In [None]:
# pickle.dump(dedup_df_simple, open('/data/Bodenreider_UMLS_DL/Interns/Bernal/UMLS2020AB_SAPBERT_Source_Info_Official_Split_Basic.p','wb'))

In [27]:
k = 100
gold_candidates = True

aui_splits = {}
cui_splits = {}

for i, row in tqdm(dedup_df_simple.iterrows(), total=len(dedup_df_simple)):
    
    split = row['split']
    
    aui = row['auis']
    cui = row['cuis']
    syns = row['2020AA_synonyms']
    
    aui_samples = aui_splits.get(split, set())
    cui_samples = cui_splits.get(split, set())
    
    candidate_auis = row['sapbert_2000-NN_auis'][:k]
    candidate_cuis = row['sapbert_2000-NN_cuis'][:k]
    
    if gold_candidates:
        for syn in syns:
            aui_samples.add((aui, syn, label))
            
    for aui_cand, cui_cand in zip(candidate_auis, candidate_cuis):
        if cui == cui_cand:
            label = 1
        else:
            label = 0
            
        aui_sample = (aui, aui_cand, label)
        cui_sample = (aui, cui_cand, label)
        
        aui_samples.append(aui_sample)
        cui_samples.append(cui_sample)
        
    aui_splits[split] = set(aui_samples)
    cui_splits[split] = set(cui_samples)

100%|████████████████████████████████████████████████████████████████████████████████████████| 414996/414996 [01:35<00:00, 4358.82it/s]


In [29]:
aui_splits['val']

[('A31737623', 'A29935341', 1),
 ('A31737623', 'A16766008', 1),
 ('A31737623', 'A29950573', 1),
 ('A31737623', 'A30226726', 1),
 ('A31737623', 'A27062244', 1),
 ('A31737623', 'A28574734', 1),
 ('A31737623', 'A31563636', 1),
 ('A31737623', 'A23912498', 1),
 ('A31737623', 'A27067674', 1),
 ('A31737623', 'A22645718', 1),
 ('A31737623', 'A30216775', 1),
 ('A31737623', 'A29774387', 1),
 ('A31737623', 'A29941406', 1),
 ('A31737623', 'A19022300', 1),
 ('A31737623', 'A29944520', 1),
 ('A31737623', 'A27065876', 1),
 ('A31737623', 'A2327896', 0),
 ('A31737623', 'A16759229', 1),
 ('A31737623', 'A29947538', 1),
 ('A31737623', 'A27056745', 1),
 ('A31737623', 'A2331087', 0),
 ('A31737623', 'A31196392', 0),
 ('A31737623', 'A29942173', 0),
 ('A31737623', 'A30295939', 0),
 ('A31737623', 'A27858090', 0),
 ('A31737623', 'A27825813', 0),
 ('A31737623', 'A29729671', 0),
 ('A31737623', 'A1532077', 0),
 ('A31737623', 'A3351516', 0),
 ('A31737623', 'A23973774', 0),
 ('A31737623', 'A16768456', 0),
 ('A31737623

In [None]:
pickle.dump(aui_splits,open('/data/Bodenreider_UMLS_DL/Interns/Bernal/aui_pairwise_data_splits.{}.p'.format(k),'wb'))
pickle.dump(cui_splits,open('/data/Bodenreider_UMLS_DL/Interns/Bernal/cui_pairwise_data_splits.{}.p'.format(k),'wb'))

In [None]:
train = pd.read_csv(train_path, sep='\t', quoting=3)

In [None]:
pickle.dump(dedup_df, open('/data/Bodenreider_UMLS_DL/Interns/Bernal/UMLS2020AB_SAPBERT_Source_Info_Official_Split.p','wb'))

In [None]:
len(umls2020AB_df), len(umls2020AB_df[['strings']].drop_duplicates())

In [None]:
umls = UMLS()

In [None]:
aui_dataset = pickle.load(open('/data/Bodenreider_UMLS_DL/Interns/Bernal/aui_pairwise_data_splits.100.p','rb'))

In [None]:
for split, tups in aui_dataset.items():
    
    one_way = []
    two_way = []
    
    if split == 'val':
        split = 'dev'
        
    for aui1, aui2, label in tqdm(tups):

        str1 = umls.aui2str[aui1]
        str2 = umls.aui2str[aui2]

        one_way.append((str1 + ' [SEP] ' + str2, label))
        if split == 'train':
            two_way.append((str1 + ' [SEP] ' + str2, label))
            two_way.append((str2 + ' [SEP] ' + str1, label))
        
    one_way_df = pd.DataFrame(one_way, columns=['sents','labels'])
    one_way_df = one_way_df.sample(len(one_way_df), random_state=np.random.RandomState(42))
    
    if split == 'train':
        two_way_df = pd.DataFrame(two_way, columns=['sents','labels'])
        two_way_df = two_way_df.sample(len(two_way_df), random_state=np.random.RandomState(42))
    else:
        two_way_df = one_way_df
    
    one_way_df.to_csv('/data/Bodenreider_UMLS_DL/Interns/Bernal/data/RW-UVA-2020AB-one-way/{}.tsv'.format(split), sep='\t', quoting=3)
    two_way_df.to_csv('/data/Bodenreider_UMLS_DL/Interns/Bernal/data/RW-UVA-2020AB-two-way/{}.tsv'.format(split), sep='\t', quoting=3)