## BioKG dataset - subgraph extraction
Extracting drug nodes with coressponding relations, which are also in DrugBank. To obtain new information for learning my KG.  

In [61]:
import pandas as pd
import itertools

import torch
import torch_geometric.transforms as T
import torch_geometric.utils as U

from ogb.linkproppred import PygLinkPropPredDataset

In [62]:
biokg_dataset_dir = '../data/dataset-ogb/'
drugbank_dir = '../data/triplets/'

In [63]:
drugs_drugbank = pd.read_csv(drugbank_dir + 'drug_atc_codes.tsv', sep='\t', index_col=[0])

In [64]:
drugs_drugbank = drugs_drugbank.drop(columns=['relation'])
drugs_drugbank

Unnamed: 0,id,atc_code
1,DB00001,B01AE02
2,DB00002,L01XC06
3,DB00003,R05CB13
4,DB00004,L01XX29
5,DB00005,L04AB01
...,...,...
4015,DB15598,B03AB10
4016,DB15874,A16AB03
4017,DB16024,V04CX01
4018,DB16581,L04AX08


In [65]:
dataset = PygLinkPropPredDataset(name='ogbl-biokg', root=biokg_dataset_dir, transform=T.ToSparseTensor())
data = dataset[0]
data

Data(
  num_nodes_dict={
    disease=10687,
    drug=10533,
    function=45085,
    protein=17499,
    sideeffect=9969
  },
  edge_index_dict={
    (disease, disease-protein, protein)=[2, 73547],
    (drug, drug-disease, disease)=[2, 5147],
    (drug, drug-drug_acquired_metabolic_disease, drug)=[2, 63430],
    (drug, drug-drug_bacterial_infectious_disease, drug)=[2, 18554],
    (drug, drug-drug_benign_neoplasm, drug)=[2, 30348],
    (drug, drug-drug_cancer, drug)=[2, 48514],
    (drug, drug-drug_cardiovascular_system_disease, drug)=[2, 94842],
    (drug, drug-drug_chromosomal_disease, drug)=[2, 316],
    (drug, drug-drug_cognitive_disorder, drug)=[2, 34660],
    (drug, drug-drug_cryptorchidism, drug)=[2, 128],
    (drug, drug-drug_developmental_disorder_of_mental_health, drug)=[2, 14314],
    (drug, drug-drug_endocrine_system_disease, drug)=[2, 55994],
    (drug, drug-drug_fungal_infectious_disease, drug)=[2, 36114],
    (drug, drug-drug_gastrointestinal_system_disease, drug)=[2, 83210

In [66]:
drug_entidx2stitch = pd.read_csv(biokg_dataset_dir + 'ogbl_biokg/mapping/drug_entidx2name.csv', index_col=[0])
drug_entidx2stitch.reset_index(inplace=True)
drug_entidx2stitch

Unnamed: 0,ent idx,ent name
0,0,CID000000001
1,1,CID000000010
2,2,CID000000014
3,3,CID000000015
4,4,CID000000037
...,...,...
10528,10528,CID125880656
10529,10529,CID144462760
10530,10530,CID151508717
10531,10531,CID151601240


#### Mapping STITCH (CID) ids to DrugBank ids or ATC codes
1. download mappings from STITCH database -> keep only ATC codes -> correct CIDs/m -> CID0/1
2. dowload mappings from https://github.com/iit-Demokritos/drug_id_mapping -> compare with mapping obtained from 1. -> if the mapping is correct -> use this one (no further preprocessing needed)

In [67]:
# download http://stitch.embl.de/download/chemical.sources.v5.0.tsv.gz 
# I deleted it because it occupies a lot of memory
# I'm using only part of it - cid2atc

# skip = list(range(9))
# cid2atc = pd.read_csv('../data/chemical.sources.v5.0.tsv', sep='\t', skiprows=skip, names=['cidm', 'cids', 'source', 'id'])
# cid2atc

CIDs / CID0... - this is a stereo-specific compound, and the suffix is the 
PubChem compound id.

CIDm / CID1... - this is a "flat" compound, i.e. with merged stereo-isomers
The suffix (without the leading "1") is the PubChem compound id.

In [68]:
# cid2atc = cid2atc[cid2atc.source == 'ATC']
# cid2atc = cid2atc.drop(columns=['source'])
# cid2atc = cid2atc.rename(columns={'id':'atc'})
# cid2atc['cidm'] = cid2atc['cidm'].str.replace(r'm', r'1')
# cid2atc['cids'] = cid2atc['cids'].str.replace(r's', r'0')
# print(cid2atc.head())
# cid2atc.to_csv('../data/cid2atc.csv')

cid2atc = pd.read_csv('../data/cid2atc.csv', index_col=[0])
cid2atc.sort_values(by=['cids'])

Unnamed: 0,cidm,cids,atc
519,CID100000001,CID000000001,N06BX12
892,CID100000119,CID000000119,L03AA03
2998,CID100000119,CID000000119,N03AG03
1362,CID100000137,CID000000137,L01XD04
1217,CID100000174,CID000000174,A06AD15
...,...,...,...
1063,CID188516236,CID090470007,L04AC03
273,CID190470996,CID090470996,B02BD30
272,CID190470996,CID090470996,B02BC06
1934,CID116139342,CID090472060,A10BX10


In [69]:
# merge cid2atc and drug_entidx2stitch on cidm/s

merged = drug_entidx2stitch.merge(cid2atc, how='inner', left_on='ent name', right_on='cids')
merged
# merged.merge(cid2atc, how='left', left_on='ent name', right_on='cidm')

Unnamed: 0,ent idx,ent name,cidm,cids,atc
0,0,CID000000001,CID100000001,CID000000001,N06BX12
1,17,CID000000119,CID100000119,CID000000119,L03AA03
2,17,CID000000119,CID100000119,CID000000119,N03AG03
3,21,CID000000137,CID100000137,CID000000137,L01XD04
4,39,CID000000187,CID100000187,CID000000187,S01EB09
...,...,...,...,...,...
1710,9217,CID054684141,CID154677977,CID054684141,L04AA31
1711,9218,CID054686350,CID154686350,CID054686350,N02BA03
1712,9224,CID056603701,CID156603701,CID056603701,L04AC06
1713,9264,CID056842117,CID156842117,CID056842117,L01XC06


In [70]:
# https://github.com/iit-Demokritos/drug_id_mapping
# !! cite if I use it !!

drugs_mapping = pd.read_csv('../data/drug-idx-mappings.tsv', sep='\t', index_col=[0])
drugs_mapping = drugs_mapping[['name', 'stitch_id']].dropna()
drugs_mapping.reset_index(inplace=True)
drugs_mapping                         

Unnamed: 0,drugbankId,name,stitch_id
0,DB01101,Capecitabine,CID000060953
1,DB01100,Pimozide,CID000016362
2,DB01105,Sibutramine,CID000005210
3,DB01104,Sertraline,CID100005203
4,DB01109,Heparin,CID153477714
...,...,...,...
1243,DB01058,Praziquantel,CID000004891
1244,DB01059,Norfloxacin,CID000004539
1245,DB01037,Deprenyl,CID000026757
1246,DB00188,PS-341,CID000387447


In [71]:
# check if cid-atc/drugbankID matches in downloaded mapping (drugs_mapping) and cid2atc
drugs_mapping_atc = drugs_mapping.merge(drugs_drugbank, how='left', left_on='drugbankId', right_on='id')
drugs_mapping_atc = drugs_mapping_atc.dropna()

drugs_mapping_atc_cid = drugs_mapping_atc.merge(cid2atc, how='inner', left_on='atc_code', right_on='atc')
drugs_mapping_atc_cid

Unnamed: 0,drugbankId,name,stitch_id,id,atc_code,cidm,cids,atc
0,DB01101,Capecitabine,CID000060953,DB01101,L01BC06,CID100060953,CID000060953,L01BC06
1,DB01100,Pimozide,CID000016362,DB01100,N05AG02,CID100016362,CID000016362,N05AG02
2,DB01105,Sibutramine,CID000005210,DB01105,A08AA10,CID100005210,CID000005210,A08AA10
3,DB01104,Sertraline,CID100005203,DB01104,N06AB06,CID100005203,CID000063009,N06AB06
4,DB01109,Heparin,CID153477714,DB01109,C05BA03,CID100000772,CID000000772,C05BA03
...,...,...,...,...,...,...,...,...
1481,DB01050,Ibuprofen,CID000003672,DB01050,M02AA13,CID100003672,CID000003672,M02AA13
1482,DB01057,Echothiophate,CID100010547,DB01057,S01EB03,CID100010547,CID000010547,S01EB03
1483,DB01058,Praziquantel,CID000004891,DB01058,P02BA01,CID100004891,CID000004891,P02BA01
1484,DB01059,Norfloxacin,CID000004539,DB01059,S01AE02,CID100004539,CID000004539,S01AE02


In [72]:
print(f'Mapping is the same in {(drugs_mapping_atc_cid["stitch_id"] == drugs_mapping_atc_cid["cids"]).sum()} / {drugs_mapping_atc_cid.shape[0]} cases')

not_same = drugs_mapping_atc_cid["stitch_id"] != drugs_mapping_atc_cid["cids"]
drugs_mapping_atc_cid[not_same]

Mapping is the same in 1167 / 1486 cases


Unnamed: 0,drugbankId,name,stitch_id,id,atc_code,cidm,cids,atc
3,DB01104,Sertraline,CID100005203,DB01104,N06AB06,CID100005203,CID000063009,N06AB06
4,DB01109,Heparin,CID153477714,DB01109,C05BA03,CID100000772,CID000000772,C05BA03
5,DB01109,Heparin,CID153477714,DB01109,S01XA14,CID100000772,CID000000772,S01XA14
6,DB01109,Heparin,CID153477714,DB01109,B01AB01,CID100000772,CID000000772,B01AB01
7,DB01106,Levocabastine,CID100003915,DB01106,R01AC02,CID100003915,CID000054384,R01AC02
...,...,...,...,...,...,...,...,...
1455,DB01026,Ketoconazole,CID100003823,DB01026,D01AC08,CID100003823,CID000456201,D01AC08
1457,DB01030,Topotecan,CID100005515,DB01030,L01XX17,CID100005515,CID000060699,L01XX17
1462,DB01036,Tolterodine,CID100005512,DB01036,G04BD07,CID100005512,CID000443878,G04BD07
1477,DB01053,Benzylpenicillin,CID000005904,DB01053,J01CE08,CID100015232,CID000015232,J01CE08


Downloaded mapping (drugs_mapping) is correct (not same rows are same as cidm instead of cids)

In [73]:
stitch2drugbank = drug_entidx2stitch.merge(drugs_mapping, how='inner', left_on='ent name', right_on='stitch_id')
stitch2drugbank = stitch2drugbank.drop(columns=['ent name'])
stitch2drugbank

Unnamed: 0,ent idx,drugbankId,name,stitch_id
0,17,DB02530,gamma-Aminobutyric acid,CID000000119
1,21,DB00855,Aminolevulinic acid,CID000000137
2,39,DB03128,Acetylcholine,CID000000187
3,59,DB06770,Benzyl alcohol,CID000000244
4,61,DB06756,Glycine betaine,CID000000247
...,...,...,...,...
950,10519,DB06817,Raltegravir,CID123668479
951,10520,DB00554,Piroxicam,CID123690938
952,10526,DB00050,Cetrorelix,CID125074886
953,10530,DB00469,Tenoxicam,CID151508717


### Selecting indices in BioKG 
Select only indices whose drugbankID is also in my data (drugs_drugbank)

In [None]:
idx2keep = stitch2drugbank.merge(drugs_drugbank, how='inner', left_on='drugbankId', right_on='id')
idx2keep = idx2keep.drop_duplicates(subset=['ent idx'])
drugs2keep = torch.tensor(list(idx2keep['ent idx']))
drugs2keep

#### Select the subgraph

1. Select subgraph with drug_ids in drugs2keep
    - relations: drug - * - drug, drug - * - *, * - * - drug, where drug in drugs2keep
2. Add relations node_type1 - * - node_type2 to the subgraph, where both node_type1 and node_type2 were selected in 1., and node_type1 and node_type 2 can be same or different node type  

In [75]:
idx2drugbank = idx2keep[['ent idx', 'drugbankId']].set_index('ent idx')['drugbankId'].to_dict()
idx2drugbank

{21: 'DB00855',
 39: 'DB03128',
 59: 'DB06770',
 61: 'DB06756',
 73: 'DB09278',
 78: 'DB04272',
 84: 'DB00936',
 103: 'DB01156',
 121: 'DB00513',
 126: 'DB09110',
 132: 'DB04398',
 152: 'DB01093',
 153: 'DB00988',
 162: 'DB00431',
 169: 'DB00145',
 170: 'DB09462',
 176: 'DB05381',
 180: 'DB09526',
 211: 'DB01065',
 227: 'DB06690',
 246: 'DB00339',
 250: 'DB00165',
 272: 'DB00152',
 360: 'DB06637',
 365: 'DB00252',
 394: 'DB01193',
 396: 'DB00316',
 397: 'DB00819',
 399: 'DB00551',
 400: 'DB06709',
 405: 'DB00787',
 411: 'DB00518',
 412: 'DB01001',
 414: 'DB00630',
 416: 'DB00346',
 418: 'DB00437',
 420: 'DB00969',
 424: 'DB00404',
 426: 'DB00488',
 427: 'DB00915',
 432: 'DB01143',
 434: 'DB00357',
 435: 'DB00345',
 437: 'DB00277',
 440: 'DB06288',
 441: 'DB00321',
 442: 'DB01025',
 443: 'DB00381',
 444: 'DB00543',
 449: 'DB00276',
 451: 'DB00261',
 452: 'DB01217',
 458: 'DB00964',
 465: 'DB00945',
 469: 'DB00335',
 471: 'DB00993',
 472: 'DB00548',
 473: 'DB00972',
 479: 'DB00181',
 487

In [115]:
"""
Replace values in the given tensor l according to mapping.
"""
def replace(l, mapping):
    return [mapping[int(val)] for val in l]

def replace_with_type(l, _type):
    return [_type + '_' +str(int(val)) for val in l]

In [106]:
"""
Filter the given 2D tensor t according to the first dim (t[0]) -> keeps just values which are in keep -> create mask
Returns 2D tensor with values according to the mask (applied in both dims).

Params:
    t - 2D tensor
    keep - 1D tensor, what values to keep
"""
def my_filter(t, keep):
    mask = torch.isin(t[0],keep) 
    return t[0][mask], t[1][mask]

In [77]:
# select only drug - drug relations

relations = list(data.edge_index_dict.keys())
drug_relations = [rel for rel in relations if rel[0] == 'drug' and rel[2] == 'drug']

result = pd.DataFrame(columns=['drug1', 'relation', 'drug2'])

for i in range(len(drug_relations)):
    edge_index = data.edge_index_dict[drug_relations[i]]
    max_idx = data.edge_index_dict[drug_relations[i]].max()
    sub_rel = U.subgraph(drugs2keep[drugs2keep < max_idx], edge_index)

    drug1_ids = replace(sub_rel[0][0], idx2drugbank)
    drug2_ids = replace(sub_rel[0][1], idx2drugbank)
    
    tmp_df = pd.DataFrame(columns=['drug1', 'relation', 'drug2'])
    tmp_df['drug1'] = drug1_ids
    tmp_df['drug2'] = drug2_ids
    tmp_df['relation'] = list(itertools.repeat(drug_relations[i][1], len(drug1_ids)))
    
    result = pd.concat([result, tmp_df])

In [81]:
result

Unnamed: 0,drug1,relation,drug2
0,DB00996,drug-drug_acquired_metabolic_disease,DB00549
1,DB00395,drug-drug_acquired_metabolic_disease,DB00916
2,DB00555,drug-drug_acquired_metabolic_disease,DB00966
3,DB00321,drug-drug_acquired_metabolic_disease,DB00222
4,DB00313,drug-drug_acquired_metabolic_disease,DB01595
...,...,...,...
10639,DB00918,drug-drug_viral_infectious_disease,DB00996
10640,DB00502,drug-drug_viral_infectious_disease,DB00575
10641,DB01016,drug-drug_viral_infectious_disease,DB01023
10642,DB00787,drug-drug_viral_infectious_disease,DB01156


In [138]:
# select drug - * relations

other_relations = [rel for rel in relations if rel[0] == 'drug' and rel[2] != 'drug']

result_hetero = pd.DataFrame(columns=['drug1', 'relation', 'drug2'])

for i in range(len(other_relations)):
    edge_index = data.edge_index_dict[other_relations[i]]
    sub_rel = my_filter(edge_index, drugs2keep)
    
    drug1_ids = replace(sub_rel[0], idx2drugbank)
    drug2_ids = replace_with_type(sub_rel[1], other_relations[i][2])
    
    tmp_df = pd.DataFrame(columns=['drug1', 'relation', 'drug2'])
    tmp_df['drug1'] = drug1_ids
    tmp_df['drug2'] = drug2_ids
    tmp_df['relation'] = list(itertools.repeat(other_relations[i][1], len(drug1_ids)))
    
    result_hetero = pd.concat([result_hetero, tmp_df])
    
    
result_hetero    

Unnamed: 0,drug1,relation,drug2
0,DB00268,drug-disease,disease_1402
1,DB00531,drug-disease,disease_3573
2,DB00501,drug-disease,disease_758
3,DB01165,drug-disease,disease_5249
4,DB00541,drug-disease,disease_1954
...,...,...,...
82961,DB00648,drug-sideeffect,sideeffect_5579
82962,DB01162,drug-sideeffect,sideeffect_7808
82963,DB01097,drug-sideeffect,sideeffect_2204
82964,DB01595,drug-sideeffect,sideeffect_1649


In [161]:
# select * - * relations (both * must be in result_hetero.drug2)
# protein - protein, disease - protein

proteins = list(result_hetero.drug2[result_hetero.drug2.str.match('protein.*')])
proteins = torch.tensor([int(val.split('_')[1]) for val in proteins])

protein_relations = [rel for rel in relations if rel[0] == 'protein' and rel[2] == 'protein']

result_protein = pd.DataFrame(columns=['drug1', 'relation', 'drug2'])

for i in range(len(protein_relations)):
    edge_index = data.edge_index_dict[protein_relations[i]]
    sub_rel = my_filter(edge_index, proteins)
    
    drug1_ids = replace_with_type(sub_rel[0], protein_relations[i][0])
    drug2_ids = replace_with_type(sub_rel[1], protein_relations[i][2])
    
    tmp_df = pd.DataFrame(columns=['drug1', 'relation', 'drug2'])
    tmp_df['drug1'] = drug1_ids
    tmp_df['drug2'] = drug2_ids
    tmp_df['relation'] = list(itertools.repeat(protein_relations[i][1], len(drug1_ids)))
    
    result_protein = pd.concat([result_protein, tmp_df])
    
result_protein    

Unnamed: 0,drug1,relation,drug2
0,protein_3772,protein-protein_activation,protein_9842
1,protein_6205,protein-protein_activation,protein_7528
2,protein_16631,protein-protein_activation,protein_10469
3,protein_7387,protein-protein_activation,protein_1524
4,protein_3171,protein-protein_activation,protein_7112
...,...,...,...
145493,protein_6773,protein-protein_reaction,protein_6169
145494,protein_16220,protein-protein_reaction,protein_7853
145495,protein_3774,protein-protein_reaction,protein_17097
145496,protein_12598,protein-protein_reaction,protein_133


In [162]:
diseases = list(result_hetero.drug2[result_hetero.drug2.str.match('disease.*')])
diseases = torch.tensor([int(val.split('_')[1]) for val in diseases])

diseases_relations = [rel for rel in relations if rel[0] == 'disease' and rel[2] == 'protein']

result_disease = pd.DataFrame(columns=['drug1', 'relation', 'drug2'])

for i in range(len(diseases_relations)):
    edge_index = data.edge_index_dict[diseases_relations[i]]
    sub_rel = my_filter(edge_index, diseases)
    
    drug1_ids = replace_with_type(sub_rel[0], diseases_relations[i][0])
    drug2_ids = replace_with_type(sub_rel[1], diseases_relations[i][2])
    
    tmp_df = pd.DataFrame(columns=['drug1', 'relation', 'drug2'])
    tmp_df['drug1'] = drug1_ids
    tmp_df['drug2'] = drug2_ids
    tmp_df['relation'] = list(itertools.repeat(diseases_relations[i][1], len(drug1_ids)))
    
    result_disease = pd.concat([result_disease, tmp_df])
    
result_disease 

Unnamed: 0,drug1,relation,drug2
0,disease_1614,disease-protein,protein_3358
1,disease_1496,disease-protein,protein_1350
2,disease_3382,disease-protein,protein_15142
3,disease_4438,disease-protein,protein_17390
4,disease_375,disease-protein,protein_16221
...,...,...,...
14013,disease_1614,disease-protein,protein_5635
14014,disease_4435,disease-protein,protein_17492
14015,disease_3856,disease-protein,protein_5710
14016,disease_213,disease-protein,protein_10075


In [164]:
# concatenate everything

biokg_subgraph = pd.concat([result, result_hetero, result_protein, result_disease])
biokg_subgraph.to_csv(drugbank_dir + 'biokg_subgraph.csv')