### This is the Primary Notebook for the Project
If you have not run the preprocessing notebook, please do so before running this notebook. This notebook will perform the following tasks:
- Load the preprocessed data
- Define the Mimic3Dataset class to create a PyTorch Dataset
- Define the DST model

#### Mount Google Drive and change directory to your project folder
Your project path will differ from the one shown below; edit the `%cd` line accordingly.

In [1]:
# Mount Google Drive inside the Colab runtime so the project files are reachable
from google.colab import drive
drive.mount('/content/drive')
# Move into the project folder (adjust this path to your own Drive layout) and list its contents
%cd /content/drive/MyDrive/dl4h_project/DynST/
%ls

Mounted at /content/drive
/content/drive/MyDrive/dl4h_project/DynST
causal.ipynb            [0m[01;34moutputs[0m/                            README.md
config.yaml             poetry.lock                         README_new.md
coxph_model.ipynb       project_causal.ipynb                results20230416.txt
[01;34mdata[0m/                   project_coxph_model.ipynb           run.py
dl4hProjectSetup.ipynb  project_descriptive_notebook.ipynb  [01;34msrc[0m/
[01;34mlightning_logs[0m/         project_model.ipynb                 [01;34mwandb[0m/
[01;34mmimic3-survival[0m/        project_preprocess.ipynb
[01;34mmultirun[0m/               pyproject.toml


#### Install some needed packages
Run this cell if you have not already installed these packages in your current runtime.
After the installation finishes you will need to restart the runtime (Runtime > Restart Runtime).

In [2]:
# NOTE(review): none of the packages below is version-pinned, so a fresh runtime
# may install newer releases than the ones this notebook was last run with.
# Install torchmetrics
%pip install torchmetrics -q
# Install PyTorch Lightning
%pip install pytorch-lightning -q
# Install wandb
%pip install wandb -qU

## now restart the runtime ##


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m719.0/719.0 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m48.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m149.6/149.6 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.5/114.5 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m28.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m23.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py

#### Change back to the project folder directory

In [1]:
# Return to the project folder (needed again once the runtime has been restarted) and list its contents
%cd /content/drive/MyDrive/dl4h_project/DynST/
%ls

/content/drive/MyDrive/dl4h_project/DynST
causal.ipynb            [0m[01;34moutputs[0m/                            README.md
config.yaml             poetry.lock                         README_new.md
coxph_model.ipynb       project_causal.ipynb                results20230416.txt
[01;34mdata[0m/                   project_coxph_model.ipynb           run.py
dl4hProjectSetup.ipynb  project_descriptive_notebook.ipynb  [01;34msrc[0m/
[01;34mlightning_logs[0m/         project_model.ipynb                 [01;34mwandb[0m/
[01;34mmimic3-survival[0m/        project_preprocess.ipynb
[01;34mmultirun[0m/               pyproject.toml


### Mimic3Dataset Class
This class is called later to create the dataset for training and testing the model. It loads the preprocessed `.npy` arrays written by the Mimic3Pipeline class in the preprocessing notebook and serves one patient per item.
Also provided is the function that pads and collates the items into batches for use by the model.

In [2]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

In [3]:
class Mimic3Dataset(Dataset):
    """PyTorch dataset over the preprocessed MIMIC-III arrays.

    Every ``.npy`` file found in ``{work_dir}/data/preprocessed_{seed}`` is
    loaded into ``self.f`` under its file name without the extension. The
    methods below read the keys ``patient_index``, ``code_lookup``, ``codes``,
    ``code_index``, ``vitals``, ``hourly_index``, ``demog``, ``treatment`` and
    ``survival``.

    Args:
        work_dir: project root that contains the ``data`` folder.
        seed: seed suffix of the preprocessed folder to load.
        intervention: if supplied, overrides every item's treatment with
            ``int(intervention)`` (True = treatment, False = control).
    """

    def __init__(self, work_dir, seed, intervention=None):
        fdir = f"{work_dir}/data/preprocessed_{seed}"
        self.f = {}
        for fname in os.listdir(fdir):
            if fname.endswith(".npy"):
                # NOTE: allow_pickle=True lets np.load unpickle object arrays;
                # only point this at files produced by your own preprocessing.
                self.f[fname[:-4]] = np.load(
                    f"{fdir}/{fname}", allow_pickle=True
                )
        self.ix = self.f["patient_index"]
        # slot 0 of the lookup is reserved for a "pad" entry
        self.code_lookup = np.insert(self.f["code_lookup"], 0, "pad")
        # NOTE(review): this shifted copy is not read anywhere in this class;
        # __getitem__ counts the unshifted self.f["codes"], so count position i
        # does not line up with self.code_lookup[i] - confirm which is intended.
        self.codes = self.f["codes"] + 1
        self.n_codes = len(self.code_lookup)
        self.n_vitals = self.f["vitals"].shape[1]
        self.n_demog = self.f["demog"].shape[1]
        self.pad_value = - 100
        # if supplied, represents treatment (True) or control (False)
        self.intervention = intervention

    def __len__(self):
        """Number of items, taken from the length of the ``treatment`` array."""
        return len(self.f["treatment"])

    def __getitem__(self, index):
        """Build the feature dict for the item at ``index``.

        Returns:
            dict with ``treatment`` and ``demog`` (raw array entries, or the
            intervention override for ``treatment``), ``codes`` (float tensor
            of per-code counts, length ``n_codes``), ``vitals`` (float tensor
            of this patient's hourly rows) and ``survival`` (tensor of the
            matching survival rows).
        """
        item = {}
        # identifier used to select this patient's rows in the long arrays
        j = self.ix[index]
        if self.intervention is None:
            item["treatment"] = self.f["treatment"][index]
        else:
            item["treatment"] = int(self.intervention)
        item["demog"] = self.f["demog"][index]
        # pad_bincount already returns a tensor; re-wrapping it in
        # torch.tensor() only triggered a copy-construct UserWarning
        item["codes"] = self.pad_bincount(
            self.f["codes"][self.f["code_index"] == j]
        )
        # vitals and survival share the same hourly row selection
        hourly_rows = self.f["hourly_index"] == j
        item["vitals"] = torch.tensor(self.f["vitals"][hourly_rows]).float()
        item["survival"] = torch.tensor(self.f["survival"][hourly_rows])
        return item

    def pad_bincount(self, records):
        """Count occurrences of each code and return a length-``n_codes`` float tensor.

        Raises:
            ValueError: if ``records`` contains a code index >= ``n_codes``.
        """
        # minlength zero-pads the counts up to all possible codes
        counts = np.bincount(records, minlength=self.n_codes)
        if len(counts) > self.n_codes:
            raise ValueError(
                f"code index {len(counts) - 1} is out of range for {self.n_codes} codes"
            )
        return torch.from_numpy(counts).float()

In [4]:
def padded_collate(batch, pad_index, causal=False):
    """Collate a list of dataset items into one batch dict.

    Args:
        batch: list of item dicts with keys ``treatment``, ``demog``,
            ``codes``, ``vitals`` and ``survival``.
        pad_index: value used to right-pad ``vitals`` and ``survival`` to the
            longest sequence in the batch.
        causal: when True, treatment is returned under its own ``treatment``
            key and ``static`` holds only the demographics; otherwise
            treatment is appended to ``static`` as an extra column.

    Returns:
        dict with ``static``, ``codes``, ``vitals``, ``survival`` and, in the
        causal case, ``treatment``.
    """

    def _pad(key):
        # right-pad the variable-length sequences stored under `key`
        sequences = [sample[key] for sample in batch]
        return pad_sequence(sequences, batch_first=True, padding_value=pad_index)

    treatment = torch.tensor(np.array([sample["treatment"] for sample in batch]))
    demog = torch.tensor(np.array([sample["demog"] for sample in batch])).float()

    if causal:
        collated = {"treatment": treatment, "static": demog}
    else:
        collated = {"static": torch.cat([demog, treatment.unsqueeze(1)], 1)}

    collated["codes"] = torch.stack([sample["codes"] for sample in batch])
    collated["vitals"] = _pad("vitals")
    collated["survival"] = _pad("survival")
    return collated

### Metrics
Here we define the metrics to be used for evaluating the model. We will use Concordance Index (C-index) and Mean Absolute Error (MAE).

In [5]:
from torchmetrics import Metric


class MeanAbsoluteError(Metric):
    """Mean absolute error between summed predicted and summed true survival rows.

    For rows whose target contains a 0 (treated as observed) the error is
    ``|t_hat - t|``; for the remaining rows (treated as censored) only
    under-prediction ``max(0, t - t_hat)`` is counted.

    Args:
        pad: padding value in ``y``; padded positions are excluded from both sums.
    """
    full_state_update = True
    higher_is_better = False

    def __init__(self, pad):
        super().__init__()
        self.pad = pad
        # running sum of errors and running count of rows
        self.add_state("error", default=torch.tensor(0.))
        self.add_state("total", default=torch.tensor(0.))

    def update(self, s_hat, y):
        """Accumulate the error of one batch.

        Args:
            s_hat: predictions, same shape as ``y``.
            y: targets, padded with ``self.pad``; indexed as (row, step) here.
        """
        # a row counts as observed if any of its targets equals 0
        observed = (y == 0).any(1).int()
        t_hat = torch.where(y == self.pad, 0, s_hat).sum(1)
        t = torch.where(y == self.pad, 0, y).sum(1)
        shortfall = t - t_hat
        # zeros_like keeps the tensor on the inputs' device; the previous
        # hard-coded .cuda() made the metric fail on CPU and other accelerators
        zeros = torch.zeros_like(shortfall)
        observed_error = torch.abs(t_hat - t) * observed
        censored_error = torch.maximum(zeros, shortfall) * (1 - observed)
        self.error += observed_error.sum() + censored_error.sum()
        self.total += t.numel()

    def compute(self):
        """Return the accumulated error divided by the number of rows seen."""
        return self.error / self.total

class ConcordanceIndex(Metric):
    """Concordance index over all pairs of rows accumulated across updates.

    Args:
        pad: padding value in ``y``; padded positions are excluded from the
            per-row sums.
    """
    higher_is_better = True

    def __init__(self, pad):
        super().__init__()
        self.pad = pad
        # list states: one tensor is appended per update() call
        for state_name in ("observed", "true", "predicted"):
            self.add_state(state_name, default=[])

    def update(self, s_hat, y):
        """Store per-row observed flags, summed targets and summed predictions."""
        is_pad = y == self.pad
        # a row counts as observed if any of its targets equals 0
        self.observed.append((y == 0).any(1).int())
        self.true.append(torch.where(is_pad, 0, y).sum(1))
        self.predicted.append(torch.where(is_pad, 0, s_hat).sum(1))

    def compute(self):
        """Return the fraction of comparable pairs that are ranked correctly.

        A pair (i, j) is comparable when row i has the smaller true value and
        is observed; it is concordant when the predictions are ordered the
        same way.
        """
        assert len(self.true) > 1
        # every unordered pair of rows, one pair per line
        true_pairs = torch.combinations(torch.cat(self.true))
        pred_pairs = torch.combinations(torch.cat(self.predicted))
        obs_pairs = torch.combinations(torch.cat(self.observed))

        first_is_earlier = true_pairs[:, 0] < true_pairs[:, 1]
        first_observed = obs_pairs[:, 0]
        pred_same_order = pred_pairs[:, 0] < pred_pairs[:, 1]

        concordant = (first_is_earlier * pred_same_order * first_observed).sum()
        comparable = (first_is_earlier * first_observed).sum()
        return concordant / comparable


### Dynamic Survival Transformers Model Class

#### First some housekeeping

In [6]:
# Import the os module
import os

# Get the current working directory (passed later as work_dir when the dataset is built)
cwd = os.getcwd()

# Print the current working directory
print("Current working directory: {0}".format(cwd))

# Import the modules and libraries used by the model classes below
import pytorch_lightning as pl
import torch
import math
from torch.nn import Linear


Current working directory: /content/drive/MyDrive/dl4h_project/DynST


#### DST Model Class

In [7]:
class DST(pl.LightningModule):
    """Dynamic Survival Transformer.

    Embeds the code counts together with the static features, optionally adds
    a per-step embedding of the vitals, runs a transformer encoder over the
    time axis, and maps every step to the complement of a hazard. The
    cumulative product of those complements is returned as the survival curve.

    Args:
        n_codes: length of the code-count vector.
        n_vitals: number of vitals per time step.
        n_demog: number of demographic features.
        d_model: transformer width (feed-forward width is ``4 * d_model``).
        n_blocks: number of encoder layers.
        n_heads: number of attention heads.
        dropout: dropout used in the encoder layers.
        pad: padding value used in ``vitals`` and ``survival``.
        dynamic: if True, add the vitals embedding at every step and use an
            autoregressive attention mask; otherwise the static embedding is
            repeated across the steps.
        causal: if True, ``batch["static"]`` carries no treatment column and
            ``batch["treatment"]`` is concatenated to the transformer output
            before the hazard head; if False, treatment is expected as one
            extra column of ``batch["static"]``.
        lr: AdamW learning rate.
        alpha: weight of the MAE term in the combined loss.
    """

    def __init__(
        self,
        n_codes,
        n_vitals,
        n_demog,
        d_model,
        n_blocks,
        n_heads,
        dropout,
        pad,
        dynamic,
        causal,
        lr=0.001,
        alpha=0.01,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.embed_codes = Linear(n_codes, d_model)
        # non-causal batches carry treatment as one extra static column
        d1 = 0 if causal else 1
        self.embed_static = Linear(d_model + n_demog + d1, d_model)
        self.embed_vitals = Linear(n_vitals, d_model)
        self.pos_encode = PositionalEncoding(d_model)
        self.pad = pad
        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model = d_model,
            nhead=n_heads,
            dropout=dropout,
            batch_first=True,
            dim_feedforward=d_model*4
        )
        norm = torch.nn.LayerNorm(d_model)
        self.transformer = torch.nn.TransformerEncoder(encoder_layer, n_blocks, norm)
        # causal models append treatment to the transformer output instead
        d2 = 1 if causal else 0
        self.to_hazard_c = torch.nn.Sequential(
            Linear(d_model + d2, d_model//2),
            torch.nn.ReLU(),
            Linear(d_model//2, 1),
            torch.nn.Sigmoid(),
        )
        self.train_mae = MeanAbsoluteError(pad=pad)
        self.val_mae = MeanAbsoluteError(pad=pad)
        self.val_ci = ConcordanceIndex(pad=pad)
        self.test_mae = MeanAbsoluteError(pad=pad)
        self.test_ci = ConcordanceIndex(pad=pad)

        # how much to weigh MAE loss
        self.alpha = alpha
        self.dynamic = dynamic
        self.causal = causal

    def forward(self, batch):
        """Return the predicted survival curve, one value per time step.

        Args:
            batch: dict with ``codes``, ``static`` and ``vitals`` (plus
                ``treatment`` when ``causal`` is True).

        Returns:
            Tensor with one row per sample and one column per time step of
            ``batch["vitals"]``, clamped from below at 1e-8.
        """
        # static features
        x = self.embed_codes(batch["codes"]).unsqueeze(1)
        x = self.embed_static(
            torch.cat([x, batch["static"].unsqueeze(1)], 2)
        )
        s = batch["vitals"].shape[1]
        # a step is padding when its first vitals column equals the pad value
        pad_mask = (batch["vitals"][:, :, 0] == self.pad)
        # time-varying features
        if self.dynamic:
            x = x + self.embed_vitals(batch["vitals"])
            # autoregressive mask, built on the input's device rather than
            # via a hard-coded .cuda() so the model also runs on CPU
            mask = (1 - torch.tril(torch.ones(s, s, device=x.device))).bool()
        else:
            mask = None
            x = x.repeat(1, s, 1)
        x = self.pos_encode(x)
        x = self.transformer(x, mask, pad_mask)
        if self.causal:
            t = torch.reshape(batch["treatment"], (-1, 1, 1))
            t = t.repeat(1, s, 1)
            # concatenate treatment as a new feature
            x = torch.cat((x, t), 2).float()
        # complement of hazard
        q_hat = self.to_hazard_c(x).squeeze(2)
        s_hat = q_hat.cumprod(1).clamp(min=1e-8)
        return s_hat

    def training_step(self, batch, batch_idx):
        """Compute and log the combined loss and the per-step training MAE."""
        s_hat = self(batch)
        loss = self.combined_loss(s_hat, batch["survival"])
        self.log("train_loss", loss)
        self.train_mae(s_hat, batch["survival"])
        self.log("train_mae", self.train_mae, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss and update/log the validation MAE and C-index."""
        s_hat = self(batch)
        loss = self.combined_loss(s_hat, batch["survival"])
        self.val_mae.update(s_hat, batch["survival"])
        self.val_ci.update(s_hat, batch["survival"])
        self.log("val_loss", loss)
        self.log("val_mae", self.val_mae, on_step=True, on_epoch=True)
        self.log("val_ci", self.val_ci, on_step=True, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute the loss and update/log the test MAE and C-index."""
        s_hat = self(batch)
        loss = self.combined_loss(s_hat, batch["survival"])
        self.test_mae.update(s_hat, batch["survival"])
        self.test_ci.update(s_hat, batch["survival"])
        self.log("test_mae", self.test_mae, on_step=True, on_epoch=True)
        self.log("test_ci", self.test_ci, on_step=True, on_epoch=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """Return estimated survival times: the survival curve summed over non-padded steps."""
        s_hat = self(batch)
        mask = (batch["survival"] != self.pad)
        return (s_hat * mask).sum(1)

    def configure_optimizers(self):
        """AdamW over all parameters with the learning rate saved in hparams."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

    def ordinal_survival_loss(self, s_hat, y):
        """Modified cross entropy between the survival curve and the targets.

        Padded positions contribute nothing; the sum is divided by the number
        of non-padded positions.
        """
        nlog_survival = -torch.log(s_hat)
        # NOTE(review): forward() only clamps s_hat from below, so this term is
        # unbounded if s_hat ever reaches exactly 1 - confirm that cannot occur.
        nlog_failure = -torch.log(1 - s_hat)
        loss = 0
        loss += nlog_survival * torch.where(y==self.pad, 0, y)
        loss += nlog_failure * torch.where(y==self.pad, 0, (1-y))
        return loss.sum() / (y != self.pad).sum()

    def mae_loss(self, s_hat, y):
        """Mean absolute error on summed survival, one-sided for censored rows.

        Rows whose target contains a 0 are treated as observed and contribute
        ``|t_hat - t|``; the others contribute only ``max(0, t - t_hat)``.
        """
        observed = (y == 0).any(1).int()
        t_hat = torch.where(y == self.pad, 0, s_hat).sum(1)
        t = torch.where(y == self.pad, 0, y).sum(1)
        shortfall = t - t_hat
        # zeros_like follows the inputs' device (was a hard-coded .cuda())
        zeros = torch.zeros_like(shortfall)
        observed_error = torch.abs(t_hat - t) * observed
        censored_error = torch.maximum(zeros, shortfall) * (1 - observed)
        return (observed_error.sum() + censored_error.sum()) / t.numel()

    def combined_loss(self, s_hat, y):
        """Blend the two losses: ``(1 - alpha) * ordinal + alpha * mae``."""
        a = self.alpha
        ordinal_loss = self.ordinal_survival_loss(s_hat, y)
        mae_loss = self.mae_loss(s_hat, y)
        return (1 - a) * ordinal_loss + a * mae_loss


#### Define the positional encoding class

In [8]:
class PositionalEncoding(torch.nn.Module):
    """Adds fixed sinusoidal position encodings to a batch-first sequence, then applies dropout.

    Args:
        d_model: size of the last dimension of the input (even or odd).
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        # even columns get the sine, odd columns the cosine
        pe[:, 0::2] = torch.sin(angles)
        # for an odd d_model there is one fewer odd column than even columns,
        # so trim the angles; for an even d_model the slice is a no-op
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # buffer: follows the module across devices and is saved in the state dict
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions and apply dropout."""
        x = x + self.pe[:x.size(1), :].unsqueeze(0)
        return self.dropout(x)

#### Set up Weights and Biases

In [9]:
# Log in to your W&B account
# note, you may have to restart runtime after installing wandb
import wandb
# the recorded output of this cell shows the API key being appended to /root/.netrc
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

### Run the model

#### Setup for training

In [10]:
import math
import random
from torch.utils.data import DataLoader, random_split
# pytorch_lightning is already imported as `pl` in the housekeeping cell above
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# Single source of truth for the run settings. Edit values here; the same
# dict is sent to W&B so every run records the settings it was trained with.
config = {
    "preprocess_seed": 30,
    "train_frac": .7,
    "val_frac": .15,
    "accelerator": "gpu",
    "devices": 1,
    "max_epochs": 5,
    "batch_size": 32,
    "train_seed": 0,
    "causal": False,
    "d_model": 32,
    "n_blocks": 3,
    "n_heads": 8,
    "dropout": .1, # random.uniform(0.01, 0.8),
    "pad": -100,
    "dynamic": True,
    "lr": 0.001,
    "alpha": 0.01,
}

# Unpack into the plain names used by the rest of this cell, so the values
# cannot drift apart from `config`.
preprocess_seed = config["preprocess_seed"]
train_frac = config["train_frac"]
val_frac = config["val_frac"]
accelerator = config["accelerator"]
devices = config["devices"]
max_epochs = config["max_epochs"]
batch_size = config["batch_size"]
train_seed = config["train_seed"]
causal = config["causal"]
d_model = config["d_model"]
n_blocks = config["n_blocks"]
n_heads = config["n_heads"]
dropout = config["dropout"]
pad = config["pad"]
dynamic = config["dynamic"]
lr = config["lr"]
alpha = config["alpha"]

# Seed Python, NumPy and torch so weight initialisation and batch shuffling
# are reproducible (the split below uses its own explicitly seeded generator).
pl.seed_everything(train_seed)

# This call to the Mimic3Dataset class relies on having already run the preprocessing pipeline. If you have not run the preprocessing notebook, this will not work.
dataset = Mimic3Dataset(work_dir=cwd, seed=preprocess_seed)

train_size = int(train_frac * len(dataset))
# isclose avoids a spurious tiny test split when the two fractions only sum
# to 1.0 up to floating-point rounding
if math.isclose(train_frac + val_frac, 1.0):
    val_size = len(dataset) - train_size
    test_size = 0
else:
    val_size = int(val_frac * len(dataset))
    test_size = len(dataset) - train_size - val_size
train_set, val_set, test_set = random_split(
    dataset,
    (train_size, val_size, test_size),
    torch.Generator().manual_seed(train_seed)
)

# set up the wandb logger; `config` is forwarded to wandb.init so the run
# settings show up alongside the logged metrics
wandb_logger = WandbLogger(
    project="mimic3-survival",
    name="model_run",
    log_model="all",
    config=config,
)

# set up the dataloaders
def collate(x):
    """Collate a list of items using this run's pad value and causal setting."""
    return padded_collate(x, pad_index=pad, causal=causal)

train_loader = DataLoader(
    train_set,
    collate_fn=collate,
    batch_size=batch_size,
    shuffle=True
    )
val_loader = DataLoader(
    val_set, collate_fn=collate, batch_size=batch_size
)
# test_loader only exists when a test split was carved out above
if test_size:
    test_loader = DataLoader(
        test_set, collate_fn=collate, batch_size=batch_size
    )

# set up the model
model = DST(
    n_codes=dataset.n_codes,
    n_vitals=dataset.n_vitals,
    n_demog=dataset.n_demog,
    d_model=d_model,
    n_blocks=n_blocks,
    n_heads=n_heads,
    dropout=dropout,
    pad=pad,
    dynamic=dynamic,
    lr=lr,
    alpha=alpha,
    causal=causal,
)

# set up the additional callbacks: keep the checkpoint with the lowest epoch-level validation MAE
callbacks = [ModelCheckpoint(monitor="val_mae_epoch", mode="min")]
# set up the trainer
trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator=accelerator,
    devices=devices,
    max_epochs=max_epochs,
    callbacks=callbacks,
)
# train the model
trainer.fit(model, train_loader, val_loader)
# test the model (no model is passed, so the trainer restores weights from a checkpoint)
if test_size:
    trainer.test(dataloaders=test_loader)

# close the wandb run
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mconlinm[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
   | Name         | Type               | Params
-----------------------------------------------------
0  | embed_codes  | Linear             | 16.6 K
1  | embed_static | Linear             | 1.2 K 
2  | embed_vitals | Linear             | 832   
3  | pos_encode   | PositionalEncoding | 0     
4  | transformer  | TransformerEncoder | 38.2 K
5  | to_hazard_c  | Sequential         | 545   
6  | train_mae    | MeanAbsoluteError  | 0     
7  | val_mae      | MeanAbsoluteError  | 0     
8  | val_ci       | ConcordanceIndex   | 0     
9  | test_m

Sanity Checking: 0it [00:00, ?it/s]

  item["codes"] = torch.tensor(
  return torch._transformer_encoder_layer_fwd(


Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at ./mimic3-survival/8g55q4t7/checkpoints/epoch=4-step=3320.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at ./mimic3-survival/8g55q4t7/checkpoints/epoch=4-step=3320.ckpt


Testing: 0it [00:00, ?it/s]

VBox(children=(Label(value='5.296 MB of 5.296 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▇▇▇▇▇▇▇█
test_ci_epoch,▁
test_mae_epoch,▁
train_loss,▇▆█▅▄▅▄▄▅▇▃▂▃▃▄▂▃▁▅▅▄▃▃▂▂▆▃▄▃▃▃▂▅▆▁▄▁▃▄▂
train_mae,▅▆█▅▄▃▄▄▄▄▃▂▃▃▃▁▂▂▅▄▆▃▄▃▂▄▃▂▃▃▃▂▅▆▁▄▁▄▃▂
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
val_ci_epoch,▅▇▁██
val_loss,▅▂█▁▂
val_mae_epoch,▆▂█▁▁

0,1
epoch,5.0
test_ci_epoch,0.69482
test_mae_epoch,11.69167
train_loss,0.47062
train_mae,9.08891
trainer/global_step,3320.0
val_ci_epoch,0.67958
val_loss,0.56871
val_mae_epoch,11.97327


### Configuration for Hyperparameter Sweep using Weights and Biases

In [None]:
# Sweep definition for the DST model. The metric and parameter names match
# what the training cell logs and configures: `val_mae_epoch` is the metric
# monitored by the checkpoint callback, and `d_model`, `n_blocks` and `lr`
# are keys of the training `config` dict.
sweep_config = {
  "method": "random",   # Random search
  "metric": {           # We want to minimize the epoch-level validation MAE
      "name": "val_mae_epoch",
      "goal": "minimize"
  },
  "parameters": {
        "d_model": {
            # Choose from pre-defined values; all are multiples of 8 so they
            # stay divisible by the n_heads = 8 used in the training cell
            "values": [32, 64, 128, 256, 512]
        },
        "n_blocks": {
            # Choose from pre-defined values
            "values": [2, 3, 4]
        },
        "lr": {
            # log uniform distribution between exp(min) and exp(max)
            "distribution": "log_uniform",
            "min": -9.21,   # exp(-9.21) = 1e-4
            "max": -4.61    # exp(-4.61) = 1e-2
        }
    }
}

In [None]:
# Register the sweep in the same W&B project the training runs log to
# (the previous "MNIST" project name was left over from an example).
sweep_id = wandb.sweep(sweep_config, project="mimic3-survival")