In [124]:
import pandas as pd
import itertools
import numpy as np
import json

In [61]:
prediction_dir = '../predictions/'
triplets_dir = '../data/triplets/'
drugbank_dir = '../data/drugbank/'
specification = 'best_pipeline2.2-200epochs' #'ogb_settings' 
data_dir = '/drugbank/'
food_compounds_names_path = '../data/triplets/compounds_names.tsv'

### Create id - name mappings

In [13]:
drug_name_map = pd.read_csv(drugbank_dir + 'drug_id_name_map.csv', sep=',', index_col=[0])
food_name_map = pd.read_csv(triplets_dir + 'food_name.tsv', sep='\t', index_col=[0])
food_compound_map = pd.read_csv(triplets_dir + 'compounds_names.tsv', sep='\t', index_col=[0])

In [14]:
drug_ids = drug_name_map.id
drug_names = drug_name_map.drug_name
drug_id_map_dict = dict(zip(drug_ids, drug_names))

food_ids = food_name_map.public_id
food_names = food_name_map.name
food_id_map_dict = dict(zip(food_ids, food_names))

compound_ids = food_compound_map.compound_id
compound_names = food_compound_map.name
compound_id_map_dict = dict(zip(compound_ids, compound_names))

### Load drugs and foods 

In [64]:
common_drugs = pd.read_csv('../data/common_drugs_num_interactions.csv', sep=';')
common_drugs = common_drugs.dropna()
common_drugs_ids = common_drugs.db_id.values

common_drugs_ids = list(set(common_drugs_ids))
common_drugs_ids[:10]

['DB01307',
 'DB00881',
 'DB08907',
 'DB01306',
 'DB01698',
 'DB01009',
 'DB00836',
 'DB01200',
 'DB00727',
 'DB06713']

In [65]:
with open('../data/foods4predictions.txt', 'r') as f:
    foods = f.readlines()

foods = [food.strip() for food in foods]
foods[:10]

['FDB012878',
 'FDB015831',
 'FDB012266',
 'FDB011824',
 'FDB001982',
 'FDB003636',
 'FDB000497',
 'FDB012196',
 'FDB015887',
 'FDB000082']

### Load predictions for specific drug/food

In [75]:
def get_predictions(prediction_file):
    try:
        predictions = pd.read_csv(prediction_file, sep=',', index_col=[0])
    except:
        return None
    # keep just drug/food/food compound predictions
    predictions['node_type'] = list(itertools.repeat('xxx', predictions.shape[0]))
    predictions.loc[predictions['tail_label'].str.contains("DB\d+", regex=True), 'node_type'] = "drug"
    predictions.loc[predictions['tail_label'].str.contains("FDB"), 'node_type'] = "food_compound"
    predictions.loc[predictions['tail_label'].str.contains("FOOD"), 'node_type'] = "food"
    predictions = predictions.loc[predictions['node_type'] != 'xxx']
    
    return predictions

# assign entity names to ids
def assign_names(predictions, show=False):
    tail_names = []
    for row in predictions.iterrows():
        tail = row[1].tail_label
        node_type = row[1].node_type

        if node_type == 'drug':
            tail_name = drug_id_map_dict[tail]
        elif node_type == 'food':
            tail_name = food_id_map_dict[tail]
        else:
            tail_name = compound_id_map_dict[tail]

        if show:
            print(tail, tail_name)    
        tail_names.append(tail_name)
    return tail_names

In [77]:
# drug predictions
drug_id = common_drugs_ids[50]
prediction_file = prediction_dir + specification + data_dir +'complex_' + drug_id + '_interacts_' + specification + '.csv'

print('Interactions with', drug_id_map_dict[drug_id])

predictions = get_predictions(prediction_file)
print(predictions.head(10))
print()
_ = assign_names(predictions, show=True)

Interactions with Clomipramine
      tail_id     score tail_label  in_validation  in_testing node_type
2256     2256  9.523676    DB00312          False       False      drug
2963     2963  8.502615    DB01054          False       False      drug
4929     4929  8.472841    DB15982          False       False      drug
3076     3076  8.296550    DB01173          False       False      drug
3565     3565  8.292852    DB06402          False       False      drug
3420     3420  8.248529    DB04839          False       False      drug
2931     2931  8.229597    DB01017          False       False      drug
2757     2757  8.188326    DB00835          False       False      drug
2377     2377  8.074543    DB00435          False       False      drug
2253     2253  8.066900    DB00308          False       False      drug

DB00312 Pentobarbital
DB01054 Nitrendipine
DB15982 Berotralstat
DB01173 Orphenadrine
DB06402 Telavancin
DB04839 Cyproterone acetate
DB01017 Minocycline
DB00835 Brompheniramine


