In [None]:
import pickle
import pandas as pd
from tqdm import tqdm
import numpy as np

In [None]:
# Raise pandas' display limit so up to 500 rows are shown before output is truncated.
pd.set_option('display.max_rows',500)

In [None]:
# One record per MRCONSO row: the atom identifier (AUI), the concept
# identifier (CUI) it belongs to, and the atom string (STR).
aui_info = []

# The file is streamed line by line rather than read with readlines(), so the
# raw lines are never all held in memory at once. The encoding is pinned to
# UTF-8 (the encoding UMLS RRF files are distributed in) so that parsing does
# not depend on the locale of the machine running the notebook.
with open('/data/Bodenreider_UMLS_DL/UMLS_VERSIONS/2020AB-ACTIVE/META/MRCONSO.RRF', 'r', encoding='utf-8') as fp:

    for line in fp:
        line = line.split('|')
        cui = line[0]
        aui = line[7]
        # Fifth piece from the end of the split line; the piece after the final
        # '|' (the line ending) counts as the last one.
        # NOTE(review): presumably the STR column of the standard layout - confirm.
        string = line[-5]

        aui_info.append({'AUI':aui, 'CUI':cui, 'STR':string})
        
# CUI -> value of the fourth column of its MRSTY row.
# NOTE(review): a CUI that appears on several MRSTY rows keeps only the value
# of its last row, because each assignment overwrites the previous one.
cui2sg = {}

# Streamed line by line (no readlines() copy of the whole file) and decoded
# explicitly as UTF-8 so the result does not depend on the system locale.
with open('/data/Bodenreider_UMLS_DL/UMLS_VERSIONS/2020AB-ACTIVE/META/MRSTY.RRF', 'r', encoding='utf-8') as fp:

    for line in fp:
        line = line.split('|')
        cui = line[0]
        sg = line[3]
        cui2sg[cui] = sg

In [None]:
# Lookup tables derived from the records in aui_info.
cui2aui = {}   # CUI -> list of its AUIs, in aui_info order
aui2cui = {}   # AUI -> CUI
aui2sg = {}    # declared here; this cell does not add anything to it

# Flat (cui, sg) and (cui, aui) pairs, one of each per record.
cui_sg = []
cui_aui = []

for record in aui_info:
    aui = record['AUI']
    cui = record['CUI']
    # Raises KeyError when the CUI has no entry in cui2sg.
    sg = cui2sg[cui]

    # Create the per-CUI list on first sight, then extend it.
    cui2aui.setdefault(cui, []).append(aui)
    aui2cui[aui] = cui

    cui_sg.append((cui, sg))
    cui_aui.append((cui, aui))

In [None]:
# Number of distinct AUI keys collected in aui2cui.
len(aui2cui)

In [None]:
# Number of distinct CUI keys collected in cui2aui.
len(cui2aui)

In [None]:
def load_large_scale(filename, sep):
    """Read a delimited text file into a list of tuples, one tuple per line.

    Args:
        filename: path of the text file to read.
        sep: separator string each line is split on.

    Returns:
        A list with one tuple of string fields per line, in file order.
        Lines are not stripped, so the last field of a line keeps its
        trailing newline when the line has one.
    """
    with open(filename,'r') as handle:
        # readlines() is kept so tqdm knows the total and can show real progress.
        raw_lines = handle.readlines()
        return [tuple(raw.split(sep)) for raw in tqdm(raw_lines)]

In [None]:
#Loading Official Test Set
# Each line becomes a tuple of '|'-separated string fields (see load_large_scale).
official_test = load_large_scale('/data/Bodenreider_UMLS_DL/thilini/EXPERIMENTS/new_testing_dataset/dedup_test_data.RRF','|')

In [None]:
# Load the FP, FN and TP error-analysis files as lists of '|'-separated tuples.
# NOTE(review): presumably false positives / false negatives / true positives of
# the model on the test set - confirm against whatever produced these files.
fps = load_large_scale('/data/Bodenreider_UMLS_DL/thilini/EXPERIMENTS/error_analysis/FP_with_auis.csv','|')
fns = load_large_scale('/data/Bodenreider_UMLS_DL/thilini/EXPERIMENTS/error_analysis/FN_with_auis.csv','|')
tps = load_large_scale('/data/Bodenreider_UMLS_DL/thilini/EXPERIMENTS/error_analysis/TP_with_auis.csv','|')

In [None]:
# The TN file is loaded in its own cell, separately from the FP/FN/TP files.
tns = load_large_scale('/data/Bodenreider_UMLS_DL/thilini/EXPERIMENTS/error_analysis/TN_with_auis.csv','|')

In [None]:
# Print the first five tuples of each loaded list (these are plain lists of
# tuples, not DataFrames).
for rows in (tps, tns, fps, fns):
    preview = rows[:5]
    print(preview)

In [None]:
# Order-independent pair key -> integer label, for every official test example.
official_sorted_test = {}

for ex in tqdm(official_test):

    # Fields 1 and 2 are the two identifiers of the pair; field 3 is the label,
    # parsed as an int after stripping surrounding whitespace.
    first, second = ex[1], ex[2]
    label = int(ex[3].strip())

    # The larger identifier (string comparison) always comes first, so
    # (x, y) and (y, x) produce the same key.
    sorted_example_key = ' = '.join(sorted((first, second), reverse=True))
    official_sorted_test[sorted_example_key] = label

