# Construct PheKG knowledge graph

In [1]:
import os
import pandas as pd
import numpy as np
import dgl 
import torch
import csv
from tqdm import tqdm

In [2]:
# set seeds
torch.manual_seed(0)
import random
random.seed(0)
import numpy as np
np.random.seed(0)

In [3]:
os.chdir(r"phekg/")

## Download relevant medical vocabularies

In [5]:
# creat subfolder to hold kg files 
!mkdir -p phekg
os.chdir(r"phekg/")
!pwd

/n/home01/ruthjohnson/github_sandbox/Clinical-knowledge-embeddings/kg/phekg/phekg


In [4]:
#
# PheMap and phecodes
#
# https://www.vumc.org/cpm/phemap

!wget https://phewascatalog.org/files/phemap/PheMap_Mapped_Terminologies_1.1.csv
!wget https://phewascatalog.org/files/phemap/Phecode_Relationship.csv 

!wget https://phewascatalog.org/files/phecode_definitions1.1.csv.zip 
!unzip phecode_definitions1.1.csv.zip

!wget https://phewascatalog.org/files/phecode_icd9_map_unrolled.csv.zip
!unzip phecode_icd9_map_unrolled.csv.zip

--2024-08-06 13:46:59--  https://phewascatalog.org/files/phemap/PheMap_Mapped_Terminologies_1.1.csv
Resolving phewascatalog.org (phewascatalog.org)... 160.129.28.255
Connecting to phewascatalog.org (phewascatalog.org)|160.129.28.255|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 235935924 (225M) [text/csv]
Saving to: ‘PheMap_Mapped_Terminologies_1.1.csv’

PheMap_Mapped_Termi   1%[                    ]   4.07M  1.91MB/s               ^C
--2024-08-06 13:47:01--  https://phewascatalog.org/files/phemap/Phecode_Relationship.csv
Resolving phewascatalog.org (phewascatalog.org)... 160.129.28.255
Connecting to phewascatalog.org (phewascatalog.org)|160.129.28.255|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 44138 (43K) [text/csv]
Saving to: ‘Phecode_Relationship.csv’


2024-08-06 13:47:02 (1.47 MB/s) - ‘Phecode_Relationship.csv’ saved [44138/44138]

--2024-08-06 13:47:02--  https://phewascatalog.org/files/phecode_definitions1.1.csv.zip
Reso

In [6]:
#
# UMLS
#

# This assumes that the file 'umls.csv' is provided. This is accessible by registering and downloading 
# the UMLS dataset (https://www.nlm.nih.gov/databases/umls.html)

!head -n 1 umls.csv  > umls_atc_rxnorm.csv
!grep -E 'ATC|RXNORM' umls.csv  >> umls_atc_rxnorm.csv
!cut -d',' -f1,3,4,5,6,7,8,9,10,11,12,13,14,15 umls_atc_rxnorm.csv | sed 's/^"//' | sed 's/"$//' > umls_atc_rxnorm_clean.csv

head: cannot open 'umls.csv' for reading: No such file or directory
grep: umls.csv: No such file or directory


In [4]:
umls_file = "umls_atc_rxnorm_clean.csv"
umls_df = pd.read_table("umls_atc_rxnorm_clean.csv", sep=',',  quoting=csv.QUOTE_NONE)
umls_df = umls_df.drop(['source_aui', 'source_cui', 'source_descriptor_dui'], axis=1)

# only keep active terms
umls_df = umls_df.query("term_status=='P'")

FileNotFoundError: [Errno 2] No such file or directory: 'umls_atc_rxnorm_clean.csv'

In [3]:
#
# Anatomical Therapeutic Chemical (ATC) Classification
#

# atc codes
atc_df = umls_df.query("source=='ATC'")

# remove atc codes < ATC4
atc_df = atc_df.assign(atc_len=[len(x) for x in atc_df['source_code']])
atc_df = atc_df.query("atc_len >= 5")

# rxnorm
rx_df = umls_df.query("source=='RXNORM'")

# overlap of atc and rxnorm by cui
uniq_rx_df = rx_df[['cui', 'source_code', 'source_name']].drop_duplicates() 
uniq_atc_df = atc_df[['cui', 'source_code', 'source_name']].drop_duplicates()

overlap_df = uniq_atc_df.merge(uniq_rx_df, on='cui', how='inner')
overlap_df.columns = ['cui', 'atc_code', 'atc_name', 'rxnorm_code', 'rxnorm_name']
uniq_overlap_df = overlap_df[['cui', 'atc_code', 'rxnorm_code']].drop_duplicates()

# filter to only codes used in phemap
phemap_df = pd.read_csv("PheMap_Mapped_Terminologies_1.1.csv")
phemap_rxnorm_codes = phemap_df.query("SOURCE=='RXNORM'")['CODE'].unique()

# save final mapping
atc_rxnorm_map_df = uniq_overlap_df.loc[uniq_overlap_df['rxnorm_code'].isin(phemap_rxnorm_codes)]
atc_rxnorm_map_df.to_csv("atc_rxnorm_map.csv", index=False)

NameError: name 'umls_df' is not defined

In [7]:
# original number of edges
phemap_df.loc[phemap_df['SOURCE'] != 'ICD10CM'].shape

(727939, 5)

In [11]:
# original number of nodes
phecodes = phemap_df.loc[phemap_df['SOURCE'] != 'ICD10CM']['PHECODE'].unique().tolist()

non_phecode_df = phemap_df.loc[phemap_df['SOURCE'] != 'ICD10CM'][['CODE', 'SOURCE']].drop_duplicates()

len(phecodes) + len(non_phecode_df)

78801

## Construct Basic KG

#### Use PheMap as basis for nodes

In [5]:
defn_df = pd.read_csv("phecode_definitions1.2.csv")

non_nan_cats = ['infectious diseases', 'other', 'neoplasms', 'endocrine/metabolic',
       'hematopoietic', 'mental disorders', 'neurological',
       'sense organs', 'circulatory system', 'respiratory', 'digestive',
       'genitourinary', 'pregnancy complications', 'dermatologic',
       'musculoskeletal', 'congenital anomalies', 'symptoms',
       'injuries & poisonings']

phecode_cats = ['sense organs', 'symptoms', 'hematopoietic', 'injuries & poisonings', 'other', 
 'mental disorders', 'respiratory', 'congenital anomalies', 'endocrine/metabolic',
 'infectious diseases', 'neoplasms', 'genitourinary', 'digestive', 'pregnancy complications',
 'circulatory system', 'dermatologic', 'neurological', 'musculoskeletal']

df = pd.read_csv("PheMap_Mapped_Terminologies_1.1.csv")

icd_df = df.loc[df['SOURCE']=='ICD9CM']

