In [1]:
import datetime
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import numpy as np
import numba
import scipy
import scipy.sparse
import torch
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import pandas as pd
import wandb

torch.set_float32_matmul_precision('high')

In [2]:
# Generate synthetic data for testing
def generate_synthetic_test_dataset(
        num_individuals = 10000,
        num_records = 1000,
        num_endpoints = 100,
        mean_observation_time = 10
):
    # simulate records with varying frequencies
    record_frequencies = np.random.rand(num_records)
    records = [np.random.randn(num_individuals, 1) < freq for freq in record_frequencies]
    records = np.concatenate(records, axis=1)

    observation_times = np.random.randn(num_individuals) + mean_observation_time
    observation_times = np.clip(observation_times, 1, None)

    # simulate events with varying frequencies
    # exponential distribution lambdas with mean event times in [0, 200] years 
    lambdas_inv = np.random.rand(1, num_endpoints)
    # make lambdas depend on records (so that the model can learn something)
    weights = np.random.randn(num_records, num_endpoints)
    lambdas_pers_inv = np.clip(lambdas_inv * 200 + (records @ weights), 0, None)
    lambdas = 1 / (lambdas_pers_inv + 0.001)
    # sample random event times for each person and endpoint
    label_times = np.random.exponential(scale=1 / lambdas, size=(num_individuals, num_endpoints))
    # censoring
    label_events = label_times < observation_times[:, None]

    censorings = observation_times
    # assume no prior events for synthetic data
    exclusions = np.zeros((num_individuals, num_endpoints), dtype=bool)

    endpoint_names = [f'endpoint_{i}' for i in range(num_endpoints)]

    records = scipy.sparse.csr_matrix(records)
    label_times = scipy.sparse.csr_matrix(label_times)
    label_events = scipy.sparse.csr_matrix(label_events)
    exclusions = scipy.sparse.csr_matrix(exclusions)

    return records, label_events, label_times, censorings, exclusions, endpoint_names

In [3]:
records, label_events, label_times, censorings, exclusions, endpoint_names = generate_synthetic_test_dataset()

In [4]:
records.shape, label_events.shape, label_times.shape, censorings.shape, exclusions.shape

((10000, 1000), (10000, 100), (10000, 100), (10000,), (10000, 100))

In [5]:
class CoxPHLoss(torch.nn.Module):
    def forward(self, logh, durations, events, eps=1e-7):
        batch_size, endpoints = durations.shape
        order = durations.sort(descending=True, dim=0)[1]
        offset = torch.tensor(
            [order.shape[0] * i for i in range(order.shape[1])], device=order.device
        )
        order = (order + offset[None, :]).flatten()
        f_reorder = lambda arr: rearrange(
            rearrange(arr, "b e -> (e b)", b=batch_size)[order], "(b e) -> b e", b=batch_size
        )
        logh = f_reorder(logh)
        events = f_reorder(events)
        gamma = logh.max(0)[0]
        log_cumsum_h = logh.sub(gamma).exp().cumsum(0).add(eps).log().add(gamma)
        s_sum = events.sum(0)
        s_sum[s_sum == 0] = 1
        return -logh.sub(log_cumsum_h).mul(events).sum(0).div(s_sum)


@numba.njit(parallel=False, nogil=True)
def cindex(events, event_times, predictions):
    idxs = np.argsort(event_times)

    events = events[idxs]
    event_times = event_times[idxs]
    predictions = predictions[idxs]

    n_concordant = 0
    n_comparable = 0

    for i in numba.prange(len(events)):
        for j in range(i + 1, len(events)):
            if events[i] and events[j]:
                n_comparable += 1
                n_concordant += (event_times[i] > event_times[j]) == (
                    predictions[i] > predictions[j]
                )
            elif events[i]:
                n_comparable += 1
                n_concordant += predictions[i] < predictions[j]

    if n_comparable > 0:
        return n_concordant / n_comparable
    else:
        return np.nan

