In [None]:
import argparse
import sys
import glob
import os
import polars as pl
import pickle

import lightning as L
import toml
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from torch.utils.data import DataLoader
import csv
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from lightning.pytorch.callbacks import Callback
# nn.Modules
import torchmetrics

In [None]:
#!pip install torch --index-url https://download.pytorch.org/whl/cu121

In [None]:
# Select the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device}")

#### Load model files

In [None]:
# Model/training configuration and outcome definitions are read from TOML files.
config = toml.load("../config/model.toml")
targets = toml.load('../config/targets.toml')

# Data-loading options.
batch_size = config["data"]["batch_size"]
num_workers = config["data"]["num_workers"]
modalities = config["data"]["modalities"]

# Optimisation options.
n_epochs = config["train"]["epochs"]
lr = config["train"]["learning_rate"]

# Architecture options; `st_first` is only read from the config for "mag" fusion
# and defaults to True for every other fusion method.
fusion_method = config["model"]["fusion_method"]
with_ts = config["model"]["with_ts"]
if fusion_method == "mag":
    st_first = config["model"]["st_first"]
else:
    st_first = True

# Flags derived from the selected modalities.
static_only = len(modalities) == 1 and "static" in modalities
with_notes = "notes" in modalities

# Outcome labels and their display names.
outcomes = targets["outcomes"]["labels"]
outcomes_disp = targets["outcomes"]["display"]

### Initial setup
ids_path = "../outputs/processed_data"
data_path = "../outputs/processed_data/mmfair_feat.pkl"
col_path = "../outputs/processed_data/mmfair_cols.pkl"
outcome = "in_hosp_death"

In [None]:
# Echo every setting resolved in the configuration cell as a quick sanity check.
print(
    batch_size,
    n_epochs,
    lr,
    num_workers,
    fusion_method,
    st_first,
    modalities,
    with_ts,
    static_only,
    with_notes,
    outcomes,
    outcomes_disp,
)

In [None]:
def load_pickle(filepath: str):
    """Deserialize and return the object stored in a pickle file.

    Unpickling can execute arbitrary code, so only point this at trusted files.

    Args:
        filepath (str): Location of the pickle (.pkl) file to read.

    Returns:
        Any: The object that was pickled into the file.
    """
    with open(filepath, "rb") as handle:
        return pickle.load(handle)
# Load the processed feature dictionary and its column metadata for inspection.
embeddings, cols = map(
    load_pickle,
    (
        "../outputs/processed_data/mmfair_feat.pkl",
        "../outputs/processed_data/mmfair_cols.pkl",
    ),
)
#emb_new = load_pickle("../outputs/prep_data/mmfair_feat.pkl")

In [None]:
# Read the subject ids of the train/validation splits prepared for `outcome`.
train_ids = (
    pl.read_csv(os.path.join(ids_path, f"training_ids_{outcome}.csv"))
    .select("subject_id")
    .to_numpy()
    .flatten()
)
val_ids = (
    pl.read_csv(os.path.join(ids_path, f"validation_ids_{outcome}.csv"))
    .select("subject_id")
    .to_numpy()
    .flatten()
)

### Downsample train and val data
# Keep only the leading ids of each split so the pipeline test below runs quickly.
train_ids = train_ids[:5000]
val_ids = val_ids[:1000]

In [None]:
# Display the down-sampled training ids (the value is shown as the cell output, not stored).
train_ids.flatten()

In [None]:
class CollateFn:
    """Batch collator for samples that carry only static features and a label.

    Each sample is indexed as ``sample[0]`` (static tensor) and ``sample[1]``
    (label tensor); both are stacked along a new leading batch dimension.
    """

    def __call__(self, batch):
        static_items, label_items = [], []
        for sample in batch:
            static_items.append(sample[0])
            label_items.append(sample[1])

        return torch.stack(static_items), torch.stack(label_items)


