Library imports and data preparation

In [70]:
# ! pip install pyhealth
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.data import Patient, Visit, Event
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.trainer import Trainer
from pyhealth.medcode import InnerMap
from tqdm import tqdm
from pyhealth.metrics.binary import binary_metrics_fn
from sklearn.metrics import precision_score
import numpy as np
import pandas as pd
import torch
from typing import List, Dict

# Set this to the directory with all MIMIC-3 dataset files
data_root = "data"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Load the dataset

mimic3_ds = MIMIC3Dataset(
        root=data_root,
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD"],
        dev=False
)

In [3]:
# Print dataset statistics

mimic3_ds.stat()


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 46520
	- Number of visits: 58976
	- Number of visits per patient: 1.2678
	- Number of events per visit in DIAGNOSES_ICD: 11.0384
	- Number of events per visit in PROCEDURES_ICD: 4.0711



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 46520\n\t- Number of visits: 58976\n\t- Number of visits per patient: 1.2678\n\t- Number of events per visit in DIAGNOSES_ICD: 11.0384\n\t- Number of events per visit in PROCEDURES_ICD: 4.0711\n'

In [4]:
# Find all diagnoses codes
# Remove diagnoses codes with fewer than 5 occurences in the dataset

all_diagnosis_codes = []
for patient_id, patient in mimic3_ds.patients.items():
  for i in range(len(patient)):
    visit: Visit = patient[i]
    conditions = visit.get_code_list(table="DIAGNOSES_ICD")
    all_diagnosis_codes.extend(conditions)

codes = pd.Series(all_diagnosis_codes)
diag_code_counts = codes.value_counts()
filtered_diag_codes = diag_code_counts[diag_code_counts > 4].index.values
n_unique_diag_codes = len(filtered_diag_codes)

In [5]:
MIN_N_VISITS_PER_PATIENT = 2

# Filter Dataset to requirements specified in paper

filtered_patients = {}
for patient_id, patient in tqdm(mimic3_ds.patients.items()):

    filtered_patient: Patient = Patient(
        patient_id=patient.patient_id,
        birth_datetime=patient.birth_datetime,
        death_datetime=patient.death_datetime,
        gender=patient.gender,
        ethnicity=patient.ethnicity
    )

    for i_visit, visit in enumerate(patient):
        filtered_visit: Visit = Visit(
            visit_id=visit.visit_id,
            patient_id=visit.patient_id,
            encounter_time=visit.encounter_time,
            discharge_time=visit.discharge_time,
            discharge_status=visit.discharge_status
        )

        diagnoses_codes = visit.get_code_list("DIAGNOSES_ICD")
        procedures_codes = visit.get_code_list("PROCEDURES_ICD")
        prescriptions_codes = visit.get_code_list("PRESCRIPTIONS")

        if len(diagnoses_codes) > 0:
            diagnosis_events = visit.event_list_dict["DIAGNOSES_ICD"]
            for i_event in range(len(diagnosis_events) - 1, -1, -1):
                event: Event = diagnosis_events[i_event]
                if event.code not in filtered_diag_codes:
                    diagnosis_events.pop(i_event) # Remove the diagnosis code with fewer than the cutoff occurrences

            if len(diagnosis_events) == 0: continue # Don't include visits with no diagnoses

            filtered_visit.set_event_list("DIAGNOSES_ICD", diagnosis_events)
        else:
            continue # Don't include visits with no diagnoses

        if len(procedures_codes) > 0:
           filtered_visit.set_event_list("PROCEDURES_ICD", visit.event_list_dict["PROCEDURES_ICD"])

        if len(prescriptions_codes) > 0:
            filtered_visit.set_event_list("PRESCRIPTIONS", visit.event_list_dict["PRESCRIPTIONS"])

        filtered_patient.add_visit(filtered_visit)

    if len(filtered_patient.visits) >= MIN_N_VISITS_PER_PATIENT:
        filtered_patients[patient_id] = filtered_patient


100%|██████████| 46520/46520 [01:12<00:00, 644.27it/s]


In [7]:
mimic3_ds.patients = filtered_patients
mimic3_ds.stat()


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 7496
	- Number of visits: 19905
	- Number of visits per patient: 2.6554
	- Number of events per visit in DIAGNOSES_ICD: 12.9735
	- Number of events per visit in PROCEDURES_ICD: 4.0975



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 7496\n\t- Number of visits: 19905\n\t- Number of visits per patient: 2.6554\n\t- Number of events per visit in DIAGNOSES_ICD: 12.9735\n\t- Number of events per visit in PROCEDURES_ICD: 4.0975\n'

