Google Cloud authentication and data download

In [1]:
# ! gcloud auth login
# ! gcloud auth application-default login
# ! gcloud config set project dl4h-final-project-383605

In [3]:
# ! pip install --upgrade google-api-python-client google-cloud-storage
from google.cloud import storage

# Replace these values with your project and bucket as needed
project_id = "dl4h-final-project-383605"
mimic3_bucket = "mimiciii-1.4.physionet.org"

storage_client = storage.Client(project=project_id)
bucket = storage_client.bucket(mimic3_bucket)
data_folder = "./data"
for blob in bucket.list_blobs():
  if "CHARTEVENTS" in blob.name:
    continue
  blob.download_to_filename(f"{data_folder}/{blob.name}")

# Extract all the files
! gunzip {data_folder}/*.gz

gzip: *.gz: No such file or directory


In [5]:
! gunzip {data_folder}/*.gz

Library imports and data preparation

In [85]:
# ! pip install pyhealth
from pyhealth.datasets import MIMIC3Dataset, SampleDataset
from pyhealth.data import Visit
import pandas as pd
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import BaseModel
from pyhealth.trainer import Trainer
from pyhealth.metrics.binary import binary_metrics_fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Optional
from enum import Enum
from functools import reduce
from operator import mul

# Set this to the directory with all MIMIC-3 dataset files
data_root = "./data"

In [92]:
# Load the dataset

mimic3_ds = MIMIC3Dataset(
        root=data_root,
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD"],
        dev=False
)

In [93]:
# Print dataset statistics

mimic3_ds.stat()


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 46520
	- Number of visits: 58976
	- Number of visits per patient: 1.2678
	- Number of events per visit in DIAGNOSES_ICD: 11.0384
	- Number of events per visit in PROCEDURES_ICD: 4.0711



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 46520\n\t- Number of visits: 58976\n\t- Number of visits per patient: 1.2678\n\t- Number of events per visit in DIAGNOSES_ICD: 11.0384\n\t- Number of events per visit in PROCEDURES_ICD: 4.0711\n'

In [94]:
# Find all diagnoses codes
# Find all procedure codes
# Remove diagnoses codes with fewer than 5 occurences in the dataset

all_diag_codes = []
for patient_id, patient in mimic3_ds.patients.items():
  for i in range(len(patient)):
    visit: Visit = patient[i]
    conditions = visit.get_code_list(table="DIAGNOSES_ICD")
    all_diag_codes.extend(conditions)

codes = pd.Series(all_diag_codes)
diag_code_counts = codes.value_counts()
filtered_diag_codes = diag_code_counts[diag_code_counts > 4].index.values
num_unique_diag_codes = len(filtered_diag_codes)

In [95]:
# Define the tasks

DIAGNOSES_KEY = "conditions"
PROCEDURES_KEY = "procedures"
INTERVAL_DAYS_KEY = "days_since_first_visit"

def flatten(l: List):
    return [item for sublist in l for item in sublist]

def patient_level_readmission_prediction(patient, time_window=30):
    """
    patient is a <pyhealth.data.Patient> object
    """
    samples = []

    # if the patient only has one visit, we drop it
    if len(patient) == 1:
        return []

    sorted_visits = sorted(patient, key=lambda visit: visit.encounter_time)

    # step 1: define label
    idx_last_visit = len(sorted_visits)-1
    last_visit: Visit = sorted_visits[idx_last_visit]
    second_to_last_visit: Visit = sorted_visits[idx_last_visit - 1]
    first_visit: Visit = sorted_visits[0]

    time_diff = (last_visit.encounter_time - second_to_last_visit.encounter_time).days
    readmission_label = 1 if time_diff < time_window else 0

    # step 2: obtain features
    visits_conditions = []
    visits_procedures = []
    visits_intervals = []
    for idx, visit in enumerate(sorted_visits):
        if idx == len(sorted_visits) - 1: break
        conditions = [c for c in visit.get_code_list(table="DIAGNOSES_ICD") if c in filtered_diag_codes]
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        time_diff_from_first_visit = (visit.encounter_time - first_visit.encounter_time).days

        if len(conditions) * len(procedures) == 0:
            continue

        visits_conditions.append(conditions)
        visits_procedures.append(procedures)
        visits_intervals.append([str(time_diff_from_first_visit)])

    unique_conditions = list(set(flatten(visits_conditions)))
    unique_procedures = list(set(flatten(visits_procedures)))

    # step 3: exclusion criteria
    if len(unique_conditions) * len(unique_procedures) == 0:
        return []

    # step 4: assemble the sample
    samples.append(
        {
            "patient_id": patient.patient_id,
            "visit_id": visit.visit_id,
            "conditions": visits_conditions,
            "procedures": visits_procedures,
            "intervals": visits_intervals,
            "label": readmission_label,
        }
    )
    return samples

In [96]:
# Create the task datasets
mimic3_dxtx = mimic3_ds.set_task(task_fn=patient_level_readmission_prediction)

Generating samples for patient_level_readmission_prediction: 100%|██████████| 46520/46520 [00:10<00:00, 4367.74it/s]


In [100]:
BATCH_SIZE = 32
train, val, test = split_by_patient(mimic3_dxtx, [0.8, 0.1, 0.1])

train_loader = get_dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = get_dataloader(test, batch_size=BATCH_SIZE, shuffle=False)

In [101]:
# Define the models
VERY_BIG_NUMBER = 1e30
VERY_SMALL_NUMBER = 1e-30
VERY_POSITIVE_NUMBER = VERY_BIG_NUMBER
VERY_NEGATIVE_NUMBER = -VERY_BIG_NUMBER

class MaskDirection(Enum):
    FORWARD = 'forward'
    BACKWARD = 'backward'
    DIAGONAL = 'diagonal'
    NONE = 'none'

# class MaskedLayerNorm(nn.Module):
#     def __init__(self, normalized_shape: int):
#         super().__init__()
#         self.gamma = nn.parameter.Parameter(torch.randn(normalized_shape))
#         self.beta = nn.parameter.Parameter(torch.randn(normalized_shape))
#         self.eps = 1e-5
#         self.normalized_shape = normalized_shape
#
#     def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None):
#         print(x.shape)
#         print(key_padding_mask.shape)
#         n = torch.sum(torch.ones_like(x) * (~key_padding_mask).int().unsqueeze(-1).expand(-1, -1, self.normalized_shape), dim=-1)
#         print(n.shape)
#         print(n)
#
#         # expected_val = torch.nanmean(x, dim=1)
#         # variance = torch.nansum(())

def flatten_keep(x: torch.Tensor, keep: int):
    fixed_shape = list(x.size())
    start = len(fixed_shape) - keep
    left = reduce(mul, [fixed_shape[i] or x.shape[i] for i in range(start)])
    out_shape = [left] + [fixed_shape[i] or x.shape[i] for i in range(start, len(fixed_shape))]
    return torch.reshape(x, out_shape)


def reshape(v, ref, embedding_dim):
    batch_size = ref.shape[0]
    n_visits = ref.shape[1]
    out = torch.reshape(v, [batch_size, n_visits, embedding_dim])
    return out

class AttentionPooling(nn.Module):
    def __init__(self, embedding_size: int):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(embedding_size, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, embedding_size)
        )

    def forward(self, input):
        x = self.fc(input)
        # x = torch.nan_to_num(x, nan=VERY_NEGATIVE_NUMBER)
        soft = F.softmax(x, dim=1) # Todo: fixme
        attn_output = torch.nansum(soft * input, 1)
        return attn_output


# todo: add temporal encoding
# todo: implement layer normalization
class MaskEnc(nn.Module):
    def __init__(
            self,
            embedding_dim: int,
            num_heads: int,
            dropout: float = 0.1,
            batch_first: bool = True,
            temporal_mask_direction: MaskDirection = MaskDirection.NONE,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.temporal_mask_direction = temporal_mask_direction

        self.attention = nn.MultiheadAttention(
                embed_dim=embedding_dim,
                num_heads=num_heads,
                dropout=dropout,
                batch_first=batch_first
            )
        self.dropout = nn.Dropout(dropout)

        self.fc = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, embedding_dim)
        )

        # self.masked_layer_norm1 = MaskedLayerNorm(embedding_dim)
        # self.masked_layer_norm2 = MaskedLayerNorm(embedding_dim)

    def forward(
            self,
            x: torch.Tensor,
            key_padding_mask: Optional[torch.Tensor] = None
    ):
        attn_mask = self._make_temporal_mask(x.shape[1])

        attn_output, attn_output_weights = self.attention(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        attn_output = self.dropout(attn_output)
        out = self.fc(attn_output)

        # x = self.masked_layer_norm1(x + attn_output, key_padding_mask)
        # x = self.masked_layer_norm2(x + self.fc(x), key_padding_mask)
        return out

    def _make_temporal_mask(self, n: int) -> Optional[torch.Tensor]:
        if self.temporal_mask_direction == MaskDirection.NONE:
            return None

        if self.temporal_mask_direction == MaskDirection.FORWARD:
            return torch.tril(torch.ones(n,n))
        if self.temporal_mask_direction == MaskDirection.BACKWARD:
            return torch.triu(torch.ones(n,n))
        if self.temporal_mask_direction == MaskDirection.DIAGONAL:
            return torch.zeros(n,n).fill_diagonal_(1)


class BiteNet(nn.Module):
    def __init__(
            self,
            embedding_dim: int = 128,
            output_dim: int = 1,
            num_heads: int = 4,
            dropout: float = 0.1,
            batch_first: bool = True,
            n_mask_enc_layers: int = 2,
            use_procedures: bool = True,
            use_intervals: bool = True,
    ):
        super().__init__()

        self.use_intervals = use_intervals
        self.use_procedures = use_procedures
        self.embedding_dim = embedding_dim

        def _make_mask_enc_block(temporal_mask_direction: MaskDirection = MaskDirection.NONE):
            return MaskEnc(
                embedding_dim = embedding_dim,
                num_heads = num_heads,
                dropout = dropout,
                batch_first = batch_first,
                temporal_mask_direction = temporal_mask_direction,
            )

        self.code_level_attn = nn.ModuleList()
        self.visit_level_attn_forward = nn.ModuleList()
        self.visit_level_attn_backward = nn.ModuleList()
        for _ in range(n_mask_enc_layers):
            self.code_level_attn.append(_make_mask_enc_block(MaskDirection.DIAGONAL))
            self.visit_level_attn_forward.append(_make_mask_enc_block(MaskDirection.FORWARD))
            self.visit_level_attn_backward.append(_make_mask_enc_block(MaskDirection.BACKWARD))
        self.code_level_attn.append(AttentionPooling(embedding_dim))
        self.visit_level_attn_forward.append(AttentionPooling(embedding_dim))
        self.visit_level_attn_backward.append(AttentionPooling(embedding_dim))

        self.fc = nn.Sequential(
            nn.Linear(2*embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, output_dim)
        )

    def forward(
            self,
            embedded_codes: torch.Tensor,
            embedded_intervals: torch.Tensor,
            codes_mask: torch.Tensor,
            visits_mask: torch.Tensor,
    ) -> torch.Tensor:

        codes_mask = ~(codes_mask.bool())

        # input tensor, reshape 4 dimension to 3
        flattened_codes = flatten_keep(embedded_codes, 2)

        # input mask, reshape 3 dimension to 2
        flattened_codes_mask = flatten_keep(codes_mask, 1)

        code_attn = flattened_codes
        for i, l in enumerate(self.code_level_attn):
            if i == len(self.code_level_attn) - 1:
                code_attn = l(code_attn)
            else:
                code_attn = l(code_attn, flattened_codes_mask)
        code_attn = reshape(code_attn, embedded_codes, self.embedding_dim)

        if self.use_intervals:
            code_attn += embedded_intervals

        visits_mask = ~(visits_mask.bool())

        u_fw = code_attn
        for i, l in enumerate(self.visit_level_attn_forward):
            if i == len(self.visit_level_attn_forward) - 1:
                u_fw = l(u_fw)
            else:
                u_fw = l(u_fw, visits_mask)

        u_bw = code_attn
        for i, l in enumerate(self.visit_level_attn_backward):
            if i == len(self.visit_level_attn_backward) - 1:
                u_bw = l(u_bw)
            else:
                u_bw = l(u_bw, visits_mask)

        u_bi = torch.cat([u_fw, u_bw], dim=-1)
        s = self.fc(u_bi)
        return s

class PyHealthBiteNet(BaseModel):
    def __init__(
            self,
            dataset: SampleDataset,
            feature_keys: List[str],
            label_key: str,
            mode: str,
            embedding_dim: int = 128,
            n_mask_enc_layers: int = 2,
            use_intervals: bool = True,
            use_procedures: bool = True,
            num_heads: int = 4,
            dropout: float = 0.1,
            batch_first: bool = True,
            **kwargs
    ):
        super().__init__(dataset, feature_keys, label_key, mode)

        self.use_intervals = use_intervals
        self.use_procedures = use_procedures

        # Any BaseModel should have these attributes, as functions like add_feature_transform_layer uses them
        self.feat_tokenizers = {}
        self.embeddings = nn.ModuleDict()
        self.linear_layers = nn.ModuleDict()
        self.label_tokenizer = self.get_label_tokenizer()
        self.embedding_dim = embedding_dim

        # self.add_feature_transform_layer will create a transformation layer for each feature
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            self.add_feature_transform_layer(
                feature_key, input_info, special_tokens=["<pad>", "<unk>"]
            )

        # final output layer
        output_size = self.get_output_size(self.label_tokenizer)
        self.bite_net = BiteNet(
            embedding_dim = embedding_dim,
            output_dim = output_size,
            num_heads = num_heads,
            dropout = dropout,
            batch_first = batch_first,
            use_intervals=use_intervals,
            use_procedures=use_procedures,
            n_mask_enc_layers=n_mask_enc_layers
        )

        # self.fc = nn.Linear(len(self.feature_keys) * hidden_dim, output_size)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:

        embeddings = {}
        masks = {}
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]

            # each patient's feature is represented by [[code1, code2],[code3]]
            assert input_info["dim"] == 3 and input_info["type"] == str
            feature_vals = kwargs[feature_key]

            x = self.feat_tokenizers[feature_key].batch_encode_3d(feature_vals, truncation=(False, False))
            x = torch.tensor(x, dtype=torch.long, device=self.device)
            pad_idx = self.feat_tokenizers[feature_key].vocabulary("<pad>")
            #create the mask
            mask = (x != pad_idx).long()
            embeds = self.embeddings[feature_key](x)
            embeddings[feature_key] = embeds
            masks[feature_key] = mask

        embedded_codes = embeddings['conditions']
        codes_mask = masks['conditions']
        if self.use_procedures:
            embedded_codes = torch.cat((embedded_codes, embeddings['procedures']), dim=2)
            codes_mask = torch.cat((codes_mask, masks['procedures']), dim=2)

        logits = self.bite_net(embedded_codes, embeddings['intervals'].squeeze(2), codes_mask, masks['intervals'].squeeze(-1))

        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        return {"loss": loss, "y_prob": y_prob, "y_true": y_true}

model_dxtx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ['procedures', 'conditions', 'intervals'],
    label_key = "label",
    mode = "binary",
    embedding_dim=128
)

data = next(iter(train_loader))
model_dxtx(**data)

# model = BiteNet(
#     embedding_dim = 4,
#     output_dim = 1,
#     num_heads = 4,
#     dropout = 0
# )
# procedures, conditions, intervals, visits_masks, procedures_masks, conditions_masks, labels = next(iter(train_loader))
# model(procedures, conditions, intervals, visits_masks, procedures_masks, conditions_masks)

{'loss': tensor(0.7038, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
 'y_prob': tensor([[0.5085],
         [0.5092],
         [0.5083],
         [0.5091],
         [0.5097],
         [0.5091],
         [0.5092],
         [0.5086],
         [0.5096],
         [0.5093],
         [0.5093],
         [0.5086],
         [0.5096],
         [0.5092],
         [0.5096],
         [0.5095],
         [0.5089],
         [0.5086],
         [0.5096],
         [0.5095],
         [0.5085],
         [0.5091],
         [0.5090],
         [0.5089],
         [0.5085],
         [0.5093],
         [0.5091],
         [0.5089],
         [0.5089],
         [0.5096],
         [0.5087],
         [0.5094]], grad_fn=<SigmoidBackward0>),
 'y_true': tensor([[1.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [1.],
         [0.],
         [0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [1.]

In [102]:
model_dxtx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ["conditions", "procedures", "intervals"],
    label_key = "label",
    mode = "binary",
)

model_dx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ["conditions", "intervals"],
    label_key = "label",
    mode = "binary",
)

In [None]:
trainer_dxtx = Trainer(model=model_dxtx)
trainer_dxtx.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=5,
    monitor="pr_auc",
)

PyHealthBiteNet(
  (embeddings): ModuleDict(
    (conditions): Embedding(3383, 128, padding_idx=0)
    (procedures): Embedding(1366, 128, padding_idx=0)
    (intervals): Embedding(1758, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): BiteNet(
    (code_level_attn): ModuleList(
      (0-1): 2 x MaskEnc(
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (fc): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=128, out_features=128, bias=True)
        )
      )
      (2): AttentionPooling(
        (fc): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): ReLU()
          (2): Linear(in_features=128, out_features=128, bias=True)
        )
    

Epoch 0 / 5:   0%|          | 0/175 [00:00<?, ?it/s]

--- Train epoch-0, step-175 ---
loss: 0.5208
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 26.03it/s]
--- Eval epoch-0, step-175 ---
pr_auc: 0.2080
roc_auc: 0.5000
f1: 0.0000
loss: 0.6445
New best pr_auc score (0.2080) at epoch-0, step-175



Epoch 1 / 5:   0%|          | 0/175 [00:00<?, ?it/s]

--- Train epoch-1, step-350 ---
loss: 0.5159
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 27.59it/s]
--- Eval epoch-1, step-350 ---
pr_auc: 0.2080
roc_auc: 0.5000
f1: 0.0000
loss: 0.6390



Epoch 2 / 5:   0%|          | 0/175 [00:00<?, ?it/s]

--- Train epoch-2, step-525 ---
loss: 0.5163
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 26.34it/s]
--- Eval epoch-2, step-525 ---
pr_auc: 0.2080
roc_auc: 0.5000
f1: 0.0000
loss: 0.6276



Epoch 3 / 5:   0%|          | 0/175 [00:00<?, ?it/s]

--- Train epoch-3, step-700 ---
loss: 0.5130
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 27.41it/s]
--- Eval epoch-3, step-700 ---
pr_auc: 0.2080
roc_auc: 0.5000
f1: 0.0000
loss: 0.6228



Epoch 4 / 5:   0%|          | 0/175 [00:00<?, ?it/s]



In [190]:
# option 1: use our built-in evaluation metric
score_dxtx = trainer_dxtx.evaluate(test_loader)
print (score_dxtx)

# option 2: use our pyhealth.metrics to evaluate
y_true_dxtx, y_prob_dxtx, loss_dxtx = trainer_dxtx.inference(test_loader)
binary_metrics_fn(y_true_dxtx, y_prob_dxtx, metrics=["pr_auc"])

  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 27.97it/s]


{'pr_auc': 0.20373027259684362, 'roc_auc': 0.5, 'f1': 0.0, 'loss': 0.6519408307292245}


  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 28.00it/s]


{'pr_auc': 0.20373027259684362}