class CollateTimeSeries:
    """Custom collate function that can handle variable-length timeseries in a batch.

    Every sample is indexed as ``(static, label, dynamic[, notes])`` where
    ``dynamic`` is a list of 2-D tensors (events x features), one per timeseries.

    Args:
        method (str): ``"pack_pad"`` zero-pads every timeseries to the longest one
            in the batch; ``"truncate"`` cuts every timeseries to a common length.
        min_events (int | None): Target length for ``"truncate"``. When ``None``
            the shortest series in the batch is used; a larger value is capped at
            the shortest series so the result can still be stacked.

    Returns (from ``__call__``):
        ``(static, labels, dynamic, lengths)`` or, when samples carry notes,
        ``(static, labels, dynamic, lengths, notes)``. ``dynamic`` is a list with
        one ``(batch, events, features)`` tensor per timeseries and ``lengths`` is
        a matching list of per-sample event counts.

    Raises:
        ValueError: If ``method`` is neither ``"pack_pad"`` nor ``"truncate"``.
    """

    _METHODS = ("pack_pad", "truncate")

    def __init__(self, method="pack_pad", min_events=None) -> None:
        self.method = method
        self.min_events = min_events

    def __call__(self, batch):
        # Previously an unknown method silently produced ``None`` for the batch.
        if self.method not in self._METHODS:
            raise ValueError(
                f"Unknown collate method {self.method!r}; expected one of {self._METHODS}."
            )

        static = torch.stack([sample[0] for sample in batch])
        labels = torch.stack([sample[1] for sample in batch])
        notes = None
        if len(batch[0]) > 3:  # noqa: PLR2004
            # notes are expected to share one shape so they can be stacked directly
            notes = torch.stack([sample[3] for sample in batch])

        # number of dynamic timeseries data (note: dynamic is a list of timeseries)
        n_ts = len(batch[0][2])

        dynamic = []
        lengths = []
        for ts in range(n_ts):
            series = [sample[2][ts] for sample in batch]
            if self.method == "pack_pad":
                events, series_lengths = self._pad(series)
            else:
                events, series_lengths = self._truncate(series)
            dynamic.append(events)
            lengths.append(series_lengths)

        if notes is not None:
            return static, labels, dynamic, lengths, notes
        return static, labels, dynamic, lengths

    @staticmethod
    def _pad(series):
        """Zero-pad each series at the end to the longest one; return (tensor, true lengths)."""
        series_lengths = [item.shape[0] for item in series]
        n_ftrs = series[0].shape[1]
        events = torch.zeros((len(series), max(series_lengths), n_ftrs))
        for i, item in enumerate(series):
            events[i, : item.shape[0]] = item
        return events, series_lengths

    def _truncate(self, series):
        """Cut each series to a common length; return (stacked tensor, kept lengths)."""
        shortest = min(item.shape[0] for item in series)
        keep = shortest if self.min_events is None else min(self.min_events, shortest)
        events = torch.stack([item[:keep] for item in series])
        return events, [keep] * len(series)


