In [1]:
from tqdm.notebook import tqdm
import re
import os 
import shutil
import numpy as np
import pandas as pd
import igraph as ig
from scipy.sparse import lil_matrix, save_npz
from sklearn.metrics.pairwise import cosine_similarity
import torch
from transformers import AutoTokenizer, AutoModel, pipeline
%load_ext autoreload
%autoreload 2

# Root folder of the raw datasets (relative path).
data_path = '../../datasets/'
# Folder with the knowledge-graph CSVs read below; the feature tables built in this notebook are written here too.
save_path = data_path+'kg/'

In [2]:
# Load the three knowledge-graph tables from save_path.
# low_memory=False: parse each file in a single pass instead of in chunks (avoids mixed-dtype inference, per the pandas docs).
# NOTE(review): only `nodes` is referenced in the cells visible here; `edges` and `kg` may be used further down.
nodes = pd.read_csv(save_path+'nodes.csv', low_memory=False)
edges = pd.read_csv(save_path+'edges.csv', low_memory=False)
kg = pd.read_csv(save_path+'kg.csv', low_memory=False)

## Drug Features

In [3]:
# Drug nodes of the KG. drug_features starts as a copy of them; the cells below left-merge feature columns onto it.
drugs = nodes.query('node_type=="drug"')
drug_features = drugs.copy()

### Drugbank

In [4]:
# (feature column name, path of the DrugBank CSV holding that feature)
data = [('atc_1', data_path+'drugbank/drug_features/atc_1.csv'),
        ('atc_2', data_path+'drugbank/drug_features/atc_2.csv'),
        ('atc_3', data_path+'drugbank/drug_features/atc_3.csv'),
        ('atc_4', data_path+'drugbank/drug_features/atc_4.csv'),
        ('Description', data_path+'drugbank/drug_features/description.csv'),
        ('Drugcat', data_path+'drugbank/drug_features/drugcat.csv'),
        ('Group', data_path+'drugbank/drug_features/group.csv'),
        ('Half-life', data_path+'drugbank/drug_features/Hl.csv'),
        ('Indication', data_path+'drugbank/drug_features/indication.csv'),
        ('MOA', data_path+'drugbank/drug_features/moa.csv'),
        ('Pathway', data_path+'drugbank/drug_features/pathway.csv'),
        ('pb', data_path+'drugbank/drug_features/pb.csv'),
        ('Pharm', data_path+'drugbank/drug_features/pharm.csv'),
        ('State', data_path+'drugbank/drug_features/state.csv')]

# unique:   features with at most one non-null value per KG drug
# multiple: features where at least one KG drug has several values
unique, multiple = [], []

for name, pth in data:
    feat = pd.read_csv(pth)
    # Restrict the feature file to drugs present in the KG (inner merge on the shared columns).
    merged = pd.merge(drugs, feat.rename(columns={"ID":'node_id'}))
    # Number of non-null feature values per drug.
    per_drug = (merged.get(['node_id', name])
                      .groupby('node_id')
                      .count()
                      .rename(columns={name:'count'}))
    repeated = per_drug.query('count>1')
    target = unique if repeated.empty else multiple
    target.append((name,pth))

In [5]:
# Features with at most one value per drug: left-merge each file straight onto drug_features.
# The file's "ID" column is renamed to node_id; no `on=` is given, so pandas joins on every column name the
# two frames share (presumably just node_id — verify the feature files carry no other overlapping column).
for name, pth in unique: 
    feat = pd.read_csv(pth)
    drug_features = pd.merge(drug_features, feat.rename(columns={"ID":'node_id'}), 'left')
    