In [None]:
# Number of distinct pair keys; smaller than len(official_test) if two examples share a key.
len(official_sorted_test)

In [None]:
# Counts the entries of official_sorted_test that get a (label, prediction) pair.
changed_values = 0

for data in [tps, tns, fps, fns]:

    # Row 0 of each list is skipped (presumably a header row - TODO confirm).
    for ex in tqdm(data[1:]):

        # The two identifiers sit in fields 0 and 2 of these files.
        first, second = ex[0], ex[2]

        #Error Analysis has 0 as positive class (Not sure why)
        label = int(ex[-2].strip()) == 0
        prediction = int(ex[-1].strip()) == 0

        # Same order-independent key as used for official_sorted_test:
        # larger identifier first.
        sorted_example_key = ' = '.join(sorted((first, second), reverse=True))

        # A KeyError here means the pair is not in the official test set.
        official_label = official_sorted_test[sorted_example_key]

        # Only entries still holding a bare int are updated; entries that are
        # already tuples are left untouched.
        if type(official_label) is int:
            # bool-vs-int comparison: True matches 1 and False matches 0.
            assert official_label == label
            official_sorted_test[sorted_example_key] = (label, prediction)

            changed_values += 1

In [None]:
# Display how many official test entries were given a (label, prediction) pair.
changed_values

In [None]:
# Rebind the four names to None so the loaded lists are no longer referenced
# through them. After this cell, tps / tns / fps / fns are None, not lists.
tps = None
tns = None
fps = None
fns = None

In [None]:
# Remove the name tns from the namespace altogether (it was already rebound to
# None above); any later use of tns raises NameError.
del tns

In [None]:
import gc

In [None]:
# Run a garbage-collection pass now; the cell output is gc.collect()'s return value.
gc.collect()

In [None]:
# Persist the pair-key dictionary to the working directory. The file is opened
# in a `with` block so the handle is flushed and closed as soon as the dump
# finishes, instead of relying on the anonymous file object being collected.
with open('official__test_label_preds_dict.p','wb') as pickle_out:
    pickle.dump(official_sorted_test, pickle_out)

In [None]:
# Write one tab-separated line per official test example:
# first identifier, second identifier, label, prediction (both written as 0/1).
with open('/data/Bodenreider_UMLS_DL/Interns/Bernal/UBERT_Analysis/official_test_predictions.tsv','w') as f:

    for ex in tqdm(official_test):

        a1, a2 = ex[1], ex[2]

        # Order-independent lookup key: larger identifier first.
        sorted_example_key = ' = '.join(sorted((a1, a2), reverse=True))

        # Unpacking fails for an entry that still holds a bare int, i.e. one
        # that never received a prediction.
        label, prediction = official_sorted_test[sorted_example_key]

        # The identifiers are written in their original order, not the sorted one.
        fields = (a1, a2, str(int(label)), str(int(prediction)))
        f.write('\t'.join(fields) + '\n')

In [None]:
def same_sg_edge(aui1, aui2):
    """Tell whether two AUIs resolve to the same cui2sg value.

    Each AUI is mapped to its CUI through aui2cui, and each CUI to its value
    in cui2sg; the two values are then compared.

    Args:
        aui1: first AUI.
        aui2: second AUI.

    Returns:
        True when both values are equal, False otherwise.

    Raises:
        KeyError: if an AUI is missing from aui2cui or a CUI from cui2sg.
    """
    sg1 = cui2sg[aui2cui[aui1]]
    sg2 = cui2sg[aui2cui[aui2]]
    return sg1 == sg2

# Split the FP pairs by whether both AUIs resolve to the same cui2sg value.
inter_sg_fps = []   # rows whose two AUIs resolve to different values
intra_sg_fps = []   # rows whose two AUIs resolve to the same value

present = 0         # rows whose two AUIs are both keys of aui2cui
total_edges = 0     # every data row seen, whether or not it could be classified

# fps is rebound to None by an earlier cell, and load_large_scale returns a
# list of tuples rather than a DataFrame, so DataFrame-style access
# (iterrows / row['aui1']) cannot work on it. Reload the rows if they were
# released and read the two AUIs by position instead.
if fps is None:
    fps = load_large_scale('/data/Bodenreider_UMLS_DL/thilini/EXPERIMENTS/error_analysis/FP_with_auis.csv','|')

# Row 0 is skipped and the AUIs are taken from fields 0 and 2, matching how
# the same file is read when predictions are merged into official_sorted_test.
for row in tqdm(fps[1:]):
    aui1 = row[0]
    aui2 = row[2]

    total_edges += 1

    # Pairs with an AUI unknown to aui2cui are counted in total_edges only.
    if aui1 in aui2cui and aui2 in aui2cui:

        if same_sg_edge(aui1, aui2):
            intra_sg_fps.append(row)
        else:
            inter_sg_fps.append(row)

        present += 1

In [None]:
# Number of FP rows whose two AUIs resolve to different cui2sg values.
len(inter_sg_fps)

In [None]:
# Number of FP rows whose two AUIs resolve to the same cui2sg value.
len(intra_sg_fps)

In [None]:
# Spot check: look up the CUI of one specific AUI (KeyError if it is not in aui2cui).
aui2cui['A24114892']