class MIMIC4Dataset(Dataset):
    """MIMIC-IV Dataset class. Subclass of Pytorch Dataset.
    Reads from .pkl data dictionary where key is patient ID and values are the dataframes.

    Args:
        data_path: Path to the pickled ``{patient_id: {...}}`` feature dictionary.
        col_path: Path to the pickled column-name dictionary.
        split: Optional split name (e.g. ``"train"``, ``"val"``, ``"test"``).
        ids: Optional sequence of patient ids to expose. When ``None`` every key
            of the data dictionary is used.
        static_only: Return only ``(static, label)`` from ``__getitem__``.
        with_notes: Additionally return the note tensor from ``__getitem__``.
        outcome: Key of the label entry inside each patient record.
    """

    # Width the note tensor is padded (or cut) to in ``__getitem__``.
    NOTES_DIM = 768

    def __init__(
        self,
        data_path=None,
        col_path=None,
        split=None,
        ids=None,
        static_only=False,
        with_notes=False,
        outcome="in_hosp_death"
    ) -> None:
        super().__init__()

        # NOTE(review): pickle executes arbitrary code on load; only use trusted files.
        self.data_dict = load_pickle(data_path)
        self.col_dict = load_pickle(col_path)
        self.id_list = list(self.data_dict.keys()) if ids is None else ids
        # Dynamic (timeseries) entries are discovered from the first patient record.
        self.dynamic_keys = sorted(
            key for key in self.data_dict[self.id_list[0]].keys() if "dynamic" in key
        )
        self.split = split
        self.static_only = static_only
        self.with_notes = with_notes
        self.splits = {"train": None, "val": None, "test": None}
        self.outcome = outcome
        # Register the resolved id list so a split given without explicit ids
        # covers the whole dictionary instead of holding ``None``.
        self.splits[split] = self.id_list

    def _current_ids(self):
        """Return the ids served by this dataset: the active split, else all ids."""
        if self.split is None:
            return self.id_list
        split_ids = self.splits[self.split]
        return self.id_list if split_ids is None else split_ids

    def __len__(self):
        return len(self._current_ids())

    def __getitem__(self, idx):
        """Return ``(static, label)``, ``(static, label, dynamic)`` or
        ``(static, label, dynamic, notes)`` depending on the dataset flags."""
        pt_id = int(self._current_ids()[idx])
        record = self.data_dict[pt_id]
        # The label is read from position [0][0] of the outcome entry and returned
        # as a float tensor of shape (1,).
        label = torch.tensor(record[self.outcome][0][0], dtype=torch.float32).unsqueeze(-1)
        static = torch.tensor(record["static"], dtype=torch.float32)

        if self.static_only:
            return static, label

        dynamic = [torch.tensor(record[key], dtype=torch.float32) for key in self.dynamic_keys]
        if not self.with_notes:
            return static, label, dynamic

        # Keep only element [1] of every note entry ("tokens only").
        # NOTE(review): assumes each emb[1] is a scalar so the result is 1 x NOTES_DIM
        # after padding -- confirm against the preprocessing code.
        emblist = [emb[1] for emb in record["notes"]]
        notes = torch.tensor(emblist, dtype=torch.float32).unsqueeze(0)
        # Pad (or, for a negative amount, trim) the last dimension towards NOTES_DIM.
        notes = torch.nn.functional.pad(notes, (0, self.NOTES_DIM - notes.shape[1]))
        return static, label, dynamic, notes

    def print_label_dist(self):
        """Print the number of positive and negative labels among the served ids."""
        # if no particular split then use entire data dict
        id_list = self._current_ids()

        n_positive = sum(
            1 for pt_id in id_list if self.data_dict[pt_id][self.outcome][0][0] == 1
        )

        if self.split is not None:
            print(f"{self.split.upper()}:")

        print(f"Positive cases: {n_positive}")
        # len() works for both plain lists and numpy arrays of ids.
        print(
            f"Negative cases: {len(id_list) - n_positive}"
        )

    def get_feature_dim(self, key="static"):
        """Return the size of axis 1 of the first patient's entry stored under ``key``."""
        return self.data_dict[int(self.id_list[0])][key].shape[1]

    def get_feature_list(self, key="static"):
        """Return the column names stored under ``"<key>_cols"`` in the column dictionary."""
        return self.col_dict[key + "_cols"]

    def get_split_ids(self, split):
        """Return the ids registered for ``split`` (``None`` if the split is unused)."""
        return self.splits[split]

In [None]:



class LSTM(nn.Module):
    """Sequence encoder: batch-first LSTM whose final hidden state is projected.

    The hidden state of the last LSTM layer is passed through a linear
    projection to ``embed_dim``, then dropout, then ReLU.
    """

    def __init__(self, input_dim, embed_dim, num_layers=1, hidden_dim=128, dropout=0):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.project = nn.Linear(in_features=hidden_dim, out_features=embed_dim)
        # `dropout` is applied to the projected embedding, not inside the LSTM.
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()

    def forward(self, input_):
        # Only the final hidden states are needed; per-step outputs are discarded.
        _, (hidden, _) = self.lstm(input_)
        last_layer_state = hidden[-1]
        projected = self.project(last_layer_state)
        return self.relu(self.dropout(projected))