def group(feat, name, join_word=' ; '):
    """Collapse a long-format feature table to one row per ID.

    Parameters
    ----------
    feat : pandas.DataFrame
        Table with an 'ID' column and a string-valued column ``name``;
        an ID may appear on several rows.
    name : str
        Name of the feature column to aggregate.
    join_word : str
        Separator placed between the values of an ID that has several rows.

    Returns
    -------
    pandas.DataFrame
        Columns ['ID', name], one row per distinct ID (in groupby order).
        The feature is '' when the ID has no row free of missing values,
        the single value when the ID has one row, and the ``join_word``-joined
        values otherwise. Missing values inside a multi-row group are skipped
        (previously they made ``str.join`` raise a TypeError).

    Raises
    ------
    AssertionError
        If the aggregated feature of an ID is not a ``str``.
    """
    id_col = 'ID'
    feat_grouped = []
    # `rows` (not `group`) so the loop variable does not shadow this function's name.
    for _, rows in feat.groupby(id_col):

        node_idx = rows.get(id_col).values[0]

        if rows.dropna().shape[0] == 0:  # no features
            f = ''
        elif rows.shape[0] > 1:  # multiple features: join only the values that are present
            f = join_word.join(rows.get(name).dropna().values)
        else:  # one only
            f = rows.get(name).values[0]

        assert isinstance(f, str)

        feat_grouped.append((node_idx, f))

    return pd.DataFrame(feat_grouped, columns=[id_col, name])
    

# Multi-valued features: normalise the raw values, collapse them to one
# string per drug with group(), then left-merge onto drug_features.
for name, pth in multiple:
    feat = pd.read_csv(pth)
    is_atc = 'atc' in name
    # Group and ATC values read as a list ("a and b"); everything else is ';'-separated.
    join_word = ' and ' if (name == 'Group' or is_atc) else ' ; '
    if is_atc:
        feat.loc[:, name] = feat.get(name).str.lower()
    if name == 'Pathway':
        # keep only the text before the first '|'
        first_entries = [x.split('|')[0] for x in feat.get(name).values]
        feat[name] = np.array(first_entries)
    feat = group(feat=feat, name=name, join_word=join_word)
    drug_features = pd.merge(drug_features, feat.rename(columns={"ID":'node_id'}), 'left')

### Drug Central

In [6]:
# DrugCentral descriptors: keep id plus MW, TPSA and CLOGP, round numeric columns to 2 decimals,
# drop duplicate rows, replace missing values with '' and cast every column to str.
dc = pd.read_csv(data_path+'drugcentral/dc_features.csv')
dc = dc.get(['id', 'MW', 'TPSA', 'CLOGP']).round(2).drop_duplicates().fillna('').astype('str')
# Rename id -> node_id, the key used when merging onto drug_features.
dc = dc.rename(columns={'id':'node_id'})

In [7]:
# Left-merge the DrugCentral columns on node_id; drugs without a match get NaN in MW / TPSA / CLOGP.
# NOTE(review): if dc holds several distinct rows for one node_id, this merge duplicates that drug's row — confirm.
drug_features = pd.merge(drug_features, dc, 'left', on='node_id')

### Processing

In [8]:
# Drop exact duplicate rows, turn missing values into '' and cast every column to str.
drug_features = drug_features.drop_duplicates().fillna('').astype('str')

# Give the DrugBank / DrugCentral feature columns lowercase snake_case names.
drug_features = drug_features.rename(columns={'Description':'description', 'Half-life':'half_life', 
                                  'Indication':'indication', 'MOA':'mechanism_of_action', 'pb':'protein_binding',
                                  'Pharm':'pharmacodynamics', 'State':'state', 'Drugcat':'category', 
                                  'Group':'group', 'Pathway':"pathway", 'MW':'molecular_weight', 
                                  'TPSA':'tpsa', 'CLOGP':'clogp'})

In [9]:
# Make sure no NaN survives before the values are cast to str below
# (rebinding instead of inplace=True keeps the cell free of in-place mutation).
drug_features = drug_features.fillna('')

