In [1]:
import pandas as pd
import itertools
import numpy as np
import json

In [144]:
prediction_dir = '../predictions/'
triplets_dir = '../data/triplets/'
drugbank_dir = '../data/drugbank/'
specification = 'best_pipeline4-run2'
data_dir = '/drugbank/'
food_compounds_names_path = '../data/triplets/compounds_names.tsv'
model_name = 'rotate_'

### Create id - name mappings

In [145]:
drug_name_map = pd.read_csv(drugbank_dir + 'drug_id_name_map.csv', sep=',', index_col=[0])
food_name_map = pd.read_csv(triplets_dir + 'food_name.tsv', sep='\t', index_col=[0])
food_compound_map = pd.read_csv(triplets_dir + 'compounds_names.tsv', sep='\t', index_col=[0])

In [146]:
drug_ids = drug_name_map.id
drug_names = drug_name_map.drug_name
drug_id_map_dict = dict(zip(drug_ids, drug_names))

food_ids = food_name_map.public_id
food_names = food_name_map.name
food_id_map_dict = dict(zip(food_ids, food_names))

compound_ids = food_compound_map.compound_id
compound_names = food_compound_map.name
compound_id_map_dict = dict(zip(compound_ids, compound_names))

### Load drugs and foods 

In [147]:
# common_drugs = pd.read_csv('../data/common_drugs_num_interactions.csv', sep=';')
common_drugs = pd.read_csv('../data/drugs4prediction.csv', sep=';')
common_drugs = common_drugs.dropna()
common_drugs_ids = common_drugs.db_id.values

common_drugs_ids = list(set(common_drugs_ids))
common_drugs_ids[:10]

['DB01006  ',
 'DB00642',
 'DB00544',
 'DB00661',
 'DB00441',
 'DB01101',
 'DB00563',
 'DB00958',
 'DB01217',
 'DB00531']

In [148]:
with open('../data/foods4predictions-2.txt', 'r') as f:
    foods = f.readlines()

foods = [food.strip() for food in foods]
foods[:10]

['FOOD00055',
 'FOOD00255',
 'FOOD00256',
 "St. John's Wort",
 'FOOD00181',
 'Dandelion',
 "Cat's claw",
 'Nettle',
 'Cordyceps',
 'Reishi Mushroom']

In [149]:
food_id_map_dict[foods[4]]

'Dandelion'

### Load predictions for specific drug/food

In [150]:
def get_predictions(prediction_file):
    try:
        predictions = pd.read_csv(prediction_file, sep=',', index_col=[0])
    except:
        return None
    # keep just drug/food/food compound predictions
    predictions['node_type'] = list(itertools.repeat('xxx', predictions.shape[0]))
    predictions.loc[predictions['tail_label'].str.contains("DB\d+", regex=True), 'node_type'] = "drug"
    predictions.loc[predictions['tail_label'].str.contains("FDB"), 'node_type'] = "food_compound"
    predictions.loc[predictions['tail_label'].str.contains("FOOD"), 'node_type'] = "food"
    predictions = predictions.loc[predictions['node_type'] != 'xxx']
    
    return predictions

# assign entity names to ids
def assign_names(predictions, show=False):
    tail_names = []
    for row in predictions.iterrows():
        tail = row[1].tail_label
        node_type = row[1].node_type

        if node_type == 'drug':
            tail_name = drug_id_map_dict.get(tail, tail)
        elif node_type == 'food':
            tail_name = food_id_map_dict.get(tail, tail)
        else:
            tail_name = compound_id_map_dict.get(tail, tail)

        if show:
            print(tail, tail_name)    
        tail_names.append(tail_name)
    return tail_names

In [151]:
# drug predictions
drug_id = common_drugs_ids[9]
prediction_file = prediction_dir + specification + data_dir + model_name + drug_id + '_interacts_' + specification + '.csv'

print('Interactions with', drug_id_map_dict[drug_id])

predictions = get_predictions(prediction_file)
print(predictions.head(10))
print()
_ = assign_names(predictions, show=False)

Interactions with Cyclophosphamide
       tail_id       score tail_label  in_validation  in_testing node_type