class Gate(nn.Module):
    """Gated fusion of a primary input with one or two auxiliary inputs.

    Each auxiliary input is scaled by a sigmoid gate computed from it and the
    primary input; the gated inputs are mapped to the primary input's size and
    added to it with a norm-based scale capped at 1, followed by LayerNorm,
    dropout and a squeeze.
    """

    # Adapted from https://github.com/emnlp-mimic/mimic/blob/main/base.py#L136 inspired by https://arxiv.org/pdf/1908.05787
    def __init__(self, inp1_size, inp2_size, inp3_size: int = 0, dropout: int = 0):
        super().__init__()

        self.fc1 = nn.Linear(inp1_size + inp2_size, 1)
        self.fc2 = nn.Linear(inp1_size + inp3_size, 1)
        self.fc3 = nn.Linear(inp2_size + inp3_size, inp1_size)
        self.beta = nn.Parameter(torch.randn((1,)))
        self.norm = nn.LayerNorm(inp1_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp1, inp2, inp3=None):
        # Scalar gate per position for the second input.
        gate2 = torch.sigmoid(self.fc1(torch.cat((inp1, inp2), dim=-1)))
        gated = [gate2 * inp2]
        if inp3 is not None:
            gate3 = torch.sigmoid(self.fc2(torch.cat((inp1, inp3), dim=-1)))
            gated.append(gate3 * inp3)
        # With a single auxiliary input only the gated second input is adjusted.
        adjust = self.fc3(torch.cat(gated, dim=-1))

        # Scale uses the norms over the whole tensors, times beta, capped at 1.
        scale = inp1.norm() / adjust.norm() * self.beta
        alpha = torch.minimum(scale, torch.ones_like(scale))
        fused = self.norm(inp1 + alpha * adjust)
        # squeeze() removes every size-1 dimension of the result.
        return self.dropout(fused).squeeze()