# remove "[L64839]" type tokens from all text
# Raw string: '\[' in a normal literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+).
for c in drug_features.columns:
    new = []
    for x in drug_features.get(c).values.astype('str'):
        x = re.sub(r'\[(.*?)\]', '', x)   # drop bracketed reference tokens
        x = x.replace('  ', ' ')          # double spaces left behind by the removal
        x = x.replace(' .', '.')          # "word ." -> "word."
        new.append(x)
    drug_features.loc[:, c] = np.array(new)

In [10]:
# half life is irregular feature with text but not sentences. probably useful for dosage only.
# .copy(): work on a private, writable array (.values can be read-only under pandas copy-on-write).
hl = drug_features.get(['half_life']).values.reshape(-1).copy()

# remove sentences with no data available
# NOTE(review): these are substring tests, so 'no' also matches inside words such as
# 'known' or 'normal' — confirm this breadth is intended.
# dtype=int: an empty list would otherwise become a float64 array, which numpy rejects as an index.
no_data_idx = np.array([i for i, x in enumerate(hl) if (('data' in x or 'Data' in x) and (('not' in x or 'Not' in x) or ('no' in x or 'No' in x)))], dtype=int)
hl[no_data_idx] = ''

# add "half life" where sentence fragments
# Fragments are non-empty entries of fewer than 7 space-separated words that do not already
# mention the half-life; the check is case-insensitive so "Half-life ..." is not prefixed again.
add_hl_idx = np.array([i for i, x in enumerate(hl)
                       if 'half life' not in x.lower() and 'half-life' not in x.lower()
                       and len(x.split(' '))<7 and x!=''], dtype=int)
for i in add_hl_idx:
    s = hl[i]
    hl[i] = 'The half-life is '+ s[0].lower() + s[1:]

drug_features.loc[:, 'half_life'] = hl

In [11]:
# convert binary and single word features into text 

def word2sent(feature, intro, end='.', names=None, name_token='<name>'):
    """Turn short feature values into full sentences.

    Every non-empty value ``v`` becomes ``intro + v + end``, where each
    occurrence of ``name_token`` in ``intro`` is replaced by the entry of
    ``names`` at the same position. Empty values stay ''.

    The previous implementation only inserted the name when ``name_token``
    was at the very start or end of ``intro``; a token in the middle
    (e.g. 'The log p value of <name> is ') was silently dropped.

    Parameters
    ----------
    feature : array-like of shape (n,)
        Feature values; cast to str.
    intro : str
        Sentence opening, optionally containing ``name_token``.
    end : str
        Sentence terminator appended after the value.
    names : array-like of shape (n,), optional
        Entity names; required when ``intro`` contains ``name_token``.
    name_token : str
        Placeholder in ``intro`` that is replaced by the name.

    Returns
    -------
    numpy.ndarray of str, shape (n,)

    Raises
    ------
    AssertionError
        If ``intro`` contains ``name_token`` but ``names`` is None or its
        length differs from ``feature``.
    """
    feature = np.asarray(feature).astype(str)
    n = feature.shape[0]

    if name_token in intro:
        assert names is not None
        names = np.asarray(names).astype(str)
        assert names.shape[0] == n
        prefixes = [intro.replace(name_token, nm) for nm in names]
    else:
        prefixes = [intro] * n

    # Empty features must not produce a dangling "<intro>." sentence.
    sentences = [pre + val + end if val != '' else ''
                 for pre, val in zip(prefixes, feature)]

    return np.array(sentences, dtype=str)

# Drug names, aligned row by row with drug_features.
drug_names = np.char.array(drug_features.get(['node_name']).values.reshape(-1).astype(str))

# (column, sentence intro) pairs; '<name>' stands for the drug name.
# One table + one loop replaces eleven copy-pasted stanzas that differed only in these two values.
sentence_intros = [
    ('state', '<name> is a '),
    ('category', '<name> is part of '),
    ('group', '<name> is '),
    ('pathway', '<name> uses '),
    ('molecular_weight', 'The molecular weight is '),
    ('tpsa', '<name> has a topological polar surface area of '),
    ('clogp', 'The log p value of <name> is '),
    ('atc_1', '<name> is anatomically related to '),
    ('atc_2', '<name> is in the therapeutic group of '),
    ('atc_3', '<name> is pharmacologically related to '),
    ('atc_4', 'The chemical and functional group of <name> is '),
]