In [6]:
icd_df_list = []
for cat in phecode_cats:
    if cat == 'other':
        non_nan_inds = defn_df.loc[defn_df['category'].isin(non_nan_cats)]['phecode'].unique()  
        cat_df = icd_df.loc[~icd_df['PHECODE'].isin(non_nan_inds)] 
    else:
        cat_df = icd_df.loc[icd_df['PHECODE'].isin(defn_df.loc[defn_df['category'] == cat]['phecode'].unique())]
    icd_df_list.append(cat_df)
    print(cat, cat_df.shape)
    
icd_phemap_df = pd.concat(icd_df_list)
icd_phemap_df.shape

sense organs (35607, 5)
symptoms (8990, 5)
hematopoietic (13067, 5)
injuries & poisonings (24748, 5)
other (6150, 5)
mental disorders (22992, 5)
respiratory (21666, 5)
congenital anomalies (10948, 5)
endocrine/metabolic (25757, 5)
infectious diseases (16599, 5)
neoplasms (21657, 5)
genitourinary (23626, 5)
digestive (25468, 5)
pregnancy complications (15324, 5)
circulatory system (33869, 5)
dermatologic (16560, 5)
neurological (20926, 5)
musculoskeletal (21610, 5)


(365564, 5)

In [7]:
df = pd.concat([icd_phemap_df,
    df.loc[df['SOURCE'] == 'LNC'],
    df.loc[df['SOURCE'] == 'SNOMEDCT_US'],
    df.loc[df['SOURCE'] == 'RXNORM'],
    df.loc[df['SOURCE'] == 'CPT']])

df.shape

(727696, 5)

In [8]:
node_df = df[['CODE', 'SOURCE', 'DESCRIPTION']].drop_duplicates()

# add nodes not listed in phemap
relation_df = pd.read_csv("Phecode_Relationship.csv")
defn_df = pd.read_csv("phecode_definitions1.2.csv")

all_phecode_df = defn_df[['phecode', 'phenotype']]
all_phecode_df.columns = ['CODE', 'DESCRIPTION']
all_phecode_df=all_phecode_df.assign(SOURCE='PHECODE')

In [9]:
# make sure not missing any phecodes in phemap
df.loc[~df['PHECODE'].isin(all_phecode_df['CODE'])]['PHECODE'].unique()

array([1010.6])

In [10]:
# add phecodes to node df
missing_entry = pd.DataFrame({'CODE': [1010.6], 'SOURCE': ['PHECODE'],
             'DESCRIPTION': ["Persons encountering health services in circumstances related to reproduction"]})
node_df = pd.concat([node_df, all_phecode_df, missing_entry])    

In [11]:
node_df.shape

(82113, 3)

#### Aggregate duplicate codes with multiple descriptions

In [12]:
#
# Identify nodes with duplicate descriptions
#
dup_code_dict = {}

for source_type in node_df['SOURCE'].unique():
    print(source_type)
    node_count_df = node_df.query("SOURCE=='{}'".format(source_type)).groupby('CODE').count().reset_index()

    node_dup_df = node_count_df[node_count_df['DESCRIPTION'] > 1]

    for i, row in node_dup_df.iterrows():
        code = row['CODE']
        dup_code_dict[(source_type, code)] = node_df.loc[(node_df['SOURCE'] == source_type) 
                                                    & (node_df['CODE'] == code)]['DESCRIPTION'].values
# aggregate duplicates
dupe_code_types = {}

for source in node_df['SOURCE'].unique():
    l = [x for (k, v),x in zip(dup_code_dict,dup_code_dict.values()) if k==source]
    if len(l) > 0:
        dupe_code_types[source] = np.concatenate(l)
    else:
        dupe_code_types[source] = []

ICD9CM
LNC
SNOMEDCT_US
RXNORM
CPT
PHECODE


In [13]:
#
# Identify duplicates based on duplicate codes
#
dup_desc_dict = {}

for source_type in node_df['SOURCE'].unique():
    print(source_type)
    source_df = node_df.query("SOURCE=='{}'".format(source_type))
    
    # remove descriptions already accounted for in duplicate codes
    source_df = source_df.loc[~source_df['DESCRIPTION'].isin(dupe_code_types[source_type])]
    
    node_count_df = source_df.groupby('DESCRIPTION').count().reset_index()

    node_dup_df = node_count_df[node_count_df['CODE'] > 1]

    for i, row in node_dup_df.iterrows():
        desc = row['DESCRIPTION']
        dup_desc_dict[(source_type, desc)] = node_df.loc[(node_df['SOURCE'] == source_type) 
                                                    & (node_df['DESCRIPTION'] == desc)]['CODE'].values

ICD9CM
LNC
SNOMEDCT_US
RXNORM
CPT
PHECODE


In [16]:
[x for x in dup_desc_dict.keys() if x[0]=='LNC']