# lightning.LightningModules
class MMModel(L.LightningModule):
    """Multimodal binary classifier over static, timeseries and optional note inputs.

    Static features are always embedded. Timeseries are embedded by one ``LSTM``
    per series when ``with_ts`` is set, and notes by a linear layer when
    ``with_notes`` is set. The embeddings are combined according to
    ``fusion_method`` and passed through a final linear layer producing logits.

    Args:
        st_input_dim: Number of static input features.
        st_embed_dim: Size of the static embedding.
        ts_input_dim: Sequence with the feature count of each timeseries.
        ts_embed_dim: Size of each timeseries embedding.
        nt_input_dim: Size of the note input vector.
        nt_embed_dim: Size of the note embedding.
        num_layers: Number of LSTM layers per timeseries encoder.
        dropout: Dropout probability used by the embedding/fusion layers.
        num_ts: Number of timeseries per sample.
        target_size: Number of output logits.
        lr: Learning rate of the SGD optimizer.
        fusion_method: ``"concat"``, ``"mag"`` or ``"None"`` (static only; a real
            ``None`` is accepted as an alias).
        st_first: For ``"mag"``, whether the static embedding is the primary input.
        with_ts: Whether timeseries are embedded and fused.
        with_notes: Whether notes are embedded (only used by ``"concat"`` fusion).
        with_packed_sequences: Whether batches carry lengths and timeseries are
            packed before being fed to the LSTMs.

    Raises:
        ValueError: For an unknown ``fusion_method`` or ``"mag"`` without timeseries.
    """

    _FUSION_METHODS = ("mag", "concat", "None")
    _METRIC_NAMES = ("acc", "auc", "f1", "ap")

    def __init__(
        self,
        st_input_dim=18,
        st_embed_dim=64,
        ts_input_dim=(9, 7),
        ts_embed_dim=64,
        nt_input_dim=768,
        nt_embed_dim=64,
        num_layers=1,
        dropout=0.1,
        num_ts=2,
        target_size=1,
        lr=0.1,
        fusion_method="concat",
        st_first=True,
        with_ts=False,
        with_notes=False,
        with_packed_sequences=False,
    ):
        super().__init__()
        self.save_hyperparameters()

        fusion_method = "None" if fusion_method is None else fusion_method
        # Fail at construction instead of leaving ``self.fc`` undefined.
        if fusion_method not in self._FUSION_METHODS:
            raise ValueError(
                f"Unknown fusion_method {fusion_method!r}; expected one of {self._FUSION_METHODS}."
            )
        if fusion_method == "mag" and not with_ts:
            raise ValueError("fusion_method='mag' fuses timeseries and requires with_ts=True.")

        self.num_ts = num_ts
        self.with_ts = with_ts
        self.with_notes = with_notes
        self.st_first = st_first
        self.fusion_method = fusion_method

        self.embed_static = nn.Sequential(
            nn.Linear(st_input_dim, st_embed_dim // 2),
            nn.LayerNorm(st_embed_dim // 2),
            nn.Linear(st_embed_dim // 2, st_embed_dim),
            nn.Dropout(dropout),
            nn.ReLU(),
        )

        if self.with_ts:
            self.embed_timeseries = nn.ModuleList(
                [
                    LSTM(
                        ts_input_dim[i],
                        ts_embed_dim,
                        num_layers=num_layers,
                        dropout=dropout,
                    )
                    for i in range(self.num_ts)
                ]
            )
        else:
            self.embed_timeseries = None

        if self.with_notes:
            self.embed_notes = nn.Linear(nt_input_dim, nt_embed_dim)
        else:
            self.embed_notes = None
            nt_embed_dim = 0

        if self.fusion_method == "mag":
            # The Gate takes the primary size first, followed by the auxiliary sizes.
            if self.st_first:
                self.fuse = Gate(
                    st_embed_dim, *([ts_embed_dim] * self.num_ts), dropout=dropout
                )
                self.fc = nn.Linear(st_embed_dim, target_size)

            else:
                self.fuse = Gate(
                    *([ts_embed_dim] * self.num_ts), st_embed_dim, dropout=dropout
                )
                self.fc = nn.Linear(ts_embed_dim, target_size)

        elif self.fusion_method == "concat":
            # embeddings must be same dim
            if self.with_ts:
                assert st_embed_dim == ts_embed_dim
            if self.with_notes:
                assert nt_embed_dim == st_embed_dim
            # Only count timeseries embeddings that are actually produced.
            n_ts_embeds = self.num_ts if self.with_ts else 0
            self.fc = nn.Linear(
                st_embed_dim + (n_ts_embeds * ts_embed_dim) + nt_embed_dim, target_size
            )

        else:  # "None": static embedding only
            self.fc = nn.Linear(st_embed_dim, target_size)

        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.lr = lr

        # Training metrics keep their original attribute names; validation gets
        # its own instances so the two stages never share accumulated state.
        self.acc = torchmetrics.Accuracy(task="binary")
        self.auc = torchmetrics.AUROC(task="binary")
        self.f1 = torchmetrics.F1Score(task="binary")
        self.ap = torchmetrics.AveragePrecision(task="binary")
        self.val_acc = torchmetrics.Accuracy(task="binary")
        self.val_auc = torchmetrics.AUROC(task="binary")
        self.val_f1 = torchmetrics.F1Score(task="binary")
        self.val_ap = torchmetrics.AveragePrecision(task="binary")

        self.with_packed_sequences = with_packed_sequences

    def prepare_batch(self, batch):  # noqa: PLR0912
        """Run the forward pass on a collated batch.

        The batch is indexed as ``(static, labels, dynamic, lengths, notes)``;
        trailing entries are only read when the matching option is enabled.

        Returns:
            tuple: ``(logits, labels)`` with logits of at least two dimensions.
        """
        s = batch[0]
        y = batch[1]

        # Timeseries are only embedded when enabled and actually fused.
        ts_embed = []
        if self.with_ts and self.fusion_method != "None":
            d = batch[2]
            lengths = batch[3] if self.with_packed_sequences else None
            for i in range(self.num_ts):
                series = d[i]
                if self.with_packed_sequences:
                    series = torch.nn.utils.rnn.pack_padded_sequence(
                        series, lengths[i], batch_first=True, enforce_sorted=False
                    )
                # unsqueeze(1) gives the embedding a middle axis before fusion
                ts_embed.append(self.embed_timeseries[i](series).unsqueeze(1))

        nt_embed = self.embed_notes(batch[4]) if self.with_notes else None
        st_embed = self.embed_static(s)

        # Fuse time-series and static data
        if self.fusion_method == "concat":
            # concat along feature dim; works for any number of ts embeddings
            embeddings = [st_embed, *ts_embed]
            if nt_embed is not None:
                embeddings.append(nt_embed)
            out = torch.concat(embeddings, dim=-1).squeeze()
        elif self.fusion_method == "mag":
            if self.st_first:
                out = self.fuse(st_embed, *ts_embed)  # b x st_embed_dim
            else:
                out = self.fuse(*ts_embed, st_embed)
        else:
            # No fusion: static data only.
            out = st_embed.squeeze()

        # Parse through FC
        x_hat = self.fc(out)  # logits
        # squeeze() above also drops a batch dimension of size 1; restore it.
        if x_hat.dim() < 2:  # noqa: PLR2004
            x_hat = x_hat.unsqueeze(0)
        return x_hat, y

    def _stage_metrics(self, stage):
        """Return ``{name: metric}`` for ``stage`` (``"train"`` or ``"val"``)."""
        prefix = "" if stage == "train" else "val_"
        return {name: getattr(self, prefix + name) for name in self._METRIC_NAMES}

    def _shared_step(self, batch, stage):
        """Compute the loss and update/log the epoch-level metrics for ``stage``."""
        x_hat, y = self.prepare_batch(batch)  # logits
        y_hat = torch.sigmoid(x_hat)  # probabilities
        loss = self.criterion(x_hat, y)
        target = y.long()  # integer targets for every classification metric

        log_kwargs = {
            "prog_bar": True,
            "on_epoch": True,
            "on_step": False,
            "batch_size": len(y),
        }
        self.log(f"{stage}_loss", loss, **log_kwargs)
        for name, metric in self._stage_metrics(stage).items():
            metric.update(y_hat, target)
            # Logging the metric object lets Lightning compute it over the whole
            # epoch and reset it afterwards.
            self.log(f"{stage}_{name}", metric, **log_kwargs)
        return loss

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def predict_step(self, batch):
        """Return ``(probabilities, labels)`` for a batch."""
        x_hat, y = self.prepare_batch(batch)
        y_hat = torch.sigmoid(x_hat)
        return y_hat, y

    def configure_optimizers(self):
        """SGD with a plateau scheduler stepped each epoch on ``val_loss``."""
        optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
        scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=20)
        return [optimizer], [
            {"scheduler": scheduler, "monitor": "val_loss", "interval": "epoch"}
        ]

class LitLSTM(L.LightningModule):
    """LSTM using time-series data only.

    A single ``LSTM`` encoder embeds one timeseries and a linear layer turns the
    embedding into logits, trained with ``BCEWithLogitsLoss``.

    Args:
        ts_input_dim: Number of features per timeseries event.
        lstm_embed_dim: Size of the embedding produced by the encoder.
        target_size: Number of output logits.
        lr: Learning rate of the Adam optimizer.
        with_packed_sequences: Whether batches carry lengths and the timeseries
            is packed before being fed to the encoder.
    """

    def __init__(
        self,
        ts_input_dim,
        lstm_embed_dim,
        target_size,
        lr=0.1,
        with_packed_sequences=False,
    ):
        super().__init__()
        # The encoder takes (input_dim, embed_dim); packing is handled here, not there.
        self.embed_timeseries = LSTM(ts_input_dim, lstm_embed_dim)
        self.fc = nn.Linear(lstm_embed_dim, target_size)

        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.lr = lr
        self.acc = torchmetrics.Accuracy(task="binary")
        self.val_acc = torchmetrics.Accuracy(task="binary")
        self.with_packed_sequences = with_packed_sequences

    def prepare_batch(self, batch):
        """Run the forward pass on a batch indexed as ``(static, labels, dynamic[, lengths])``.

        Returns:
            tuple: ``(logits, labels)``.
        """
        y, d = batch[1], batch[2]
        lengths = batch[3] if self.with_packed_sequences else None

        # CollateTimeSeries yields a list of timeseries (and of lengths); this
        # single-encoder model consumes the first one.
        if isinstance(d, (list, tuple)):
            d = d[0]
            if lengths is not None:
                lengths = lengths[0]

        if self.with_packed_sequences:
            d = torch.nn.utils.rnn.pack_padded_sequence(
                d, lengths, batch_first=True, enforce_sorted=False
            )

        # The encoder already reduces the sequence to its final hidden state,
        # so its output is a plain (batch, embed) tensor.
        ts_embed = self.embed_timeseries(d)
        logits = self.fc(ts_embed)
        return logits, y

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        logits, y = self.prepare_batch(batch)
        # BCEWithLogitsLoss applies the sigmoid itself, so it receives raw logits.
        loss = self.criterion(logits, y)
        self.acc(torch.sigmoid(logits), y.long())
        self.log("train_loss", loss, prog_bar=True, batch_size=len(y))
        self.log("train_acc", self.acc, prog_bar=True, batch_size=len(y))
        return loss

    def validation_step(self, batch, batch_idx):
        logits, y = self.prepare_batch(batch)
        loss = self.criterion(logits, y)
        self.val_acc(torch.sigmoid(logits), y.long())
        self.log("val_loss", loss, prog_bar=True, batch_size=len(y))
        self.log("val_acc", self.val_acc, prog_bar=True, batch_size=len(y))
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
    
class SaveLossesCallback(Callback):
    """Append the logged train/validation losses to ``<log_dir>/losses.csv``."""

    def __init__(self, log_dir="logs", save_every_n_epochs=5):
        """
        Callback to save train/validation losses to a CSV file every n epochs.

        Args:
            log_dir (str): Directory to save the logs.
            save_every_n_epochs (int): Interval (in epochs) to save the losses.

        Raises:
            ValueError: If ``save_every_n_epochs`` is smaller than 1.
        """
        super().__init__()
        if save_every_n_epochs < 1:
            raise ValueError("save_every_n_epochs must be a positive integer.")
        self.log_dir = log_dir
        self.save_every_n_epochs = save_every_n_epochs
        os.makedirs(self.log_dir, exist_ok=True)
        self.csv_file = os.path.join(self.log_dir, "losses.csv")

        # Initialize the CSV file with headers if it doesn't exist
        if not os.path.exists(self.csv_file):
            with open(self.csv_file, mode="w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["Epoch", "Train Loss", "Validation Loss"])

    @staticmethod
    def _as_number(value):
        """Convert a logged value to a plain Python number; ``None`` stays ``None``."""
        return value.item() if hasattr(value, "item") else value

    def on_train_epoch_end(self, trainer, pl_module=None):
        """Write one CSV row every ``save_every_n_epochs`` epochs.

        ``pl_module`` is accepted because Lightning passes it to this hook; it
        is not used.
        """
        epoch = trainer.current_epoch + 1
        # Save losses every n epochs
        if epoch % self.save_every_n_epochs != 0:
            return

        train_loss = self._as_number(trainer.callback_metrics.get("train_loss", None))
        val_loss = self._as_number(trainer.callback_metrics.get("val_loss", None))

        # Append the losses to the CSV file (a missing metric becomes an empty cell)
        with open(self.csv_file, mode="a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([epoch, train_loss, val_loss])

        print(f"Saved losses to {self.csv_file}")

#### Test pipeline

In [None]:
# Show the modality flags that decide what the datasets and the model below use.
print(static_only, with_notes, with_ts)

In [None]:
# Build the train/validation datasets. `outcome` is passed explicitly so the
# labels come from the same outcome that selected the id files.
training_set = MIMIC4Dataset(
    data_path,
    col_path,
    "train",
    ids=train_ids,
    static_only=static_only,
    with_notes=with_notes,
    outcome=outcome,
)
validation_set = MIMIC4Dataset(
    data_path,
    col_path,
    "val",
    ids=val_ids,
    static_only=static_only,
    with_notes=with_notes,
    outcome=outcome,
)
training_set.print_label_dist()
n_static_features = training_set.get_feature_dim()  # add -1 if dropping label col

if not static_only:
    n_dynamic_features = (
        training_set.get_feature_dim("dynamic_0"),
        training_set.get_feature_dim("dynamic_1"),
    )
    print(n_dynamic_features)
    print(n_static_features)
    n_val_features = (
        validation_set.get_feature_dim("static"),
        validation_set.get_feature_dim("dynamic_0"),
        validation_set.get_feature_dim("dynamic_1"),
    )
    print(n_val_features)
    # The model is sized from the training set; validation must match it.
    assert n_val_features == (n_static_features, *n_dynamic_features), (
        "train/validation feature dimensions differ"
    )
else:
    n_dynamic_features = (None, None)

In [None]:
# Inspect one training sample. A sample has 2, 3 or 4 entries depending on the
# `static_only` / `with_notes` flags, so it is unpacked by position.
sample = training_set[2]
static, label = sample[0], sample[1]
dynamic = sample[2] if len(sample) > 2 else []
notes = sample[3] if len(sample) > 3 else None

print("Static shape:", static.shape)
print("Label shape:", label.shape)
for i, ts in enumerate(dynamic):
    print(f"Dynamic modality {i} shape:", ts.shape)
if notes is not None:
    print("Notes shape:", notes.shape)

    print(notes)
print(static)

In [None]:
# Inspect one validation sample. A sample has 2, 3 or 4 entries depending on the
# `static_only` / `with_notes` flags, so it is unpacked by position.
sample = validation_set[2]
static, label = sample[0], sample[1]
dynamic = sample[2] if len(sample) > 2 else []
notes = sample[3] if len(sample) > 3 else None

print("Static shape:", static.shape)
print("Label shape:", label.shape)
for i, ts in enumerate(dynamic):
    print(f"Dynamic modality {i} shape:", ts.shape)
if notes is not None:
    print("Notes shape:", notes.shape)

    print(notes)
print(static)

In [None]:
# Data loaders. Training batches are reshuffled every epoch; validation keeps a
# fixed order. Workers stay at 0 here (the configured `num_workers` is not applied).
training_dataloader = DataLoader(
    training_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=CollateFn() if static_only else CollateTimeSeries(),
)
val_dataloader = DataLoader(
    validation_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=CollateFn() if static_only else CollateTimeSeries(),
)

# The learning rate from the config is passed through; without it the model
# would fall back to its own default.
model = MMModel(
    st_input_dim=n_static_features,
    ts_input_dim=n_dynamic_features,
    with_packed_sequences=not static_only,
    fusion_method=fusion_method,
    with_notes=with_notes,
    with_ts=with_ts,
    st_first=st_first,
    lr=lr,
)
early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=5)

mod_str = "_".join(modalities)
checkpoint = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    filename=f"{outcome}_{fusion_method}_{mod_str}",
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")
#save_losses_callback = SaveLossesCallback(log_dir=f"logs/{outcome}_{fusion_method}_{mod_str}/", save_every_n_epochs=5)

# Train on the GPU when one is available and fall back to the CPU otherwise.
trainer = L.Trainer(
    max_epochs=n_epochs,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    callbacks=[early_stop, checkpoint, lr_monitor],
)

In [None]:
# Train the model, validating on the validation loader after each epoch.
trainer.fit(model, training_dataloader, val_dataloader)