In [6]:
class RecordsDataset(Dataset):
    """
    PyTorch Dataset for loading medical records data.

    Args:
        records (scipy.sparse.csr_matrix): Sparse matrix of medical records.
        exclusions (scipy.sparse.csr_matrix): Sparse matrix of exclusions.
        labels_events (scipy.sparse.csr_matrix): Sparse matrix of event labels.
        labels_times (scipy.sparse.csr_matrix): Sparse matrix of time labels.
        censorings (Optional[np.array], optional): Array of censoring times. Defaults to None.
    """

    def __init__(
        self,
        records: scipy.sparse.csr_matrix,
        exclusions: scipy.sparse.csr_matrix,
        labels_events: scipy.sparse.csr_matrix,
        labels_times: scipy.sparse.csr_matrix,
        censorings: Optional[np.array] = None,
    ):
        self.records = records
        self.exclusions = exclusions
        self.labels_events = labels_events
        self.labels_times = labels_times
        self.censorings = censorings

    def __len__(self):
        return self.records.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        records = torch.Tensor(self.records[idx].todense()).squeeze().bool().float()
        exclusions = torch.Tensor(self.exclusions[idx].todense()).squeeze()
        labels_events = torch.Tensor(self.labels_events[idx].todense()).squeeze()
        labels_times = torch.Tensor(self.labels_times[idx].todense()).squeeze()

        censorings = None
        if self.censorings is not None:
            if not isinstance(idx, list):
                idx = [idx]
            censorings = torch.Tensor(self.censorings[idx]).squeeze()

        return dict(
            records=records,
            labels_events=labels_events,
            labels_times=labels_times,
            censorings=censorings,
            exclusions=exclusions,
        )


class RecordsDataModule(LightningDataModule):
    """
    PyTorch Lightning DataModule for loading medical records data.

    Args:
        records (scipy.sparse.csr_matrix): Sparse matrix of medical records.
        exclusions (scipy.sparse.csr_matrix): Sparse matrix of exclusions.
        labels_events (scipy.sparse.csr_matrix): Sparse matrix of event labels.
        labels_times (scipy.sparse.csr_matrix): Sparse matrix of time labels.
        censorings (Optional[np.array], optional): Array of censoring times. Defaults to None.
        indices (Tuple[np.array, np.array, np.array]): Tuple of train, validation, and test indices.
        batch_size (int): Batch size for data loading.

    Attributes:
        data_train (RecordsDataset): Training dataset.
        data_val (RecordsDataset): Validation dataset.
        data_test (RecordsDataset): Test dataset.

    Methods:
        setup(stage: str): Splits the data into train, validation, and test sets.
        train_dataloader(): Returns a DataLoader for the training set.
        val_dataloader(): Returns a DataLoader for the validation set.
    """
    def __init__(
        self, records, label_events, label_times, exclusions, censorings, indices, batch_size
    ):
        super().__init__()

        self.records = records
        self.label_events = label_events
        self.label_times = label_times
        self.exclusions = exclusions
        self.censorings = censorings
        self.indices = indices
        self.batch_size = batch_size

    def setup(self, stage: str):
        train_idxs, val_idxs, test_idxs = self.indices

        self.data_train = RecordsDataset(
            self.records[train_idxs],
            self.exclusions[train_idxs],
            self.label_events[train_idxs],
            self.label_times[train_idxs],
            self.censorings[train_idxs],
        )

        self.data_val = RecordsDataset(
            self.records[val_idxs],
            self.exclusions[val_idxs],
            self.label_events[val_idxs],
            self.label_times[val_idxs],
            self.censorings[val_idxs],
        )

        self.data_test = RecordsDataset(
            self.records[test_idxs],
            self.exclusions[test_idxs],
            self.label_events[test_idxs],
            self.label_times[test_idxs],
            self.censorings[test_idxs],
        )

    def train_dataloader(self):
        return DataLoader(self.data_train, batch_size=self.batch_size, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.data_val, batch_size=self.batch_size, drop_last=True)

    def test_dataloader(self):
        return DataLoader(self.data_test, batch_size=self.batch_size, drop_last=True)

