Google Cloud authentication and data download

In [1]:
! gcloud auth login
! gcloud auth application-default login
! gcloud config set project dl4h-final-project-383605

In [6]:
# ! pip install --upgrade google-api-python-client google-cloud-storage
from google.cloud import storage

# Replace these values with your project and bucket as needed
project_id = "dl4h-final-project-383605"
mimic3_bucket = "mimiciii-1.4.physionet.org"

storage_client = storage.Client(project=project_id)
bucket = storage_client.bucket(mimic3_bucket)
for blob in bucket.list_blobs():
  if "CHARTEVENTS" in blob.name:
    continue
  blob.download_to_filename(blob.name)

# Extract all the files
! gunzip *.gz

Library imports and data preparation

In [1]:
! pip install pyhealth
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.data import Patient, Visit
import pandas as pd
from pyhealth.datasets import split_by_patient, get_dataloader, BaseDataset
from pyhealth.models import BaseModel
from pyhealth.trainer import Trainer
from pyhealth.metrics.binary import binary_metrics_fn
import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Optional
from enum import Enum

# Set this to the directory with all MIMIC-3 dataset files
data_root = "/Users/cyg1122/Desktop/school/dl4h/mimic3/physionet.org/files/mimiciii/1.4/"


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3.1[0m[39;49m -> [0m[32;49m23.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


  from tqdm.autonotebook import trange


In [2]:
# Load the dataset

mimic3_ds = MIMIC3Dataset(
        root=data_root,
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD"],
        dev=False
)

In [3]:
# Print dataset statistics

mimic3_ds.stat()


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 46520
	- Number of visits: 58976
	- Number of visits per patient: 1.2678
	- Number of events per visit in DIAGNOSES_ICD: 11.0384
	- Number of events per visit in PROCEDURES_ICD: 4.0711



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 46520\n\t- Number of visits: 58976\n\t- Number of visits per patient: 1.2678\n\t- Number of events per visit in DIAGNOSES_ICD: 11.0384\n\t- Number of events per visit in PROCEDURES_ICD: 4.0711\n'

In [26]:
# Find all diagnoses codes
# Find all procedure codes
# Remove diagnoses codes with fewer than 5 occurences in the dataset

all_diag_codes = []
all_proc_codes = []
for patient_id, patient in mimic3_ds.patients.items():
  for i in range(len(patient)):
    visit: Visit = patient[i]
    conditions = visit.get_code_list(table="DIAGNOSES_ICD")
    procedures = visit.get_code_list(table="PROCEDURES_ICD")
    all_diag_codes.extend(conditions)
    all_proc_codes.extend(procedures)

codes = pd.Series(all_diag_codes)
diag_code_counts = codes.value_counts()
filtered_diag_codes = diag_code_counts[diag_code_counts > 4].index.values
num_unique_diag_codes = len(filtered_diag_codes)

unique_proc_codes = list(set(all_proc_codes))
num_unique_proc_codes = len(unique_proc_codes)

In [29]:
proc_code_to_index_map = {}
diag_code_to_index_map = {}

index = 0
for proc_code in unique_proc_codes:
    proc_code_to_index_map[proc_code] = index
    index += 1

index = 0
for diag_code in filtered_diag_codes:
    diag_code_to_index_map[diag_code] = index
    index += 1

def codes_to_multihot(codes, n_codes, code_to_index_map):
    multihot = np.zeros(n_codes)
    indices = [ code_to_index_map[code] for code in codes ]
    multihot[indices] = 1
    return multihot

def proc_code_to_onehot(codes):
    return codes_to_multihot(codes, num_unique_proc_codes, proc_code_to_index_map)

def diag_code_to_onehot(codes):
    return codes_to_multihot(codes, num_unique_diag_codes, diag_code_to_index_map)

In [21]:
# Define the tasks

DIAGNOSES_KEY = "conditions"
PROCEDURES_KEY = "procedures"
INTERVAL_DAYS_KEY = "days_since_first_visit"

# Readmission prediction for dataset consisting of diagnoses and procedure codes
def readmission_prediction_mimic3_fn(patient: Patient, time_window=30):
    if len(patient) < 2:
        return []

    samples = []

    # we will drop the last visit
    for i in range(len(patient) - 1):
        first_visit: Visit = patient[0]
        current_visit: Visit = patient[i]
        next_visit: Visit = patient[i + 1]

        # get time difference between current visit and next visit
        time_diff = (next_visit.encounter_time - current_visit.encounter_time).days
        time_diff_from_first_visit = (current_visit.encounter_time - first_visit.encounter_time).days
        readmission_label = 1 if time_diff < time_window else 0

        conditions = [c for c in current_visit.get_code_list(table="DIAGNOSES_ICD") if c in filtered_diag_codes]
        procedures = current_visit.get_code_list(table="PROCEDURES_ICD")
        # exclude: visits without condition, procedure, or drug code
        if len(conditions) * len(procedures) == 0:
            continue
        samples.append(
            {
                "visit_id": current_visit.visit_id,
                "patient_id": patient.patient_id,
                DIAGNOSES_KEY: conditions,
                PROCEDURES_KEY: procedures,
                INTERVAL_DAYS_KEY: [time_diff_from_first_visit],
                "label": readmission_label,
            }
        )
    # no cohort selection
    return samples

In [22]:
# Create the task datasets

mimic3_dxtx = mimic3_ds.set_task(task_fn=readmission_prediction_mimic3_fn)

Generating samples for readmission_prediction_mimic3_fn: 100%|██████████| 46520/46520 [00:10<00:00, 4265.44it/s]


In [33]:
print(type(mimic3_ds.patients))
print(mimic3_ds.patients.keys())

<class 'dict'>
dict_keys(['10', '100', '1000', '10000', '10001', '10002', '10003', '10004', '10005', '10006', '10007', '10008', '10009', '1001', '10010', '10011', '10012', '10013', '10014', '10015', '10016', '10017', '10019', '1002', '10020', '10021', '10022', '10023', '10024', '10025', '10026', '10027', '10028', '10029', '1003', '10030', '10032', '10033', '10034', '10035', '10036', '10037', '10038', '10039', '1004', '10040', '10041', '10042', '10043', '10044', '10045', '10046', '10047', '10048', '10049', '1005', '10050', '10051', '10052', '10054', '10055', '10056', '10057', '10058', '10059', '1006', '10060', '10061', '10062', '10063', '10064', '10065', '10066', '10067', '10068', '10069', '1007', '10071', '10072', '10073', '10074', '10075', '10076', '10077', '10079', '1008', '10080', '10081', '10082', '10083', '10084', '10085', '10086', '10087', '10088', '10089', '1009', '10090', '10091', '10092', '10093', '10094', '10096', '10097', '10098', '10099', '101', '1010', '10100', '10101', '1

In [23]:
# Print dataset statistics

print(mimic3_dxtx.stat())

Statistics of sample dataset:
	- Dataset: MIMIC3Dataset
	- Task: readmission_prediction_mimic3_fn
	- Number of samples: 10746
	- Number of patients: 6819
	- Number of visits: 10746
	- Number of visits per patient: 1.5759
	- conditions:
		- Number of conditions per sample: 13.4028
		- Number of unique conditions: 3327
		- Distribution of conditions (Top-10): [('4019', 3734), ('4280', 3649), ('42731', 2896), ('5849', 2316), ('41401', 2308), ('25000', 2015), ('51881', 1817), ('2724', 1676), ('5990', 1642), ('53081', 1417)]
	- procedures:
		- Number of procedures per sample: 4.3769
		- Number of unique procedures: 1331
		- Distribution of procedures (Top-10): [('3893', 3632), ('9604', 2008), ('966', 1935), ('9904', 1793), ('9671', 1711), ('3995', 1427), ('9672', 1322), ('3891', 983), ('9915', 841), ('8856', 838)]
	- days_since_first_visit:
		- Number of days_since_first_visit per sample: 1.0000
		- Length of days_since_first_visit: 1
	- label:
		- Number of label per sample: 1.0000
		- Num

In [24]:
# Create the dataloaders

BATCH_SIZE = 32
train, val, test = split_by_patient(mimic3_dxtx, [0.8, 0.1, 0.1])

# obtain train/val/test dataloader, they are <torch.data.DataLoader> object
train_loader = get_dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = get_dataloader(test, batch_size=BATCH_SIZE, shuffle=False)

In [25]:
# Define the models

class MaskDirection(Enum):
    FORWARD = 'forward'
    BACKWARD = 'backward'
    DIAGONAL = 'diagonal'
    NONE = 'none'

# class MaskedLayerNorm(nn.Module):
#     def __init__(self, normalized_shape: int):
#         super().__init__()
#         self.gamma = nn.parameter.Parameter(torch.randn(normalized_shape))
#         self.beta = nn.parameter.Parameter(torch.randn(normalized_shape))
#         self.eps = 1e-5
#         self.normalized_shape = normalized_shape
#
#     def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None):
#         print(x.shape)
#         print(key_padding_mask.shape)
#         n = torch.sum(torch.ones_like(x) * (~key_padding_mask).int().unsqueeze(-1).expand(-1, -1, self.normalized_shape), dim=-1)
#         print(n.shape)
#         print(n)
#
#         # expected_val = torch.nanmean(x, dim=1)
#         # variance = torch.nansum(())

# class CodeEmbedding(nn.Module):
#     def __init__(
#             self,
#             num_embeddings: int,
#             embedding_dim: int
#     ):
#         self.emb = nn.Embedding()

# todo: add temporal encoding
# todo: implement layer normalization
# todo: implement fully connected network structure and linear layer activation functions
class MaskEnc(nn.Module):
    def __init__(
            self,
            embedding_dim: int,
            num_heads: int,
            dropout: float = 0.1,
            batch_first: bool = True,
            temporal_mask_direction: MaskDirection = MaskDirection.NONE
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.temporal_mask_direction = temporal_mask_direction

        self.attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=batch_first
        )

        # self.masked_layer_norm1 = MaskedLayerNorm(embedding_dim)
        # self.masked_layer_norm2 = MaskedLayerNorm(embedding_dim)

        self.fc = nn.Linear(embedding_dim, embedding_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(
            self,
            x: torch.Tensor,
            key_padding_mask: Optional[torch.Tensor] = None
    ):
        attn_mask = self._make_temporal_mask(x.shape[1])

        attn_output, attn_output_weights = self.attention(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        # attn_output = torch.nan_to_num(attn_output, nan=-1e9)


        # x = self.masked_layer_norm1(x + attn_output, key_padding_mask)
        # x = self.masked_layer_norm2(x + self.fc(x), key_padding_mask)

        x = self.dropout(self.relu(self.fc(attn_output)))
        return x

    def _make_temporal_mask(self, n: int) -> Optional[torch.Tensor]:
        if self.temporal_mask_direction == MaskDirection.NONE:
            return None

        mask = torch.ones(n,n)
        if self.temporal_mask_direction == MaskDirection.FORWARD:
            mask = torch.tril(mask)
        if self.temporal_mask_direction == MaskDirection.BACKWARD:
            mask = torch.triu(mask)
        if self.temporal_mask_direction == MaskDirection.DIAGONAL:
            mask = mask.fill_diagonal_(0)

        return mask.bool()

class BiteNet(nn.Module):
    def __init__(
            self,
            embedding_dim: int,
            output_dim: int,
            num_heads: int,
            dropout: float = 0.1,
            batch_first: bool = True
    ):
        super().__init__()

        # self.diag_emb = nn.Embedding(...)
        # self.proc_emb = nn.Embedding(...)

        def _make_mask_enc_block(temporal_mask_direction: MaskDirection = MaskDirection.NONE):
            return MaskEnc(
                embedding_dim = embedding_dim,
                num_heads = num_heads,
                dropout = dropout,
                batch_first = batch_first,
                temporal_mask_direction = temporal_mask_direction
            )

        self.code_level_attn = _make_mask_enc_block(MaskDirection.DIAGONAL)
        self.visit_level_attn_forward = _make_mask_enc_block(MaskDirection.FORWARD)
        self.visit_level_attn_backward = _make_mask_enc_block(MaskDirection.BACKWARD)
        self.fc = nn.Linear(2*embedding_dim, output_dim)

    def forward(
            self,
            x: torch.Tensor,
            key_padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        code_attn = self.code_level_attn(x, key_padding_mask)
        code_attn = torch.nan_to_num(code_attn)

        u_fw = self.visit_level_attn_forward(code_attn, key_padding_mask)
        u_fw = u_fw.nansum(dim=1).squeeze()

        u_bw = self.visit_level_attn_backward(code_attn, key_padding_mask)
        u_bw = u_bw.nansum(dim=1).squeeze()

        u_bi = torch.cat([u_fw, u_bw], dim=-1)
        s = self.fc(u_bi)
        return s

class PyHealthBiteNet(BaseModel):
    def __init__(
            self,
            dataset: BaseDataset,
            feature_keys: List[str],
            label_key: str,
            mode: str,
            use_interval_emb: bool = True,
            use_procedure_codes = True,
            embedding_dim: int = 128,
            num_heads: int = 4,
            dropout: float = 0.1,
            batch_first: bool = True,
            **kwargs
    ):
        super().__init__(dataset, feature_keys, label_key, mode)

        # Any BaseModel should have these attributes, as functions like add_feature_transform_layer uses them
        self.feat_tokenizers = {}
        self.embeddings = nn.ModuleDict()
        self.linear_layers = nn.ModuleDict()
        self.label_tokenizer = self.get_label_tokenizer()
        self.embedding_dim = embedding_dim

        # self.add_feature_transform_layer will create a transformation layer for each feature
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            self.add_feature_transform_layer(
                feature_key, input_info
            )

        # final output layer
        output_size = self.get_output_size(self.label_tokenizer)
        self.bite_net = BiteNet(
            embedding_dim = embedding_dim,
            output_dim = output_size,
            num_heads = num_heads,
            dropout = dropout,
            batch_first = batch_first
        )

        # self.fc = nn.Linear(len(self.feature_keys) * hidden_dim, output_size)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        patient_emb = []
        patient_mask = []
        for feature_key in self.feature_keys:

            feature_vals = kwargs[feature_key]
            x = self.feat_tokenizers[feature_key].batch_encode_2d(feature_vals, padding=True, truncation=False)
            x = torch.tensor(x, dtype=torch.long, device=self.device)
            pad_idx = self.feat_tokenizers[feature_key].vocabulary("<pad>")

            #create the mask
            mask = (x != pad_idx).bool()

            embeds = self.embeddings[feature_key](x)
            patient_emb.append(embeds)
            patient_mask.append(mask)

        # (patient, features * hidden_dim)
        patient_emb = torch.cat(patient_emb, dim=1)
        patient_mask = ~(torch.cat(patient_mask, dim=1))

        # (patient, label_size)
        logits = self.bite_net(patient_emb, patient_mask)
        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        return {"loss": loss, "y_prob": y_prob, "y_true": y_true}

model_dxtx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ["conditions", "procedures"],
    label_key = "label",
    mode = "binary",
)

data = next(iter(train_loader))
model_dxtx(**data)

KeyboardInterrupt: 

In [18]:
model_dxtx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ["conditions", "procedures"],
    label_key = "label",
    mode = "binary",
)

model_dx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ["conditions"],
    label_key = "label",
    mode = "binary",
)

In [19]:
data = next(iter(train_loader))
model_dx(**data)

{'loss': tensor(0.7363, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
 'y_prob': tensor([[0.4940],
         [0.5260],
         [0.5713],
         [0.6283],
         [0.4630],
         [0.4646],
         [0.5993],
         [0.5617],
         [0.5722],
         [0.5628],
         [0.5436],
         [0.4776],
         [0.6207],
         [0.5105],
         [0.6244],
         [0.4756],
         [0.5805],
         [0.5534],
         [0.6040],
         [0.5931],
         [0.5287],
         [0.5496],
         [0.5750],
         [0.4966],
         [0.5697],
         [0.4620],
         [0.5815],
         [0.5287],
         [0.6660],
         [0.5615],
         [0.5513],
         [0.5093]], grad_fn=<SigmoidBackward0>),
 'y_true': tensor([[0.],
         [0.],
         [0.],
         [0.],
         [1.],
         [1.],
         [1.],
         [1.],
         [0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]

In [20]:
trainer_dxtx = Trainer(model=model_dxtx)
trainer_dxtx.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=5,
    monitor="pr_auc",
)

PyHealthBiteNet(
  (embeddings): ModuleDict(
    (conditions): Embedding(3329, 128, padding_idx=0)
    (procedures): Embedding(1333, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): BiteNet(
    (code_level_attn): MaskEnc(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (fc): Linear(in_features=128, out_features=128, bias=True)
      (relu): ReLU()
      (dropout): Dropout(p=0, inplace=False)
    )
    (visit_level_attn_forward): MaskEnc(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (fc): Linear(in_features=128, out_features=128, bias=True)
      (relu): ReLU()
      (dropout): Dropout(p=0, inplace=False)
    )
    (visit_level_attn_backward): MaskEnc(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, o

Epoch 0 / 5:   0%|          | 0/268 [00:00<?, ?it/s]

--- Train epoch-0, step-268 ---
loss: 0.6895
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 35/35 [00:00<00:00, 123.56it/s]
--- Eval epoch-0, step-268 ---
pr_auc: 0.5514
roc_auc: 0.5000
f1: 0.7109
loss: 0.6888
New best pr_auc score (0.5514) at epoch-0, step-268



Epoch 1 / 5:   0%|          | 0/268 [00:00<?, ?it/s]

--- Train epoch-1, step-536 ---
loss: 0.6861
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 35/35 [00:00<00:00, 122.29it/s]
--- Eval epoch-1, step-536 ---
pr_auc: 0.5514
roc_auc: 0.5000
f1: 0.7109
loss: 0.6882



Epoch 2 / 5:   0%|          | 0/268 [00:00<?, ?it/s]

--- Train epoch-2, step-804 ---
loss: 0.6853
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 35/35 [00:00<00:00, 122.83it/s]
--- Eval epoch-2, step-804 ---
pr_auc: 0.5514
roc_auc: 0.5000
f1: 0.7109
loss: 0.6883



Epoch 3 / 5:   0%|          | 0/268 [00:00<?, ?it/s]



KeyboardInterrupt: 

In [66]:
# option 1: use our built-in evaluation metric
score_dxtx = trainer_dxtx.evaluate(test_loader)
print (score_dxtx)

# option 2: use our pyhealth.metrics to evaluate
y_true_dxtx, y_prob_dxtx, loss_dxtx = trainer_dxtx.inference(test_loader)
binary_metrics_fn(y_true_dxtx, y_prob_dxtx, metrics=["pr_auc"])

Evaluation: 100%|██████████| 33/33 [00:00<00:00, 367.20it/s]


{'pr_auc': 0.726784508178901, 'roc_auc': 0.685080036129632, 'f1': 0.7090909090909092, 'loss': 0.6331828191424861}


Evaluation: 100%|██████████| 33/33 [00:00<00:00, 363.51it/s]


{'pr_auc': 0.726784508178901}