# Rewrite each column in place as a sentence (empty values stay empty).
for f_name, intro in sentence_intros:
    f = drug_features.get([f_name]).values.reshape(-1)
    drug_features.loc[:, f_name] = word2sent(feature=f, intro=intro, names=drug_names)

In [12]:
# Drop the node metadata columns; node_index is not in the drop list, so it stays as the row key.
drug_features = drug_features.drop(['node_id', 'node_type', 'node_name', 'node_source'], axis=1)

In [13]:
# Write the drug feature table to <save_path>drug_features.csv, without the DataFrame index.
drug_features.to_csv(save_path+'drug_features.csv', index=False)

In [14]:
# Coverage summary: treat '' as missing, then take the count / unique rows of describe(), one row per feature.
# NOTE: this rebinds drug_features with NaN in place of ''; the CSV was already written by the previous cell.
drug_features = drug_features.replace('', float('nan'))
drug_features_stats = drug_features.describe(include='all').loc[['count','unique'],:].T
c = 'percent_covered'
# Percentage of rows with a non-missing value in each column ...
drug_features_stats.loc[:, c] = 100*drug_features_stats.get('count').values.reshape(-1)/drug_features.shape[0]
# ... cast to float and rounded to one decimal.
drug_features_stats.loc[:, c] = drug_features_stats.get([c]).astype(float).round(1)
# Cell output: features ordered by count, highest first.
drug_features_stats.sort_values('count', ascending=False)

Unnamed: 0,count,unique,percent_covered
node_index,7957,7957,100.0
group,7957,7903,100.0
state,6517,6463,81.9
category,5431,5431,68.3
description,4591,4565,57.7
indication,3393,3076,42.6
mechanism_of_action,3242,3161,40.7
atc_4,2818,1040,35.4
atc_3,2818,2818,35.4
atc_2,2818,2818,35.4


## Disease features

In [15]:
# Disease nodes of the KG, and the mapping file for grouped diseases (all columns read as str).
# NOTE(review): disease_map presumably has columns group_id_bert, group_name_bert, node_id, node_name
# (those are the columns selected after the merge below) — confirm against the CSV.
kg_diseases = nodes.query('node_type=="disease"')
disease_map = pd.read_csv(data_path+'kg/auxillary/kg_grouped_diseases_bert_map.csv').astype(str)

# x: disease nodes whose source is not "MONDO_grouped"; node_id / node_name become mondo_id / mondo_name
# and the group columns are set to ''.
x = kg_diseases.query('node_source!="MONDO_grouped"').get(['node_index', 'node_id', 'node_name'])\
.rename(columns={'node_id':'mondo_id', 'node_name':'mondo_name'})
x['group_id_bert'] = ''
x['group_name_bert'] = ''

# y: "MONDO_grouped" nodes, whose node_id is treated as group_id_bert, outer-merged with disease_map
# on their shared columns; the map's node_id / node_name become mondo_id / mondo_name.
# The outer join also keeps map rows with no matching KG node (node_index is NaN there).
y = pd.merge(kg_diseases.query('node_source=="MONDO_grouped"').get(['node_index','node_id']).rename(columns={'node_id':'group_id_bert'}), 
         disease_map, 'outer').get(['node_index', 'group_id_bert', 'group_name_bert', 'node_id', 'node_name'])\
.rename(columns={'node_id':'mondo_id', 'node_name':'mondo_name'})

# Stack both parts with a fresh 0..n-1 index; disease_features starts as a copy and gains columns below.
diseases = pd.concat([x,y]).reset_index().drop('index', axis=1)
disease_features = diseases.copy()

