In [1]:
import pandas as pd
import itertools
import numpy as np

In [2]:
# Input locations, plus the tag of the prediction run being evaluated
# (used both as a sub-directory of prediction_dir and inside file names).
prediction_dir, triplets_dir, drugbank_dir = (
    '../predictions/',
    '../data/triplets/',
    '../data/drugbank/',
)
specification = 'best_pipeline4'

In [3]:
# ID -> name lookup tables for drugs, foods and food compounds.
drug_name_map = pd.read_csv(f'{drugbank_dir}drug_id_name_map.csv', index_col=[0])
food_name_map = pd.read_csv(f'{triplets_dir}food_name.tsv', sep='\t', index_col=[0])
food_compound_map = pd.read_csv(f'{triplets_dir}compounds_names.tsv', sep='\t', index_col=[0])

In [4]:
# Plain dicts mapping each entity id to its display name.
# If an id occurs more than once, the last row wins.
drug_ids, drug_names = drug_name_map['id'], drug_name_map['drug_name']
drug_id_map_dict = {entity_id: name for entity_id, name in zip(drug_ids, drug_names)}

food_ids, food_names = food_name_map['public_id'], food_name_map['name']
food_id_map_dict = {entity_id: name for entity_id, name in zip(food_ids, food_names)}

compound_ids, compound_names = food_compound_map['compound_id'], food_compound_map['name']
compound_id_map_dict = {entity_id: name for entity_id, name in zip(compound_ids, compound_names)}

In [5]:
# Drugs to evaluate. Rows with a missing value in any column are dropped.
# The CSV is stored sorted by num_interactions, descending.
common_drugs = (
    pd.read_csv('../data/common_drugs_num_interactions.csv', sep=';')
    .dropna()
)
common_drugs_ids = common_drugs['db_id'].values

print(common_drugs_ids)

['DB00321' 'DB00091' 'DB00564' ... 'DB11237' 'DB00878' 'DB00768']


In [6]:
# Display the drug table used for evaluation.
# The CSV on disk was produced by sorting on num_interactions (descending)
# and writing it back with the two lines below. to_csv writes the index by
# default, so each such round-trip adds another "Unnamed: 0.N" column (visible
# in the output); pass index=False if this is ever re-run.
# Note: several brand names share one db_id (e.g. DB00564, DB00878).
# common_drugs = common_drugs.sort_values(by=["num_interactions"], ascending=False)
# common_drugs.to_csv('../data/common_drugs_num_interactions.csv', sep=';')

common_drugs

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,drug_name,ATC_WHO,db_id,num_interactions,atc_code
0,976,976,amitriptylin,N06AA09,DB00321,1940.0,N06AA09
1,1048,1048,sandimmun,L04AD01,DB00091,1919.0,L04AD01
2,1419,1419,biston,N03AF01,DB00564,1918.0,N03AF01
3,320,320,neurotop,N03AF01,DB00564,1918.0,N03AF01
4,262,262,digoxin,C01AA05,DB00390,1913.0,C01AA05
...,...,...,...,...,...,...,...
1373,517,517,bepanthen,D08AC52,DB00878,1.0,D08AC52
1374,583,583,loceryl,D01AE16,DB09056,1.0,D01AE16
1375,308,308,wobenzym,M09AB52,DB11237,1.0,M09AB52
1376,46,46,corsodyl,A01AB03,DB00878,1.0,A01AB03


In [34]:
# Inspect the predictions for one drug: the first id in the evaluation list.
drug_id = common_drugs_ids[0]
prediction_file = f'{prediction_dir}{specification}/complex_{drug_id}_negative_{specification}.csv'

def get_predictions(prediction_file):
    """Load a prediction CSV and keep only drug, food and food-compound tails.

    Parameters
    ----------
    prediction_file : str
        Path to a comma-separated file whose first column is the index and
        which has a ``tail_label`` column.

    Returns
    -------
    pandas.DataFrame or None
        The rows whose ``tail_label`` contains ``FOOD`` (node_type ``'food'``),
        else ``FDB`` (``'food_compound'``), else matches ``DB\\d+``
        (``'drug'``), with an added ``node_type`` column. Rows with any other
        label, or a missing label, are dropped. Row order from the file is
        preserved (no re-sorting by score). Returns ``None`` if the file does
        not exist or is empty, so callers can skip that drug.
    """
    try:
        predictions = pd.read_csv(prediction_file, sep=',', index_col=[0])
    except (FileNotFoundError, pd.errors.EmptyDataError):
        return None

    labels = predictions['tail_label']
    # 'xxx' marks rows that are none of the entity types we evaluate.
    predictions['node_type'] = 'xxx'
    # Later assignments win. A label such as 'FDB000001' also matches r'DB\d+',
    # so the food-compound and food checks must run after the drug check.
    # na=False keeps missing labels out of every mask instead of raising.
    predictions.loc[labels.str.contains(r'DB\d+', regex=True, na=False), 'node_type'] = 'drug'
    predictions.loc[labels.str.contains('FDB', regex=False, na=False), 'node_type'] = 'food_compound'
    predictions.loc[labels.str.contains('FOOD', regex=False, na=False), 'node_type'] = 'food'

    # copy() so callers get an independent frame, not a view of the full file.
    return predictions.loc[predictions['node_type'] != 'xxx'].copy()