In [7]:
class MedicalHistoryModule(LightningModule):
    """
    PyTorch Lightning module for training a medical history model using Cox proportional hazards loss.

    Args:
        model (torch.nn.Module): The medical history model to train.
        endpoint_names (list): A list of endpoint names for the model.
        exclusions_on_metrics (bool, optional): Whether to exclude samples with exclusions when computing metrics. Defaults to True.
        lr (float, optional): The learning rate for the optimizer. Defaults to 0.005.
        weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.01.
        n_chunks (int, optional): The number of chunks to split the data into when computing the loss. Defaults to 1.
    """
    def __init__(
        self,
        model: torch.nn.Module,
        endpoint_names: list,
        exclusions_on_metrics: bool = True,
        lr: float = 0.005,
        weight_decay: float = 0.01,
    ):
        super().__init__()

        self.loss = CoxPHLoss()

        self.model = model
        self.endpoint_names = endpoint_names
        self.num_endpoints = len(endpoint_names)
        self.exclusions_on_metrics = exclusions_on_metrics

        self.lr = lr
        self.weight_decay = weight_decay

        self.executor = ThreadPoolExecutor(max_workers=multiprocessing.cpu_count())
        self.max_mean_cindex = 0

        self.valid_data = [([], [], []) for _ in range(self.num_endpoints)]

        self.apply(self._init_weights)
        self.save_hyperparameters(ignore=["model"])

    def get_loss(self, batch, predictions):
        ehr_events = batch["labels_events"].bool()
        ehr_times = batch["labels_times"]
        ehr_censorings = batch["censorings"][:, None].repeat(1, self.num_endpoints)
        # set event time to censoring time for non-events
        ehr_times[~ehr_events] = ehr_censorings[~ehr_events]

        losses = self.loss(predictions, ehr_times.squeeze(), ehr_events.squeeze())
        loss = torch.mean(losses)

        return loss

    def shared_step(self, batch, batch_idx):
        ehr_records = batch["records"]
        ehr_records = ehr_records.bool().float()

        latents, predictions = self.model(ehr_records)

        return latents, predictions, self.get_loss(batch, predictions)

    def training_step(self, batch, batch_idx):
        latents, predictions, loss = self.shared_step(batch, batch_idx)
        self.log("train/loss", loss.item(), batch_size=len(predictions))

        return loss

    def validation_step(self, batch, batch_idx):
        latents, predictions, loss = self.shared_step(batch, batch_idx)

        self.log("valid/loss", loss.item(), batch_size=len(predictions))

        events = batch["labels_events"].detach().cpu()
        times = batch["labels_times"].detach().cpu()
        exclusions = batch["exclusions"].detach().cpu()

        for endpoint_idx in range(self.num_endpoints):
            if self.exclusions_on_metrics:
                mask = exclusions[:, endpoint_idx] == 0

                predictions_ = predictions[mask, endpoint_idx]
                events_ = events[mask, endpoint_idx]
                times_ = times[mask, endpoint_idx]
            else:
                predictions_ = predictions[:, endpoint_idx]
                events_ = events[:, endpoint_idx]
                times_ = times[:, endpoint_idx]

            self.valid_data[endpoint_idx][0].append(predictions_.detach().cpu().float().numpy())
            self.valid_data[endpoint_idx][1].append(events_.numpy())
            self.valid_data[endpoint_idx][2].append(times_.numpy())

        return loss

    def on_validation_epoch_end(self) -> None:
        def compute(valid_data):
            preds = np.concatenate(valid_data[0]).squeeze()
            events = np.concatenate(valid_data[1]).astype(bool)
            times = np.concatenate(valid_data[2]).squeeze()

            return cindex(events, times, 1 - preds)

        cindices = list(self.executor.map(compute, self.valid_data))
        for endpoint_idx in range(self.num_endpoints):
            cidx = cindices[endpoint_idx]
            endpoint_name = (
                self.endpoint_names[endpoint_idx] if self.endpoint_names else endpoint_idx
            )
            self.log(f"valid/cindex_{endpoint_name}", cidx)

        self.valid_data = [([], [], []) for _ in range(self.num_endpoints)]

        self.log(f"valid/mean_cindex", np.nanmean(cindices))

        self.max_mean_cindex = max(np.nanmean(cindices), self.max_mean_cindex)
        self.log(f"valid/mean_cindex_max", self.max_mean_cindex)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        return optimizer

    def _init_weights(self, m):
        if isinstance(m, (torch.nn.LayerNorm, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            torch.nn.init.constant_(m.weight, 1)
            torch.nn.init.constant_(m.bias, 0)
        elif isinstance(m, (torch.nn.Conv2d, torch.nn.Conv3d)):
            torch.nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, torch.nn.Linear):
            torch.nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

In [8]:
class MedicalHistoryModel(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hidden, dropout_input, dropout_hidden):
        super().__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        self.model = torch.nn.Sequential(
            torch.nn.Dropout(dropout_input),
            torch.nn.Linear(num_inputs, num_hidden),
            torch.nn.LeakyReLU(),
            torch.nn.LayerNorm(num_hidden),

            torch.nn.Dropout(dropout_hidden),
            torch.nn.Linear(num_hidden, num_hidden),
            torch.nn.LeakyReLU(),
            torch.nn.LayerNorm(num_hidden),

            torch.nn.Dropout(dropout_hidden),
        )

        # Define the output layer
        self.head = torch.nn.Linear(num_hidden, num_outputs, bias=False)

    def forward(self, records):
        # Compute the latent representation of the input records
        latents = self.model(records)

        # Compute the model predictions
        predictions = self.head(latents)

        return latents, predictions

In [9]:
# Set the batch size for training
batch_size = 2048

# Get the number of inputs and outputs for the model
num_inputs = records.shape[1]
num_outputs = len(endpoint_names)

# Set the number of hidden units and dropout rates for the model
num_hidden = 4096
dropout_input = 0.18
dropout_hidden = 0.85

# Initialize a MedicalHistoryModel object with the specified settings
model = MedicalHistoryModel(num_inputs, num_outputs, num_hidden, dropout_input, dropout_hidden)

# Initialize a MedicalHistoryModule object with the specified settings
module = MedicalHistoryModule(model, endpoint_names, lr=0.000628, weight_decay=0.3)

In [10]:
# random train, valid, test split of data
num_individuals = records.shape[0]

idxs = np.arange(num_individuals)
np.random.shuffle(idxs)

train_idxs = idxs[: int(0.8 * num_individuals)]
val_idxs = idxs[int(0.8 * num_individuals) : int(0.9 * num_individuals)]
test_idxs = idxs[int(0.9 * num_individuals) :]

indices = (train_idxs, val_idxs, test_idxs)

data = RecordsDataModule(
    records,
    label_events,
    label_times,
    exclusions,
    censorings,
    indices,
    batch_size=batch_size,
)

In [11]:
# Get today's date in YYMMDD format
date = datetime.date.today().strftime("%y%m%d")

# Set the name of the run
run_name = "simulated_test"

# Initialize a WandbLogger object with the specified settings
wandb_logger = WandbLogger(
    name=f"{date}_{run_name}",
    project="medhist_simulated",
    log_model=True,
    settings=wandb.Settings(start_method="thread"),
    notes=repr(model),
)

# Watch the model for logging purposes
wandb_logger.watch(model, log="all")

# Initialize a Trainer object with the specified settings
trainer = Trainer(
    logger=wandb_logger,
    callbacks=[
        EarlyStopping(monitor="valid/mean_cindex", mode="max", min_delta=1e-8, patience=20),
        ModelCheckpoint(mode="max", monitor="valid/mean_cindex", save_top_k=1, save_last=True),
        LearningRateMonitor(logging_interval="step"),
    ],
    log_every_n_steps=10,
    val_check_interval=0.25,
    accelerator="auto", 
    devices="auto", 
    strategy="auto",
    max_epochs=10,
)

# Train the model using the Trainer object and the LightningModule and LightningDataModule objects
trainer.fit(module, datamodule=data)

# Finish the logging process
wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnebw[0m ([33mcardiors[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                | Params
----------------------------------------------
0 | loss  | CoxPHLoss           | 0     
1 | model | MedicalHistoryModel | 21.3 M
----------------------------------------------
21.3 M    Trainable params
0         Non-trainable params
21.3 M    Total params
85.230    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


wandb: ERROR Error uploading "/home/wildb/.local/share/wandb/artifacts/staging/tmpm4yrgw63": OSError, [Errno 28] No space left on device: '/sc-scratch/sc-scratch-ukb-cvd/.cache/wandb/artifacts/obj/md5/99/tmp_AFE6C0FF'
wandb: ERROR Uploading artifact file failed. Artifact won't be committed.


0,1
epoch,▁▅█
lr-AdamW,▁▁▁
train/loss,█▄▁
trainer/global_step,▁▁▅▅██

0,1
epoch,9.0
lr-AdamW,0.00063
train/loss,73.86647
trainer/global_step,29.0