5675      5675 -105.695244    DB04836          False       False      drug
3235      3235 -105.768951    DB00054          False       False      drug
3529      3529 -105.810745    DB00357          False       False      drug
3571      3571 -105.848656    DB00400          False       False      drug
4182      4182 -105.881790    DB01015          False       False      drug
3589      3589 -105.918869    DB00418          False       False      drug
3562      3562 -105.930542    DB00391          False       False      drug
3598      3598 -106.036926    DB00427          False       False      drug
10868    10868 -106.067764    DB14568          False       False      drug
4198      4198 -106.214905    DB01032          False       False      drug



In [152]:
tails = predictions.tail_label.values

In [153]:
# food predicitons
idx=2
food_id = foods[idx]
prediction_file = prediction_dir + specification + data_dir + model_name + food_id + '_interacts_' + specification + '.csv'

food_name = food_id
if "FOOD" in food_id:
    food_name = food_id_map_dict[food_id]
print('Interactions with', food_name)

predictions = get_predictions(prediction_file)
print(predictions.head())
print()
assign_names(predictions)

Interactions with Grapefruit
       tail_id       score tail_label  in_validation  in_testing node_type
3739      3739 -107.699249    DB00570          False       False      drug
4292      4292 -107.819351    DB01128          False       False      drug
3439      3439 -108.764252    DB00266          False       False      drug
10463    10463 -108.867500    DB13884          False       False      drug
3773      3773 -108.982513    DB00604          False       False      drug