In [77]:
# Define the tasks

DIAGNOSES_KEY = "conditions"
PROCEDURES_KEY = "procedures"
INTERVAL_DAYS_KEY = "days_since_first_visit"

icd9cm = InnerMap.load("ICD9CM")

def flatten(l: List):
    return [item for sublist in l for item in sublist]

def patient_level_readmission_prediction(patient, time_window: int = 30, max_length_visits: int = None):
    """
    patient is a <pyhealth.data.Patient> object
    """
    samples = []

    sorted_visits = sorted(patient, key=lambda visit: visit.encounter_time)

    # Clip the patient visits to the most recent max_length_visits + 1 if max_length_visits is not None
    if max_length_visits is not None:
        n_visits = len(sorted_visits)
        if n_visits > max_length_visits + 1:
            sorted_visits = sorted_visits[n_visits - (max_length_visits + 1):]

    feature_visits: List[Visit] = sorted_visits[:-1]
    last_visit: Visit = sorted_visits[-1]
    second_to_last_visit: Visit = feature_visits[-1]
    first_visit: Visit = feature_visits[0]

    # step 1 a: define readmission label
    time_diff = (last_visit.encounter_time - second_to_last_visit.encounter_time).days
    readmission_label = 1 if time_diff <= time_window else 0

    # step 1 b: define diagnosis prediction label
    diagnosis_label = list(set([icd9cm.get_ancestors(code)[1] for code in last_visit.get_code_list("DIAGNOSES_ICD")]))

    # step 2: obtain features
    visits_diagnoses = []
    visits_procedures = []
    visits_intervals = []
    for idx, visit in enumerate(feature_visits):
        diagnoses = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        time_diff_from_first_visit = (visit.encounter_time - first_visit.encounter_time).days

        # Exclude visits that are missing either diagnoses or procedures.
        # BiteNet can handle missing procedures, but other PyHealth models like RNN
        # require all features have a length greater than 0.
        if len(diagnoses) == 0:
            continue

        visits_diagnoses.append(diagnoses)
        visits_procedures.append(procedures)
        visits_intervals.append([str(time_diff_from_first_visit)])

    unique_diagnoses = list(set(flatten(visits_diagnoses)))

    # step 3: exclusion criteria
    if len(unique_diagnoses) == 0:
        return []

    # step 4: assemble the sample
    samples.append(
        {
            "patient_id": patient.patient_id,
            "visit_id": visit.visit_id,
            "diagnoses": visits_diagnoses,
            "procedures": visits_procedures,
            "intervals": visits_intervals,
            "readmission_label": readmission_label,
            "diagnosis_label": diagnosis_label
        }
    )
    return samples

In [88]:
BATCH_SIZE = 32
KS = list(range(5, 31, 5))
N_TRIALS=5
SEQ_LENS = list(range(6, 17, 2))
RESULTS_FILE = "results2.csv"

In [89]:
from sklearn.metrics import confusion_matrix
def train_and_inference(model, train_loader, val_loader, test_loader, lr=0.001, monitor="pr_auc", optim = torch.optim.RMSprop):
    trainer = Trainer(model=model, device=device)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=10,
        monitor=monitor,
        optimizer_class=optim,
        optimizer_params = {"lr" : lr},
        load_best_model_at_last=False
    )

    return trainer.inference(test_loader)

def precision_at_k(y_true: np.ndarray, y_prob: np.ndarray):

    y_pred: np.ndarray = (y_prob > 0.5).astype(int)
    desc_idx: np.ndarray = np.flip(np.argsort(y_prob, axis=-1), axis=-1)

    y_true = np.take(y_true, desc_idx).astype(float)
    y_pred = np.take(y_pred, desc_idx).astype(float)

    precisions: List[float] = []
    for k in KS:
        precisions.append(
            precision_score(y_true[:, :k], y_pred[:, :k], average='samples', labels=[0,1])
        )

    precisions: Dict[str, float] = {
        f"precision@{k}": p for k, p in zip(KS, precisions)
    }
    return precisions

In [90]:
# DxTx with interval embeddings

metrics_df = pd.DataFrame(columns=['model_name', 'feature_set', 'seq_len', 'trial', 'pr_auc', 'roc_auc', 'f1'] + [f"precision@{k}" for k in KS])