### Mondo definitions

In [16]:
# MONDO definitions: keep id + definition, renamed to mondo_id / mondo_definition, and left-merge onto the
# diseases on the shared columns; diseases without a definition get ''.
# NOTE(review): astype(str) runs before any NaN handling, so an empty definition cell in the CSV would
# become the literal string 'nan' — confirm the file has none.
mondo_def = pd.read_csv(data_path+'mondo/mondo_definitions.csv').astype(str)
mondo_def = mondo_def.get(['id','definition']).rename(columns={'id':'mondo_id', 'definition':'mondo_definition'})
disease_features = pd.merge(disease_features, mondo_def, 'left').fillna('')

### UMLS

In [17]:
# UMLS definitions for diseases and disorders, stacked, without the 'source' column and de-duplicated.
umls1 = pd.read_csv(data_path+'umls/umls_def_disease_2021.csv')
umls2 = pd.read_csv(data_path+'umls/umls_def_disorder_2021.csv')
umls = pd.concat([umls1,umls2]).drop('source',axis=1).drop_duplicates()
# UMLS <-> MONDO id mapping. Built from data_path like every other input
# (it was a hard-coded '../../datasets/...' literal, which breaks if data_path is changed).
mondo_umls_map = pd.read_csv(data_path+'vocab/umls_mondo.csv')
# Keep only definitions whose CUI appears in the mapping, then drop both UMLS key columns.
umls = pd.merge(mondo_umls_map, umls, 'inner', left_on='umls_id', right_on='CUI')
umls = umls.drop(['umls_id','CUI'],axis=1).drop_duplicates().rename(columns={'description':'umls_description'})

def fix_format(x):
    """Normalise the formatting of a free-text description.

    Steps, in order:
      1. lowercase every space-separated word of more than one character
         that is entirely upper case;
      2. remove bracketed ``[...]``, angle-bracketed ``<...>`` and
         parenthesised ``(...)`` spans;
      3. collapse runs of spaces into a single space;
      4. close the gap before a period (``"word ."`` -> ``"word."``).

    Changes from the earlier version: step 4 used to delete the period
    together with the space, fusing adjacent sentences; step 3 only
    collapsed runs of up to four spaces; and the regex literals were
    non-raw strings with invalid escape sequences.

    Parameters
    ----------
    x : str

    Returns
    -------
    str
    """
    words = [w.lower() if len(w) > 1 and w.isupper() else w for w in x.split(' ')]
    s = ' '.join(words)

    # Non-greedy so each bracket pair is removed on its own.
    for pattern in (r'\[.*?\]', r'<.*?>', r'\(.*?\)'):
        s = re.sub(pattern, '', s)

    s = re.sub(r' {2,}', ' ', s)
    s = s.replace(' .', '.')
    return s

# Clean every UMLS description with fix_format, drop rows that became identical, cast to str and
# left-merge onto the diseases; diseases without a description get ''.
# No `on=` is given, so the join uses all column names shared by both frames (presumably mondo_id — verify).
umls.loc[:, 'umls_description'] = [fix_format(x) for x in umls.get('umls_description').values]
umls = umls.drop_duplicates().astype(str)
disease_features = pd.merge(disease_features, umls, 'left').fillna('')

### Orphanet

In [18]:
# MONDO cross-references restricted to the Orphanet ontology (the 'ontology' column is then dropped).
mondo_xref = pd.read_csv(data_path+'mondo/mondo_references.csv').astype(str).query('ontology=="Orphanet"').drop(['ontology'],axis=1)

# Orphanet texts without the UMLS column.
orphanet = pd.read_csv(data_path+'orphanet/orphanet.csv').drop(['UMLS'],axis=1)
# Keep the second ':'-separated field of disease_id (presumably the bare Orphanet number — confirm).
orphanet.loc[:, 'disease_id'] = [x.split(':')[1] for x in orphanet.get('disease_id').values]
orphanet = orphanet.drop(['disease_shortname','disease_name'],axis=1).drop_duplicates()
# '' and '-' both mean "missing".
orphanet = orphanet.replace('', float('nan')).replace('-', float('nan'))
orphanet = orphanet.rename(columns={'definition':'orphanet_definition'})

