In [1]:
import pandas as pd
import itertools
import numpy as np
import json

In [101]:
prediction_dir = '../predictions/'
triplets_dir = '../data/triplets/'
drugbank_dir = '../data/drugbank/'
specification = 'best_pipeline4'
data_dir = '/interactions/'
food_compounds_names_path = '../data/triplets/compounds_names.tsv'
model_name = 'rotate_'

### Create id - name mappings

In [7]:
drug_name_map = pd.read_csv(drugbank_dir + 'drug_id_name_map.csv', sep=',', index_col=[0])
food_name_map = pd.read_csv(triplets_dir + 'food_name.tsv', sep='\t', index_col=[0])
food_compound_map = pd.read_csv(triplets_dir + 'compounds_names.tsv', sep='\t', index_col=[0])

In [8]:
drug_ids = drug_name_map.id
drug_names = drug_name_map.drug_name
drug_id_map_dict = dict(zip(drug_ids, drug_names))

food_ids = food_name_map.public_id
food_names = food_name_map.name
food_id_map_dict = dict(zip(food_ids, food_names))

compound_ids = food_compound_map.compound_id
compound_names = food_compound_map.name
compound_id_map_dict = dict(zip(compound_ids, compound_names))

In [9]:
food_id_map_dict

{'FOOD00001': 'Angelica',
 'FOOD00002': 'Savoy cabbage',
 'FOOD00003': 'Silver linden',
 'FOOD00004': 'Kiwi',
 'FOOD00005': 'Allium',
 'FOOD00006': 'Garden onion',
 'FOOD00007': 'Leek',
 'FOOD00008': 'Garlic',
 'FOOD00009': 'Chives',
 'FOOD00010': 'Lemon verbena',
 'FOOD00011': 'Cashew nut',
 'FOOD00012': 'Pineapple',
 'FOOD00013': 'Dill',
 'FOOD00014': 'Custard apple',
 'FOOD00015': 'Wild celery',
 'FOOD00016': 'Peanut',
 'FOOD00017': 'Burdock',
 'FOOD00018': 'Horseradish',
 'FOOD00019': 'Tarragon',
 'FOOD00020': 'Mugwort',
 'FOOD00021': 'Asparagus',
 'FOOD00022': 'Oat',
 'FOOD00023': 'Star fruit',
 'FOOD00024': 'Brazil nut',
 'FOOD00025': 'Common beet',
 'FOOD00026': 'Borage',
 'FOOD00027': 'Chinese mustard',
 'FOOD00028': 'Swede',
 'FOOD00029': 'Rape',
 'FOOD00030': 'Common cabbage',
 'FOOD00031': 'Cauliflower',
 'FOOD00032': 'Brussel sprouts',
 'FOOD00033': 'Kohlrabi',
 'FOOD00034': 'Broccoli',
 'FOOD00035': 'Chinese cabbage',
 'FOOD00036': 'Turnip',
 'FOOD00037': 'Pigeon pea',
 'F

### Load drugs and foods 

In [10]:
# common_drugs = pd.read_csv('../data/common_drugs_num_interactions.csv', sep=';')
common_drugs = pd.read_csv('../data/drugs4prediction.csv', sep=';')
common_drugs = common_drugs.dropna()
common_drugs_ids = common_drugs.db_id.values

common_drugs_ids = list(set(common_drugs_ids))
common_drugs_ids[:10]

['DB01217',
 'DB00958',
 'DB00661',
 'DB01229',
 'DB00642',
 'DB01006  ',
 'DB00682',
 'DB00563',
 'DB00321',
 'DB00441']

In [11]:
with open('../data/foods4predictions-2.txt', 'r') as f:
    foods = f.readlines()

foods = [food.strip() for food in foods]
foods[:10]

['FOOD00055',
 'FOOD00255',
 'FOOD00256',
 "St. John's Wort",
 'FOOD00181',
 'Dandelion',
 "Cat's claw",
 'Nettle',
 'Cordyceps',
 'Reishi Mushroom']

In [12]:
food_id_map_dict[foods[4]]

'Dandelion'

### Load predictions for specific drug/food

In [13]:
def get_predictions(prediction_file):
    try:
        predictions = pd.read_csv(prediction_file, sep=',', index_col=[0])
    except:
        return None
    # keep just drug/food/food compound predictions
    predictions['node_type'] = list(itertools.repeat('xxx', predictions.shape[0]))
    predictions.loc[predictions['tail_label'].str.contains("DB\d+", regex=True), 'node_type'] = "drug"
    predictions.loc[predictions['tail_label'].str.contains("FDB"), 'node_type'] = "food_compound"
    predictions.loc[predictions['tail_label'].str.contains("FOOD"), 'node_type'] = "food"
    predictions = predictions.loc[predictions['node_type'] != 'xxx']
    
    return predictions

# assign entity names to ids
def assign_names(predictions, show=False):
    tail_names = []
    for row in predictions.iterrows():
        tail = row[1].tail_label
        node_type = row[1].node_type

        if node_type == 'drug':
            tail_name = drug_id_map_dict.get(tail, tail)
        elif node_type == 'food':
            tail_name = food_id_map_dict.get(tail, tail)
        else:
            tail_name = compound_id_map_dict.get(tail, tail)

        if show:
            print(tail, tail_name)    
        tail_names.append(tail_name)
    return tail_names

In [14]:
# drug predictions
drug_id = common_drugs_ids[9]
prediction_file = prediction_dir + specification + data_dir + model_name + drug_id + '_interacts_' + specification + '.csv'

print('Interactions with', drug_id_map_dict[drug_id])

predictions = get_predictions(prediction_file)
print(predictions.head(10))
print()
_ = assign_names(predictions, show=True)

Interactions with Gemcitabine
      tail_id       score tail_label  in_validation  in_testing node_type
3495     3495 -104.338692    DB00322          False       False      drug
3449     3449 -104.818581    DB00276          False       False      drug
3909     3909 -105.393562    DB00742          False       False      drug
4404     4404 -105.414932    DB01240          False       False      drug
3450     3450 -105.444244    DB00277          False       False      drug
3264     3264 -105.457451    DB00087          False       False      drug
3511     3511 -105.514542    DB00339          False       False      drug
3650     3650 -105.556931    DB00479          False       False      drug
3198     3198 -105.899529    DB00015          False       False      drug
3415     3415 -105.903229    DB00242          False       False      drug

DB00322 Floxuridine
DB00276 Amsacrine
DB00742 Mannitol
DB01240 Epoprostenol
DB00277 Theophylline
DB00087 Alemtuzumab
DB00339 Pyrazinamide
DB00479 Amikacin


In [15]:
tails = predictions.tail_label.values

In [16]:
# food predicitons
idx=2
food_id = foods[idx]
prediction_file = prediction_dir + specification + data_dir + model_name + food_id + '_interacts_' + specification + '.csv'

food_name = food_id
if "FOOD" in food_id:
    food_name = food_id_map_dict[food_id]
print('Interactions with', food_name)

predictions = get_predictions(prediction_file)
print(predictions.head())
print()
assign_names(predictions)

Interactions with Grapefruit
       tail_id       score tail_label  in_validation  in_testing node_type
3421      3421 -107.578560    DB00248          False       False      drug
10868    10868 -107.997795    DB14568          False       False      drug
3439      3439 -108.089890    DB00266          False       False      drug
3739      3739 -108.356354    DB00570          False       False      drug
4387      4387 -108.673676    DB01223          False       False      drug



['Cabergoline',
 'Ivosidenib',
 'Dicoumarol',
 'Vinblastine',
 'Aminophylline',
 'Cisapride',
 'Diazepam',
 'Zidovudine',
 'Tamoxifen',
 'Brivaracetam',
 'Imatinib',
 'Tenecteplase',
 'Propranolol',
 'Telmisartan',
 'Siltuximab',
 'Doxazosin',
 'Exemestane',
 'Clobazam',
 'Cefepime',
 'Lumiracoxib',
 'Vindesine',
 'Methotrexate',
 'Albendazole',
 'Cilostazol',
 'Erlotinib',
 'Ibuprofen',
 'Tolmetin',
 'Oxyphenbutazone',
 'Alteplase',
 'Bicalutamide',
 'Heparin',
 'Vinorelbine',
 'Magnesium',
 'Sirolimus',
 'Alprazolam',
 'Cyproterone acetate',
 'Flumethasone',
 'Clopidogrel',
 'Pentoxifylline',
 'Pantoprazole',
 'Triamcinolone',
 'Dabrafenib',
 'Romidepsin',
 'Cephalexin',
 'Digitoxin',
 'Flucytosine',
 'Antithrombin III human',
 'Sildenafil',
 'Theophylline',
 'Quinidine',
 'Dyphylline',
 'Flurbiprofen',
 'Etodolac',
 'Docetaxel',
 'Terbinafine',
 'Bortezomib',
 'Bexarotene',
 'Cefadroxil',
 'Bupropion',
 'Irinotecan',
 'Urokinase',
 'Clonazepam',
 'Oxycodone',
 'Zileuton',
 'Bacitrac

In [83]:
# # check if the predicted drug/food is in different interaction with the same drug in the training data
def check_known_triplets(idx, snd_idx, data):

    train_triplets = pd.read_csv(triplets_dir + 'train_' + data + '.tsv', sep='\t', index_col=[0])
    valid_triplets = pd.read_csv(triplets_dir + 'valid_' + data + '.tsv', sep='\t', index_col=[0])
    test_triplets = pd.read_csv(triplets_dir + 'test_' + data + '.tsv', sep='\t', index_col=[0])

    filtered_triplets = train_triplets.loc[train_triplets.index == idx]
    in_train = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

    filtered_triplets = valid_triplets.loc[valid_triplets.index == idx]
    # print(filtered_triplets.loc[filtered_triplets['tail'] == snd_idx])
    in_valid = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

    filtered_triplets = test_triplets.loc[test_triplets.index == idx]
    in_test = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

#     print(f'Relation in triplets:')
#     print(f'- train:', 'yes' if in_train else 'no')
#     print(f'- valid:', 'yes' if in_valid else 'no')
#     print(f'- test:', 'yes' if in_test else 'no')

    in_triplets = in_train or in_valid or in_test

    # find also symetric relations
    filtered_triplets = train_triplets.loc[train_triplets.index == snd_idx]
    in_train = filtered_triplets.loc[filtered_triplets['tail'] == idx].any().sum()
    # print(filtered_triplets.loc[filtered_triplets['tail'] == idx])

    filtered_triplets = valid_triplets.loc[valid_triplets.index == snd_idx]
    in_valid = filtered_triplets.loc[filtered_triplets['tail'] == idx].any().sum()

    filtered_triplets = test_triplets.loc[test_triplets.index == snd_idx]
    in_test = filtered_triplets.loc[filtered_triplets['tail'] == idx].any().sum()

#     print(f'Symetric relation in triplets:')
#     print(f'- train:', 'yes' if in_train else 'no')
#     print(f'- valid:', 'yes' if in_valid else 'no')
#     print(f'- test:', 'yes' if in_test else 'no')
    
    if in_triplets or in_train or in_valid or in_test:
        return True
    
    return False
    


In [84]:
data = 'drugbank'
idx = 'DB01217'
snd_idx = 'DB01062'
print('Prediction:', idx, 'interacts with', snd_idx)
check_known_triplets(idx, snd_idx, data)

Prediction: DB01217 interacts with DB01062


True

### Metrics

In [21]:
# https://pykeen.readthedocs.io/en/stable/api/pykeen.metrics.ranking.HitsAtK.html
# The number of triplets that are relevant in predictions in the top K positions * 1/total number of predictions (=100).
def compute_hits_at_k(common_drugs_ids, prediction_dir, specification, model_name, data_dir, triplets_dir, k=100):
    hits_k_valid = []
    hits_k_test = []

    data = data_dir.replace('/', '')
    valid_triplets = pd.read_csv(triplets_dir + 'valid_' + data + '.tsv', sep='\t', index_col=[0])
    test_triplets = pd.read_csv(triplets_dir + 'test_' + data + '.tsv', sep='\t', index_col=[0])

    for drug in common_drugs_ids:
        prediction_file = prediction_dir + specification + data_dir + model_name + drug + '_interacts_' + specification + '.csv'
        preds = get_predictions(prediction_file)
        
        if preds is None:
    #         print(drug)
            continue

        preds_count = preds.shape[0]
        
        # we are intersested only in first k predictions
        preds = preds.head(k)

        in_valid_count = 0
        in_test_count = 0

        filtered_triplets_valid = valid_triplets.loc[valid_triplets.index == drug]
        filtered_triplets_test = test_triplets.loc[test_triplets.index == drug]
        tails = preds.tail_label.values

        # count triplets (drug, interacts, tail_id) that are in validation/testing dataset
        for tail_id in tails:
            if filtered_triplets_valid.loc[filtered_triplets_valid['tail'] == tail_id].any().sum() > 0:
                in_valid_count += 1
            if filtered_triplets_test.loc[filtered_triplets_test['tail'] == tail_id].any().sum() > 0:
                in_test_count += 1

        hits_k_valid.append(in_valid_count/preds_count)
        hits_k_test.append(in_test_count/preds_count)


    print(f'Avg. hits@{k} on validation dataset: {np.mean(hits_k_valid)}')
    print(f'Avg. hits@{k} on test dataset: {np.mean(hits_k_test)}')

In [28]:
compute_hits_at_k(common_drugs_ids, prediction_dir, specification, model_name, data_dir, triplets_dir, 1)
compute_hits_at_k(common_drugs_ids, prediction_dir, specification, model_name, data_dir, triplets_dir, 5)
compute_hits_at_k(common_drugs_ids, prediction_dir, specification, model_name, data_dir, triplets_dir, 10)
compute_hits_at_k(common_drugs_ids, prediction_dir, specification, model_name, data_dir, triplets_dir, 30)
compute_hits_at_k(common_drugs_ids, prediction_dir, specification, model_name, data_dir, triplets_dir, 100)

Avg. hits@1 on validation dataset: 0.003341750841750842
Avg. hits@1 on test dataset: 0.006683501683501684
Avg. hits@5 on validation dataset: 0.019217171717171716
Avg. hits@5 on test dataset: 0.024242424242424242
Avg. hits@10 on validation dataset: 0.04175925925925927
Avg. hits@10 on test dataset: 0.038468013468013464
Avg. hits@30 on validation dataset: 0.10861111111111112
Avg. hits@30 on test dataset: 0.10195286195286195
Avg. hits@100 on validation dataset: 0.22052188552188554
Avg. hits@100 on test dataset: 0.207962962962963


In [29]:
# MRR (https://en.wikipedia.org/wiki/Mean_reciprocal_rank)
# if the predicted triplet isn't in the first 100 test data -> index = 100
def compute_mrr(common_drugs_ids, prediction_dir, specification, model_name, data_dir, triplets_dir):
    data = data_dir.replace('/', '')
    valid_triplets = pd.read_csv(triplets_dir + 'valid_' + data + '.tsv', sep='\t', index_col=[0])
    test_triplets = pd.read_csv(triplets_dir + 'test_' + data + '.tsv', sep='\t', index_col=[0])
    
    mrr_valid = []
    mrr_test = []

    for drug in common_drugs_ids:
        prediction_file = prediction_dir + specification + data_dir + model_name + drug + '_interacts_' + specification + '.csv'
        preds = get_predictions(prediction_file)
        
        if preds is None:
    #         print(drug)
            continue

        preds = preds.head(100)

        filtered_triplets_valid = valid_triplets.loc[valid_triplets.index == drug].head(100)
        filtered_triplets_test = test_triplets.loc[test_triplets.index == drug].head(100)
        tails = preds.tail_label.values
        
        filtered_triplets_valid['index'] = list(range(filtered_triplets_valid.shape[0]))
        filtered_triplets_valid = filtered_triplets_valid.set_index('index')
        
        filtered_triplets_test['index'] = list(range(filtered_triplets_test.shape[0]))
        filtered_triplets_test = filtered_triplets_test.set_index('index')

        # count triplets (drug, interacts, tail_id) that are in validation/testing dataset
        for tail_id in tails:
            idx = filtered_triplets_valid.index[filtered_triplets_valid['tail'] == tail_id].tolist()
            if len(idx) > 0:
                mrr_valid.append(1/(idx[0]+1))
            
            else:
                mrr_valid.append(1/100)
                
            idx = filtered_triplets_test.index[filtered_triplets_test['tail'] == tail_id].tolist()
            if len(idx) > 0:
                mrr_test.append(1/(idx[0]+1))
            
            else:
                mrr_test.append(1/100)
                
    print('MRR on valid dataset:', np.mean(mrr_valid))
    print('MRR on test dataset:', np.mean(mrr_test))

In [24]:
compute_mrr(common_drugs_ids, prediction_dir, specification, model_name, data_dir, triplets_dir)


MRR on valid dataset: 0.03112084001039423
MRR on test dataset: 0.03007297136760064


### Human evaluation
Transform the predictions into a human readable form (id -> name), so a human can decide if the predictions make sense (based on some medical knowledge).

In [106]:
predictions_human_all = pd.DataFrame()

for drug_id in common_drugs_ids:
    prediction_file = prediction_dir + specification + data_dir + model_name + drug_id + '_interacts_' + specification + '.csv'
#     prediction_file = prediction_dir + specification + data_dir + 'common_preds/common_preds_' + drug_id + '.csv'
    predictions = get_predictions(prediction_file)
    
    if predictions is None:
        continue
        
    # filter out known triplets 
    for row in predictions.iterrows():
        snd_idx = row[1].tail_label
        known_triplet = check_known_triplets(drug, snd_idx, 'interactions')
        
        if known_triplet:
            predictions.drop(row[0])
        
        
    predictions_human = pd.DataFrame()
    k = 10
    
    predictions = predictions.head(k)
    drug_name = drug_id_map_dict[drug_id]
    tail_names = assign_names(predictions)
    
    predictions_human['drug1'] = list(itertools.repeat(drug_name, k))
    predictions_human['drug1_id'] = list(itertools.repeat(drug_id, k))
    predictions_human['relation'] = list(itertools.repeat('interacts', k))
    predictions_human['drug2'] = tail_names
    predictions_human['drug2_id'] = predictions.tail_label.values
    
    predictions_human_all = pd.concat([predictions_human_all, predictions_human])

predictions_human_all

Unnamed: 0,drug1,drug1_id,relation,drug2,drug2_id
0,Anastrozole,DB01217,interacts,Terbinafine,DB00857
1,Anastrozole,DB01217,interacts,Hydroxyurea,DB01005
2,Anastrozole,DB01217,interacts,Carbamazepine,DB00564
3,Anastrozole,DB01217,interacts,Irinotecan,DB00762
4,Anastrozole,DB01217,interacts,Mycophenolic acid,DB01024
...,...,...,...,...,...
5,Cyclophosphamide,DB00531,interacts,Palifermin,DB00039
6,Cyclophosphamide,DB00531,interacts,Aminosalicylic acid,DB00233
7,Cyclophosphamide,DB00531,interacts,Metformin,DB00331
8,Cyclophosphamide,DB00531,interacts,Raltitrexed,DB00293


In [107]:
predictions_human_all_food = pd.DataFrame()

for food_id in foods:
    prediction_file = prediction_dir + specification + data_dir + model_name + food_id + '_interacts_' + specification + '.csv'
#     prediction_file = prediction_dir + specification + data_dir + 'common_preds/common_preds_' + food_id + '.csv'

    predictions = get_predictions(prediction_file)
    
    if predictions is None or predictions.shape[0] == 0:
        continue
    
    # filter out known triplets 
    for row in predictions.iterrows():
        snd_idx = row[1].tail_label
        known_triplet = check_known_triplets(drug, snd_idx, 'interactions')
        
        if known_triplet:
            predictions.drop(row[0])
            
    
    predictions_human = pd.DataFrame()
    k = predictions.shape[0]
    
    predictions = predictions.drop(predictions[predictions.node_type != 'drug'].index)
    
    if k > 10:
        k = 10
        predictions = predictions.head(k)

    
    if 'FDB' in food_id:
        food_name = compound_id_map_dict[food_id]
    elif 'FOOD' in food_id:
        food_name = food_id_map_dict[food_id]
    else:
        food_name = food_id
    tail_names = assign_names(predictions)
    
    predictions_human['food'] = list(itertools.repeat(food_name, k))
    predictions_human['food_id'] = list(itertools.repeat(food_id, k))
    predictions_human['relation'] = list(itertools.repeat('interacts', k))
    predictions_human['drug2'] = tail_names
    predictions_human['drug2_id'] = predictions.tail_label.values
    
    predictions_human_all_food = pd.concat([predictions_human_all_food, predictions_human])

predictions_human_all_food

Unnamed: 0,food,food_id,relation,drug2,drug2_id
0,Grapefruit,FOOD00256,interacts,Dexamethasone acetate,DB14649
1,Grapefruit,FOOD00256,interacts,Solriamfetol,DB14754
2,Grapefruit,FOOD00256,interacts,Desogestrel,DB00304
3,Grapefruit,FOOD00256,interacts,Testosterone cypionate,DB13943
4,Grapefruit,FOOD00256,interacts,Acenocoumarol,DB01418
5,Grapefruit,FOOD00256,interacts,Vinorelbine,DB00361
6,Grapefruit,FOOD00256,interacts,Zidovudine,DB00495
7,Grapefruit,FOOD00256,interacts,Lepirudin,DB00001
8,Grapefruit,FOOD00256,interacts,Voriconazole,DB00582
9,Grapefruit,FOOD00256,interacts,Vitamin E,DB00163


In [108]:
predictions_human_all.to_csv(prediction_dir + "predictions_human_all_interactions_drugs.csv")
predictions_human_all_food.to_csv(prediction_dir + "predictions_human_all_interactions_foods.csv")

# predictions_human_all.to_csv(prediction_dir + "predictions_human_all_drugbank_drugs.csv")
# predictions_human_all_food.to_csv(prediction_dir + "predictions_human_all_drugbank_foods.csv")

# predictions_human_all.to_csv(prediction_dir + "predictions_human_all_drugs_common_interactions.csv")
# predictions_human_all_food.to_csv(prediction_dir + "predictions_human_all_foods_common_interactions.csv")

In [None]:
# find foods for each food compound in predictions_human_all_food
food_compound_map = pd.read_csv(triplets_dir + 'food_compound.tsv', sep='\t', index_col=[0])

compound_to_food = {}

for index, row in food_compound_map.iterrows():
    compound_id = row['compound_id']
    food_id = row['food_id']
    
    if compound_id not in compound_to_food:
        compound_to_food[compound_id] = [food_id_map_dict[food_id]]
    else:
        compound_to_food[compound_id].append(food_id_map_dict[food_id])
        
compound_to_food

# TODO: sort by an amount of a comound in a food (take first x) ??

In [None]:
with open(prediction_dir + 'compound_to_food.json', 'w') as f:
    json.dump(compound_to_food, f, indent=4)

In [3]:
train = pd.read_csv(triplets_dir + 'train_' + 'drugbank' + '.tsv', sep='\t')
valid = pd.read_csv(triplets_dir + 'valid_' + 'drugbank' + '.tsv', sep='\t')
test = pd.read_csv(triplets_dir + 'test_' + 'drugbank' + '.tsv', sep='\t')

# find leaked triplets in test data
num_leaked_triplets = 0
for row in train.iterrows():
    h = row[1].values[0]
    t = row[1].values[2]
    tmp = test.loc[test['head'] == h]
    leaked = tmp.loc[test['tail'] == t].any().sum()
    num_leaked_triplets += leaked
    if leaked:
        print('Same:',row[1].values)
    
    tmp = test.loc[test['tail'] == h]
    leaked = tmp.loc[test['head'] == t].any().sum()
    num_leaked_triplets += leaked
    if leaked:
        print('Inversed:', row[1].values)

num_leaked_triplets    

KeyboardInterrupt: 