def train_and_record_metrics(model_readm, model_diag, model_name, feature_set, seq_len, train_loader, val_loader, test_loader, lr=0.001, trial=None):
    global metrics_df

    y_true, y_prob, _ = train_and_inference(
        model_readm,
        train_loader,
        val_loader,
        test_loader,
        lr=lr
    )
    binary_metrics = binary_metrics_fn(y_true, y_prob, metrics=["pr_auc", "roc_auc", "f1"])

    y_true, y_prob, _ = train_and_inference(
        model_diag,
        train_loader,
        val_loader,
        test_loader,
        lr=lr,
        monitor="pr_auc_samples",
        optim=torch.optim.RMSprop
    )
    precisions = precision_at_k(y_true, y_prob)

    row = binary_metrics | precisions | {
        "model_name": model_name,
        "feature_set": feature_set,
        "seq_len": seq_len,
        "trial": trial
    }
    row = {
        k: [v] for k, v in row.items()
    }

    metrics_df = pd.concat([metrics_df, pd.DataFrame.from_dict(row)], ignore_index=True)

    # Save df for checkpoint
    metrics_df.to_csv(RESULTS_FILE, index=False)

In [None]:
%%capture
from bitenet import BiteNet
from baseline import RNN, BRNN, RETAIN, Deepr