# Drop rows in which all five text fields are missing, then check that none is left.
q = '@orphanet.get("orphanet_definition").isna() and @orphanet.get("prevalence").isna() and @orphanet.get("epidemiology").isna() and @orphanet.get("clinical_description").isna() and @orphanet.get("management_and_treatment").isna()'
orphanet = orphanet.drop(orphanet.query(q).index)
assert orphanet.query(q).empty

# Attach the MONDO cross-reference columns via disease_id == ontology_id, drop both key columns,
# de-duplicate, and store everything as str with '' for missing.
orphanet = pd.merge(orphanet, mondo_xref, 'left', left_on='disease_id', right_on='ontology_id')
orphanet = orphanet.drop(['disease_id','ontology_id'],axis=1).drop_duplicates()
orphanet = orphanet.fillna('').astype(str)

def fix_num(x):
    """Compact the number formatting in a prevalence / epidemiology string.

    ' / ' is tightened to '/', then every space or comma that sits between a
    digit and three following digits is removed ('100 000' and '100,000'
    both become '100000'). Spaces are processed before commas, and each
    separator is removed one match at a time until none is left.
    """
    text = x.replace(' / ', '/')
    for expr in [r'\d{1} \d{3}', r'\d{1},\d{3}']:
        match = re.search(expr, text)
        while match is not None:
            # The separator is the character right after the matched leading digit.
            sep_pos = match.start() + 1
            text = text[:sep_pos] + text[sep_pos + 1:]
            match = re.search(expr, text)
    return text

# Compact the numbers in prevalence / epidemiology; a prevalence of exactly "Unknown" becomes ''.
orphanet.loc[:, 'prevalence'] = [fix_num(x) if x!="Unknown" else '' for x  in orphanet.get('prevalence').values]
orphanet.loc[:, 'epidemiology'] = [fix_num(x) for x  in orphanet.get('epidemiology').values]
# Prefix the remaining text columns with 'orphanet_'.
orphanet = orphanet.rename(columns={'prevalence':'orphanet_prevalence', 'epidemiology':'orphanet_epidemiology', 
                                    'clinical_description':'orphanet_clinical_description', 
                                    'management_and_treatment':'orphanet_management_and_treatment'})
orphanet = orphanet.drop_duplicates()
# Left-merge on the columns shared with disease_features (presumably mondo_id, brought in by mondo_xref — verify);
# diseases without Orphanet text get ''.
disease_features = pd.merge(disease_features, orphanet, 'left').fillna('')

### Mayo clinic

Note: The code used to manually identify mappings between names in Mayo Clinic and names in the RxData KG is provided in a separate file.

In [19]:
# Mayo Clinic texts and the Mayo-name -> KG-node mapping, all columns read as str.
mayo_data = pd.read_csv(data_path+'mayoclinic/mayo.csv').astype(str)
mayo_map = pd.read_csv(data_path+'mayoclinic/mayo_kg_map.csv').astype(str)

mayo_data = mayo_data.drop(['link'],axis=1)

# Cells that read 'None' or 'None.' carry no content: turn them into ''.
mayo_data = mayo_data.replace('None',float('nan')).replace('None.',float('nan')).fillna('')

# process symptom
# Split each Symptoms text at the start of the line containing the first 'doctor':
# the text before that line stays in Symptoms, that line and everything after it goes to see_doc.
symptoms, see_doc = [], []
for x in mayo_data.get(['Symptoms']).values.reshape(-1):
    if 'doctor' in x:
        # Index of the newline before the first 'doctor', or -1 when it is on the first line.
        last_break = x[:x.find('doctor')].rfind('\n')
        # max(..., 0): with -1 the slice x[:-1] used to chop the last character off the text
        # while see_doc received the whole text. Now Symptoms is '' in that case.
        symptoms.append(x[:max(last_break, 0)])
        see_doc.append(x[last_break+1:])
    else:
        symptoms.append(x)
        see_doc.append('')