[('LNC', '1'),
 ('LNC', '1,3 beta glucan'),
 ('LNC', '1,4-Dichlorobenzene'),
 ('LNC', '14-3-3 protein'),
 ('LNC', '17-Hydroxyprogesterone'),
 ('LNC', '1H'),
 ('LNC', '1st trimester'),
 ('LNC', '2'),
 ('LNC', '2H'),
 ('LNC', '2nd trimester'),
 ('LNC', '3-Methylcrotonyl-CoA carboxylase deficiency'),
 ('LNC', '4'),
 ('LNC', "5'-Nucleotidase"),
 ('LNC', '5-Hydroxytryptophan'),
 ('LNC', '6'),
 ('LNC', '8'),
 ('LNC', '???lead'),
 ('LNC', 'A'),
 ('LNC', 'A Ag'),
 ('LNC', 'ABO & Rh group'),
 ('LNC', 'ABXBACT'),
 ('LNC', 'ALPRAZolam'),
 ('LNC', 'AP'),
 ('LNC', 'ARIPiprazole'),
 ('LNC', 'ART'),
 ('LNC', 'AV shunt'),
 ('LNC', 'Abacavir'),
 ('LNC', 'Abdomen'),
 ('LNC', 'Abdomen+'),
 ('LNC', 'Abdomen+Chest'),
 ('LNC', 'Abdomen+Pelvis>Genitourinary tract'),
 ('LNC', 'Abdomen+Pelvis>Iliac artery'),
 ('LNC', 'Abdomen>Abdominal wall'),
 ('LNC', 'Abdomen>Kidney'),
 ('LNC', 'Abdomen>Liver'),
 ('LNC', 'Abdomen>Renal vein.left'),
 ('LNC', 'Abdomen>Renal vein.right'),
 ('LNC', 'Abdomen>Right upper quadrant'

In [17]:
#
# Aggregate and rename nodes with duplicate description
#
node_source_list = []
node_id_list = []
node_name_list = []
node_desc_list = []

for k in dup_code_dict.keys():
    source = k[0]
    code = k[1]
    new_id = 'group_' + str(code)
    desc = dup_code_dict[k][0]
    
    node_source_list.append(source)
    node_id_list.append(new_id)
    node_name_list.append(code)
    node_desc_list.append(desc)

rename_node_code_df = pd.DataFrame({'SOURCE': node_source_list, 'CODE': node_id_list,
                               'OG_CODE': node_name_list, 'DESCRIPTION': node_desc_list})

In [18]:
rename_node_code_df.shape

(2507, 4)

In [20]:
rename_node_code_df.shape

(2507, 4)

In [16]:
#
# Aggregate and label nodes with duplicate codes
#
node_source_list = []
node_id_list = []
node_name_list = []
node_desc_list = []

for k in dup_desc_dict.keys():
    source = k[0]
    desc = k[1]
    dup_list = dup_desc_dict[k]
    new_id = 'group_' + str(dup_list[0])
    
    node_source_list.append(source)
    node_id_list.append(new_id)
    node_name_list.append(desc)

rename_node_desc_df = pd.DataFrame({'SOURCE': node_source_list, 'CODE': node_id_list,
                               'DESCRIPTION': node_name_list})

rename_node_desc_df.shape

(10101, 3)

In [17]:
# remove renamed nodes from original node df (by code)
remaining_node_list = []
for source in rename_node_code_df['SOURCE'].unique():
    print(source)
    remove_code = rename_node_code_df.loc[rename_node_code_df['SOURCE'] == source]['OG_CODE']
    keep_df = node_df.loc[(node_df['SOURCE'] == source) & (~node_df['CODE'].isin(remove_code))]
    exclude_list = rename_node_desc_df.loc[rename_node_desc_df['SOURCE'] == source]['DESCRIPTION'].values
    keep_df = keep_df.loc[~keep_df['DESCRIPTION'].isin(exclude_list)]
    
    remaining_node_list.append(keep_df)

# add back sources that didn't have any duplicates (i.e. rxnorm)
for source in set(node_df['SOURCE'].unique()) - set(rename_node_code_df['SOURCE'].unique()):
    keep_df = node_df.loc[(node_df['SOURCE'] == source)]
    exclude_list = rename_node_desc_df.loc[rename_node_desc_df['SOURCE'] == source]['DESCRIPTION'].values
    keep_df = keep_df.loc[~keep_df['DESCRIPTION'].isin(exclude_list)]
    
    remaining_node_list.append(keep_df)
    
# add back nodes that were grouped
remaining_node_list.append(rename_node_code_df.drop(['OG_CODE'], axis=1))

len(remaining_node_list)

ICD9CM
SNOMEDCT_US
CPT


7

In [18]:
# remove renamed nodes from original node df
for source in rename_node_desc_df['SOURCE'].unique():
    print(source)
    remove_desc = rename_node_desc_df.loc[rename_node_desc_df['SOURCE'] == source]['DESCRIPTION']
    keep_df = node_df.loc[(node_df['SOURCE'] == source) 
                          & (~node_df['DESCRIPTION'].isin(remove_desc))]
    exclude_list = rename_node_code_df.loc[rename_node_code_df['SOURCE'] == source]['OG_CODE'].values
    keep_df = keep_df.loc[~keep_df['CODE'].isin(exclude_list)]
    
    remaining_node_list.append(keep_df)

# add back sources that didn't have any duplicates (i.e. rxnorm)
for source in set(node_df['SOURCE'].unique()) - set(rename_node_desc_df['SOURCE'].unique()):
    keep_df = node_df.loc[(node_df['SOURCE'] == source)]
    exclude_list = rename_node_code_df.loc[rename_node_code_df['SOURCE'] == source]['OG_CODE'].values
    keep_df = keep_df.loc[~keep_df['CODE'].isin(exclude_list)]
    remaining_node_list.append(keep_df)
    
# add back nodes that were grouped
remaining_node_list.append(rename_node_desc_df)

new_node_df = pd.concat(remaining_node_list).drop_duplicates()

new_node_df.shape

ICD9CM
LNC
SNOMEDCT_US
CPT


(66756, 3)

In [19]:
new_node_df.tail()

Unnamed: 0,CODE,SOURCE,DESCRIPTION
10096,group_73120,CPT,X-ray of hand
10097,group_70210,CPT,X-ray of paranasal sinuses
10098,group_72170,CPT,X-ray of pelvis
10099,group_70190,CPT,X-ray of skull
10100,group_70300,CPT,X-ray of teeth


In [20]:
#
# merge with edges in 3 parts
#
new_node_mapping = new_node_df.copy()[['CODE', 'SOURCE', 'DESCRIPTION']]
new_node_mapping.columns = ['CODE_new', 'SOURCE', 'DESCRIPTION']

merge_df_list = []

for source in node_df['SOURCE'].unique():
    hold_back_codes = rename_node_code_df.loc[rename_node_code_df['SOURCE']==source]['OG_CODE'].values

    hold_back_desc = rename_node_desc_df.loc[rename_node_desc_df['SOURCE']==source]['DESCRIPTION'].values

    source_df = df.loc[df['SOURCE']==source]

    source_keep_df = source_df.loc[(~source_df['CODE'].isin(hold_back_codes)) &
                  (~source_df['DESCRIPTION'].isin(hold_back_desc))]

    new_source_keep_df = source_keep_df.merge(new_node_mapping, on=['DESCRIPTION', 'SOURCE'], how='inner').drop('CODE', axis=1)

    new_source_keep_df = source_keep_df.rename(columns={'CODE_new': 'CODE'})[['PHECODE', 'CODE', 'SOURCE', 'DESCRIPTION', 'TFIDF']]

    merge_df_list.append(new_source_keep_df)

    hold_back_code_df = source_df.loc[source_df['CODE'].isin(hold_back_codes)]

    hold_back_code_df = hold_back_code_df.drop(['DESCRIPTION'], axis=1).groupby(
        ['PHECODE', 'CODE', 'SOURCE']).mean().reset_index()

    new_hold_back_code_df = hold_back_code_df.merge(rename_node_code_df, 
                            left_on=['SOURCE', 'CODE'],
                            right_on=['SOURCE', 'OG_CODE'], how='inner')

    new_hold_back_code_df['CODE'] = new_hold_back_code_df['CODE_y'].copy()
    new_hold_back_code_df = new_hold_back_code_df.drop(['CODE_x', 'CODE_y', 'OG_CODE'], axis=1)

    merge_df_list.append(new_hold_back_code_df)

    hold_back_desc_df = source_df.loc[source_df['DESCRIPTION'].isin(hold_back_desc)]

    hold_back_desc_df = hold_back_desc_df.drop(['CODE'], axis=1).groupby(
        ['PHECODE', 'DESCRIPTION', 'SOURCE']).mean().reset_index()

    new_hold_back_desc_df = hold_back_desc_df.merge(rename_node_desc_df, 
                            left_on=['SOURCE', 'DESCRIPTION'],
                            right_on=['SOURCE', 'DESCRIPTION'], how='inner')

    merge_df_list.append(new_hold_back_desc_df)

    
new_df = pd.concat(merge_df_list).drop_duplicates()

phecode_df = node_df.query("SOURCE=='PHECODE'").reset_index().drop('index', axis=1)
new_node_df = new_node_df.reset_index().drop('index', axis=1)

new_node_df.shape

(66756, 3)

In [21]:
new_df.tail()

Unnamed: 0,PHECODE,CODE,SOURCE,DESCRIPTION,TFIDF
504,1001.0,group_93451,CPT,Right heart catheterization,6.88946
505,1010.0,group_86480,CPT,Tuberculosis test,1.51313
506,1010.6,group_59409,CPT,Vaginal delivery,2.541951
507,1011.0,group_20200,CPT,Biopsy of muscle,1.370796
508,1011.0,group_90675,CPT,Rabies vaccine,1.376661


#### Add edges between parent and children phecodes

In [22]:
# add parent relation nodes
relation_df = pd.read_csv("Phecode_Relationship.csv")

parent_df = relation_df[['PARENT_CODE', 'PARENT_STR']].drop_duplicates()

parent_df = parent_df.merge(defn_df[['phenotype', 'phecode']], left_on='PARENT_CODE',
                right_on='phecode', how='left')

# node_index,node_id,node_type,node_name,node_source
new_node_df.columns = ['node_id', 'node_source', 'node_name']
new_node_df['node_type'] = new_node_df['node_source'].copy()
new_node_df['node_index'] = new_node_df.index.values

new_node_df.shape

(66756, 5)

In [23]:
# translate to edges
new_df.columns = ['x_id', 'y_id', 'y_source', 'y_name', 'CS']

# define edges
edge_df = new_df.merge(new_node_df.loc[new_node_df['node_source']=='PHECODE'],
         left_on='x_id', right_on='node_id', how='left')

edge_df = edge_df.rename(columns={'node_source': 'x_source',
                        'node_name': 'x_name', 'node_type': 'x_type',
                       'node_index': 'x_index'}).drop(['node_id'], axis=1)

edge_df['y_type'] = edge_df['y_source'].copy()

edge_df = edge_df.merge(new_node_df,left_on=['y_id', 'y_source'], 
              right_on=['node_id', 'node_source'], how='left')

edge_df = edge_df.rename(columns={'node_index': 'y_index'})

edge_df = edge_df[['x_index', 'x_id', 'x_name', 'x_type', 'x_source', 
        'y_index', 'y_id', 'y_name', 'y_type', 'y_source', 'CS']]
edge_df.shape

(635712, 11)

In [24]:
# add parent edges
parent_child_df = relation_df[['CHILD_CODE', 'PARENT_CODE']]
parent_child_df.columns = ['x_id', 'y_id']
parent_child_df['x_source'] = 'PHECODE'
parent_child_df['y_source'] = 'PHECODE'
parent_child_df['x_type'] = 'PHECODE'
parent_child_df['y_type'] = 'PHECODE'

# get x info
parent_child_df = parent_child_df.merge(new_node_df.loc[new_node_df['node_source'] == 'PHECODE'],
                     left_on='x_id', right_on='node_id', how='left')

parent_child_df = parent_child_df.rename(columns={'node_name':'x_name',
                               'node_index': 'x_index'})


parent_child_df = parent_child_df.drop(['node_type', 'node_source', 'node_id'], axis=1)

# y info
parent_child_df = parent_child_df.merge(new_node_df.loc[new_node_df['node_source'] == 'PHECODE'],
                     left_on='y_id', right_on='node_id', how='left')

parent_child_df = parent_child_df.rename(columns={
                                                  'node_name':'y_name',
                               'node_index': 'y_index'}).drop(['node_type', 'node_source', 'node_id'], axis=1)

parent_child_df['CS'] = edge_df['CS'].max()

parent_child_df = parent_child_df[['x_index', 'x_id', 'x_name', 'x_type', 'x_source', 
        'y_index', 'y_id', 'y_name', 'y_type', 'y_source', 'CS']]

all_edge_df = pd.concat([edge_df, parent_child_df])

all_edge_df['relation'] = all_edge_df['x_type'] + '_' + all_edge_df['y_type']
print(all_edge_df.shape)

node_check_df = pd.concat([all_edge_df[['x_id', 'x_source']], all_edge_df[['y_id', 'y_source']].rename(columns={'y_id': 'x_id', 'y_source': 'x_source'})]).drop_duplicates()
node_check_df = node_check_df.rename(columns={'x_id': 'node_id', 'x_source': 'node_source'})


(636737, 12)


In [25]:
phecode_df = node_df.query("SOURCE=='PHECODE'").reset_index().drop('index', axis=1)

new_node_df = new_node_df.merge(node_check_df, on=['node_id', 'node_source'], how='inner')

new_node_df = new_node_df.reset_index().drop('index', axis=1)

# add parent relation nodes
relation_df = pd.read_csv("Phecode_Relationship.csv")

parent_df = relation_df[['PARENT_CODE', 'PARENT_STR']].drop_duplicates()

parent_df = parent_df.merge(defn_df[['phenotype', 'phecode']], left_on='PARENT_CODE',
                right_on='phecode', how='left')


new_node_df['node_index'] = new_node_df.index.values

# translate to edges
new_df.columns = ['x_id', 'y_id', 'y_source', 'y_name', 'CS']


# define edges
edge_df = new_df.merge(new_node_df.loc[new_node_df['node_source']=='PHECODE'],
         left_on='x_id', right_on='node_id', how='left')

edge_df = edge_df.rename(columns={'node_source': 'x_source',
                        'node_name': 'x_name', 'node_type': 'x_type',
                       'node_index': 'x_index'}).drop(['node_id'], axis=1)

edge_df['y_type'] = edge_df['y_source'].copy()

edge_df = edge_df.merge(new_node_df,left_on=['y_id', 'y_source'], 
              right_on=['node_id', 'node_source'], how='left')

edge_df = edge_df.rename(columns={'node_index': 'y_index'})

edge_df = edge_df[['x_index', 'x_id', 'x_name', 'x_type', 'x_source', 
        'y_index', 'y_id', 'y_name', 'y_type', 'y_source', 'CS']]

# add parent edges
parent_child_df = relation_df[['CHILD_CODE', 'PARENT_CODE']]
parent_child_df.columns = ['x_id', 'y_id']
parent_child_df['x_source'] = 'PHECODE'
parent_child_df['y_source'] = 'PHECODE'
parent_child_df['x_type'] = 'PHECODE'
parent_child_df['y_type'] = 'PHECODE'

# get x info
parent_child_df = parent_child_df.merge(new_node_df.loc[new_node_df['node_source'] == 'PHECODE'],
                     left_on='x_id', right_on='node_id', how='left')

parent_child_df = parent_child_df.rename(columns={'node_name':'x_name',
                               'node_index': 'x_index'})


parent_child_df = parent_child_df.drop(['node_type', 'node_source', 'node_id'], axis=1)

# y info
parent_child_df = parent_child_df.merge(new_node_df.loc[new_node_df['node_source'] == 'PHECODE'],
                     left_on='y_id', right_on='node_id', how='left')

parent_child_df = parent_child_df.rename(columns={
                                                  'node_name':'y_name',
                               'node_index': 'y_index'}).drop(['node_type', 'node_source', 'node_id'], axis=1)

parent_child_df['CS'] = edge_df['CS'].max()

parent_child_df = parent_child_df[['x_index', 'x_id', 'x_name', 'x_type', 'x_source', 
        'y_index', 'y_id', 'y_name', 'y_type', 'y_source', 'CS']]

all_edge_df = pd.concat([edge_df, parent_child_df])

all_edge_df['relation'] = all_edge_df['x_type'] + '_' + all_edge_df['y_type']

node_check_df = pd.concat([all_edge_df[['x_id', 'x_source']], all_edge_df[['y_id', 'y_source']].rename(columns={'y_id': 'x_id', 'y_source': 'x_source'})]).drop_duplicates()

node_check_df = node_check_df.rename(columns={'x_id': 'node_id', 'x_source': 'node_source'})

node_check_df.shape

(66650, 2)

#### Put all nodes and edges together into a dictionary

In [26]:
kg_df = all_edge_df.drop_duplicates()

# re-index according to node type
nodes_df = new_node_df.merge(node_check_df, on=['node_id', 'node_source'], how='inner')

nodes_df['node_type_index'] = nodes_df.groupby('node_type').cumcount()
kg_df['x_type_index'] = nodes_df.loc[kg_df['x_index'], 'node_type_index'].values
kg_df['y_type_index'] = nodes_df.loc[kg_df['y_index'], 'node_type_index'].values

grouped_edges = kg_df.groupby(['x_type', 'relation', 'y_type'], sort = False)

# add all edges into a dict
data_dict = {}

# construct graph from edges
for (x_type, relation, y_type), edges_subset in grouped_edges:

    # convert edge indices to torch tensor
    edge_indices = (torch.tensor(edges_subset['x_type_index'].values), torch.tensor(edges_subset['y_type_index'].values))

    # add edge indices to data object
    data_dict[(x_type, relation, y_type)] = edge_indices

In [27]:
#
# Add reverse edges for phecode parent-child relationships
#
data_dict_undir = data_dict.copy()

etypes = list(data_dict_undir.keys())

for e_tup in etypes:
    if e_tup[1] == 'PHECODE_PHECODE':
        new_e_type = (e_tup[2], 'PHECODE_PHECODE_rev', e_tup[0])
    else:
        new_e_type = (e_tup[2], e_tup[1], e_tup[0])
    u_inds = data_dict_undir[e_tup][1]
    v_inds = data_dict_undir[e_tup][0]

    data_dict_undir[new_e_type] = (u_inds, v_inds)

#### Make DGL graph obj

In [28]:
hg = dgl.heterograph(data_dict)

rev_hg = dgl.heterograph(data_dict_undir)

In [29]:
# save
fname = "phemap_het_graph_undir.pt"
dgl.save_graphs(fname, [rev_hg])

In [30]:
# iterate over the groups and add global node indices to the graph
for (x_type, relation, y_type), edges_subset in grouped_edges:
    kg_indices = torch.vstack((torch.tensor(edges_subset['x_index'].values), 
                               torch.tensor(edges_subset['y_index'].values))).T
    graph_indices = torch.vstack((hg.edges(etype = (x_type, relation, y_type), form='uv')[0],
                        hg.edges(etype = (x_type, relation, y_type), form = 'uv')[1])).T
    
    hg[(x_type, relation, y_type)].edata['kg_index'] = torch.tensor(kg_indices)
    hg[(x_type, relation, y_type)].edata['graph_index'] = torch.tensor(graph_indices)

df_list = []
for (x_type, relation, y_type) in hg.canonical_etypes:
    x_graph_index = hg[(x_type, relation, y_type)].edata['graph_index'][:,0]
    y_graph_index = hg[(x_type, relation, y_type)].edata['graph_index'][:,1]
    x_kg_index = hg[(x_type, relation, y_type)].edata['kg_index'][:,0]
    y_kg_index = hg[(x_type, relation, y_type)].edata['kg_index'][:,1]
    
    df = pd.DataFrame({'graph_index': x_graph_index, 'kg_index': x_kg_index})
    df['ntype'] = x_type
    df_list.append(df)
    
    df = pd.DataFrame({'graph_index': y_graph_index, 'kg_index': y_kg_index})
    df['ntype'] = y_type
    df_list.append(df)

all_df = pd.concat(df_list).drop_duplicates()

node_map_df = all_df.merge(nodes_df[['node_index', 'node_name', 'node_id']].drop_duplicates(), left_on='kg_index', right_on='node_index', how='left').drop(['node_index'], axis=1)

  hg[(x_type, relation, y_type)].edata['kg_index'] = torch.tensor(kg_indices)
  hg[(x_type, relation, y_type)].edata['graph_index'] = torch.tensor(graph_indices)


In [31]:
node_map_df.shape

(66650, 5)

In [32]:
node_map_df.tail()

Unnamed: 0,graph_index,kg_index,ntype,node_name,node_id
66645,37746,65747,SNOMEDCT_US,Suspected child abuse,group_139863006
66646,33733,61734,SNOMEDCT_US,Congenital preauricular fistula,group_123283003
66647,33808,61809,SNOMEDCT_US,Costello syndrome,group_205803001
66648,34829,62830,SNOMEDCT_US,Fistula of intestine NOS,group_197245009
66649,34673,62674,SNOMEDCT_US,FH: Fragility fracture,group_389284009


In [33]:
node_map_df.to_csv("node_map_df.csv", index=False)

In [34]:
# add node features (kg index)
for node_type in hg.ntypes:
    print(node_type)
    hg.nodes[node_type].data['kg_index'] = torch.tensor(node_map_df.loc[node_map_df['ntype'] == node_type].sort_values(
        by='graph_index')['kg_index'].values)

CPT
ICD9CM
LNC
PHECODE
RXNORM
SNOMEDCT_US


In [35]:
# add node features (edge index)
cat_i = 0
for node_type in hg.ntypes:
    print(node_type)
    hg.nodes[node_type].data['type_index'] = torch.full((1, hg.num_nodes(node_type)),
                                                        cat_i).squeeze(dim=0)

CPT
ICD9CM
LNC
PHECODE
RXNORM
SNOMEDCT_US


In [36]:
# save!
fname = "phemap_het_graph.pt"
dgl.save_graphs(fname, [hg])

In [37]:
print(hg.num_nodes())
print(hg.num_edges())

66650
636737


## Add extra edges for terminology connections

In [38]:
# read in pre-computed basic kg and mapping for convenience
hg = dgl.load_graphs("phemap_het_graph_undir.pt")[0][0]
node_map_df = pd.read_csv("node_map_df.csv")
node_map_df.shape

(66650, 5)

#### Add PHECODE-ICD edges 

In [12]:
phecode_map_df = pd.read_csv("phecode_icd9_map_unrolled.csv")

phecode_map_df.tail()

Unnamed: 0,icd9,phecode
20778,V85.51,260.0
20779,V85.53,278.0
20780,V85.54,278.1
20781,V85.54,278.0
20782,V87.41,197.0


In [13]:
# total number of ICDs represented in phecode
len(phecode_map_df['icd9'].unique())

13707

In [14]:
# total number of phecodes
len(phecode_map_df['phecode'].unique())

1817

In [15]:
# total concepts just in phecode df 
len(phecode_map_df['icd9'].unique()) + len(phecode_map_df['phecode'].unique())

15524

In [17]:
# total edges just in phecode df
phecode_map_df.shape

(20783, 2)

In [40]:
# read in mapping files
group_df = node_map_df.loc[node_map_df['node_id'].str.startswith('group_')]
group_df = group_df.assign(new_node_id=group_df['node_id'].str.split('_', expand=True)[1])
group_df = group_df.assign(node_id=group_df['new_node_id'])
group_df = group_df.drop(['new_node_id'], axis=1)
print(group_df.shape)

og_df = node_map_df.loc[~node_map_df['node_id'].str.startswith('group_')]
node_map_df = pd.concat([og_df, group_df])
print(node_map_df.shape)

phecode_map_df['phecode'] = phecode_map_df['phecode'].astype(str)
print(phecode_map_df.shape)

phecode_map_df = phecode_map_df.merge(node_map_df.loc[node_map_df['ntype'] == 'ICD9CM'][['node_id', 'graph_index']], left_on='icd9', right_on='node_id').drop(['node_id'], axis=1)
print(phecode_map_df.shape)


(12608, 5)
(66650, 5)
(20783, 2)
(19705, 3)


In [41]:
phecode_map_df = phecode_map_df.merge(node_map_df.loc[node_map_df['ntype'] == 'PHECODE'][['node_id', 'graph_index']], left_on='phecode', right_on='node_id')
phecode_map_df.columns = ['icd9', 'phecode', 'icd9_index', 'node_id', 'phecode_index'] 
print(phecode_map_df.shape)

icd_phecode_index_map = {}
for i, row in phecode_map_df.sort_values(by='phecode_index').iterrows():
    k = row['icd9_index']
    v = row['phecode_index']
    icd_phecode_index_map[k] = v
print(len(icd_phecode_index_map.keys()))

(19463, 5)
12749


In [42]:
# get all PHECODE - ICD9CM edges
all_edges = hg.edges(etype=('PHECODE', 'PHECODE_ICD9CM', 'ICD9CM'))
phecode_src_nodes = all_edges[0]
icd_dst_nodes = all_edges[1]
# add phecode friends from phecode-icd associations
mapped_icd_dst = [icd_phecode_index_map[x.item()] if x.item() in icd_phecode_index_map.keys() else -1 for x in icd_dst_nodes]

keep_edge_list = []
for u,v in zip(phecode_src_nodes.tolist(), mapped_icd_dst):
    if v != -1:
        keep_edge_list.append([u, v])

# why so many duplicate phecode-phecodes?
keep_edge_df = pd.DataFrame(keep_edge_list)
keep_edge_df = keep_edge_df.loc[keep_edge_df[0] != keep_edge_df[1]].drop_duplicates()
print(keep_edge_df.shape)

new_phecode_src_ids = torch.tensor(keep_edge_df[0].values)
new_phecode_dst_ids = torch.tensor(keep_edge_df[1].values)

(34496, 2)


In [43]:
# forward phecode edge
hg.add_edges(new_phecode_src_ids,
             new_phecode_dst_ids,
             etype=('PHECODE', 'PHECODE_PHECODE', 'PHECODE'))

# rev phecode edge
hg.add_edges(new_phecode_dst_ids,
             new_phecode_src_ids,
             etype=('PHECODE', 'PHECODE_PHECODE_rev', 'PHECODE'))

# ICD9CM -> PHECODE
hg.add_edges(torch.tensor(phecode_map_df['icd9_index'].values),
             torch.tensor(phecode_map_df['phecode_index'].values),
             etype=('ICD9CM', 'PHECODE_ICD9CM', 'PHECODE'))

# PHECODE -> ICD9CM
hg.add_edges(torch.tensor(phecode_map_df['phecode_index'].values),
             torch.tensor(phecode_map_df['icd9_index'].values),
             etype=('PHECODE', 'PHECODE_ICD9CM', 'ICD9CM'))

In [44]:
print(hg.num_nodes())
print(hg.num_edges())

66650
1381392


#### Add ATC-RXNORM edges

In [45]:
# read in pre-computed atc-rxnorm mapping
atc_rxnorm_map_df = pd.read_csv("atc_rxnorm_map.csv")

# create atc4 from atc5
atc_rxnorm_map_df = atc_rxnorm_map_df.assign(atc4=atc_rxnorm_map_df['atc_code'].str.slice(0,5))
atc_bridge_dict = {}

atc4_list = np.sort(atc_rxnorm_map_df['atc4'].unique())
for atc_code in tqdm(atc4_list):
    rx_list = atc_rxnorm_map_df.loc[atc_rxnorm_map_df['atc4'] == atc_code]['rxnorm_code'].unique().tolist()
    atc_bridge_dict[atc_code] = rx_list


100%|██████████| 474/474 [00:00<00:00, 3495.04it/s]


In [47]:
# convert to homogenous graph for global indexing
homo_hg = dgl.to_homogeneous(hg)

ntype_dict = {}
for ntype_ind, ntype in zip(range(0, len(hg.ntypes)), hg.ntypes):
    ntype_dict[ntype_ind] = ntype

global_node_list = []
for new_id, old_id, type_ind in zip(homo_hg.nodes().tolist(), homo_hg.ndata['_ID'].tolist(), homo_hg.ndata['_TYPE'].tolist()):
    ntype_name = ntype_dict[type_ind]
    global_node_list.append([old_id, ntype_name, new_id])

global_node_df = pd.DataFrame(global_node_list)
global_node_df.columns = ['graph_index', 'ntype', 'global_graph_index']

node_map_df = node_map_df.merge(global_node_df, on=['ntype', 'graph_index'])
#temp_df = node_map_df.merge(global_node_df, on=['ntype', 'graph_index'])

In [48]:
edata_df = pd.DataFrame({'src': homo_hg.edges(form='uv')[0].tolist(), 'dst': homo_hg.edges(form='uv')[1].tolist(), 'etype': homo_hg.edata['_TYPE'].tolist()})
edata_df = edata_df.merge(node_map_df[['node_name', 'ntype', 'global_graph_index']], left_on='src', right_on='global_graph_index')
edata_df = edata_df.merge(node_map_df[['node_name', 'ntype', 'global_graph_index']], left_on='dst', right_on='global_graph_index')
edata_df.shape

(1381392, 9)

In [49]:
sorted_atc_codes = np.sort(list(atc_bridge_dict.keys()))

new_ind = homo_hg.num_nodes()
rxnorm_id_dict = {}
for k in sorted_atc_codes:
    rxnorm_id_dict[k] = [new_ind]
    new_ind += 1

In [50]:
# make new nodes for atc
n_atc = len(atc_bridge_dict.keys())

# new node type: ATC - 6
homo_hg = dgl.add_nodes(homo_hg, n_atc, data={'_TYPE': torch.tensor([6]*n_atc)})

In [51]:
homo_hg.num_edges()

1381392

In [52]:
rxnorm_id_to_index = {}
for i, row in node_map_df.loc[node_map_df['ntype']=='RXNORM'].iterrows():
    rx_code = str(row['node_id'])
    global_ind = row['global_graph_index']
    rxnorm_id_to_index[rx_code] = global_ind

In [53]:
src_list = []
dst_list = []

for k in rxnorm_id_dict.keys():
    # get graph index of new atc code
    atc_graph_index = rxnorm_id_dict[k][0]
    
    # connect atc code with rxnorm children
    for rxnorm_id in atc_bridge_dict[k]:
        rxnorm_id = str(rxnorm_id)
        rxnorm_index = rxnorm_id_to_index[rxnorm_id]
        
        src_list.append(atc_graph_index)
        dst_list.append(rxnorm_index)

In [54]:
src_list[0:10]

[66650, 66651, 66651, 66651, 66651, 66651, 66651, 66651, 66651, 66651]

In [55]:
# update node map
atc_list = []
for k in rxnorm_id_dict.keys():
    atc_graph_index = rxnorm_id_dict[k][0]
    atc_node_name = 'ATC4_' + str(k)
    atc_list.append(['ATC4', atc_node_name, k, atc_graph_index])
    
atc_df = pd.DataFrame(atc_list)
atc_df.columns = ['ntype', 'node_name', 'node_id', 'global_graph_index']
new_node_df = pd.concat([node_map_df[['ntype', 'node_name', 'node_id', 'global_graph_index']], atc_df])

In [56]:
new_node_df.tail()

Unnamed: 0,ntype,node_name,node_id,global_graph_index
469,ATC4,ATC4_V04CM,V04CM,67119
470,ATC4,ATC4_V04CX,V04CX,67120
471,ATC4,ATC4_V06DC,V06DC,67121
472,ATC4,ATC4_V08AB,V08AB,67122
473,ATC4,ATC4_V08CA,V08CA,67123


In [57]:
new_node_df.to_csv("new_node_map_df.csv", index=False)
node_map_df = new_node_df.copy()

In [58]:
edge_df = pd.DataFrame({'src': src_list, 'dst': dst_list})
edge_df = edge_df.merge(node_map_df, left_on='src', right_on='global_graph_index')
edge_df = edge_df.merge(node_map_df, left_on='dst', right_on='global_graph_index')
edge_df.shape

(1200, 10)

In [59]:
# forward edge: atc -> rx
homo_hg = dgl.add_edges(homo_hg, u=torch.tensor(src_list), v=torch.tensor(dst_list), data={'_TYPE': torch.tensor([12]*len(src_list))})

# backward edge
# rx -> atc
homo_hg = dgl.add_edges(homo_hg, u=torch.tensor(dst_list), v=torch.tensor(src_list), data={'_TYPE': torch.tensor([13]*len(src_list))})

In [60]:
print(homo_hg.num_nodes())
print(homo_hg.num_edges())

67124
1383792


#### Phecode to ATC edges

In [61]:
atc_id_to_index = {}
for i, row in node_map_df.loc[node_map_df['ntype']=='ATC4'].iterrows():
    atc_id = row['node_id']
    atc_index = row['global_graph_index']
    atc_id_to_index[atc_id] = atc_index

In [62]:
atc_bridge_dict['A06AD']

[6218, 6582, 6585, 6628, 9945, 28395, 36709, 52356]

In [63]:
# mapping rxnorm-graph-index -> atc-graph-index
rxnorm_atc_index_mapping = {}

for atc_code in atc_bridge_dict.keys():
    rx_id_list = atc_bridge_dict[atc_code]
    atc_graph_index = atc_id_to_index[atc_code]
    
    for rx_id in rx_id_list:
        rx_graph_index = rxnorm_id_to_index[str(rx_id)]
        rxnorm_atc_index_mapping[rx_graph_index] = atc_graph_index

In [65]:
src_list = homo_hg.edges(form='uv')[0].tolist()
dst_list = homo_hg.edges(form='uv')[1].tolist()
etype_inds = homo_hg.edata['_TYPE'].tolist()

In [66]:

edge_df = pd.DataFrame({'src': src_list, 'dst': dst_list, 'e_ind': etype_inds})
edge_df = edge_df.merge(node_map_df, left_on='src', right_on='global_graph_index')
edge_df = edge_df.merge(node_map_df, left_on='dst', right_on='global_graph_index')
edge_df.shape

(1383792, 11)

In [67]:
# u (src): PHECODE
# v (dst): RXNORM (RXNORM) -> ATC
src_list = []
dst_list = []

for u_nodes, v_nodes, etype in zip(homo_hg.edges(form='uv')[0], homo_hg.edges(form='uv')[1], homo_hg.edata['_TYPE']):
    if etype == 8: # phecode - rxnorm edge
        #print(u_nodes, v_nodes)
        rx_index = v_nodes.item()
        if rx_index in rxnorm_atc_index_mapping.keys():
            atc_index = rxnorm_atc_index_mapping[rx_index]
            src_list.append(u_nodes.item())
            dst_list.append(atc_index)

In [68]:
dst_list[0:10]

[67034, 67103, 67033, 67103, 66676, 67079, 67033, 66870, 67079, 67034]

In [69]:
edge_df = pd.DataFrame({'src': src_list, 'dst': dst_list}).drop_duplicates()
edge_df = edge_df.merge(node_map_df, left_on='src', right_on='global_graph_index')
edge_df = edge_df.merge(node_map_df, left_on='dst', right_on='global_graph_index')
edge_df.shape

(2732, 10)

In [70]:
# remove duplicates 
src_list = edge_df['src'].values
dst_list = edge_df['dst'].values

len(src_list)

2732

In [71]:
# forward edge: phecode -> atc (14)
homo_hg = dgl.add_edges(homo_hg, u=torch.tensor(src_list), v=torch.tensor(dst_list), data={'_TYPE': torch.tensor([14]*len(src_list))})

# backward edge
# atc -> phecode (15)
homo_hg = dgl.add_edges(homo_hg, u=torch.tensor(dst_list), v=torch.tensor(src_list), data={'_TYPE': torch.tensor([15]*len(src_list))})

In [72]:
print(homo_hg.num_nodes()) # 67124
print(homo_hg.num_edges()) # 1389256

67124
1389256


In [73]:
# save!
fname = "intermediate_phemap_homo_graph.pt"
dgl.save_graphs(fname, [homo_hg])

## Filter high degree nodes

In [21]:
# read in pre-computed basic kg and mapping for convenience
fname = "intermediate_phemap_homo_graph.pt"
homo_hg = dgl.load_graphs(fname)[0][0]
node_map_df = pd.read_csv("new_node_map_df.csv")

In [22]:
phecode_node_list = node_map_df.loc[node_map_df['ntype'] == 'PHECODE']['global_graph_index'].values.tolist()
phecode_node_degree = homo_hg.in_degrees(torch.tensor(phecode_node_list))
phecode_map_df = node_map_df.loc[node_map_df['ntype'] == 'PHECODE']
phecode_map_df = phecode_map_df.assign(degree=phecode_node_degree)

In [23]:
phecode_map_df.head()

Unnamed: 0,ntype,node_name,node_id,global_graph_index,degree
0,PHECODE,Intestinal infection,8.0,24217,457
1,PHECODE,Bacterial enteritis,8.5,24218,555
2,PHECODE,Intestinal infection due to C. difficile,8.52,24220,281
3,PHECODE,Viral Enteritis,8.6,24221,348
4,PHECODE,Tuberculosis,10.0,24223,1473


In [25]:
phecode_map_df.tail()

Unnamed: 0,ntype,node_name,node_id,global_graph_index,degree
21796,PHECODE,Bullous dermatoses,695.2,25589,52
21797,PHECODE,Benign neoplasm of brain and other parts of ne...,225.0,24419,6
21798,PHECODE,Abnormal Papanicolaou smear of cervix and cerv...,792.0,25856,1
21799,PHECODE,Occlusion of cerebral arteries,433.2,25037,98
21800,PHECODE,Non-autoimmune hemolytic anemias,283.2,24612,2


In [26]:
degree_thresh = 1000
high_degree_df = phecode_map_df.loc[(phecode_map_df['degree'] >= 1000)]
icd_inds = node_map_df.loc[node_map_df['ntype'] == 'ICD9CM']['global_graph_index'].values.tolist()
keep_icd_inds = node_map_df.loc[(node_map_df['ntype'] == 'ICD9CM') & (node_map_df['node_id'].str.len() <= 5)]['global_graph_index'].values.tolist()

In [30]:
high_degree_df.head(30)

Unnamed: 0,ntype,node_name,node_id,global_graph_index,degree
4,PHECODE,Tuberculosis,10.0,24223,1473
5,PHECODE,Septicemia,38.0,24226,1122
7,PHECODE,Bacteremia,38.3,24229,1021
34,PHECODE,Other tests,1010.0,24268,1570
132,PHECODE,"Lymphoid leukemia, chronic",204.12,24393,1064
322,PHECODE,Delirium due to conditions classified elsewhere,290.2,24661,1091
323,PHECODE,Other persistent mental disorders due to condi...,290.3,24662,1089
324,PHECODE,Specific nonpsychotic mental disorders due to ...,291.4,24665,1188
329,PHECODE,Memory loss,292.3,24672,1621
340,PHECODE,Suicide or self-inflicted injury,297.2,24688,1019


In [28]:
high_degree_df.shape

(64, 5)

In [79]:
len(keep_icd_inds)

6884

In [80]:
from tqdm import tqdm 

edge_tuples = homo_hg.edges()

# for each high degree node, cut all icd neighbors. Get eid for forward/backward edges
u_remove_list = []
v_remove_list = []

for src_node in tqdm(high_degree_df['global_graph_index'].values):
    for u,v in (zip(edge_tuples[0], edge_tuples[1])):
        if (u.item() == src_node) and (v.item() in icd_inds):
            # check if has >4 digits
            if v.item() not in keep_icd_inds:
                u_remove_list.append(u)
                v_remove_list.append(v)

                # reverse edge
                u_remove_list.append(v)
                v_remove_list.append(u)

100%|██████████| 64/64 [04:32<00:00,  4.26s/it]


In [82]:
len(u_remove_list)

76874

In [84]:
high_degree_df.head()

Unnamed: 0,ntype,node_name,node_id,global_graph_index,degree
4,PHECODE,Tuberculosis,10.0,24223,1473
5,PHECODE,Septicemia,38.0,24226,1122
7,PHECODE,Bacteremia,38.3,24229,1021
34,PHECODE,Other tests,1010.0,24268,1570
132,PHECODE,"Lymphoid leukemia, chronic",204.12,24393,1064


In [85]:
remove_eids = homo_hg.edge_ids(torch.tensor(u_remove_list, dtype=torch.int64), torch.tensor(v_remove_list, dtype=torch.int64))
filter_homo_hg = dgl.remove_edges(homo_hg, remove_eids)

In [86]:
# add random features
n_feat=128
xavier_tensor = torch.empty(filter_homo_hg.num_nodes(), n_feat)
torch.nn.init.xavier_uniform_(xavier_tensor)
filter_homo_hg.ndata['feat'] = xavier_tensor

In [87]:
print(filter_homo_hg.num_edges())
print(filter_homo_hg.num_nodes())

1315610
67124


In [88]:
dgl.save_graphs("new_homo_hg_hms.pt", [filter_homo_hg])