predictions = get_predictions(prediction_file)
# get_predictions returns None for a missing/empty file; fail here with a clear
# message rather than with an AttributeError on None in a later cell.
if predictions is None:
    raise FileNotFoundError(f'No predictions found at {prediction_file}')
predictions

Unnamed: 0,tail_id,score,tail_label,in_validation,in_testing,node_type
2658,2658,204.137115,DB00732,False,False,drug
3530,3530,176.278229,DB06210,False,False,drug
2290,2290,175.588440,DB00347,False,False,drug
3147,3147,170.297607,DB01245,False,False,drug
4395,4395,165.667740,DB11855,False,False,drug
...,...,...,...,...,...,...
2824,2824,121.586510,DB00908,False,False,drug
4396,4396,121.371811,DB11859,False,False,drug
2521,2521,121.322815,DB00589,False,False,drug
4336,4336,121.278244,DB11633,False,False,drug


In [28]:
# Number of this drug's kept predictions that are triplets from the test split.
predictions['in_testing'].sum()

0

In [47]:
# # assign entity names to ids
# for row in predictions.iterrows():
#     tail = row[1].tail_label
#     node_type = row[1].node_type
    
#     if node_type == 'drug':
#         tail_name = drug_id_map_dict[tail]
#     elif node_type == 'food':
#         tail_name = food_id_map_dict[tail]
#     else:
#         tail_name = compound_id_map_dict[tail]
        
#     print(tail, tail_name)    

In [46]:
# # check if the predicted drug/food is in different interaction with the same drug in the training data

# idx = drug_id
# snd_idx = 'DB11166'

# train_triplets = pd.read_csv(triplets_dir + 'train_with_biokg.tsv', sep='\t', index_col=[0])
# valid_triplets = pd.read_csv(triplets_dir + 'valid_with_biokg.tsv', sep='\t', index_col=[0])
# test_triplets = pd.read_csv(triplets_dir + 'test_with_biokg.tsv', sep='\t', index_col=[0])

# filtered_triplets = train_triplets.loc[train_triplets.index == idx]
# in_train = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

# filtered_triplets = valid_triplets.loc[valid_triplets.index == idx]
# in_valid = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

# filtered_triplets = test_triplets.loc[test_triplets.index == idx]
# in_test = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

# print(f'Relation in triplets:')
# print(f'- train:', 'yes' if in_train else 'no')
# print(f'- valid:', 'yes' if in_valid else 'no')
# print(f'- test:', 'yes' if in_test else 'no')
      

### Metrics: hits@k and MRR over all common drugs

In [40]:
# hits@k: for each drug, count how many of its k highest-ranked predictions are
# triplets from the test split (in_testing == True), then average that count
# over all drugs that have a prediction file.
# NOTE(review): the ranking is the row order of each prediction file, because
# get_predictions does not re-sort by score. This assumes every file is already
# sorted by score, descending; confirm for each run.

HITS_K = (10, 20, 30, 100)
hits_at_k = {k: [] for k in HITS_K}

# Several brand names in common_drugs share one db_id (e.g. DB00564), so iterate
# over unique ids. Otherwise those drugs would be counted more than once.
for drug in pd.unique(common_drugs_ids):
    prediction_file = f'{prediction_dir}{specification}/complex_{drug}_negative_{specification}.csv'
    preds = get_predictions(prediction_file)

    if preds is None:  # no prediction file for this drug
        continue

    in_testing = preds['in_testing'].to_numpy()
    for k, hits in hits_at_k.items():
        hits.append(in_testing[:k].sum())  # positional top-k slice

hits10, hits20, hits30, hits100 = (hits_at_k[k] for k in HITS_K)

if hits10:
    for k in HITS_K:
        print(f'Avg. hits@{k}:', np.mean(hits_at_k[k]))
else:
    print('No prediction files found for', specification)

Avg. hits@10: 0.0
Avg. hits@20: 0.0
Avg. hits@30: 0.0
Avg. hits@100: 0.0


In [42]:
# MRR (https://en.wikipedia.org/wiki/Mean_reciprocal_rank)
# A drug's rank is the 1-based position of its first prediction that is in the
# test split. Ranks are capped at MRR_MAX_RANK, and a drug with no test hit at
# all also gets MRR_MAX_RANK. This way a late hit never scores worse than
# having no hit.
# NOTE(review): as with hits@k, this assumes each prediction file is already
# sorted by score, descending.

MRR_MAX_RANK = 100

mrr = []

# db_ids repeat across brand names in common_drugs; evaluate each file once.
for drug in pd.unique(common_drugs_ids):
    prediction_file = f'{prediction_dir}{specification}/complex_{drug}_negative_{specification}.csv'
    preds = get_predictions(prediction_file)

    if preds is None:  # no prediction file for this drug
        continue

    hit_positions = np.flatnonzero(preds['in_testing'].eq(True).to_numpy())
    rank = int(hit_positions[0]) + 1 if hit_positions.size else MRR_MAX_RANK
    mrr.append(1 / min(rank, MRR_MAX_RANK))

if mrr:
    print('MRR:', np.mean(mrr))
else:
    print('No prediction files found for', specification)

MRR: 0.009999999999999998
