In [16]:
import pandas as pd
import itertools
import numpy as np

In [2]:
prediction_dir = '../predictions/'
triplets_dir = '../data/triplets/'
drugbank_dir = '../data/drugbank/'

In [3]:
drug_name_map = pd.read_csv(drugbank_dir + 'drug_id_name_map.csv', sep=',', index_col=[0])
food_name_map = pd.read_csv(triplets_dir + 'food_name.tsv', sep='\t', index_col=[0])
food_compound_map = pd.read_csv(triplets_dir + 'compounds_names.tsv', sep='\t', index_col=[0])

In [4]:
drug_ids = drug_name_map.id
drug_names = drug_name_map.drug_name
drug_id_map_dict = dict(zip(drug_ids, drug_names))

food_ids = food_name_map.public_id
food_names = food_name_map.name
food_id_map_dict = dict(zip(food_ids, food_names))

compound_ids = food_compound_map.compound_id
compound_names = food_compound_map.name
compound_id_map_dict = dict(zip(compound_ids, compound_names))

In [5]:
common_drugs = pd.read_csv('../data/common_drugs.csv', sep=';')
common_drugs_ids = common_drugs.DrugBank_id.values
print(common_drugs)
common_drugs_ids

                    drug  DrugBank_name DrugBank_id
0               lexaurin     Bromazepam     DB01558
1            fraxiparine     Nadroparin     DB08813
2               novalgin     Metamizole     DB04817
3              dithiaden      Diltiazem     DB00343
4               diazepam       Diazepam     DB00829
5              tamoxifen      Tamoxifen     DB00675
6   paracetamol(paralen)  Acetaminophen     DB00316
7              metamizol     Metamizole     DB04817
8              neurontin     Gabapentin     DB00996
9                ketonal      Ketoprofe     DB01009
10               ibalgin      Ibuprofen     DB01050
11               calcium        Calcium     DB01373
12              euthyrox  Levothyroxine     DB00451
13            filgrastim     Filgrastim     DB00099
14                ananas      pineapple   FOOD00012
15                   mak          poppy   FOOD00127
16               merunka       appricot   FOOD00144
17                  grep     grapefruit   FOOD00256
18          

array(['DB01558', 'DB08813', 'DB04817', 'DB00343', 'DB00829', 'DB00675',
       'DB00316', 'DB04817', 'DB00996', 'DB01009', 'DB01050', 'DB01373',
       'DB00451', 'DB00099', 'FOOD00012', 'FOOD00127', 'FOOD00144',
       'FOOD00256', 'FOOD00206', 'FOOD00178'], dtype=object)

In [18]:
drug_id = common_drugs_ids[1]

def get_predictions(drug_id):
    try:
        predictions = pd.read_csv(prediction_dir + 'complex_' + drug_id + '_interacts_ddi.csv', sep=',', index_col=[0])
    except:
        return None
    # keep just drug/food/food compound predictions
    predictions['node_type'] = list(itertools.repeat('xxx', predictions.shape[0]))
    predictions.loc[predictions['tail_label'].str.contains("DB\d+", regex=True), 'node_type'] = "drug"
    predictions.loc[predictions['tail_label'].str.contains("FDB"), 'node_type'] = "food_compound"
    predictions.loc[predictions['tail_label'].str.contains("FOOD"), 'node_type'] = "food"
    predictions = predictions.loc[predictions['node_type'] != 'xxx']
    
    return predictions

get_predictions(drug_id)

Unnamed: 0,tail_id,score,tail_label,in_validation,in_testing,node_type
1991,1991,8.604259,DB11166,False,False,drug
459,459,8.298801,DB00552,False,True,drug
584,584,8.085953,DB00686,False,False,drug
570,570,7.685174,DB00671,False,False,drug
852,852,7.239124,DB00974,True,False,drug
...,...,...,...,...,...,...
603,603,3.608001,DB00706,False,False,drug
322,322,3.586452,DB00398,False,False,drug
1098,1098,3.540261,DB01242,False,False,drug
28,28,3.531218,DB00033,True,False,drug


In [7]:
print(predictions.in_validation.sum())
print(predictions.in_testing.sum())

7
2


In [None]:
# assign entity names to ids
for row in predictions.iterrows():
    tail = row[1].tail_label
    node_type = row[1].node_type
    
    if node_type == 'drug':
        tail_name = drug_id_map_dict[tail]
    elif node_type == 'food':
        tail_name = food_id_map_dict[tail]
    else:
        tail_name = compound_id_map_dict[tail]
        
    print(tail, tail_name)    

In [9]:
# check if the predicted drug/food is in different interaction with the same drug in the training data

# idx = drug_id
# snd_idx = 'DB11166'

# train_triplets = pd.read_csv(triplets_dir + 'train.tsv', sep='\t', index_col=[0])
# valid_triplets = pd.read_csv(triplets_dir + 'valid.tsv', sep='\t', index_col=[0])
# test_triplets = pd.read_csv(triplets_dir + 'test.tsv', sep='\t', index_col=[0])

# filtered_triplets = train_triplets.loc[train_triplets.index == idx]
# in_train = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

# filtered_triplets = valid_triplets.loc[valid_triplets.index == idx]
# in_valid = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

# filtered_triplets = test_triplets.loc[test_triplets.index == idx]
# in_test = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

# print(f'Relation in triplets:')
# print(f'- train:', 'yes' if in_train else 'no')
# print(f'- valid:', 'yes' if in_valid else 'no')
# print(f'- test:', 'yes' if in_test else 'no')
      

Relation in triplets:
- train: no
- valid: no
- test: yes


### Metrics

In [29]:
# hits@k (how many predicted triplets in the first k positions are in test data)

hits10 = []
hits20 = []
hits30 = []
hits100 = []

for drug in common_drugs_ids:
    preds = get_predictions(drug)
    
    if preds is None:
        continue
        
    hits10.append(preds.in_testing[:10].sum())
    hits20.append(preds.in_testing[:20].sum())
    hits30.append(preds.in_testing[:30].sum())
    hits100.append(preds.in_testing[:100].sum())
    

print('Avg. hits@10:', np.mean(hits10))
print('Avg. hits@20:', np.mean(hits20))
print('Avg. hits@30:', np.mean(hits30))
print('Avg. hits@100:', np.mean(hits100))

Avg. hits@10: 0.5
Avg. hits@20: 0.75
Avg. hits@30: 0.8125
Avg. hits@100: 3.5


In [37]:
# MRR (https://en.wikipedia.org/wiki/Mean_reciprocal_rank)
# if the predicted triplet isn't in the first 100 test data -> index = 100

mrr = []

for drug in common_drugs_ids:
    preds = get_predictions(drug)
    
    if preds is None:
        continue
    
    try:
        idx = list(preds.in_testing).index(True) + 1
        mrr.append(idx)
    except:
        idx = 100
    
print('MRR:', np.mean(mrr))

MRR: 22.3


In [34]:
l = [False, False]
l.index(True, 0)

ValueError: True is not in list

### All common drugs

In [10]:
# common_drugs_all = pd.read_csv('../data/common_drugs_num_interactions.csv', sep=';', index_col=[0])
# common_drugs_all = common_drugs_all[['drug_name', 'db_id']].dropna()
# common_drugs_all