mayo_data.loc[:, 'Symptoms'] = symptoms
mayo_data.loc[:, 'see_doc'] = see_doc

def clean_text(x):
    """Flatten a multi-line text block into a single line.

    Newlines after ':' or '.' become a space, every other newline becomes
    ', '. Parenthesised spans are removed, double spaces are halved and
    ' , ' is tightened to ', '. Falsy input (e.g. '') returns ''.

    The parenthesis pattern is a raw string: '\\(' in a normal literal is an
    invalid escape sequence (SyntaxWarning on Python 3.12+).

    Parameters
    ----------
    x : str

    Returns
    -------
    str
    """
    if not x:
        return ''
    x = x.replace(':\n', ': ').replace('.\n', '. ').replace('\n', ', ')
    x = re.sub(r'\(.*?\)', '', x)
    x = x.replace('  ', ' ')
    x = x.replace(' , ', ', ')
    return x

# Clean every column except 'name' cell by cell with clean_text; 'name' is kept verbatim as the join key.
# NOTE(review): DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map — switch once
# the minimum supported pandas version allows it.
mayo_data = pd.concat([mayo_data.get(['name']),mayo_data.drop(['name'],axis=1).applymap(clean_text)],axis=1)

# Keep only Mayo entries that have a KG mapping (inner join on the Mayo name).
mayo_data = pd.merge(mayo_data, mayo_map, 'inner', left_on='name', right_on='mayo_name')
mayo_data = mayo_data.get(['node_id', 'Symptoms', 'Causes', 'Risk factors', 'Complications',
       'Prevention', 'see_doc']).drop_duplicates()
# node_id becomes mondo_id; the text columns get a 'mayo_' prefix.
mayo_data = mayo_data.rename(columns={'Symptoms':'mayo_symptoms', 'Causes':'mayo_causes', 'see_doc':'mayo_see_doc',
                                      'node_id':'mondo_id', 'Risk factors':'mayo_risk_factors', 
                                      'Complications':'mayo_complications', 'Prevention':'mayo_prevention',})

# Left-merge on the shared columns (presumably mondo_id — verify); diseases without Mayo text get ''.
disease_features = pd.merge(disease_features, mayo_data, 'left').fillna('')

### Save and describe features

In [20]:
# Remove exact duplicate rows and write the table to <save_path>disease_features.csv without the index.
disease_features = disease_features.drop_duplicates()
disease_features.to_csv(save_path+'disease_features.csv', index=False)

In [21]:
# Keep an untouched copy: the coverage cells below rebind disease_features to filtered subsets of it.
disease_features_full  = disease_features.copy()

In [22]:
# Coverage report keyed by MONDO id (raw, ungrouped diseases).
col_name = 'mondo_id'
feature_cols = ['mondo_definition', 'umls_description',
       'orphanet_definition', 'orphanet_prevalence', 'orphanet_epidemiology',
       'orphanet_clinical_description', 'orphanet_management_and_treatment',
       'mayo_symptoms', 'mayo_causes', 'mayo_risk_factors',
       'mayo_complications', 'mayo_prevention', 'mayo_see_doc']

# Key + feature columns, de-duplicated, with '' treated as missing.
disease_features = (disease_features_full
                    .get([col_name] + feature_cols)
                    .drop_duplicates()
                    .replace('',float('nan')))
# Keep only rows that have at least one non-missing feature.
has_feature = disease_features.loc[:, disease_features.columns != col_name].notnull().any(axis=1).values
disease_features = disease_features.loc[has_feature, :]

row_fmt = '{} \t {} \t {}'

