## Predict the drug activity against ME/CFS

### Disease

**Myalgic encephalomyelitis / chronic fatigue syndrome (ME/CFS)**: a syndrome characterized by persistent or recurrent fatigue, diffuse musculoskeletal pain, sleep disturbances, and subjective cognitive impairment lasting 6 months or longer. Symptoms are not caused by ongoing exertion, are not relieved by rest, and result in a substantial reduction of previous levels of occupational, educational, social, or personal activities. Minor alterations of immune, neuroendocrine, and autonomic function may be associated with this syndrome. There is also considerable overlap between this condition and FIBROMYALGIA. (From Semin Neurol 1998;18(2):237-42; Ann Intern Med 1994 Dec 15;121(12):953-9)

- [MESH: D015673](https://meshb.nlm.nih.gov/record/ui?ui=D015673)
- [DOID: 8544](https://disease-ontology.org/do)

**Long COVID**: the post-acute stage of COVID-19 virus infection. Persistent symptoms may include FATIGUE; DYSPNEA; and MEMORY LOSS.

*This disease entity does not exist in the original DRKG.*

- [MESH: D000094024](https://www.ncbi.nlm.nih.gov/mesh/?term=long+covid)
- [DOID: 0080848](https://disease-ontology.org/do)

### Drug candidates
Potential drug candidates are listed in **deduplicated_drugs.txt**.

In [16]:
import pandas as pd
import csv
import torch
import numpy as np

## Get DRKG

In [None]:
%%bash
export DATA_DIR="../models/drkg/data"
export MODEL_DIR="../models/drkg/models"
mkdir -p ${DATA_DIR} ${MODEL_DIR}
wget https://s3.us-west-2.amazonaws.com/dgl-data/dataset/DRKG/drkg.tar.gz -O ${DATA_DIR}/drkg.tar.gz
tar -xvzf ${DATA_DIR}/drkg.tar.gz -C ${DATA_DIR}

## Disease list

In [2]:
disease_list = [
    'Disease::MESH:D015673',
    'Disease::DOID:8544'
]

## Drug candidate list

In [6]:
candidate_file = 'deduplicated_drugs.txt'
drug_cands = pd.read_csv(candidate_file, sep='\t')
drug_cands

| | DrugBankID | DrugName | Category |
|---|---|---|---|
| 0 | DB00001 | Lepirudin | |
| 1 | DB00005 | Etanercept | |
| 2 | DB00006 | Bivalirudin | |
| 3 | DB00009 | Alteplase | |
| 4 | DB00013 | Urokinase | |
| ... | ... | ... | ... |
| 808 | DB17289 | L-arginine +/- L-citrulline | Vasodilating |
| 809 | DB17508 | Hafnium oxide | |
| 810 | DB17614 | Ergothioneine | Mushroom Derivatives |
| 811 | DB17735 | Ajoene | |


## Treatment relation

In [7]:
treatment = ['Hetionet::CtD::Compound:Disease','GNBR::T::Compound:Disease']

## Load embeddings of disease and drugs

In [17]:
data_dir = '../models/drkg/data'
entity_idmap_file = f'{data_dir}/embed/entities.tsv'
relation_idmap_file = f'{data_dir}/embed/relations.tsv'
entity_emb = np.load(f'{data_dir}/embed/DRKG_TransE_l2_entity.npy')
rel_emb = np.load(f'{data_dir}/embed/DRKG_TransE_l2_relation.npy')


In [27]:
# Get drugname/disease name to entity ID mappings
entity_map = {}
entity_id_map = {}
relation_map = {}
with open(entity_idmap_file, newline='', encoding='utf-8') as csvfile:
    reader = csv.DictReader(csvfile, delimiter='\t', fieldnames=['name','id'])
    for row_val in reader:
        entity_map[row_val['name']] = int(row_val['id'])
        entity_id_map[int(row_val['id'])] = row_val['name']
        
with open(relation_idmap_file, newline='', encoding='utf-8') as csvfile:
    reader = csv.DictReader(csvfile, delimiter='\t', fieldnames=['name','id'])
    for row_val in reader:
        relation_map[row_val['name']] = int(row_val['id'])
        
# handle the ID mapping
drug_ids = []
disease_ids = []
for _, drug in drug_cands.iterrows():
    drug_name=f'Compound::{drug.DrugBankID}'
    if drug_name in entity_map:
        drug_ids.append(entity_map[drug_name])
    
for disease in disease_list:
    if disease in entity_map:
        disease_ids.append(entity_map[disease])

treatment_rid = [relation_map[treat]  for treat in treatment]

print(f'Total available disease: {len(disease_ids)}')
print(f'Total available drugs: {len(drug_ids)}')
print(f'Total available relations: {len(treatment_rid)}')

Total available disease: 1
Total available drugs: 727
Total available relations: 2
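
The counts above show that only one of the two disease identifiers was found in the DRKG entity map, and that not every candidate drug has an embedding. A minimal sketch of how the unmatched identifiers could be reported follows. The helper name `split_by_coverage` and the toy mapping are illustrative assumptions, not part of the notebook.

```python
def split_by_coverage(names, entity_map):
    """Split names into (found, missing) by membership in entity_map."""
    found = [n for n in names if n in entity_map]
    missing = [n for n in names if n not in entity_map]
    return found, missing


# Toy stand-in for the mapping read from entities.tsv (integer IDs are made up).
toy_entity_map = {'Disease::MESH:D015673': 0, 'Compound::DB00001': 1}

toy_diseases = ['Disease::MESH:D015673', 'Disease::DOID:8544']
found, missing = split_by_coverage(toy_diseases, toy_entity_map)
print('found:', found)
print('missing:', missing)
```

Running the same check on the real `entity_map` with the `Compound::...` names would list the candidates that are silently dropped by the `if drug_name in entity_map` test.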


In [28]:
# Load embeddings
drug_ids = torch.tensor(drug_ids).long()
disease_ids = torch.tensor(disease_ids).long()
treatment_rid = torch.tensor(treatment_rid)

drug_emb = torch.tensor(entity_emb[drug_ids])
treatment_embs = [torch.tensor(rel_emb[rid]) for rid in treatment_rid]

print(drug_emb.size())

torch.Size([727, 400])
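
The shape above reflects one 400-dimensional row per matched drug. The lookup works because the integer IDs in `entities.tsv` are row indices into the embedding matrix. A small sketch of that relationship, using an invented three-entity TSV and a toy 4-dimensional matrix in place of the real files:

```python
import csv
import io

import numpy as np

# Toy entities.tsv content: name <TAB> integer ID (invented for illustration).
toy_tsv = "Compound::DB00001\t0\nDisease::MESH:D015673\t1\nCompound::DB00005\t2\n"

entity_map = {}
reader = csv.DictReader(io.StringIO(toy_tsv), delimiter='\t', fieldnames=['name', 'id'])
for row in reader:
    entity_map[row['name']] = int(row['id'])

# Toy embedding matrix: 3 entities x 4 dimensions (the real matrix has 400 columns).
entity_emb = np.arange(12, dtype=np.float32).reshape(3, 4)

# Indexing with an array of IDs gathers the corresponding rows.
drug_ids = np.array([entity_map['Compound::DB00001'], entity_map['Compound::DB00005']])
drug_emb = entity_emb[drug_ids]
print(drug_emb.shape)
```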


## Predict edge score
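Here we score each (drug, treatment relation, disease) triple with the L2-distance score function of the pretrained TransE model: the closer `head + rel` lies to `tail`, the higher the score.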
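
Written out, the function defined in the following cell is the margin $\gamma$ minus the L2 distance between the translated head and the tail, passed through a log-sigmoid. The symbols $\mathbf{h}$, $\mathbf{r}$, and $\mathbf{t}$ are chosen here for illustration and stand for the drug, treatment-relation, and disease embeddings:

$$
s(\mathbf{h}, \mathbf{r}, \mathbf{t}) = \gamma - \lVert \mathbf{h} + \mathbf{r} - \mathbf{t} \rVert_2, \qquad \gamma = 12
$$

$$
\text{score} = \log \sigma\bigl(s(\mathbf{h}, \mathbf{r}, \mathbf{t})\bigr) = -\log\bigl(1 + e^{-s(\mathbf{h}, \mathbf{r}, \mathbf{t})}\bigr)
$$

Because $\log \sigma(\cdot)$ is never positive, every score is at most zero, and values closer to zero correspond to triples the model places closer together.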

In [50]:
import torch.nn.functional as fn

gamma = 12.0  # margin (gamma) of the TransE score

def transE_l2(head, rel, tail):
    # TransE score: margin minus the L2 distance between (head + rel) and tail
    score = head + rel - tail
    return gamma - torch.norm(score, p=2, dim=-1)

scores_per_disease = []
dids = []
for treatment_emb in treatment_embs:
    for disease_id in disease_ids:
        # Convert the NumPy row to a tensor so all operands share one type
        disease_emb = torch.tensor(entity_emb[int(disease_id)])
        score = fn.logsigmoid(transE_l2(drug_emb, treatment_emb, disease_emb))
        scores_per_disease.append(score)
        dids.append(drug_ids)
scores = torch.cat(scores_per_disease)
dids = torch.cat(dids)

# Sort scores in descending order
idx = torch.flip(torch.argsort(scores), dims=[0])
scores = scores[idx].numpy()
dids = dids[idx].numpy()
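
Because every drug is scored once per treatment relation, the sorted `dids` array contains each drug more than once. One way to keep only the best-scoring entry per drug and truncate to the top candidates is sketched below with toy arrays. It uses the `np.unique(..., return_index=True)` pattern and is an assumption about how a ranked listing could be produced, not code taken from this notebook. The cutoff `topk` is hypothetical.

```python
import numpy as np

# Toy stand-ins for the sorted outputs of the cell above:
# scores are in descending order and dids holds the matching entity IDs.
scores = np.array([-0.1, -0.2, -0.3, -0.5, -0.9, -1.2])
dids = np.array([7, 3, 7, 5, 3, 5])

# For each distinct ID, np.unique returns the index of its first occurrence.
# Since scores are sorted descending, the first occurrence is the best score.
_, first_idx = np.unique(dids, return_index=True)

topk = 2  # hypothetical cutoff
keep = np.sort(first_idx)[:topk]  # restore descending-score order, then truncate

top_dids = dids[keep]
top_scores = scores[keep]
for did, score in zip(top_dids, top_scores):
    print(f'{did}\t{score}')
```

With the real arrays, each printed ID can be translated back to a `Compound::...` name through `entity_id_map`.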


In [51]:
results = []
for rid in range(len(treatment_embs)):
    treatment_emb=treatment_embs[rid]
    for disease_id in disease_ids:
        disease_emb = entity_emb[disease_id]
        score = fn.logsigmoid(transE_l2(drug_emb, treatment_emb, disease_emb))
        for i in range(len(score)):
            results.append([
                entity_id_map[int(drug_ids[i])].split('::')[-1], 
                float(score[i]), 
                treatment[rid], 
                entity_id_map[int(disease_id)]
            ])
results = pd.DataFrame(results, columns=['DrugBankID', 'Score', 'Treatment', 'Disease_id'])
results

| | DrugBankID | Score | Treatment | Disease_id |
|---|---|---|---|---|
| 0 | DB00001 | -5.067222 | Hetionet::CtD::Compound:Disease | Disease::MESH:D015673 |
| 1 | DB00005 | -4.401855 | Hetionet::CtD::Compound:Disease | Disease::MESH:D015673 |
| 2 | DB00006 | -5.267787 | Hetionet::CtD::Compound:Disease | Disease::MESH:D015673 |
| 3 | DB00009 | -5.116919 | Hetionet::CtD::Compound:Disease | Disease::MESH:D015673 |
| 4 | DB00013 | -5.128342 | Hetionet::CtD::Compound:Disease | Disease::MESH:D015673 |
| ... | ... | ... | ... | ... |
| 1449 | DB15091 | -2.734260 | GNBR::T::Compound:Disease | Disease::MESH:D015673 |
| 1450 | DB15536 | -1.890350 | GNBR::T::Compound:Disease | Disease::MESH:D015673 |
| 1451 | DB15566 | -3.139803 | GNBR::T::Compound:Disease | Disease::MESH:D015673 |
| 1452 | DB15584 | -1.230093 | GNBR::T::Compound:Disease | Disease::MESH:D015673 |


In [32]:
# Merge back to results
results = results.merge(drug_cands, on='DrugBankID', how='left')

Compound::DB01065	-0.16302219033241272
Compound::DB00472	-0.18183228373527527
Compound::DB00898	-0.22857537865638733
Compound::DB00787	-0.23830386996269226
Compound::DB00575	-0.25220105051994324
Compound::DB00715	-0.25349968671798706
Compound::DB00741	-0.2603774070739746
Compound::DB14128	-0.26912784576416016
Compound::DB01104	-0.27237817645072937
Compound::DB00126	-0.28711453080177307
Compound::DB00624	-0.3165704607963562
Compound::DB00458	-0.3234078586101532
Compound::DB00915	-0.323972225189209
Compound::DB01151	-0.3281634449958801
Compound::DB12116	-0.33403462171554565
Compound::DB00627	-0.3344188332557678
Compound::DB00182	-0.3382665514945984
Compound::DB00184	-0.3432280719280243
Compound::DB01576	-0.3432931900024414
Compound::DB00313	-0.3471139073371887
Compound::DB00215	-0.35058334469795227
Compound::DB00159	-0.35303106904029846
Compound::DB00363	-0.3551667332649231
Compound::DB00201	-0.3563298285007477
Compound::DB01234	-0.35930660367012024
Compound::DB00331	-0.3648461103439331
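
A ranked listing like the one above shows only DrugBank IDs. A hedged sketch of attaching drug names by joining against the candidate table follows. It uses a two-row excerpt of the candidate columns and a toy ranking whose scores are illustrative values only. The `entity` column name is an assumption made for this example.

```python
import pandas as pd

# Two candidate rows as they appear in the drug candidate table.
drug_cands = pd.DataFrame({
    'DrugBankID': ['DB00001', 'DB00005'],
    'DrugName': ['Lepirudin', 'Etanercept'],
})

# Toy ranking in the "Compound::<DrugBankID>" form used by DRKG entity names.
ranked = pd.DataFrame({
    'entity': ['Compound::DB00005', 'Compound::DB00001'],
    'Score': [-4.401855, -5.067222],
})

# Strip the entity-type prefix so the key matches the candidate table.
ranked['DrugBankID'] = ranked['entity'].str.split('::').str[-1]

# A left join keeps every ranked row, even when no name is available.
named = ranked.merge(drug_cands, on='DrugBankID', how='left')
print(named[['DrugBankID', 'DrugName', 'Score']])
```

The left join preserves the order of the ranking, so the output stays sorted by score.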
