
Project notebook available in Google Colab space: https://colab.research.google.com/github/johnbiggan/DeepLearningforHealthcare-Project/blob/main/Code/Project.ipynb

Public GitHub repo: https://github.com/johnbiggan/DeepLearningforHealthcare-Project/blob/main/Code/Project.ipynb

and video presentation: [INSERT HTML HERE]

# Domain Knowledge Guided Deep Learning
### with Electronic Health Records Replication

# Introduction

Every year, more than a quarter of all healthcare spending in the US goes toward preventable diseases [1]. In an effort to intervene earlier, researchers look to the ever-growing pool of electronic health records (EHRs) to predict disease risk. Many models have been developed, but given the sequential nature of healthcare visits, recurrent neural networks (RNNs) have shown the most promise.

Unfortunately, early RNNs treated the time between consecutive visits, as well as the time between visits and disease diagnosis, as unimportant. Additionally, other early models failed to incorporate medical knowledge in the form of relationships between diagnoses (e.g., comorbid, causes, caused-by).

The Domain Knowledge Guided Recurrent Neural Network (DG-RNN) being replicated in this project addresses both shortcomings and has been found to improve heart failure prediction over and above previous methodologies [2]. In their paper, the authors described a model that uses multiple long short-term memory (LSTM) models [3] with attention to utilize disease diagnoses, procedure codes, visit time differences, and a medical knowledge graph (Figure 1). Incorporating medical knowledge graphs poses challenges. For example, the semi-structured nature of graphs adds to the complexity. Moreover, readily available medical knowledge graphs have been few and far between.

When tested on the MIMIC-III dataset [4,5,6], the model outperformed numerous other models, including random forests, support vector machines, and traditional long short-term memory RNNs, on a heart failure prediction task.


![Figure 1](https://raw.githubusercontent.com/johnbiggan/DeepLearningforHealthcare-Project/main/Figures/Figure1.png "Figure 1")

Figure 1. DG-RNN model.

# Scope of Reproducibility

This replication will focus on reproducing the main findings of the paper. Specifically, the following hypotheses will be tested:

Hypothesis 1: Including the time between visits will improve the model over the base model that treats all visits in a simple sequential manner.

Hypothesis 2: Including domain knowledge, along with the time between visits, will lead to a better performing model than the base model.

# Methodology
Python version 3.11 was used.

Additionally, this project uses tools from the pyhealth, numpy, pytorch, and pandas packages, which should be installed if they have not already been.

In [None]:
# Install dependencies if not already installed
! pip install pyhealth
! pip install numpy
! pip install torch torchvision torchaudio
! pip install pandas

In [None]:
# Set seeds for reproducibility
import numpy as np
import torch

np.random.seed(1983)
torch.manual_seed(1983)

##  Data

### MIMIC-III

For their model, the authors used a proprietary dataset as well as the MIMIC-III (Medical Information Mart for Intensive Care III) dataset [4,5,6], which is a freely available database for researchers consisting of deidentified health data from over 40,000 patients who stayed in critical care units of the Beth Israel Deaconess Medical Center between 2001 and 2012. It includes detailed demographics and visit-level healthcare data.

Although the data is freely available to researchers, it is not publicly available due to privacy concerns. Fortunately, in recent years pyhealth [7,8] contributors have created a synthetic dataset based on the MIMIC-III data that does not raise the same privacy concerns. As this is a public-facing project, the synthetic MIMIC-III data from pyhealth will be used. However, the same analyses may be conducted with the original MIMIC-III dataset by changing the **root** argument to point at the original dataset.

In [None]:
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets import split_by_patient, get_dataloader

In [None]:
base_dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD"],
    code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC"},
    dev=False,
    refresh_cache=False,
)
base_dataset.stat()

### PrimeKG

For the knowledge graph, the authors originally used the KnowLife knowledge graph [9]. Unfortunately, at the time of this replication the KnowLife knowledge graph was no longer available.

More recently, researchers have developed an updated medical knowledge graph called PrimeKG [10,11]. This appears to be a worthwhile alternative. One challenge was that the nodes are coded in Monarch Disease Ontology (MONDO) coding, for which there is not a simple mapping available to convert to ICD-9 or CCSCM codes, which is what the MIMIC-III dataset uses.

To overcome this, a subset of the diseases was created by including only heart-related conditions and conditions with a connection to heart-related conditions. Specifically, only disease descriptions including the keywords (heart, atrial, ventricular, cardia, myocardial, and coronary) were kept. This resulted in 576 diseases to manually code with ICD-9 codes.
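
The keyword filter described above can be sketched in a few lines. This is a minimal illustration rather than the code used to build the subset: the `disease_names` list and the `is_heart_related` helper are hypothetical, and matching is assumed to be a case-insensitive substring test.

```python
# Keywords taken from the text; substring matching is an assumption.
HEART_KEYWORDS = ("heart", "atrial", "ventricular", "cardia", "myocardial", "coronary")


def is_heart_related(description: str) -> bool:
    """Return True if any keyword occurs anywhere in the disease description."""
    text = description.lower()
    return any(keyword in text for keyword in HEART_KEYWORDS)


# Hypothetical disease descriptions standing in for PrimeKG node names.
disease_names = [
    "congestive heart failure",
    "Atrial fibrillation",
    "tachycardia",
    "asthma",
    "type 2 diabetes mellitus",
]

heart_subset = [name for name in disease_names if is_heart_related(name)]
print(heart_subset)
```

Substring matching is deliberately broad: a stem such as "cardia" also catches terms like "tachycardia", which is why the resulting list still needs manual review.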

Because ICD-9 codes are not as granular as MONDO and other newer coding schemas (e.g., ICD-10), redundancies reduced the final list to 91 unique ICD-9 codes. This knowledge graph was then converted to CCSCM codes, which are typically used in pyhealth. This further reduced the total to 35 unique codes.

Unfortunately, all of this conversion resulted in a loss of granularity, which may have worsened the performance of the final model. However, this still allowed for a comparison of models with and without some information via a knowledge graph.

In [None]:
# Import knowledge graph
import pandas as pd

# Load the CSV data into a pandas DataFrame
url = "https://raw.githubusercontent.com/johnbiggan/DeepLearningforHealthcare-Project/main/Data/icd_graph_structure.csv"
kg = pd.read_csv(url)

# Display the first few rows of the DataFrame to understand its structure
kg.head()

In [None]:
# Convert ICD-9 to CCSCM for knowledge graph
import ast

from pyhealth.medcode import CrossMap

mapping = CrossMap.load(source_vocabulary="ICD9CM", target_vocabulary="CCSCM")

kg_ccscm = {}

# Iterate through the DataFrame and convert from ICD-9 to CCSCM
for index, row in kg.iterrows():
    ccscm_id = mapping.map(row['x_icd_id'])

    # Skip source codes that have no CCSCM equivalent
    if not ccscm_id:
        continue
    ccscm_id = ccscm_id[0]

    # Several ICD-9 codes can share one CCSCM code, so merge rather than overwrite
    ccscm_set = kg_ccscm.setdefault(ccscm_id, set())

    # y_icd_id holds a stringified list; literal_eval parses it without executing code
    for y_icd_id in ast.literal_eval(row['y_icd_id']):
        if y_icd_id == "":
            continue
        y_ccscm_id = mapping.map(y_icd_id)
        # Keep mapped neighbors only, and drop self-loops
        if y_ccscm_id and y_ccscm_id[0] != ccscm_id:
            ccscm_set.add(y_ccscm_id[0])

print("The CCSCM coded graph contains", len(kg_ccscm), "primary codes.")

In [None]:
# Confirm the CCSCM code for congestive heart failure to be used in the labeling task
from pyhealth.medcode import CrossMap, InnerMap

mapping = CrossMap.load(source_vocabulary="ICD9CM", target_vocabulary="CCSCM")
print("The CCSCM code that maps to ICD-9 code 428.0 (congestive heart failure) is", mapping.map("428.0"))
print("The CCSCM code that maps to ICD-9 code 428.9 (heart failure, unspecified) is", mapping.map("428.9"))

ccscm = InnerMap.load("CCSCM")
print("Confirmation that CCSCM code '108' corresponds to", ccscm.lookup("108"))

In [None]:
# Create custom visit time difference calculation and heart failure prediction task
from pyhealth.data import Patient, Visit
import numpy as np


def visit_time_diff_mimic3_fn(patient: Patient):
    """Processes a single patient for the visit time difference task.

    Visit time difference calculates the delay between the current visit and
    the previous visit.

    Args:
        patient: a Patient object

    Returns:
        samples: a list of samples, each sample is a dict with patient_id,
            visit_id, and other task-specific attributes as key

    Examples:
        >>> from pyhealth.datasets import MIMIC3Dataset
        >>> mimic3_base = MIMIC3Dataset(
        ...    root="/srv/local/data/physionet.org/files/mimiciii/1.4",
        ...    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
        ...    code_mapping={"ICD9CM": "CCSCM"},
        ... )
        >>> mimic3_sample = mimic3_base.set_task(visit_time_diff_mimic3_fn)
        >>> mimic3_sample.samples[0]
        {'visit_id': '130744', 'patient_id': '103', 'conditions': [['129', '157', '99', '101', '110', '55', '62', '105', '4', '63', '138']], 'related_conditions': [['96', '238', '108', '95', '47', '213', '104', '105']], 'procedures': [['1']], 'visit_diff': [[0.0, 0.0, 0.0, 0.0]], 'label': 0}
    """
    samples = []
    criterion_time = None

    for i in range(len(patient) - 1):
        visit: Visit = patient[i]
        visit_time = visit.encounter_time

        next_visit: Visit = patient[i + 1]
        # Label is 1 if the next visit carries the heart failure CCSCM code '108'
        hf_label = int('108' in next_visit.get_code_list(table="DIAGNOSES_ICD"))

        visit_diff = []

        if i > 0:
            prev_visit: Visit = patient[i - 1]
            prev_visit_time = prev_visit.encounter_time

            # Find criterion time (i.e. first encounter with HF diagnosis)
            for j in range(len(patient) - 1):
                c_visit: Visit = patient[j]
                if criterion_time is None and '108' in c_visit.get_code_list(table="DIAGNOSES_ICD"):
                    criterion_time = c_visit.encounter_time

            if criterion_time is not None:
                v_c_s = np.sin(((visit_time - criterion_time).days)/10000).item()
                v_c_c = np.cos(((visit_time - criterion_time).days)/10000).item()
            else:
                v_c_s = 0.0
                v_c_c = 0.0

            v_p = visit_time - prev_visit_time

            visit_diff = [
                v_c_s,
                v_c_c,
                np.sin(((v_p).days)/10000).item(),
                np.cos(((v_p).days)/10000).item()
            ]

        else:
            visit_diff = [0.0, 0.0, 0.0, 0.0]

        related_conditions = set()

        for ccscm_id, ccscm_values in kg_ccscm.items():
            if ccscm_id in visit.get_code_list(table="DIAGNOSES_ICD"):
                for related_id in ccscm_values:
                    related_conditions.add(related_id)

        if not related_conditions:
            related_conditions.add('9999') # add dummy code if no related conditions

        conditions = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        # exclude visits that lack condition codes or procedure codes
        if len(conditions) * len(procedures) == 0 or len(patient) < 2:
            continue
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": [conditions],
                "related_conditions": [list(related_conditions)],
                "procedures": [procedures],
                "visit_diff": [visit_diff],
                "label": hf_label,
            }
        )
    # no cohort selection
    return samples
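
The time features built inside `visit_time_diff_mimic3_fn` can be looked at in isolation. The sketch below uses only the standard library and mirrors the logic above under stated assumptions: the helper names `encode_gap` and `visit_diff_features` and the example dates are hypothetical, and the divisor of 10000 days is copied from the function.

```python
import math
from datetime import datetime
from typing import List, Optional

TIME_SCALE_DAYS = 10000  # same divisor as in visit_time_diff_mimic3_fn


def encode_gap(later: datetime, earlier: datetime) -> List[float]:
    """Encode the whole-day gap between two timestamps as [sin, cos]."""
    scaled = (later - earlier).days / TIME_SCALE_DAYS
    return [math.sin(scaled), math.cos(scaled)]


def visit_diff_features(
    visit_time: datetime,
    prev_visit_time: Optional[datetime],
    criterion_time: Optional[datetime],
) -> List[float]:
    """Build [sin, cos] to the criterion visit, then [sin, cos] to the previous visit."""
    if prev_visit_time is None:  # first visit: all-zero placeholder
        return [0.0, 0.0, 0.0, 0.0]
    if criterion_time is None:  # no heart failure diagnosis found for this patient
        to_criterion = [0.0, 0.0]
    else:
        to_criterion = encode_gap(visit_time, criterion_time)
    return to_criterion + encode_gap(visit_time, prev_visit_time)


# Hypothetical dates: the current visit is 10 days after the previous one.
visit = datetime(2150, 3, 11)
previous = datetime(2150, 3, 1)
features = visit_diff_features(visit, previous, criterion_time=None)
print(features)
```

One consequence of this encoding is that a true zero-day gap maps to `[0.0, 1.0]`, whereas the placeholder for a missing gap is `[0.0, 0.0]`, so the two cases remain distinguishable to the model.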

In [None]:
# Create a sample dataset with a label to indicate an HF diagnosis and visit difference vector
sample_dataset = base_dataset.set_task(visit_time_diff_mimic3_fn)
sample_dataset.stat()

# Split dataset into training, validation, and testing using an 80%, 10%, 10% split
train_dataset, val_dataset, test_dataset = split_by_patient(
    sample_dataset, [0.8, 0.1, 0.1]
)

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
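
Splitting by patient rather than by visit keeps all of a patient's visits in the same partition, so no patient contributes to both training and testing. The following is a minimal standard-library sketch of that idea, not pyhealth's implementation: `split_patient_ids` and the toy `samples` list are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def split_patient_ids(
    patient_ids: Sequence[str], ratios: Tuple[float, float, float], seed: int
) -> Tuple[List[str], List[str], List[str]]:
    """Shuffle unique patient IDs and cut them into train/val/test groups."""
    unique_ids = sorted(set(patient_ids))
    random.Random(seed).shuffle(unique_ids)
    n_train = int(len(unique_ids) * ratios[0])
    n_val = int(len(unique_ids) * ratios[1])
    return (
        unique_ids[:n_train],
        unique_ids[n_train:n_train + n_val],
        unique_ids[n_train + n_val:],
    )


# Hypothetical samples: 10 patients with 3 visits each.
samples = [{"patient_id": f"p{i % 10}", "visit_id": f"v{i}"} for i in range(30)]
train_ids, val_ids, test_ids = split_patient_ids(
    [s["patient_id"] for s in samples], (0.8, 0.1, 0.1), seed=1983
)

# Assign every sample to the partition of its patient.
train_id_set = set(train_ids)
train_samples = [s for s in samples if s["patient_id"] in train_id_set]
print(len(train_ids), len(val_ids), len(test_ids), len(train_samples))
```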

In [None]:
# View a single sample
print(sample_dataset[9])

##   DG-RNN Model
The domain-guided recurrent neural network (DG-RNN) uses multiple long short-term memory units that incorporate both the time between visits and domain knowledge from a knowledge graph of related conditions (see Figure 1).
  * Model architecture [2]:
    * **Embedding Layer:** Converts medical event codes into dense vectors, capturing semantic similarities among codes.
    * **LSTM Layer:** Utilizes Long Short-Term Memory units to model sequences of medical events. The LSTM handles variable-length input sequences, crucial for EHR data where each patient's record may differ in length.
    * **Knowledge Graph Attention Mechanism:** Dynamically integrates external domain knowledge (from a medical knowledge graph) into the LSTM outputs at each time step. This layer selectively enhances the LSTM output with information relevant to the current patient state.
    * **Global Max Pooling Layer:** Aggregates all LSTM outputs across the time dimension, to summarize the entire input sequence effectively.
    * **Fully Connected Layer:** A dense layer that transforms the pooled LSTM outputs into final prediction logits.
    * **Output Activation Function:** Sigmoid

The original paper's repository is available at https://github.com/AIMedLab/DG-RNN/tree/master.
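
The attention and pooling steps in the architecture list can be illustrated with a framework-free sketch. It is not the authors' implementation and not what pyhealth's `RNN` class does internally. Dot-product scoring, combining by addition, and the toy two-dimensional vectors are all assumptions made for illustration.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query: Vector, neighbors: List[Vector]) -> Vector:
    """Dot-product attention: weight neighbor embeddings by similarity to the query."""
    scores = [sum(q * n for q, n in zip(query, nb)) for nb in neighbors]
    weights = softmax(scores)
    return [
        sum(w * nb[d] for w, nb in zip(weights, neighbors))
        for d in range(len(query))
    ]


def global_max_pool(states: List[Vector]) -> Vector:
    """Element-wise maximum over the time dimension."""
    return [max(column) for column in zip(*states)]


# Hypothetical hidden states for three visits.
hidden_states = [[0.1, 0.9], [0.7, 0.2], [0.4, 0.4]]
# Hypothetical embeddings of knowledge-graph neighbors for each visit.
kg_neighbors = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[1.0, 0.0]],
    [[0.5, 0.5], [0.0, 1.0]],
]

# Add the attended knowledge vector to each hidden state, then pool over time.
enriched = [
    [h + k for h, k in zip(state, attend(state, neighbors))]
    for state, neighbors in zip(hidden_states, kg_neighbors)
]
pooled = global_max_pool(enriched)
print(pooled)
```

The pooled vector has one value per hidden dimension regardless of how many visits a patient has, which is what lets a fixed-size fully connected layer follow it.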

In [None]:
from pyhealth.models import RNN
from pyhealth.trainer import Trainer
import torch
import torch.optim as optim

In [None]:
# Create simple LSTM without knowledge graph or time
base_model = RNN(
    dataset=sample_dataset,
    feature_keys=["conditions", "procedures"],
    label_key="label",
    mode="binary",
    rnn_type="LSTM",
    embedding_dim=128,
    hidden_dim=128,
)

In [None]:
# Create LSTM with visit time differences but without knowledge graph
time_model = RNN(
    dataset=sample_dataset,
    feature_keys=["conditions", "procedures", "visit_diff"],
    label_key="label",
    mode="binary",
    rnn_type="LSTM",
    embedding_dim=128,
    hidden_dim=128,
)

In [None]:
# Create full model
full_gn_rnn_model = RNN(
    dataset=sample_dataset,
    feature_keys=["conditions", "related_conditions", "procedures", "visit_diff"],
    label_key="label",
    mode="binary",
    rnn_type="LSTM",
    embedding_dim=128,
    hidden_dim=128,
)

##   Training
Training used 80% of the synthetic MIMIC-III dataset, split randomly by patient. A 10% validation set was used to monitor AUROC during training.
  * Hyperparameters:
    * **Loss Function:** Binary Cross Entropy
    * **Optimizer:** Adam
    * **Learning Rate:** 0.0001
  * Computational Requirements:
    * **Type of Hardware:** CPU
    * **Average Runtime per Epoch:**
      * Base model = 0.73 seconds
      * Model with time = 1.03 seconds
      * Full model = 1.34 seconds
    * **Number of Training Epochs:** 50
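
For reference, binary cross entropy for a single sample can be written directly from a raw model output (logit). This is a small standard-library sketch of the formula, assuming a sigmoid output. The helper name `bce_with_logits` is hypothetical and is not a claim about how the trainer computes its loss.

```python
import math


def bce_with_logits(logit: float, label: int) -> float:
    """Binary cross entropy for one sample, computed stably from a raw logit.

    Equivalent to -[y * log(p) + (1 - y) * log(1 - p)] with p = sigmoid(logit).
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


# An uninformative logit of 0 (p = 0.5) costs log(2) for either label.
print(bce_with_logits(0.0, 1))
# A confident logit is cheap when correct and expensive when wrong.
print(bce_with_logits(4.0, 1), bce_with_logits(4.0, 0))
```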

In [None]:
# Simple model training
from datetime import datetime

start_time = datetime.now()

base_trainer = Trainer(model=base_model)
base_trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-4},
    monitor="roc_auc",
)

end_time = datetime.now()
print(f"Training started at: {start_time}")
print(f"Training ended at: {end_time}")
print(f"Total elapsed time: {end_time - start_time}")
print(f"Average epoch time: {(end_time - start_time)/50}")

In [None]:
# Simple model w/time training

start_time = datetime.now()

time_trainer = Trainer(model=time_model)
time_trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-4},
    monitor="roc_auc",
)

end_time = datetime.now()
print(f"Training started at: {start_time}")
print(f"Training ended at: {end_time}")
print(f"Total elapsed time: {end_time - start_time}")
print(f"Average epoch time: {(end_time - start_time)/50}")

In [None]:
# Full model training

start_time = datetime.now()

full_gn_rnn_trainer = Trainer(model=full_gn_rnn_model)
full_gn_rnn_trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-4},
    monitor="roc_auc",
)

end_time = datetime.now()
print(f"Training started at: {start_time}")
print(f"Training ended at: {end_time}")
print(f"Total elapsed time: {end_time - start_time}")
print(f"Average epoch time: {(end_time - start_time)/50}")

## Evaluation

All three models were evaluated using a test set of 10% of the original dataset that was held out for testing. The metric of interest was the area under the receiver operating characteristic curve (AUROC), which provides a measure of model performance in binary classification tasks where scores range from 0-1 and higher scores indicate better model performance. Note that all three models were tested on the same set.
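
AUROC can be read as the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, so 0.5 corresponds to chance-level ranking. The sketch below computes it by direct pair counting. It is an illustration with hypothetical toy data, not the routine used by the trainer, and it is quadratic in the number of samples.

```python
from typing import Sequence


def auroc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """AUROC as the fraction of (positive, negative) pairs ranked correctly; ties count half."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUROC needs at least one positive and one negative label")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Hypothetical labels and predicted scores.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(auroc(labels, scores))  # 3 of 4 pairs are ordered correctly -> 0.75
```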

In [None]:
# Base model performance (no time or knowledge graph)

print(base_trainer.evaluate(test_dataloader))

In [None]:
# Model performance with time added (no knowledge graph)

print(time_trainer.evaluate(test_dataloader))

In [None]:
# Full DG-RNN model performance

print(full_gn_rnn_trainer.evaluate(test_dataloader))

# Results

Similar to the models presented in the original paper, more information in the model led to improved overall performance (Table 1). For instance, the full model (AUROC = 0.5790) outperformed the partial model that included the time component, but not the knowledge graph information (AUROC = 0.5647), which outperformed the base model with no time or knowledge graph information (AUROC = 0.5629).

These findings are consistent with both hypotheses: with each additional piece of information, the model's predictive performance improved, although the gain from adding visit timing alone was small (0.0018 AUROC).

*Table 1.* Model AUROC values.

| Full Model | Model w/o KG | Simple Model w/o Time or KG |
|----------|----------|----------|
|    0.5790     |    0.5647     |    0.5629     |


## Model comparison

The authors reported an AUROC of 0.7375 for the full DG-RNN model on the MIMIC-III dataset [2]. They also reported lower performance (AUROC = 0.7238) for the simpler model without the knowledge graph.

This replication showed a similar pattern, but the AUROC was considerably lower across the board. Possible reasons for the lower performance are discussed below.

The full model that included both the time component and the additional information from the knowledge graph outperformed (AUROC = 0.5790) the model with only the time component (AUROC = 0.5647), which outperformed the simplest model with no time or knowledge graph information (AUROC = 0.5629).

While the predictive ability was lower overall for the replication models, they display the same pattern as the original results. This replicates the finding that adding further knowledge of interconnections can improve model performance and enhance predictive ability.

# Discussion

This replication supports the overarching claim from the original authors that adding timing information and knowledge about disease relationships can enhance an RNN over and above the base model in a heart failure prediction task [2].

In this replication, pyhealth [7,8] was used where possible, as it simplifies data pre-processing, model building/training, and model evaluation. In addition to simplifying the process, the resulting code is more human-readable. Overall, using pyhealth made the reproduction easier than it would have been otherwise and more portable to show to less technical audiences.

The largest challenge faced had to do with the knowledge graph. The one that the original authors used [9] no longer appears to be available and the other knowledge graphs that were explored had their own challenges. For instance, the knowledge graph, PrimeKG [10], that ended up being used was extensive, but used a medical coding style, MONDO, that was difficult to translate to ICD-9 and CCSCM codes.

To make this knowledge graph work with the MIMIC-III dataset, the diagnostic codes needed to be manually coded into an ICD-9 format. Not only is this prone to human error, but as the original dataset consisted of 6,392 individual nodes and numerous connections, the dataset had to be narrowed to only disease-related codes and only those that were related to the heart. Between this reduction, the conversion to the less granular ICD-9 codes, and the further conversion to even less granular CCSCM codes, a lot of information that may have been predictive of heart failure at the next visit was lost. This likely led, in part, to the lower AUROC score for the model.

Another possible reason for the reduced performance relative to what was reported in the paper is that the datasets used were not exactly the same. Since this project is public, the synthesized version of the MIMIC-III dataset [8] was used instead of the original MIMIC-III dataset. Moreover, the original article also used data from a proprietary source that was not available for this replication. Those two factors may also have contributed to the differences in overall magnitude of the model performance. A future replication that is not public facing may be improved by using the original MIMIC-III and proprietary datasets.

That being said, the replication showed the same pattern of results as the original article. This is important in that it demonstrates that visit timing and an enhanced knowledge base can be used to improve the predictive quality of a model and may allow physicians to intervene sooner, resulting in better outcomes for patients. Additionally, this replication was accomplished using a new package, pyhealth, that could make this type of model simpler to implement in clinical settings. Future research should expand on this by generalizing to other diseases.

# References

1.   Bolnick, H. J., Bui, A. L., Bulchis, A., Chen, C., Chapin, A., Lomsadze, L., ... Dieleman, J. L. (2020). Health-care spending attributable to modifiable risk factors in the USA: An economic attribution analysis. The Lancet Public Health, 5(10), e525-e535. https://doi.org/10.1016/S2468-2667(20)30203-6
2.   Yin, C., Zhao, R., Qian, B., Lv, X., & Zhang, P. (2019). Domain Knowledge Guided Deep Learning with Electronic Health Records. In 2019 IEEE International Conference on Data Mining (ICDM) (pp. 738-747). Beijing, China. https://doi.org/10.1109/ICDM.2019.00084
3.   Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735-1780. https://doi.org/10.1162/neco.1997.9.8.1735
4.   Johnson, A., Pollard, T., & Mark, R. (2016). MIMIC-III Clinical Database (version 1.4). PhysioNet. https://doi.org/10.13026/C2XW26.
5.   Johnson, A. E. W., Pollard, T. J., Shen, L., Lehman, L. H., Feng, M., Ghassemi, M., Moody, B., Szolovits, P., Celi, L. A., & Mark, R. G. (2016). MIMIC-III, a freely accessible critical care database. Scientific Data, 3, 160035.
6.   Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C., Mark, R., ... & Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex physiologic signals. Circulation [Online]. 101 (23), pp. e215–e220.
7.   Yang, C., Wu, Z., Jiang, P., Lin, Z., Gao, J., Danek, B. P., & Sun, J. (2023). PyHealth: A Deep Learning Toolkit for Healthcare Applications. In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (pp. 5788–5789). New York, NY, USA: Association for Computing Machinery.
8.   Theodorou, B., Xiao, C., & Sun, J. (2023). Synthesize high-dimensional longitudinal electronic health records via hierarchical autoregressive language model. Nature Communications, 14, 5305. https://doi.org/10.1038/s41467-023-41093-0
9.   Ernst, P., Siu, A., & Weikum, G. (2015). Knowlife: A versatile approach for constructing a large knowledge graph for biomedical sciences. BMC Bioinformatics.
10.   Chandak, P., Huang, K., & Zitnik, M. (2023). Building a knowledge graph to enable precision medicine. Scientific Data, 10(1), Article 67. https://doi.org/10.1038/s41597-023-01960-3
11.   Chandak, P. (2022). PrimeKG (V2) [Data set]. Harvard Dataverse. https://doi.org/10.7910/DVN/IXA7BM