# Benchmarking for different dataset and methods

### Benchmark dataset description
**Full-DTIs-LC-Benchmark**
A benchmark dataset containing Drug-Target Interation (DTI) data of Lung Cancer (LC). This dataset avoided information that is only generated automatically, through text mining, and focused on most trustworthy sources, namely DrugBank, KEGG, DGIdb and TTD data. The union of DrugBank, KEGG, DGIdb and TTD provided 44,169 positive drug-gene interactions in total, with 1931 of those related on one side (drug) or the other (gene) to Lung Cancer (LC). As for the negative drug-gene pairs, there are 627,971 pairs, for which no interaction is reported in any of the above databases.

- [Paper](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-023-05373-2)
- [Code](https://github.com/fotais/drug-gene-interactions/tree/main)
- [Data](https://github.com/fotais/drug-gene-interactions/blob/main/Full-DTIs-LC-Benchmark.csv)

*Disease: Lung cancer MESH:D008175*


In [1]:
from benchmark import benchmark
lc = benchmark()
lc.add_dataset(
    disease='MESH:D008175', 
    dataset='Full-DTIs-LC-Benchmark.csv',
    activity_col=4,
    header=True
)

# Statistics of the benchmark dataset
print(f'Total disease: {lc.n_disease}')
print(f'Total compounds: {lc.n_compound}')
print(f'Total positive relations: {lc.n_pos_relation}')
print(f'Total negative relations: {lc.n_neg_relation}')

Total disease: 1
Total compounds: 9687
Total positive relations: 9677
Total negative relations: 9687


## Add evaluation metrics
Here, MRR is the default metric.
If you want to add new evaluation metrics, please define a function containing the following arguments:
- _predicts:dict_, prediction results
- _positives:dict_, positive relationships of benchmark dataset
- _negatives:dict_, negative relationships of benchmark dataset
- other arguments if needed.

When call **lc.evaluate** with the prediction results, _positives_ and _negatives_ are automatically passed by benchmark.

As an example, we added hits@k for k=10,000.

In [2]:
from benchmark_metrics import MRR, hitsk
lc.add_metric('hits@10k', hitsk, k=10000)
lc.metrics

{'MRR': {'func': <function benchmark_metrics.MRR(predicts: dict, positives: dict, negatives: dict)>,
  'kwargs': {}},
 'hits@10k': {'func': <function benchmark_metrics.hitsk(predicts: dict, positives: dict, negatives: dict, k=10000)>,
  'kwargs': {'k': 10000}}}

## Perform prediction
Here, we use DRKG to do a prediction

In [3]:
import pandas as pd
import csv
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Disease
disease_list = ['Disease::MESH:D008175']

# Treatments
treatments = [
    'DRUGBANK::treats::Compound:Disease', 
    'Hetionet::CtD::Compound:Disease',
    'GNBR::T::Compound:Disease'
]

# Drug candidates
# Use all drugbank candidates from DRKG
entities = '../models/drkg/data/entity2src.tsv'
drug_cands = set()
with open(entities) as fr:
    for l in fr:
        ws = l.rstrip().split('\t')
        if ws[0].startswith('Compound::DB'):
            drug_cands.add(ws[0].replace('Compound::', ''))


# load embeddings
data_dir = '../models/drkg/data'
entity_idmap_file = f'{data_dir}/embed/entities.tsv'
relation_idmap_file = f'{data_dir}/embed/relations.tsv'
entity_emb = np.load(f'{data_dir}/embed/DRKG_TransE_l2_entity.npy')
rel_emb = np.load(f'{data_dir}/embed/DRKG_TransE_l2_relation.npy')

# Get drugname/disease name to entity ID mappings
entity_map = {}
entity_id_map = {}
relation_map = {}
with open(entity_idmap_file, newline='', encoding='utf-8') as csvfile:
    reader = csv.DictReader(csvfile, delimiter='\t', fieldnames=['name','id'])
    for row_val in reader:
        entity_map[row_val['name']] = int(row_val['id'])
        entity_id_map[int(row_val['id'])] = row_val['name']
        
with open(relation_idmap_file, newline='', encoding='utf-8') as csvfile:
    reader = csv.DictReader(csvfile, delimiter='\t', fieldnames=['name','id'])
    for row_val in reader:
        relation_map[row_val['name']] = int(row_val['id'])
        
# handle the ID mapping
drug_ids = []
disease_ids = []
for drug in drug_cands:
    drug_name=f'Compound::{drug}'
    if drug_name in entity_map:
        drug_ids.append(entity_map[drug_name])
    
for disease in disease_list:
    if disease in entity_map:
        disease_ids.append(entity_map[disease])

treatment_rid = [relation_map[treat]  for treat in treatments]

print(f'Total available disease: {len(disease_ids)}')
print(f'Total available drugs: {len(drug_ids)}')
print(f'Total available relations: {len(treatment_rid)}')

Total available disease: 1
Total available drugs: 10551
Total available relations: 3


In [4]:
# Load embeddings
drug_ids = torch.tensor(drug_ids).long()
disease_ids = torch.tensor(disease_ids).long()
treatment_rid = torch.tensor(treatment_rid)

drug_emb = torch.tensor(entity_emb[drug_ids])
treatment_embs = [torch.tensor(rel_emb[rid]) for rid in treatment_rid]

print(drug_emb.size())

torch.Size([10551, 400])


In [5]:
import torch.nn.functional as fn

gamma=12.0
def transE_l2(head, rel, tail):
    score = head + rel - tail
    return gamma - torch.norm(score, p=2, dim=-1)

scores_per_disease = []
dids = []
for rid in range(len(treatment_embs)):
    treatment_emb=treatment_embs[rid]
    for disease_id in disease_ids:
        disease_emb = entity_emb[disease_id]
        score = fn.logsigmoid(transE_l2(drug_emb, treatment_emb, disease_emb))
        scores_per_disease.append(score)
        dids.append(drug_ids)
scores = torch.cat(scores_per_disease)
dids = torch.cat(dids)

# sort scores in decending order
idx = torch.flip(torch.argsort(scores), dims=[0])
scores = scores[idx].numpy()
dids = dids[idx].numpy()

results = []
for rid in range(len(treatment_embs)):
    treatment_emb=treatment_embs[rid]
    for disease_id in disease_ids:
        disease_emb = entity_emb[disease_id]
        score = fn.logsigmoid(transE_l2(drug_emb, treatment_emb, disease_emb))
        for i in range(len(score)):
            results.append([
                entity_id_map[int(drug_ids[i])].split('::')[-1], 
                float(score[i]), 
                treatments[rid], 
                entity_id_map[int(disease_id)]
            ])
results = pd.DataFrame(results, columns=['DrugBankID', 'Score', 'Treatment', 'Disease_id'])
results

Unnamed: 0,DrugBankID,Score,Treatment,Disease_id
0,DB06436,-0.952330,DRUGBANK::treats::Compound:Disease,Disease::MESH:D008175
1,DB00432,-0.351053,DRUGBANK::treats::Compound:Disease,Disease::MESH:D008175
2,DB12734,-2.203085,DRUGBANK::treats::Compound:Disease,Disease::MESH:D008175
3,DB03160,-3.012343,DRUGBANK::treats::Compound:Disease,Disease::MESH:D008175
4,DB07981,-3.079597,DRUGBANK::treats::Compound:Disease,Disease::MESH:D008175
...,...,...,...,...
31648,DB13987,-0.502823,GNBR::T::Compound:Disease,Disease::MESH:D008175
31649,DB12110,-1.251546,GNBR::T::Compound:Disease,Disease::MESH:D008175
31650,DB02597,-2.687775,GNBR::T::Compound:Disease,Disease::MESH:D008175
31651,DB03311,-2.766237,GNBR::T::Compound:Disease,Disease::MESH:D008175


For benchmarking, we format the prediction results to a dict

In [6]:
results_grouped = results.groupby(['Disease_id', 'DrugBankID']).Score.min().sort_values()
predicts = {}
for disease, drug in results_grouped.index:
    disease = disease.replace('Disease::', '')
    if disease not in predicts:
        predicts[disease] = []
    predicts[disease].append(drug)
predicts

{'MESH:D008175': ['DB15213',
  'DB14777',
  'DB15245',
  'DB14051',
  'DB14058',
  'DB14752',
  'DB14337',
  'DB14569',
  'DB14052',
  'DB14733',
  'DB14579',
  'DB13704',
  'DB14484',
  'DB15595',
  'DB14521',
  'DB14725',
  'DB15559',
  'DB15271',
  'DB11398',
  'DB01906',
  'DB15167',
  'DB15193',
  'DB04015',
  'DB13571',
  'DB13289',
  'DB15192',
  'DB15415',
  'DB13355',
  'DB04728',
  'DB01642',
  'DB13824',
  'DB08161',
  'DB15336',
  'DB15528',
  'DB03841',
  'DB15354',
  'DB08668',
  'DB02119',
  'DB14930',
  'DB06906',
  'DB13360',
  'DB13297',
  'DB04112',
  'DB02151',
  'DB13788',
  'DB01460',
  'DB15300',
  'DB15409',
  'DB04580',
  'DB04496',
  'DB13603',
  'DB02599',
  'DB04488',
  'DB01491',
  'DB12123',
  'DB11474',
  'DB15383',
  'DB08328',
  'DB12780',
  'DB13459',
  'DB15397',
  'DB02272',
  'DB07808',
  'DB03799',
  'DB15206',
  'DB03592',
  'DB13753',
  'DB14093',
  'DB15349',
  'DB03584',
  'DB15492',
  'DB13583',
  'DB08757',
  'DB13361',
  'DB07704',
  'DB0748

## Evaluation
To evaluation, simply pass the formatted prediction dict to the evaluate function

In [7]:
lc.evaluate(predicts)

{'MRR': 0.0005905747018456806, 'hits@10k': 0.859357238813682}