x = set(np.unique(diseases.get(col_name).values))
print('coverage of raw KG with {} total diseases'.format(len(x)))
y = set(np.unique(disease_features.get(col_name).values))
print('proportion of diseases with features {:.2f}'.format(len(y)/len(x)))
print()
print('#values  #unique feature')
print(row_fmt.format(disease_features.shape[0], len(y), 'any'))
# Per feature: number of distinct (key, value) rows and number of distinct keys.
for c in feature_cols:

    df = disease_features.get([col_name, c]).dropna().drop_duplicates()
    y = set(np.unique(df.get(col_name).values))
    print(row_fmt.format(df.shape[0], len(y), c))

coverage of raw KG with 22205 total diseases
proportion of diseases with features 0.82

#values  #unique feature
40068 	 18152 	 any
15238 	 15238 	 mondo_definition
28468 	 8689 	 umls_description
6564 	 6548 	 orphanet_definition
3989 	 3989 	 orphanet_prevalence
2350 	 2348 	 orphanet_epidemiology
2294 	 2292 	 orphanet_clinical_description
1732 	 1731 	 orphanet_management_and_treatment
6642 	 5789 	 mayo_symptoms
6629 	 5776 	 mayo_causes
6284 	 5501 	 mayo_risk_factors
5011 	 4455 	 mayo_complications
2529 	 2273 	 mayo_prevention
5862 	 5234 	 mayo_see_doc


In [23]:
# Coverage report keyed by KG node index (grouped diseases).
col_name = 'node_index'
# Single definition of the feature columns, used for both the selection and the per-feature loop
# (the list used to be spelled out twice in this cell).
feature_cols = ['mondo_definition', 'umls_description',
       'orphanet_definition', 'orphanet_prevalence', 'orphanet_epidemiology',
       'orphanet_clinical_description', 'orphanet_management_and_treatment',
       'mayo_symptoms', 'mayo_causes', 'mayo_risk_factors',
       'mayo_complications', 'mayo_prevention', 'mayo_see_doc']

disease_features = disease_features_full.get([col_name] + feature_cols)

# '' counts as missing; drop rows in which every feature is missing.
disease_features = disease_features.drop_duplicates().replace('',float('nan'))
disease_features = disease_features.loc[np.invert(disease_features.loc[:, disease_features.columns != col_name].isnull().all(axis=1).values), :]

x = set(np.unique(kg_diseases.get(col_name).values))
print('coverage of grouped KG with {} total diseases'.format(len(x)))
# col_name instead of a second hard-coded 'node_index', so the key is defined in one place.
y = set(np.unique(disease_features.get(col_name).values))
print('proportion of dieases with features {:.2f}'.format(len(y)/len(x)))
print()

print('#values  #unique feature')
print('{} \t {} \t {}'.format(disease_features.shape[0], len(y), 'any'))
# Per feature: number of distinct (key, value) rows and number of distinct keys.
for c in feature_cols:

    df = disease_features.get([col_name, c]).dropna().drop_duplicates()
    y = set(np.unique(df.get(col_name).values))
    print('{} \t {} \t {}'.format(df.shape[0], len(y), c))

coverage of grouped KG with 17080 total diseases
proportion of dieases with features 0.83

#values  #unique feature
39800 	 14252 	 any
15238 	 12001 	 mondo_definition
25374 	 6964 	 umls_description
6562 	 5645 	 orphanet_definition
3500 	 3430 	 orphanet_prevalence
2335 	 2026 	 orphanet_epidemiology
2293 	 1972 	 orphanet_clinical_description
1722 	 1553 	 orphanet_management_and_treatment
5140 	 4470 	 mayo_symptoms
5128 	 4459 	 mayo_causes
4898 	 4299 	 mayo_risk_factors
3792 	 3396 	 mayo_complications
1907 	 1776 	 mayo_prevention
4531 	 4058 	 mayo_see_doc
