Library imports and data preparation

In [1]:
from pyhealth.data import Visit
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.trainer import Trainer
from pyhealth.medcode import InnerMap
from pyhealth.metrics.binary import binary_metrics_fn
from sklearn.metrics import precision_score
import numpy as np
import pandas as pd
import torch
from typing import List, Dict
from model.bitenet import BiteNet
from model.baseline import RNN, BRNN, RETAIN, Deepr
import pickle

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global seed. split_by_patient receives it explicitly below; torch and numpy
# are seeded here as well so model weight initialisation and DataLoader
# shuffling are reproducible under Restart & Run All.
RANDOM_SEED = 0
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

BATCH_SIZE = 32
# Cut-offs k used by precision@k: 5, 10, ..., 30.
KS = list(range(5, 31, 5))
# Values passed as max_length_visits to the task function: 6, 8, ..., 16.
SEQ_LENS = list(range(6, 17, 2))

  from tqdm.autonotebook import trange


In [2]:
# Deserialize the MIMIC-III dataset object from the local pickle file.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open("mimic3_dataset.pkl", mode="rb") as pickled_ds:
    mimic3_ds = pickle.load(pickled_ds)

In [3]:
# Define the tasks

# Field-name constants.
# NOTE(review): none of these names are referenced by the code visible in this
# notebook, and DIAGNOSES_KEY ("conditions") differs from the "diagnoses" key
# the task function emits -- confirm before relying on them.
DIAGNOSES_KEY, PROCEDURES_KEY, INTERVAL_DAYS_KEY = (
    "conditions",
    "procedures",
    "days_since_first_visit",
)

# ICD-9-CM code map; the task function uses its get_ancestors() to build the
# diagnosis-prediction label.
icd9cm = InnerMap.load("ICD9CM")