for seq_len in SEQ_LENS:

    dataset = mimic3_ds.set_task(
        task_fn=lambda p: patient_level_readmission_prediction(p, max_length_visits=seq_len)
    )

    for trial in range(N_TRIALS):

        train, val, test = split_by_patient(dataset, [0.8, 0.1, 0.1])

        train_loader = get_dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = get_dataloader(val, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = get_dataloader(test, batch_size=BATCH_SIZE, shuffle=False)

        #################### BITENET ####################
        train_and_record_metrics(
            model_readm=BiteNet(
                dataset = dataset,
                feature_keys = ["diagnoses", "procedures", "intervals"],
                label_key = "readmission_label",
                mode = "binary",
            ).to(device),
            model_diag=BiteNet(
                dataset = dataset,
                feature_keys = ["diagnoses", "procedures", "intervals"],
                label_key = "diagnosis_label",
                mode = "multilabel",
            ).to(device),
            model_name="bitenet",
            feature_set="dxtx",
            seq_len=seq_len,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            lr=0.0005,
            trial=trial
        )

        train_and_record_metrics(
            model_readm=BiteNet(
                dataset = dataset,
                feature_keys = ["diagnoses", "intervals"],
                label_key = "readmission_label",
                mode = "binary",
            ).to(device),
            model_diag=BiteNet(
                dataset = dataset,
                feature_keys = ["diagnoses", "intervals"],
                label_key = "diagnosis_label",
                mode = "multilabel",
            ).to(device),
            model_name="bitenet",
            feature_set="dx",
            seq_len=seq_len,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            lr=0.0005,
            trial=trial
        )

        #################### RNN ####################
        train_and_record_metrics(
                model_readm=RNN(
                dataset = dataset,
                feature_keys = ["diagnoses", "procedures"],
                label_key = "readmission_label",
                mode = "binary",
            ).to(device),
            model_diag=RNN(
                dataset = dataset,
                feature_keys = ["diagnoses", "procedures"],
                label_key = "diagnosis_label",
                mode = "multilabel",
            ).to(device),
            model_name="rnn",
            feature_set="dxtx",
            seq_len=seq_len,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            trial=trial
        )

        train_and_record_metrics(
                model_readm=RNN(
                dataset = dataset,
                feature_keys = ["diagnoses"],
                label_key = "readmission_label",
                mode = "binary",
            ).to(device),
            model_diag=RNN(
                dataset = dataset,
                feature_keys = ["diagnoses"],
                label_key = "diagnosis_label",
                mode = "multilabel",
            ).to(device),
            model_name="rnn",
            feature_set="dx",
            seq_len=seq_len,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            trial=trial
        )

        #################### BRNN ####################
        train_and_record_metrics(
                model_readm=BRNN(
                dataset = dataset,
                feature_keys = ["diagnoses", "procedures"],
                label_key = "readmission_label",
                mode = "binary",
                bidirectional=True
            ).to(device),
            model_diag=RNN(
                dataset = dataset,
                feature_keys = ["diagnoses", "procedures"],
                label_key = "diagnosis_label",
                mode = "multilabel",
                bidirectional=True
            ).to(device),
            model_name="brnn",
            feature_set="dxtx",
            seq_len=seq_len,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            trial=trial
        )

        train_and_record_metrics(
                model_readm=BRNN(
                dataset = dataset,
                feature_keys = ["diagnoses"],
                label_key = "readmission_label",
                mode = "binary",
                bidirectional=True
            ).to(device),
            model_diag=RNN(
                dataset = dataset,
                feature_keys = ["diagnoses"],
                label_key = "diagnosis_label",
                mode = "multilabel",
                bidirectional=True
            ).to(device),
            model_name="brnn",
            feature_set="dx",
            seq_len=seq_len,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            trial=trial
        )

        #################### RETAIN ####################
        train_and_record_metrics(
                model_readm=RETAIN(
                dataset = dataset,
                feature_keys = ["diagnoses", "procedures"],
                label_key = "readmission_label",
                mode = "binary"
            ).to(device),
            model_diag=RETAIN(
                dataset = dataset,
                feature_keys = ["diagnoses", "procedures"],
                label_key = "diagnosis_label",
                mode = "multilabel"
            ).to(device),
            model_name="retain",
            feature_set="dxtx",
            seq_len=seq_len,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            trial=trial
        )

        train_and_record_metrics(
                model_readm=RETAIN(
                dataset = dataset,
                feature_keys = ["diagnoses"],
                label_key = "readmission_label",
                mode = "binary"
            ).to(device),
            model_diag=RETAIN(
                dataset = dataset,
                feature_keys = ["diagnoses"],
                label_key = "diagnosis_label",
                mode = "multilabel"
            ).to(device),
            model_name="retain",
            feature_set="dx",
            seq_len=seq_len,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            trial=trial
        )

        #################### Deepr ####################
        train_and_record_metrics(
                model_readm=Deepr(
                dataset = dataset,
                feature_keys = ["diagnoses", "procedures"],
                label_key = "readmission_label",
                mode = "binary"
            ).to(device),
            model_diag=Deepr(
                dataset = dataset,
                feature_keys = ["diagnoses", "procedures"],
                label_key = "diagnosis_label",
                mode = "multilabel"
            ).to(device),
            model_name="deepr",
            feature_set="dxtx",
            seq_len=seq_len,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            trial=trial
        )

        train_and_record_metrics(
                model_readm=Deepr(
                dataset = dataset,
                feature_keys = ["diagnoses"],
                label_key = "readmission_label",
                mode = "binary"
            ).to(device),
            model_diag=Deepr(
                dataset = dataset,
                feature_keys = ["diagnoses"],
                label_key = "diagnosis_label",
                mode = "multilabel"
            ).to(device),
            model_name="deepr",
            feature_set="dx",
            seq_len=seq_len,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            trial=trial
        )

BiteNet(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3428, 128, padding_idx=0)
    (procedures): Embedding(1358, 128, padding_idx=0)
    (intervals): Embedding(1649, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): _BiteNet(
    (flatten): Flatten()
    (unflatten): Unflatten()
    (code_attn): Sequential(
      (0): MaskEnc(
        (attention): PrePostProcessingWrapper(
          (module): MultiHeadAttention(
            (q_linear): Linear(in_features=128, out_features=128, bias=False)
            (k_linear): Linear(in_features=128, out_features=128, bias=False)
            (v_linear): Linear(in_features=128, out_features=128, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (layer_norm): LayerNorm()
        )
        (fc): PrePostProcessingWrapper(
          (module): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Dropout(p=0.1, inplace=Fa

In [69]:
print(metrics_df)

  model_name feature_set seq_len    pr_auc   roc_auc        f1  precision@5   
0    bitenet        dxtx       6  0.375744  0.591116  0.304348     0.418667  \
1    bitenet        dxtx       8  0.358171  0.574900  0.229167     0.000000   
2    bitenet        dxtx      10  0.357258  0.596695  0.296610     0.880222   
3    bitenet        dxtx      12  0.385917  0.622832  0.270492     0.961333   
4    bitenet        dxtx      14  0.264675  0.529203  0.184615     0.184778   

   precision@10  precision@15  precision@20  precision@25  precision@30  
0      0.636000      0.754667      0.828000      0.864000      0.882667  
1      0.000000      0.000000      0.000000      0.000000      0.000000  
2      0.860667      0.860000      0.849778      0.840222      0.834444  
3      0.964000      0.965333      1.000000      1.000000      1.000000  
4      0.239667      0.259444      0.272000      0.278556      0.284000  