In [37]:
tails = predictions.tail_label.values

In [101]:
# food predicitons
food_id = foods[37]
prediction_file = prediction_dir + specification + data_dir +'complex_' + food_id + '_interacts_' + specification + '.csv'

print('Interactions with', compound_id_map_dict[food_id])

predictions = get_predictions(prediction_file)
print(predictions.head())
print()
assign_names(predictions)

Interactions with delta-Carotene
      tail_id     score tail_label  in_validation  in_testing      node_type
4954     4954  0.624877    DB16746          False       False           drug
6854     6854  0.522141  FDB001936          False       False  food_compound
7414     7414  0.513568  FOOD00394          False       False           food
4887     4887  0.457773    DB15270          False       False           drug
6981     6981  0.429160  FDB012521          False       False  food_compound



['Elivaldogene autotemcel',
 'Stigmasterol',
 'Lambsquarters',
 'Efgartigimod alfa',
 'Campesterol',
 'Other fruit product',
 'Beractant',
 'Erenumab',
 'Peppermint',
 'Pie',
 'Quercetin 3-rutinoside',
 'Pectic acid',
 'Pantothenic acid',
 'Adzuki bean',
 'beta-Sitosterol',
 'Luteolin',
 'L-Tyrosine',
 'Hazelnut',
 'Silicon',
 'L-Histidine',
 'Quercetin',
 'Pear',
 'Caffeic acid',
 'Margarine-like spread',
 'Ginger',
 'alpha-Carotene',
 'Dried milk',
 'L-Cystine',
 'Spectinomycin',
 'beta-Carotene',
 'Indocyanine green',
 'L-Lysine',
 'D-Fructose',
 'Prasterone',
 'Hushpuppy',
 'Empanada',
 'L-Methionine',
 'Wild celery',
 'Dodecanoic acid',
 'Econazole',
 'Phyllo dough',
 'Burdock',
 'Hepatitis A Vaccine',
 'Oat',
 'Centella asiatica',
 'Glycine',
 'Waffle',
 'Pantothenic acid',
 'Riboflavine',
 'Spearmint',
 'Ergosterol',
 'Parsnip',
 'Mupirocin',
 'Nitrogen',
 'Brincidofovir',
 '4(10)-Thujene',
 'Japanese persimmon',
 'Common cabbage',
 'Malvidin',
 'Betulin',
 'Hyaluronic acid',
 '

In [39]:
# # check if the predicted drug/food is in different interaction with the same drug in the training data

data = 'drugbank'
idx = drug_id
snd_idx = tails[60]

train_triplets = pd.read_csv(triplets_dir + 'train_' + data + '.tsv', sep='\t', index_col=[0])
valid_triplets = pd.read_csv(triplets_dir + 'valid_' + data + '.tsv', sep='\t', index_col=[0])
test_triplets = pd.read_csv(triplets_dir + 'test_' + data + '.tsv', sep='\t', index_col=[0])

filtered_triplets = train_triplets.loc[train_triplets.index == idx]
in_train = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

filtered_triplets = valid_triplets.loc[valid_triplets.index == idx]
in_valid = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

filtered_triplets = test_triplets.loc[test_triplets.index == idx]
in_test = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

print(f'Relation in triplets:')
print(f'- train:', 'yes' if in_train else 'no')
print(f'- valid:', 'yes' if in_valid else 'no')
print(f'- test:', 'yes' if in_test else 'no')
      

Relation in triplets:
- train: no
- valid: yes
- test: no


### Metrics

In [62]:
# https://pykeen.readthedocs.io/en/stable/api/pykeen.metrics.ranking.HitsAtK.html
# The number of triplets that are relevant in predictions in the top K positions * 1/total number of predictions (=100).
def compute_hits_at_k(common_drugs_ids, prediction_dir, specification, data_dir, triplets_dir, k=100):
    hits_k_valid = []
    hits_k_test = []

    data = data_dir.replace('/', '')
    valid_triplets = pd.read_csv(triplets_dir + 'valid_' + data + '.tsv', sep='\t', index_col=[0])
    test_triplets = pd.read_csv(triplets_dir + 'test_' + data + '.tsv', sep='\t', index_col=[0])

    for drug in common_drugs_ids:
        prediction_file = prediction_dir + specification + data_dir +'complex_' + drug + '_interacts_' + specification + '.csv'
        preds = get_predictions(prediction_file)
        
        if preds is None:
    #         print(drug)
            continue

        preds_count = preds.shape[0]
        
        # we are intersested only in first k predictions
        preds = preds.head(k)

        in_valid_count = 0
        in_test_count = 0

        filtered_triplets_valid = valid_triplets.loc[valid_triplets.index == drug]
        filtered_triplets_test = test_triplets.loc[test_triplets.index == drug]
        tails = preds.tail_label.values

        # count triplets (drug, interacts, tail_id) that are in validation/testing dataset
        for tail_id in tails:
            if filtered_triplets_valid.loc[filtered_triplets_valid['tail'] == tail_id].any().sum() > 0:
                in_valid_count += 1
            if filtered_triplets_test.loc[filtered_triplets_test['tail'] == tail_id].any().sum() > 0:
                in_test_count += 1

        hits_k_valid.append(in_valid_count/preds_count)
        hits_k_test.append(in_test_count/preds_count)


    print(f'Avg. hits@{k} on validation dataset: {np.mean(hits_k_valid)}')
    print(f'Avg. hits@{k} on test dataset: {np.mean(hits_k_test)}')

In [63]:
compute_hits_at_k(common_drugs_ids, prediction_dir, specification, data_dir, triplets_dir, 1)
compute_hits_at_k(common_drugs_ids, prediction_dir, specification, data_dir, triplets_dir, 5)
compute_hits_at_k(common_drugs_ids, prediction_dir, specification, data_dir, triplets_dir, 10)
compute_hits_at_k(common_drugs_ids, prediction_dir, specification, data_dir, triplets_dir, 30)
compute_hits_at_k(common_drugs_ids, prediction_dir, specification, data_dir, triplets_dir, 100)

Avg. hits@1 on validation dataset: 0.0053124999999999995
Avg. hits@1 on test dataset: 0.0053124999999999995
Avg. hits@5 on validation dataset: 0.023125
Avg. hits@5 on test dataset: 0.029375
Avg. hits@10 on validation dataset: 0.05
Avg. hits@10 on test dataset: 0.056875
Avg. hits@30 on validation dataset: 0.15375
Avg. hits@30 on test dataset: 0.1625
Avg. hits@100 on validation dataset: 0.5125
Avg. hits@100 on test dataset: 0.51625


In [66]:
# TODO: correct MRR !!

In [17]:
# MRR (https://en.wikipedia.org/wiki/Mean_reciprocal_rank)
# if the predicted triplet isn't in the first 100 test data -> index = 100

mrr = []

for drug in common_drugs_ids:
    prediction_file = prediction_dir + specification + data_dir +'complex_' + drug + '_interacts_' + specification + '.csv'
    preds = get_predictions(prediction_file)
    
    if preds is None:
        continue
    
    try:
        idx = 1 / (list(preds.in_testing).index(True) + 1)
    except:
        idx = 1 / 100
    mrr.append(idx)
    
print('MRR:', np.mean(mrr))

MRR: 0.009999999999999998


### Human evaluation
Transform the predictions into a human readable form (id -> name), so a human can decide if the predictions make sense (based on some medical knowledge).

In [88]:
predictions_human_all = pd.DataFrame()

for drug_id in common_drugs_ids:
    prediction_file = prediction_dir + specification + data_dir +'complex_' + drug_id + '_interacts_' + specification + '.csv'
    predictions = get_predictions(prediction_file)
    
    if predictions is None:
        continue
        
    predictions_human = pd.DataFrame()
    k = 10
    
    predictions = predictions.head(k)
    drug_name = drug_id_map_dict[drug_id]
    tail_names = assign_names(predictions)
    
    predictions_human['drug1'] = list(itertools.repeat(drug_name, k))
    predictions_human['drug1_id'] = list(itertools.repeat(drug_id, k))
    predictions_human['relation'] = list(itertools.repeat('interacts', k))
    predictions_human['drug2'] = tail_names
    predictions_human['drug2_id'] = predictions.tail_label.values
    
    predictions_human_all = pd.concat([predictions_human_all, predictions_human])

predictions_human_all

Unnamed: 0,drug1,drug1_id,relation,drug2,drug2_id
0,Verapamil,DB00661,interacts,Picosulfuric acid,DB09268
1,Verapamil,DB00661,interacts,Methylphenobarbital,DB00849
2,Verapamil,DB00661,interacts,Phenformin,DB00914
3,Verapamil,DB00661,interacts,Rabeprazole,DB01129
4,Verapamil,DB00661,interacts,Hydrocortisone acetate,DB14539
...,...,...,...,...,...
5,Rivaroxaban,DB06228,interacts,Ibritumomab tiuxetan,DB00078
6,Rivaroxaban,DB06228,interacts,Vorinostat,DB02546
7,Rivaroxaban,DB06228,interacts,Lenalidomide,DB00480
8,Rivaroxaban,DB06228,interacts,Calcium Phosphate,DB11348


In [113]:
predictions_human_all_food = pd.DataFrame()

for food_id in foods:
    prediction_file = prediction_dir + specification + data_dir +'complex_' + food_id + '_interacts_' + specification + '.csv'
    predictions = get_predictions(prediction_file)
    
    if predictions is None:
        continue
        
    predictions_human = pd.DataFrame()
    k = 10
    
    predictions = predictions.drop(predictions[predictions.node_type != 'drug'].index)
    
    if k < predictions.shape[0]:
        predictions = predictions.head(k)
    
    if 'FDB' in food_id:
        food_name = compound_id_map_dict[food_id]
    elif 'FOOD' in food_id:
        food_name = food_id_map_dict[food_id]
    else:
        food_name = food_id
    tail_names = assign_names(predictions)
    
    predictions_human['food'] = list(itertools.repeat(food_name, k))
    predictions_human['food_id'] = list(itertools.repeat(food_id, k))
    predictions_human['relation'] = list(itertools.repeat('interacts', k))
    predictions_human['drug2'] = tail_names
    predictions_human['drug2_id'] = predictions.tail_label.values
    
    predictions_human_all_food = pd.concat([predictions_human_all_food, predictions_human])

predictions_human_all_food

Unnamed: 0,food,food_id,relation,drug2,drug2_id
0,Bentonite,FDB012878,interacts,Bendazac,DB13501
1,Bentonite,FDB012878,interacts,Avapritinib,DB15233
2,Bentonite,FDB012878,interacts,Clobetasol propionate,DB01013
3,Bentonite,FDB012878,interacts,Nandrolone decanoate,DB08804
4,Bentonite,FDB012878,interacts,Cholestyramine,DB01432
...,...,...,...,...,...
5,St. John's Wort,St. John's Wort,interacts,Larotrectinib,DB14723
6,St. John's Wort,St. John's Wort,interacts,Alectinib,DB11363
7,St. John's Wort,St. John's Wort,interacts,Busulfan,DB01008
8,St. John's Wort,St. John's Wort,interacts,Argatroban,DB00278


In [126]:
predictions_human_all.to_csv(prediction_dir + "predictions_human_all_drugs.csv")
predictions_human_all_food.to_csv(prediction_dir + "predictions_human_all_foods.csv")

In [123]:
# find foods for each food compound in predictions_human_all_food
food_compound_map = pd.read_csv(triplets_dir + 'food_compound.tsv', sep='\t', index_col=[0])

compound_to_food = {}

for index, row in food_compound_map.iterrows():
    compound_id = row['compound_id']
    food_id = row['food_id']
    
    if compound_id not in compound_to_food:
        compound_to_food[compound_id] = [food_id_map_dict[food_id]]
    else:
        compound_to_food[compound_id].append(food_id_map_dict[food_id])
        
compound_to_food

# TODO: sort by an amount of a comound in a food (take first x) ??

{'FDB013255': ['Angelica',
  'Savoy cabbage',
  'Kiwi',
  'Chives',
  'Lemon verbena',
  'Wild celery',
  'Horseradish',
  'Tarragon',
  'Common cabbage',
  'Cauliflower',
  'Brussel sprouts',
  'Kohlrabi',
  'Pepper',
  'Roman camomile',
  'Watermelon',
  'Lemon',
  'Muskmelon',
  'Cucumber',
  'Cucurbita',
  'Cumin',
  'Lemon grass',
  'Globe artichoke',
  'Wild carrot',
  'Fennel',
  'Lettuce',
  'Lentils',
  'Flaxseed',
  'Mexican oregano',
  'Apple',
  'German camomile',
  'Cornmint',
  'Spearmint',
  'Peppermint',
  'Sweet basil',
  'Olive',
  'Common oregano',
  'Parsley',
  'Common bean',
  'Anise',
  'Pistachio',
  'Pomegranate',
  'Rosemary',
  'Common sage',
  'Potato',
  'Spinach',
  'Dandelion',
  'Cocoa bean',
  'Common thyme',
  'Fenugreek',
  'Common grape',
  'Celery stalks',
  'Garden onion (var.)',
  'Carrot',
  'Celery leaves',
  'Pak choy',
  'Red beetroot',
  'Alfalfa',
  'Ginkgo nuts',
  'Sacred lotus',
  'Common salsify',
  'Green bell pepper',
  'Yellow bell pe

In [125]:
with open(prediction_dir + 'compound_to_food.json', 'w') as f:
    json.dump(compound_to_food, f, indent=4)