['Vinblastine',
 'Bicalutamide',
 'Dicoumarol',
 'Albutrepenonacog alfa',
 'Cisapride',
 'Heparin',
 'Teniposide',
 'Cephalexin',
 'Phenprocoumon',
 'Doxorubicin',
 'Theophylline',
 'Doxazosin',
 'Erlotinib',
 'Pretomanid',
 'Tamoxifen',
 'Eucalyptus oil',
 'Bortezomib',
 'Drotrecogin alfa',
 'Prednisone',
 'Primaquine',
 'Cabergoline',
 'Atorvastatin',
 'Cerivastatin',
 'Reteplase',
 'Nateglinide',
 'Sildenafil',
 'Tolbutamide',
 'Irinotecan',
 'Argatroban',
 'Aminophenazone',
 'Perhexiline',
 'Methotrexate',
 'Docetaxel',
 'Carboplatin',
 'Esomeprazole',
 'Lovastatin',
 'Fentanyl',
 'Methysergide',
 'Digitoxin',
 'Pentoxifylline',
 'Dabrafenib',
 'Cefpirome',
 'Lumiracoxib',
 'Capecitabine',
 'Vindesine',
 'Zidovudine',
 'Cefoxitin',
 'Synthetic Conjugated Estrogens, B',
 'Pantoprazole',
 'Anakinra',
 'Urokinase',
 'Selenium',
 'Carbamazepine',
 'Saquinavir',
 'Cefamandole',
 'Epoprostenol',
 'Vorapaxar',
 'Desogestrel',
 'Rolapitant',
 'Propranolol',
 'Voriconazole',
 'Levacetylmeth

In [157]:
# # check if the predicted drug/food is in different interaction with the same drug in the training data
def check_known_triplets(idx, snd_idx, data):
    
    run = 'run2/'
    
    train_triplets = pd.read_csv(triplets_dir + run + 'train_' + data + '.tsv', sep='\t', index_col=[0])
    valid_triplets = pd.read_csv(triplets_dir + run + 'valid_' + data + '.tsv', sep='\t', index_col=[0])
    test_triplets = pd.read_csv(triplets_dir + run + 'test_' + data + '.tsv', sep='\t', index_col=[0])

    filtered_triplets = train_triplets.loc[train_triplets.index == idx]
    in_train = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

    filtered_triplets = valid_triplets.loc[valid_triplets.index == idx]
    in_valid = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

    filtered_triplets = test_triplets.loc[test_triplets.index == idx]
    in_test = filtered_triplets.loc[filtered_triplets['tail'] == snd_idx].any().sum()

#     print(f'Relation in triplets:')
#     print(f'- train:', 'yes' if in_train else 'no')
#     print(f'- valid:', 'yes' if in_valid else 'no')
#     print(f'- test:', 'yes' if in_test else 'no')

    in_triplets = in_train or in_valid or in_test

    # find also symetric relations
    filtered_triplets = train_triplets.loc[train_triplets.index == snd_idx]
    in_train = filtered_triplets.loc[filtered_triplets['tail'] == idx].any().sum()

    filtered_triplets = valid_triplets.loc[valid_triplets.index == snd_idx]
    in_valid = filtered_triplets.loc[filtered_triplets['tail'] == idx].any().sum()

    filtered_triplets = test_triplets.loc[test_triplets.index == snd_idx]
    in_test = filtered_triplets.loc[filtered_triplets['tail'] == idx].any().sum()

#     print(f'Symetric relation in triplets:')
#     print(f'- train:', 'yes' if in_train else 'no')
#     print(f'- valid:', 'yes' if in_valid else 'no')
#     print(f'- test:', 'yes' if in_test else 'no')
    
    if in_triplets or in_train or in_valid or in_test:
        return True
    
    return False
    


In [155]:
data = 'drugbank'
idx = 'DB01217'
snd_idx = 'DB01062'
print('Prediction:', idx, 'interacts with', snd_idx)
check_known_triplets(idx, snd_idx, data)

Prediction: DB01217 interacts with DB01062


True

### Metrics

In [178]:
def relevant_predictions_k(common_drugs_ids, prediction_dir, specification, model_name, data_dir, run, triplets_dir, k=10):
    relevant = []
    

    data = data_dir.replace('/', '')

    valid_triplets = pd.read_csv(triplets_dir + run + 'valid_' + data + '.tsv', sep='\t', index_col=[0])
    test_triplets = pd.read_csv(triplets_dir + run + 'test_' + data + '.tsv', sep='\t', index_col=[0])
    
    valid_triplets2 = valid_triplets.reset_index().set_index('tail')
    test_triplets2 = test_triplets.reset_index().set_index('tail')

    for drug in common_drugs_ids:
        prediction_file = prediction_dir + specification + data_dir + model_name + drug + '_interacts_' + specification + '.csv'
        preds = get_predictions(prediction_file)
        
        if preds is None:
            continue
        
        # we are intersested only in first k predictions
        preds = preds.head(k)

        in_valid_count = 0
        in_test_count = 0
        
        tails = preds.tail_label.values

        filtered_triplets_valid = valid_triplets.loc[valid_triplets.index == drug]
        filtered_triplets_test = test_triplets.loc[test_triplets.index == drug]
        

        # count triplets (drug, interacts, tail_id) that are in validation/testing dataset
        for tail_id in tails:
            if filtered_triplets_valid.loc[filtered_triplets_valid['tail'] == tail_id].any().sum() > 0:
                in_valid_count += 1
            if filtered_triplets_test.loc[filtered_triplets_test['tail'] == tail_id].any().sum() > 0:
                in_test_count += 1
        
        # check also symmetric triplets
        filtered_triplets_valid2 = valid_triplets2.loc[valid_triplets2.index == drug]
        filtered_triplets_test2 = test_triplets2.loc[test_triplets2.index == drug]

        # count triplets (drug, interacts, tail_id) that are in validation/testing dataset
        for tail_id in tails:
            if filtered_triplets_valid2.loc[filtered_triplets_valid2['head'] == tail_id].any().sum() > 0:
                in_valid_count += 1
            if filtered_triplets_test2.loc[filtered_triplets_test2['head'] == tail_id].any().sum() > 0:
                in_test_count += 1

        relevant.append((in_valid_count + in_test_count)/k)

#     print(f'Avg. percentage of relevant triplets for first {k} predictions: {np.mean(relevant)}')
#     print(f'Max. percentage of relevant triplets for first {k} predictions: {np.max(relevant)}')
#     print(f'Min. percentage of relevant triplets for first {k} predictions: {np.min(relevant)}')
    
    return np.mean(relevant)


# relevant_predictions_k(common_drugs_ids, prediction_dir, specification + 'run2', model_name, data_dir, 'run2/', triplets_dir, 10)


In [182]:
runs = ['run1', 'run2', 'run3']
specification = 'best_pipeline4-'
data_dirs = ['/interactions/', '/drugbank/', '/hetionet/', '/biokg/']
model_names = ['rotate_', 'complex_']

for model_name in model_names:
    for data_dir in data_dirs:
        rel_10_list = []
        for run in runs:
            
#             rel1 = relevant_predictions_k(common_drugs_ids, prediction_dir, specification + run, model_name, data_dir, run + '/', triplets_dir, 1)
            rel10 = relevant_predictions_k(common_drugs_ids, prediction_dir, specification + run, model_name, data_dir, run + '/', triplets_dir, 10)
#             rel20 = relevant_predictions_k(common_drugs_ids, prediction_dir, specification + run, model_name, data_dir, run + '/', triplets_dir, 20)
#             rel100 = relevant_predictions_k(common_drugs_ids, prediction_dir, specification + run, model_name, data_dir, run + '/', triplets_dir, 100)

            rel_10_list.append(rel10)
                
        print(f"{model_name} {data_dir}: {np.mean(rel_10_list)}")
            
            

rotate_ /interactions/: 0.8666666666666667
rotate_ /drugbank/: 0.8222222222222223
rotate_ /hetionet/: nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


rotate_ /biokg/: nan
complex_ /interactions/: 0.6499999999999999
complex_ /drugbank/: 0.5166666666666667
complex_ /hetionet/: 0.8527777777777779


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


complex_ /biokg/: nan


### Human evaluation
Transform the predictions into a human readable form (id -> name), so a human can decide if the predictions make sense (based on some medical knowledge).

In [52]:
predictions_human_all = pd.DataFrame()

for drug_id in common_drugs_ids:
    prediction_file = prediction_dir + specification + data_dir + model_name + drug_id + '_interacts_' + specification + '.csv'
#     prediction_file = prediction_dir + specification + data_dir + 'common_preds/common_preds_' + drug_id + '.csv'
    predictions = get_predictions(prediction_file)
    
    if predictions is None:
        continue
        
    
    # filter out known triplets 
    for row in predictions.iterrows():
        snd_idx = row[1].tail_label
        known_triplet = check_known_triplets(drug_id, snd_idx, 'interactions')
        
        # TODO: opravit
        if known_triplet:
            predictions.drop(row[0])
  
        
    predictions_human = pd.DataFrame()
    k = 10
    
    predictions = predictions.head(k)
    drug_name = drug_id_map_dict[drug_id]
    tail_names = assign_names(predictions)
    
    predictions_human['drug1'] = list(itertools.repeat(drug_name, k))
    predictions_human['drug1_id'] = list(itertools.repeat(drug_id, k))
    predictions_human['relation'] = list(itertools.repeat('interacts', k))
    predictions_human['drug2'] = tail_names
    predictions_human['drug2_id'] = predictions.tail_label.values
    
    predictions_human_all = pd.concat([predictions_human_all, predictions_human])

predictions_human_all

[(116, tail_id                116
score           -33.843025
tail_label         DB00122
in_validation        False
in_testing           False
node_type             drug
Name: 116, dtype: object), (194, tail_id                194
score           -33.889561
tail_label         DB00233
in_validation        False
in_testing           False
node_type             drug
Name: 194, dtype: object), (145, tail_id                145
score           -33.965977
tail_label         DB00180
in_validation        False
in_testing           False
node_type             drug
Name: 145, dtype: object), (243, tail_id                243
score           -33.966263
tail_label         DB00287
in_validation        False
in_testing           False
node_type             drug
Name: 243, dtype: object), (185, tail_id                185
score           -34.103146
tail_label         DB00223
in_validation        False
in_testing           False
node_type             drug
Name: 185, dtype: object), (1645, tail_id          

In [42]:
predictions_human_all_food = pd.DataFrame()

for food_id in foods:
    prediction_file = prediction_dir + specification + data_dir + model_name + food_id + '_interacts_' + specification + '.csv'
#     prediction_file = prediction_dir + specification + data_dir + 'common_preds/common_preds_' + food_id + '.csv'

    predictions = get_predictions(prediction_file)
    
    if predictions is None or predictions.shape[0] == 0:
        continue
    
    # filter out known triplets 
    for row in predictions.iterrows():
        snd_idx = row[1].tail_label
        known_triplet = check_known_triplets(drug, snd_idx, 'interactions')
        
        if known_triplet:
            predictions.drop(row[0])
            
    
    predictions_human = pd.DataFrame()
    k = predictions.shape[0]
    
    predictions = predictions.drop(predictions[predictions.node_type != 'drug'].index)
    
    if k > 10:
        k = 10
        predictions = predictions.head(k)

    
    if 'FDB' in food_id:
        food_name = compound_id_map_dict[food_id]
    elif 'FOOD' in food_id:
        food_name = food_id_map_dict[food_id]
    else:
        food_name = food_id
    tail_names = assign_names(predictions)
    
    predictions_human['food'] = list(itertools.repeat(food_name, k))
    predictions_human['food_id'] = list(itertools.repeat(food_id, k))
    predictions_human['relation'] = list(itertools.repeat('interacts', k))
    predictions_human['drug2'] = tail_names
    predictions_human['drug2_id'] = predictions.tail_label.values
    
    predictions_human_all_food = pd.concat([predictions_human_all_food, predictions_human])

predictions_human_all_food

In [108]:
predictions_human_all.to_csv(prediction_dir + "predictions_human_all_interactions_drugs.csv")
predictions_human_all_food.to_csv(prediction_dir + "predictions_human_all_interactions_foods.csv")

# predictions_human_all.to_csv(prediction_dir + "predictions_human_all_drugbank_drugs.csv")
# predictions_human_all_food.to_csv(prediction_dir + "predictions_human_all_drugbank_foods.csv")

# predictions_human_all.to_csv(prediction_dir + "predictions_human_all_drugs_common_interactions.csv")
# predictions_human_all_food.to_csv(prediction_dir + "predictions_human_all_foods_common_interactions.csv")

In [None]:
# find foods for each food compound in predictions_human_all_food
food_compound_map = pd.read_csv(triplets_dir + 'food_compound.tsv', sep='\t', index_col=[0])

compound_to_food = {}

for index, row in food_compound_map.iterrows():
    compound_id = row['compound_id']
    food_id = row['food_id']
    
    if compound_id not in compound_to_food:
        compound_to_food[compound_id] = [food_id_map_dict[food_id]]
    else:
        compound_to_food[compound_id].append(food_id_map_dict[food_id])
        
compound_to_food

# TODO: sort by an amount of a comound in a food (take first x) ??

In [None]:
with open(prediction_dir + 'compound_to_food.json', 'w') as f:
    json.dump(compound_to_food, f, indent=4)

In [None]:
train = pd.read_csv(triplets_dir + 'train_' + 'drugbank' + '.tsv', sep='\t')
valid = pd.read_csv(triplets_dir + 'valid_' + 'drugbank' + '.tsv', sep='\t')
test = pd.read_csv(triplets_dir + 'test_' + 'drugbank' + '.tsv', sep='\t')

# find leaked triplets in test data
num_leaked_triplets = 0
for row in train.iterrows():
    h = row[1].values[0]
    t = row[1].values[2]
    tmp = test.loc[test['head'] == h]
    leaked = tmp.loc[test['tail'] == t].any().sum()
    num_leaked_triplets += leaked
    if leaked:
        print('Same:',row[1].values)
    
    tmp = test.loc[test['tail'] == h]
    leaked = tmp.loc[test['head'] == t].any().sum()
    num_leaked_triplets += leaked
    if leaked:
        print('Inversed:', row[1].values)

num_leaked_triplets    

In [None]:
# Prediciton analysis - search them in test/valid data

list_of_predictions = [
    [('DB00563', 'DB00330'), ('DB00563', 'DB00515'), ('DB00563', 'DB00530'), ('DB00563', 'DB00007'), ('DB00563', 'DB00279'), ('DB00563', 'DB00063'),
('DB00563', 'DB00436'), ('DB00563', 'DB00201'), ('DB00563', 'DB00479'), ('DB00563', 'DB00065')],
    [('DB00642', 'DB00365'), ('DB00642', 'DB00289'), ('DB00642', 'DB00461'), ('DB00642', 'DB00059'), ('DB00642', 'DB00573'), ('DB00642', 'DB00330'), ('DB00642', 'DB00095'), ('DB00642', 'DB00245'), ('DB00642', 'DB00074'), ('DB00642', 'DB00537')],
    [('DB00441', 'DB00225'), ('DB00441', 'DB00193'), ('DB00441', 'DB00017'), ('DB00441', 'DB00327'), ('DB00441', 'DB00322'), ('DB00441', 'DB00262'), ('DB00441', 'DB00007'), ('DB00441', 'DB00227'), ('DB00441', 'DB00198'), ('DB00441', 'DB00276')],
    [('DB01101', 'DB00853'), ('DB01101', 'DB00576'), ('DB01101', 'DB00675'), ('DB01101', 'DB00812'), ('DB01101', 'DB00888'), ('DB01101', 'DB00731'), ('DB01101', 'DB00700'), ('DB01101', 'DB00829'), ('DB01101', 'DB00936'), ('DB01101', 'DB00498')],
    [('DB00958', 'DB00547'), ('DB00958', 'DB00006'), ('DB00958', 'DB00583'), ('DB00958', 'DB00428'), ('DB00958', 'DB00499'), ('DB00958', 'DB00789'), ('DB00958', 'DB00229'), ('DB00958', 'DB00477'), ('DB00958', 'DB00661'), ('DB00958', 'DB00035')],
    [('DB01229', 'DB00227'), ('DB01229', 'DB00999'), ('DB01229', 'DB00714'), ('DB01229', 'DB00734'), ('DB01229', 'DB00857'), ('DB01229', 'DB00875'), ('DB01229', 'DB01261'), ('DB01229', 'DB00717'), ('DB01229', 'DB00722'), ('DB01229', 'DB00968')],
    [('DB00531', 'DB00224'), ('DB00531', 'DB00054'), ('DB00531', 'DB00391'), ('DB00531', 'DB00514'), ('DB00531', 'DB00286'), ('DB00531', 'DB00087'), ('DB00531', 'DB00490'), ('DB00531', 'DB00357'), ('DB00531', 'DB00361'), ('DB00531', 'DB00071')],
    [('DB00544', 'DB00075'), ('DB00544', 'DB00307'), ('DB00544', 'DB00526'), ('DB00544', 'DB00495'), ('DB00544', 'DB00354'), ('DB00544', 'DB00382'), ('DB00544', 'DB00206'), ('DB00544', 'DB00363'), ('DB00544', 'DB00218'), ('DB00544', 'DB00433')],
    [('DB01217', 'DB01073'), ('DB01217', 'DB01167'), ('DB01217', 'DB00613'), ('DB01217', 'DB00338'), ('DB01217', 'DB01047'), ('DB01217', 'DB00374'), ('DB01217', 'DB00762'), ('DB01217', 'DB00394'), ('DB01217', 'DB00056'), ('DB01217', 'DB00862')],

    [('DB00682', 'DB00445'), ('DB00682', 'DB00479'), ('DB00682', 'DB00455'), ('DB00682', 'DB00307'), ('DB00682', 'DB00258'), ('DB00682', 'DB00391'), ('DB00682', 'DB00487'), ('DB00682', 'DB00006'), ('DB00682', 'DB00197'), ('DB00682', 'DB00620')],
    [('DB00661', 'DB00345'), ('DB00661', 'DB00208'), ('DB00661', 'DB00041'), ('DB00661', 'DB00299'), ('DB00661', 'DB00470'), ('DB00661', 'DB00444'), ('DB00661', 'DB00624'), ('DB00661', 'DB00648'), ('DB00661', 'DB00344'), ('DB00661', 'DB00297')],

]

for l in list_of_predictions:
    count = 0
    for pred in l:
        count += check_known_triplets(pred[0], pred[1], 'drugbank')
    print(count)