def flatten(l: List):
    """Concatenate the elements of an iterable of iterables into one list (one level deep)."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

def patient_level_readmission_prediction(patient, time_window: int = 30, max_length_visits: int = None):
    """Build a single readmission / next-visit-diagnosis sample for one patient.

    Visits are sorted by ``encounter_time``. The most recent visit provides the
    labels; the visits before it provide the features.

    Args:
        patient: a ``pyhealth.data.Patient``; iterating it must yield visits
            exposing ``encounter_time``, ``visit_id`` and ``get_code_list``.
        time_window: the readmission label is 1 when the last visit starts at
            most this many (whole) days after the second-to-last visit, else 0.
        max_length_visits: when not None, keep only the most recent
            ``max_length_visits`` feature visits (plus the label visit).

    Returns:
        ``[sample]`` with keys patient_id, visit_id, diagnoses, procedures,
        intervals, readmission_label and diagnosis_label; or ``[]`` when fewer
        than two visits remain or no feature visit has any diagnosis code.
    """
    sorted_visits = sorted(patient, key=lambda visit: visit.encounter_time)

    # Clip to the most recent max_length_visits + 1 visits (features + label visit).
    if max_length_visits is not None:
        n_visits = len(sorted_visits)
        if n_visits > max_length_visits + 1:
            sorted_visits = sorted_visits[n_visits - (max_length_visits + 1):]

    # Need at least one feature visit and one label visit; without this guard
    # feature_visits[-1] raises IndexError for single-visit patients.
    if len(sorted_visits) < 2:
        return []

    feature_visits: List[Visit] = sorted_visits[:-1]
    last_visit: Visit = sorted_visits[-1]
    second_to_last_visit: Visit = feature_visits[-1]
    first_visit: Visit = feature_visits[0]

    # Step 1a: readmission label from the gap (in days) between the last two visits.
    days_between_last_two = (last_visit.encounter_time - second_to_last_visit.encounter_time).days
    readmission_label = 1 if days_between_last_two <= time_window else 0

    # Step 1b: diagnosis label = distinct entries at index 1 of
    # icd9cm.get_ancestors(code) over the last visit's diagnosis codes.
    # NOTE(review): assumes every code has at least two ancestors (IndexError otherwise).
    diagnosis_label = list(set([icd9cm.get_ancestors(code)[1] for code in last_visit.get_code_list(table="DIAGNOSES_ICD")]))

    # Step 2: per-visit features.
    visits_diagnoses = []
    visits_procedures = []
    visits_intervals = []
    for visit in feature_visits:
        diagnoses = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")

        # Skip visits without diagnosis codes.
        # NOTE(review): visits with an empty procedure list are kept. The
        # original comment said such visits were excluded too, because
        # RNN-style PyHealth models need non-empty features. Confirm whether
        # procedures should also be filtered.
        if not diagnoses:
            continue

        days_since_first_visit = (visit.encounter_time - first_visit.encounter_time).days
        visits_diagnoses.append(diagnoses)
        visits_procedures.append(procedures)
        # Intervals are stored as single string tokens, one list per visit.
        visits_intervals.append([str(days_since_first_visit)])

    # Step 3: exclusion criterion. Every kept visit has >= 1 diagnosis, so an
    # empty list here means no diagnosis codes at all.
    if not visits_diagnoses:
        return []

    # Step 4: assemble the sample; visit_id is that of the last feature visit.
    return [
        {
            "patient_id": patient.patient_id,
            "visit_id": second_to_last_visit.visit_id,
            "diagnoses": visits_diagnoses,
            "procedures": visits_procedures,
            "intervals": visits_intervals,
            "readmission_label": readmission_label,
            "diagnosis_label": diagnosis_label,
        }
    ]

In [4]:
# CSV path that train_and_record_metrics writes its checkpoint to; the sweep
# cells further down reassign it before running.
RESULTS_FILE: str = "baseline_comparison.csv"

In [5]:
from sklearn.metrics import confusion_matrix
def train_and_inference(model, train_loader, val_loader, test_loader, lr=0.0001, monitor="pr_auc", optim=torch.optim.Adam, epochs=10, load_best_model_at_last=False):
    """Train ``model`` with a PyHealth ``Trainer`` and run inference on the test loader.

    Args:
        model: the model to train (already moved to ``device`` by callers).
        train_loader, val_loader, test_loader: data loaders for each split.
        lr: learning rate passed to the optimizer.
        monitor: validation metric name forwarded to ``Trainer.train``.
        optim: optimizer class (not instance).
        epochs: number of training epochs (default 10, as before).
        load_best_model_at_last: forwarded to ``Trainer.train``; controls
            whether the best-on-``monitor`` checkpoint is reloaded before
            returning. Defaults to False, i.e. the final-epoch weights are used.

    Returns:
        The result of ``trainer.inference(test_loader)``; callers unpack it as
        ``(y_true, y_prob, _)``.
    """
    trainer = Trainer(model=model, device=device)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=epochs,
        monitor=monitor,
        optimizer_class=optim,
        optimizer_params={"lr": lr},
        load_best_model_at_last=load_best_model_at_last,
    )

    return trainer.inference(test_loader)

def precision_at_k(y_true: np.ndarray, y_prob: np.ndarray, ks=None) -> Dict[str, float]:
    """Precision of thresholded predictions within each sample's top-k ranked labels.

    For every sample (row), labels are ranked by descending predicted
    probability. For each cut-off k, the top-k ranked entries of all rows are
    pooled, and precision = TP / (TP + FP) is computed for the predictions
    ``y_prob > 0.5`` against ``y_true == 1``. When nothing in the top k is
    predicted positive, the precision is 0.0 (same value
    ``sklearn.metrics.precision_score`` returns in that case).

    NOTE(review): this measures precision of the 0.5-thresholded predictions
    inside the top k, not the "true labels in top k / k" variant.

    Args:
        y_true: 2-D (n_samples, n_labels) ground truth, assumed 0/1.
        y_prob: 2-D (n_samples, n_labels) predicted probabilities.
        ks: cut-offs to evaluate; defaults to the module-level ``KS``.

    Returns:
        ``{"precision@k": value}`` for each k, in the order of ``ks``.
    """
    if ks is None:
        ks = KS

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    y_pred = (y_prob > 0.5).astype(int)
    # Per-row label indices, highest probability first.
    desc_idx = np.flip(np.argsort(y_prob, axis=-1), axis=-1)

    # take_along_axis reorders every row by its own indices. np.take without an
    # axis indexes the flattened array, which gathered every row from row 0.
    y_true_ranked = np.take_along_axis(y_true, desc_idx, axis=-1).astype(int)
    y_pred_ranked = np.take_along_axis(y_pred, desc_idx, axis=-1)

    precisions: Dict[str, float] = {}
    for k in ks:
        top_true = y_true_ranked[:, :k] == 1
        top_pred = y_pred_ranked[:, :k] == 1
        n_pred_pos = int(top_pred.sum())
        n_true_pos = int((top_true & top_pred).sum())
        precisions[f"precision@{k}"] = n_true_pos / n_pred_pos if n_pred_pos else 0.0
    return precisions

In [6]:
def train_and_record_metrics(model_readm, model_diag, df, row_fields, train_loader, val_loader, test_loader, lr=0.0005, readm_optim=torch.optim.Adam, diag_optim=torch.optim.Adam, results_file=None):
    """Train/evaluate a readmission and a diagnosis model and append one result row.

    The readmission model is scored with ``binary_metrics_fn`` (pr_auc,
    roc_auc, f1). The diagnosis model is trained with monitor
    "pr_auc_samples" and scored with ``precision_at_k``. Both share the same
    loaders and learning rate. After appending, the whole frame is written to
    CSV as a checkpoint.

    Args:
        model_readm: model for the "readmission_label" task.
        model_diag: model for the "diagnosis_label" task.
        df: results frame to append to (not modified in place).
        row_fields: identifying columns (e.g. model_name, feature_set, seq_len);
            these win over metric keys on name clashes.
        train_loader, val_loader, test_loader: data loaders for each split.
        lr: learning rate for both models.
        readm_optim, diag_optim: optimizer classes for each model.
        results_file: CSV checkpoint path; defaults to the global
            ``RESULTS_FILE`` as read at call time. Missing parent directories
            are created.

    Returns:
        A new DataFrame equal to ``df`` with the row appended.
    """
    import os

    if results_file is None:
        results_file = RESULTS_FILE

    y_true, y_prob, _ = train_and_inference(
        model_readm,
        train_loader,
        val_loader,
        test_loader,
        lr=lr,
        optim=readm_optim,
    )
    binary_metrics = binary_metrics_fn(y_true, y_prob, metrics=["pr_auc", "roc_auc", "f1"])

    y_true, y_prob, _ = train_and_inference(
        model_diag,
        train_loader,
        val_loader,
        test_loader,
        lr=lr,
        monitor="pr_auc_samples",
        optim=diag_optim,
    )
    precisions = precision_at_k(y_true, y_prob)

    # Dict union: later operands win on duplicate keys.
    row = {key: [value] for key, value in (binary_metrics | precisions | row_fields).items()}
    df = pd.concat([df, pd.DataFrame.from_dict(row)], ignore_index=True)

    # Checkpoint after every row; create e.g. ./results/ if it does not exist yet,
    # otherwise to_csv raises for a missing directory.
    results_dir = os.path.dirname(results_file)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    df.to_csv(results_file, index=False)

    return df

In [None]:
%%capture

# Compare BiteNet performance to baselines.
# For every sequence length and every spec below, a readmission (binary) model
# and a diagnosis (multilabel) model are trained and one result row is appended.

# Set explicitly: the sweep cells below reassign RESULTS_FILE, so re-running
# this cell afterwards would otherwise write baseline rows into their CSVs.
RESULTS_FILE = "baseline_comparison.csv"

# (model_name, model class, feature_set label, feature keys, extra constructor kwargs)
BASELINE_SPECS = [
    ("bitenet", BiteNet, "dxtx", ["diagnoses", "procedures", "intervals"], {}),
    ("bitenet", BiteNet, "dx", ["diagnoses", "intervals"], {}),
    ("rnn", RNN, "dxtx", ["diagnoses", "procedures"], {}),
    ("rnn", RNN, "dx", ["diagnoses"], {}),
    # Both brnn models use BRNN; the diagnosis model was previously built with
    # RNN(bidirectional=True), inconsistent with the readmission model.
    ("brnn", BRNN, "dxtx", ["diagnoses", "procedures"], {"bidirectional": True}),
    ("brnn", BRNN, "dx", ["diagnoses"], {"bidirectional": True}),
    ("retain", RETAIN, "dxtx", ["diagnoses", "procedures"], {}),
    ("retain", RETAIN, "dx", ["diagnoses"], {}),
    ("deepr", Deepr, "dxtx", ["diagnoses", "procedures"], {}),
    ("deepr", Deepr, "dx", ["diagnoses"], {}),
]


def build_model(model_cls, dataset, feature_keys, label_key, mode, **model_kwargs):
    """Instantiate ``model_cls`` with the shared constructor arguments and move it to ``device``."""
    return model_cls(
        dataset=dataset,
        feature_keys=list(feature_keys),
        label_key=label_key,
        mode=mode,
        **model_kwargs,
    ).to(device)


metrics_df = pd.DataFrame(columns=['model_name', 'feature_set', 'seq_len', 'pr_auc', 'roc_auc', 'f1'] + [f"precision@{k}" for k in KS])

for seq_len in SEQ_LENS:

    # Bind seq_len as a default argument so the task function never depends on
    # the loop variable's value at call time.
    dataset = mimic3_ds.set_task(
        task_fn=lambda p, seq_len=seq_len: patient_level_readmission_prediction(p, max_length_visits=seq_len)
    )

    train, val, test = split_by_patient(dataset, [0.8, 0.1, 0.1], seed=RANDOM_SEED)

    train_loader = get_dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = get_dataloader(val, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = get_dataloader(test, batch_size=BATCH_SIZE, shuffle=False)

    for model_name, model_cls, feature_set, feature_keys, model_kwargs in BASELINE_SPECS:
        metrics_df = train_and_record_metrics(
            model_readm=build_model(model_cls, dataset, feature_keys, "readmission_label", "binary", **model_kwargs),
            model_diag=build_model(model_cls, dataset, feature_keys, "diagnosis_label", "multilabel", **model_kwargs),
            df=metrics_df,
            row_fields={
                "model_name": model_name,
                "feature_set": feature_set,
                "seq_len": seq_len,
            },
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
        )

BiteNet(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3428, 128, padding_idx=0)
    (procedures): Embedding(1358, 128, padding_idx=0)
    (intervals): Embedding(1649, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): _BiteNet(
    (flatten): Flatten()
    (unflatten): Unflatten()
    (code_attn): Sequential(
      (0): MaskEnc(
        (attention): PrePostProcessingWrapper(
          (module): MultiHeadAttention(
            (q_linear): Linear(in_features=128, out_features=128, bias=False)
            (k_linear): Linear(in_features=128, out_features=128, bias=False)
            (v_linear): Linear(in_features=128, out_features=128, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (layer_norm): LayerNorm()
        )
        (fc): PrePostProcessingWrapper(
          (module): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Dropout(p=0.1, inplace=Fa

In [None]:
%%capture

# Evaluate BiteNet performance as number of MaskEnc layers changes

RESULTS_FILE = "./results/changing_n_layers.csv"
# NOTE(review): the 'trial' column is never filled by this cell (left as NaN).
n_layers_df = pd.DataFrame(columns=['model_name', 'feature_set', 'seq_len', 'n_layers', 'trial', 'pr_auc', 'roc_auc', 'f1'] + [f"precision@{k}" for k in KS])

# BiteNet feature sets: with procedures ("dxtx") and diagnoses only ("dx").
BITENET_FEATURE_SETS = [
    ("dxtx", ["diagnoses", "procedures", "intervals"]),
    ("dx", ["diagnoses", "intervals"]),
]
N_MASK_ENC_LAYERS = range(9)  # 0..8 MaskEnc layers

for seq_len in SEQ_LENS:

    # Bind seq_len as a default argument (avoids late-binding of the loop variable).
    dataset = mimic3_ds.set_task(
        task_fn=lambda p, seq_len=seq_len: patient_level_readmission_prediction(p, max_length_visits=seq_len)
    )

    # The split uses a fixed seed and does not depend on n_layers, so build the
    # split and loaders once per seq_len instead of once per n_layers value.
    train, val, test = split_by_patient(dataset, [0.8, 0.1, 0.1], seed=RANDOM_SEED)

    train_loader = get_dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = get_dataloader(val, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = get_dataloader(test, batch_size=BATCH_SIZE, shuffle=False)

    for n_layers in N_MASK_ENC_LAYERS:
        for feature_set, feature_keys in BITENET_FEATURE_SETS:
            n_layers_df = train_and_record_metrics(
                model_readm=BiteNet(
                    dataset=dataset,
                    feature_keys=list(feature_keys),
                    label_key="readmission_label",
                    mode="binary",
                    n_mask_enc_layers=n_layers,
                ).to(device),
                model_diag=BiteNet(
                    dataset=dataset,
                    feature_keys=list(feature_keys),
                    label_key="diagnosis_label",
                    mode="multilabel",
                    n_mask_enc_layers=n_layers,
                ).to(device),
                df=n_layers_df,
                row_fields={
                    "model_name": "bitenet",
                    "feature_set": feature_set,
                    "seq_len": seq_len,
                    "n_layers": n_layers,
                },
                train_loader=train_loader,
                val_loader=val_loader,
                test_loader=test_loader,
            )

In [None]:
%%capture

# Evaluate BiteNet performance as number of attention heads in MaskEnc layers changes

RESULTS_FILE = "./results/changing_n_heads.csv"
# NOTE(review): the 'trial' column is never filled by this cell (left as NaN).
n_heads_df = pd.DataFrame(columns=['model_name', 'feature_set', 'seq_len', 'n_heads', 'trial', 'pr_auc', 'roc_auc', 'f1'] + [f"precision@{k}" for k in KS])

# BiteNet feature sets: with procedures ("dxtx") and diagnoses only ("dx").
BITENET_FEATURE_SETS = [
    ("dxtx", ["diagnoses", "procedures", "intervals"]),
    ("dx", ["diagnoses", "intervals"]),
]
# NOTE(review): presumably each value must divide the embedding size
# (128 in the printed BiteNet repr) -- confirm against BiteNet.
N_HEADS_VALUES = [4, 8, 16, 32]

for seq_len in SEQ_LENS:

    # Bind seq_len as a default argument (avoids late-binding of the loop variable).
    dataset = mimic3_ds.set_task(
        task_fn=lambda p, seq_len=seq_len: patient_level_readmission_prediction(p, max_length_visits=seq_len)
    )

    # The split uses a fixed seed and does not depend on n_heads, so build the
    # split and loaders once per seq_len instead of once per n_heads value.
    train, val, test = split_by_patient(dataset, [0.8, 0.1, 0.1], seed=RANDOM_SEED)

    train_loader = get_dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = get_dataloader(val, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = get_dataloader(test, batch_size=BATCH_SIZE, shuffle=False)

    for n_heads in N_HEADS_VALUES:
        for feature_set, feature_keys in BITENET_FEATURE_SETS:
            n_heads_df = train_and_record_metrics(
                model_readm=BiteNet(
                    dataset=dataset,
                    feature_keys=list(feature_keys),
                    label_key="readmission_label",
                    mode="binary",
                    n_heads=n_heads,
                ).to(device),
                model_diag=BiteNet(
                    dataset=dataset,
                    feature_keys=list(feature_keys),
                    label_key="diagnosis_label",
                    mode="multilabel",
                    n_heads=n_heads,
                ).to(device),
                df=n_heads_df,
                row_fields={
                    "model_name": "bitenet",
                    "feature_set": feature_set,
                    "seq_len": seq_len,
                    "n_heads": n_heads,
                },
                train_loader=train_loader,
                val_loader=val_loader,
                test_loader=test_loader,
            )

In [None]:
# Ablations