<a href="https://colab.research.google.com/github/johnbiggan/DeepLearningforHealthcare-Project/blob/main/Code/Project_alt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Project notebook available in Google Colab space: https://colab.research.google.com/github/johnbiggan/DeepLearningforHealthcare-Project/blob/main/Code/Project.ipynb

and public GitHub repo: https://github.com/johnbiggan/DeepLearningforHealthcare-Project/blob/1da3ddf002c103638dbb566d9dffa9518eb24436/Code/Project.ipynb

# Domain Knowledge Guided Deep Learning
### with Electronic Health Records Replication

# Introduction

Every year, more than a quarter of all healthcare spending in the US goes to preventable diseases [1]. In an effort to intervene earlier, researchers look to the ever-growing pool of electronic health records (EHRs) to predict disease risk. Many models have been developed, but given the sequential nature of healthcare visits, recurrent neural networks (RNNs) have shown the most promise.

Unfortunately, early RNNs treated the time between consecutive visits, as well as the time between visits and disease diagnosis, as unimportant. Additionally, other early models failed to incorporate medical knowledge in the form of relationships between diagnoses (e.g., comorbid, causes, caused-by).

The Domain Knowledge Guided Recurrent Neural Network (DG-RNN) being replicated in this project addresses both shortcomings and has been found to improve heart failure prediction over previous methodologies [2]. In their paper, the authors describe a model that uses multiple long short-term memory (LSTM) networks [3] with attention to combine disease diagnoses, procedure codes, visit time differences, and a medical knowledge graph (Figure 1). Incorporating medical knowledge graphs presents challenges. For example, the semi-structured nature of graphs adds complexity. Moreover, readily available medical knowledge graphs have been few and far between.

When tested on the MIMIC-III dataset [4,5,6], the model outperformed numerous other models, including random forests, support vector machines, and traditional long short-term memory RNNs, on a heart failure prediction task.


![Figure 1](https://raw.githubusercontent.com/johnbiggan/DeepLearningforHealthcare-Project/main/Figures/Figure1.png "Figure 1")

Figure 1. DG-RNN model.

# Scope of Reproducibility

This replication will focus on reproducing the main findings of the paper. Specifically, the following hypotheses will be tested:

Hypothesis 1: Including the time between visits will improve the model over the base model that treats all visits in a simple sequential manner.

Hypothesis 2: Including domain knowledge, along with the time between visits, will lead to a better performing model than the base model.

# Methodology

This project uses tools from the pyhealth package, which should be installed first if it is not already available.

In [None]:
! pip install pyhealth

##  Data

### MIMIC-III

For their model, the authors used a proprietary dataset as well as the MIMIC-III (Medical Information Mart for Intensive Care III) dataset [4,5,6], which is a freely available database for researchers consisting of deidentified health data from over 40,000 patients who stayed in critical care units of the Beth Israel Deaconess Medical Center between 2001 and 2012. It includes detailed demographics and visit-level healthcare data.

Although the data is freely available to researchers, it is not publicly available due to privacy concerns. Fortunately, in recent years pyhealth [7] contributors have created a synthetic dataset based on the MIMIC-III data that does not raise the same privacy concerns. As this is a public-facing project, the synthetic MIMIC-III data from pyhealth will be used. However, the same analyses may be conducted with the original MIMIC-III dataset by changing the **root** argument to point at the original dataset.
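
As a minimal sketch of that switch, the dataset arguments can be gathered in one place so that only `root` changes between the synthetic and the original data. The helper name `mimic3_config` and the environment variable `MIMIC3_ROOT` are assumptions made for illustration; neither is part of pyhealth.

```python
import os

# URL of the synthetic MIMIC-III data used in this notebook
SYNTHETIC_ROOT = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/"


def mimic3_config(local_root=None):
    """Return keyword arguments for MIMIC3Dataset.

    `local_root` (hypothetical) is a directory holding the original
    MIMIC-III CSV files. When it is not given, the MIMIC3_ROOT
    environment variable (also hypothetical) is checked, and the
    synthetic data is used as the fallback.
    """
    root = local_root or os.environ.get("MIMIC3_ROOT") or SYNTHETIC_ROOT
    return {
        "root": root,
        "tables": ["DIAGNOSES_ICD", "PROCEDURES_ICD"],
        "code_mapping": {"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC"},
    }
```

With this in place, the dataset could be built as `MIMIC3Dataset(**mimic3_config())` for the synthetic data, or with a local path passed to `mimic3_config` for the original data.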

In [None]:
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import RNN
from pyhealth.models import RETAIN
from pyhealth.trainer import Trainer

In [None]:
base_dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD"],
    code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC"},
    dev=False,
    refresh_cache=False,
)
base_dataset.stat()

In [None]:
# Confirm the CCSCM code for congestive heart failure to be used in the labeling task
from pyhealth.medcode import CrossMap
from pyhealth.medcode import InnerMap

mapping = CrossMap.load(source_vocabulary="ICD9CM", target_vocabulary="CCSCM")
print("The CCSCM code that maps to ICD-9 code 428.0 (congestive heart failure) is", mapping.map("428.0"))
print("The CCSCM code that maps to ICD-9 code 428.9 (heart failure, unspecified) is", mapping.map("428.9"))

ccscm = InnerMap.load("CCSCM")
print("Confirmation that CCSCM code '108' corresponds to", ccscm.lookup("108"))

In [None]:
# Create custom visit time difference calculation and heart failure prediction task
from pyhealth.data import Patient, Visit
import numpy as np


def visit_time_diff_mimic3_fn(patient: Patient):
    """Processes a single patient for the visit time difference task.

    Visit time difference calculates the delay between the current visit and
    the previous visit.

    Args:
        patient: a Patient object

    Returns:
        samples: a list of samples, each sample is a dict with patient_id,
            visit_id, and other task-specific attributes as key

    Examples:
        >>> from pyhealth.datasets import MIMIC3Dataset
        >>> mimic3_base = MIMIC3Dataset(
        ...    root="/srv/local/data/physionet.org/files/mimiciii/1.4",
        ...    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
        ...    code_mapping={"ICD9CM": "CCSCM"},
        ... )
        >>> mimic3_sample = mimic3_base.set_task(visit_time_diff_mimic3_fn)
        >>> mimic3_sample.samples[0]
        [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'visit_diff': [[0.0, 0.0, 0.0, 0.0]], 'label': 0}]
    """
    samples = []
    criterion_time = None

    for i in range(len(patient) - 1):
        visit: Visit = patient[i]
        visit_time = visit.encounter_time

        next_visit: Visit = patient[i + 1]
        # Label is 1 when the next visit carries CCSCM code 108 (heart failure)
        hf_label = int('108' in next_visit.get_code_list(table="DIAGNOSES_ICD"))

        visit_diff = []

        if i > 0:
            prev_visit: Visit = patient[i - 1]
            prev_visit_time = prev_visit.encounter_time

            # Find criterion time (i.e. first encounter with HF diagnosis);
            # it is computed once and reused for the patient's later visits
            if criterion_time is None:
                for j in range(len(patient) - 1):
                    c_visit: Visit = patient[j]
                    if '108' in c_visit.get_code_list(table="DIAGNOSES_ICD"):
                        criterion_time = c_visit.encounter_time
                        break

            if criterion_time is not None:
                # Days between this visit and the first HF encounter
                days_to_criterion = (visit_time - criterion_time).days
                v_c_s = np.sin(days_to_criterion / 10000).item()
                v_c_c = np.cos(days_to_criterion / 10000).item()
            else:
                # No HF encounter found: use a zero placeholder
                v_c_s = 0.0
                v_c_c = 0.0

            v_p = visit_time - prev_visit_time

            visit_diff = [
                v_c_s,
                v_c_c,
                np.sin(((v_p).days)/10000).item(),
                np.cos(((v_p).days)/10000).item()
            ]

        else:
            visit_diff = [0.0, 0.0, 0.0, 0.0]

        conditions = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        # exclude: visits missing either condition codes or procedure codes
        if len(conditions) * len(procedures) == 0 or len(patient) < 2:
            continue
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": [conditions],
                "procedures": [procedures],
                "visit_diff": [visit_diff],
                "label": hf_label,
            }
        )
    # no cohort selection
    return samples
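
The time encoding used in the task function above can be checked in isolation. The sketch below reproduces only the sine/cosine step with the same divisor of 10000; the helper name `encode_gap` and the visit dates are hypothetical and chosen purely for illustration.

```python
import math
from datetime import datetime

SCALE_DAYS = 10000  # same divisor as in visit_time_diff_mimic3_fn


def encode_gap(later, earlier):
    """Encode the whole-day gap between two timestamps as [sin, cos]."""
    days = (later - earlier).days
    return [math.sin(days / SCALE_DAYS), math.cos(days / SCALE_DAYS)]


# Hypothetical visit dates, chosen only for illustration
prev_visit = datetime(2010, 1, 1)
visit = datetime(2010, 1, 31)  # 30 days after prev_visit

same_day = encode_gap(visit, visit)        # gap of 0 days
one_month = encode_gap(visit, prev_visit)  # gap of 30 days

print(same_day)   # [0.0, 1.0]
print(one_month)
```

A gap of zero days encodes as `[0.0, 1.0]`, whereas the task function writes `0.0` for both components when there is no reference point (the first visit, or a patient with no heart failure encounter). The two cases therefore remain distinguishable in the `visit_diff` vector.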

In [None]:
# Create a sample dataset with a label to indicate an HF diagnosis and visit difference vector
sample_dataset = base_dataset.set_task(visit_time_diff_mimic3_fn)
sample_dataset.stat()

# Split dataset into training, validation, and testing using an 80%, 10%, 10% split
train_dataset, val_dataset, test_dataset = split_by_patient(
    sample_dataset, [0.8, 0.1, 0.1]
)

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
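
Splitting by patient rather than by sample keeps all visits of one patient in the same subset, so that no patient contributes to both training and testing. The sketch below illustrates the idea on toy data; it is not pyhealth's implementation of `split_by_patient`, and the function name `split_samples_by_patient` is an assumption.

```python
import random
from collections import defaultdict


def split_samples_by_patient(samples, ratios, seed=0):
    """Split samples so that each patient lands in exactly one subset."""
    # Group samples by their patient id
    by_patient = defaultdict(list)
    for sample in samples:
        by_patient[sample["patient_id"]].append(sample)

    # Shuffle patients (not samples) reproducibly
    patient_ids = sorted(by_patient)
    random.Random(seed).shuffle(patient_ids)

    # Cut the shuffled patient list into train / validation / test
    n = len(patient_ids)
    train_end = int(n * ratios[0])
    val_end = train_end + int(n * ratios[1])
    groups = (
        patient_ids[:train_end],
        patient_ids[train_end:val_end],
        patient_ids[val_end:],
    )
    return [[s for pid in group for s in by_patient[pid]] for group in groups]


# Toy samples: 10 hypothetical patients with two visits each
toy = [
    {"patient_id": str(p), "visit_id": f"{p}-{v}"}
    for p in range(10)
    for v in range(2)
]
train, val, test = split_samples_by_patient(toy, [0.8, 0.1, 0.1])
print(len(train), len(val), len(test))
```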

In [None]:
# View a single sample (one visit of one patient)
print(sample_dataset[9])

### PrimeKG

For the knowledge graph, the authors originally used the KnowLife knowledge graph [8]. Unfortunately, at the time of this replication the KnowLife knowledge graph was no longer available.

More recently, researchers have developed an updated medical knowledge graph called PrimeKG [9,10]. This appears to be a worthwhile alternative, and its use is still being explored at the time of writing. One challenge is that its nodes are coded using the Monarch Disease Ontology (MONDO), for which no simple mapping is available.

In [None]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

# Load the CSV data into a pandas DataFrame
url = "https://raw.githubusercontent.com/johnbiggan/DeepLearningforHealthcare-Project/main/Data/kg_grouped_diseases_bert_map.csv"
df = pd.read_csv(url)

# Display the first few rows of the DataFrame to understand its structure
df.head()


In [None]:
from itertools import combinations

# Function to extract unique edges from the group_id_bert column
def extract_edges(grouped_ids):
    # Split the string by underscore and keep each id once, in a stable order
    ids = sorted(set(grouped_ids.split('_')))
    # Every unordered pair of ids in the group becomes one edge
    return list(combinations(ids, 2))

# Initialize an empty graph
G = nx.Graph()

# Add nodes and edges to the graph
for _, row in df.iterrows():
    # Add the disease node to the graph
    G.add_node(row['node_id'], name=row['node_name'], type=row['node_type'], source=row['node_source'])

    # Add edges from this disease to others in the same group
    edges = extract_edges(row['group_id_bert'])
    G.add_edges_from(edges)

# Visualize the graph
plt.figure(figsize=(12, 12))
nx.draw(G, with_labels=False, node_size=20, alpha=0.6, edge_color="r", font_size=8)
plt.title("Graph of Disease Connections")
plt.show()
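
Beyond plotting, a graph like this is mainly useful for looking up which diseases are related to a given diagnosis. The sketch below builds such a lookup from underscore-separated group strings in the same format as `group_id_bert`. The ids are made up for illustration, and the lookup is an assumption about a possible next step rather than the method used in the original paper.

```python
from collections import defaultdict
from itertools import combinations


def build_adjacency(grouped_ids):
    """Map each node id to the set of ids that share a group with it."""
    neighbors = defaultdict(set)
    for group in grouped_ids:
        ids = sorted(set(group.split("_")))
        for a, b in combinations(ids, 2):
            neighbors[a].add(b)
            neighbors[b].add(a)
    return neighbors


# Hypothetical group strings in the underscore-separated format of group_id_bert
toy_groups = ["100_200_300", "300_400", "500"]
adjacency = build_adjacency(toy_groups)

print(sorted(adjacency["300"]))  # ['100', '200', '400']
```

A group containing a single id, such as `"500"` here, produces no edges and therefore no entry in the lookup.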

##   Model

In [None]:
base_model = RNN(
    dataset=sample_dataset,
    feature_keys=["conditions", "procedures"],
    label_key="label",
    mode="binary",
)

In [None]:
base_trainer = Trainer(model=base_model)
base_trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    monitor="roc_auc",
)

In [None]:
time_model = RNN(
    dataset=sample_dataset,
    feature_keys=["conditions", "procedures", "visit_diff"],
    label_key="label",
    mode="binary",
)

In [None]:
time_trainer = Trainer(model=time_model)
time_trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    monitor="roc_auc",
)

In [None]:
retain_model = RETAIN(
    dataset=sample_dataset,
    feature_keys=["conditions", "procedures"],
    label_key="label",
    mode="binary",
)

In [None]:
retain_trainer = Trainer(model=retain_model)
retain_trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    monitor="roc_auc",
)

In [None]:
class my_model():
  # placeholder for the full DG-RNN model, which is not yet implemented
  pass

model = my_model()
loss_func = None
optimizer = None

def train_model_one_iter(model, loss_func, optimizer):
  # placeholder: should run one training pass and return (train_loss, valid_loss)
  return None, None

num_epoch = 10
# model training loop: print the training/validation losses during training
for i in range(num_epoch):
  train_loss, valid_loss = train_model_one_iter(model, loss_func, optimizer)
  if train_loss is None or valid_loss is None:
    # nothing to report until the placeholder above is implemented
    continue
  print("Train Loss: %.2f, Validation Loss: %.2f" % (train_loss, valid_loss))


# Results



In [None]:
# Base model performance

print(base_trainer.evaluate(test_dataloader))

In [None]:
# Time-enhanced model performance

print(time_trainer.evaluate(test_dataloader))

In [None]:
# RETAIN model performance

print(retain_trainer.evaluate(test_dataloader))

In [None]:
# Full DG-RNN model performance (disabled until the full model is implemented)

# print(dg_rnn_trainer.evaluate(test_dataloader))

## Model comparison
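
One simple way to line the models up is to collect the dictionaries printed by the evaluation cells and rank them on a single metric. The sketch below assumes that each `evaluate` call returns a dictionary mapping metric names to values. The function name `comparison_rows`, the model names, and the scores are all dummy placeholders, not results from this project.

```python
def comparison_rows(results, metric="roc_auc"):
    """Return (model name, score) pairs sorted from best to worst."""
    rows = [(name, scores[metric]) for name, scores in results.items()]
    return sorted(rows, key=lambda row: row[1], reverse=True)


# Dummy scores for illustration only; these are NOT results from this project
dummy_results = {
    "model_a": {"roc_auc": 0.50, "pr_auc": 0.25},
    "model_b": {"roc_auc": 0.75, "pr_auc": 0.50},
}

for name, score in comparison_rows(dummy_results):
    print(f"{name}: {score:.2f}")
```

In the notebook, the input could be built from the trained models, for example `{"Base RNN": base_trainer.evaluate(test_dataloader), "Time RNN": time_trainer.evaluate(test_dataloader)}`.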


# Discussion




# References

1.   Bolnick, H. J., Bui, A. L., Bulchis, A., Chen, C., Chapin, A., Lomsadze, L., ... Dieleman, J. L. (2020). Health-care spending attributable to modifiable risk factors in the USA: An economic attribution analysis. The Lancet Public Health, 5(10), e525-e535. https://doi.org/10.1016/S2468-2667(20)30203-6
2.   Yin, C., Zhao, R., Qian, B., Lv, X., & Zhang, P. (2019). Domain Knowledge Guided Deep Learning with Electronic Health Records. In 2019 IEEE International Conference on Data Mining (ICDM) (pp. 738-747). Beijing, China. https://doi.org/10.1109/ICDM.2019.00084
3.   Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735-1780. https://doi.org/10.1162/neco.1997.9.8.1735
4.   Johnson, A., Pollard, T., & Mark, R. (2016). MIMIC-III Clinical Database (version 1.4). PhysioNet. https://doi.org/10.13026/C2XW26
5.   Johnson, A. E. W., Pollard, T. J., Shen, L., Lehman, L. H., Feng, M., Ghassemi, M., Moody, B., Szolovits, P., Celi, L. A., & Mark, R. G. (2016). MIMIC-III, a freely accessible critical care database. Scientific Data, 3, 160035.
6.   Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C., Mark, R., ... Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex physiologic signals. Circulation, 101(23), e215–e220.
7.   Yang, C., Wu, Z., Jiang, P., Lin, Z., Gao, J., Danek, B. P., & Sun, J. (2023). PyHealth: A Deep Learning Toolkit for Healthcare Applications. In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (pp. 5788–5789). New York, NY, USA: Association for Computing Machinery.
8.   Ernst, P., Siu, A., & Weikum, G. (2015). KnowLife: A versatile approach for constructing a large knowledge graph for biomedical sciences. BMC Bioinformatics.
9.   Chandak, P., Huang, K., & Zitnik, M. (2023). Building a knowledge graph to enable precision medicine. Scientific Data, 10(1), Article 67. https://doi.org/10.1038/s41597-023-01960-3
10.   Chandak, P. (2022). PrimeKG (V2) [Data set]. Harvard Dataverse. https://doi.org/10.7910/DVN/IXA7BM