Lightning example for CoxPH loss #78

Open
wants to merge 10 commits into base: refactor_out_torchtuples
213 changes: 213 additions & 0 deletions examples/lightning_coxph.py
@@ -0,0 +1,213 @@
'''
A minimal example of how to fit a CoxPH model with PyTorch Lightning, independent of torchtuples.
Original author: Rohan Shad @rohanshad
'''
from typing import Tuple
import numpy as np
import pandas as pd
import torch

import pytorch_lightning as pl
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pycox.datasets import metabric
from pycox.evaluation import EvalSurv
from pycox.models import coxph

# For preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper

# Lightning DataModule

class MetaBrick(pl.LightningDataModule):
    '''
    Prepares the METABRIC dataset for either discrete-time or Cox proportional hazards models.
    batch_size (int) - batch size, default = 256
    num_durations (int) - number of time points to discretize the data into (for discrete-time models only), default = 10
    num_workers (int) - number of cpu workers used to load data, default = 0
    discretize (bool) - whether or not to discretize the data (set to True only for discrete-time models), default = False
    '''
    def __init__(self, batch_size: int = 256, num_durations: int = 10, num_workers: int = 0, discretize: bool = False):
        super().__init__()
        self.batch_size = batch_size
        self.num_durations = num_durations
        self.num_workers = num_workers
        self.discretize = discretize
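        # NOTE: `discretize` and `num_durations` are accepted for parity with the
        # discrete-time examples, but setup() below never uses them; CoxPH models
        # operate on continuous durations.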

    def setup(self, stage=None):
        '''
        Splits the METABRIC dataset into training, validation, and test dataframes,
        then preprocesses the features and targets (duration and event).
        '''

        # Load and split the dataset into train, val, and test
        df_train = metabric.read_df()
        df_test = df_train.sample(frac=0.2)
        df_train = df_train.drop(df_test.index)
        df_val = df_train.sample(frac=0.2)
        df_train = df_train.drop(df_val.index)

        self.x_train, self.x_val, self.x_test = self._preprocess_features(df_train, df_val, df_test)

        if stage == 'fit' or stage is None:
            # Set up targets (duration, event)
            self.y_train = torch.from_numpy(np.concatenate(self._get_target(df_train), axis=1))
            self.y_val = torch.from_numpy(np.concatenate(self._get_target(df_val), axis=1))

            # Create training and validation datasets
            self.train_set = TensorDataset(self.x_train, self.y_train)
            self.val_set = TensorDataset(self.x_val, self.y_val)

            # Input and output dimensions for building the net
            self.in_dims = self.x_train.shape[1]
            self.out_dims = 1

        if stage == 'test' or stage is None:
            # Correctly preprocessed target y_test (torch.Tensor) and the full df_test (pd.DataFrame) for metric calculations
            self.y_test = torch.from_numpy(np.concatenate(self._get_target(df_test), axis=1))
            self.df_test = df_test


    def train_dataloader(self):
        '''
        Build the training dataloader.
        num_workers is set to 0 by default because of a threading issue when loading data.
        '''
        train_loader = DataLoader(
            dataset=self.train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        return train_loader
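    # A minimal val_dataloader sketch (an assumption, not part of the original
    # example) so the val split created in setup() can be consumed by a
    # validation_step on the LightningModule, if one is defined.
    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )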

    @classmethod
    def _preprocess_features(cls, df_train: pd.DataFrame, df_val: pd.DataFrame, df_test: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        '''
        Preprocess the covariates of the training, validation, and test sets and
        return a tensor for each.
        '''
        cols_standardize = ["x0", "x1", "x2", "x3", "x8"]
        cols_leave = ["x4", "x5", "x6", "x7"]

        standardize = [([col], StandardScaler()) for col in cols_standardize]
        leave = [(col, None) for col in cols_leave]
        x_mapper = DataFrameMapper(standardize + leave)

        # Fit the scaler on the training set only, then apply it to val and test
        x_train = x_mapper.fit_transform(df_train).astype("float32")
        x_val = x_mapper.transform(df_val).astype("float32")
        x_test = x_mapper.transform(df_test).astype("float32")

        return torch.from_numpy(x_train), torch.from_numpy(x_val), torch.from_numpy(x_test)

    @staticmethod
    def _get_target(df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        '''
        Takes a pandas dataframe and converts the duration and event targets into np.arrays.
        '''
        duration = df['duration'].to_numpy().reshape(-1, 1)
        event = df['event'].to_numpy().reshape(-1, 1)

        return duration, event

# Survival model class

class SurvModel(pl.LightningModule):
    '''
    Defines the model, optimizer, forward pass, and training step.
    Define a validation step as def validation_step if needed (a sketch is included below).
    Configured to use the CoxPH loss from coxph.CoxPHLoss().
    '''

    def __init__(self, lr, in_features, out_features):
        super().__init__()

        self.save_hyperparameters()
        self.lr = lr
        self.in_features = in_features
        self.out_features = out_features

        # Define Model Here (in this case MLP)
        self.net = nn.Sequential(
            nn.Linear(self.in_features, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.1),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.1),
            nn.Linear(32, self.out_features),
        )

        # Define loss function:
        self.loss_func = coxph.CoxPHLoss()
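        # Note: CoxPHLoss takes the model output (interpreted as the log partial
        # hazard) together with durations and event indicators, and computes the
        # negative Cox partial log-likelihood over the batch.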

    def forward(self, x):
        return self.net(x)

    # A training step is required; a validation step sketch is included below.
    def training_step(self, batch, batch_idx):
        x, target = batch
        output = self.forward(x)

        # target contains duration (column 0) and event (column 1) as a concatenated tensor
        loss = self.loss_func(output, target[:, 0], target[:, 1])

        # Progress-bar logging metrics (custom metric definitions could be added later)
        self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
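    # A minimal validation_step sketch (an assumption, not part of the original
    # example) mirroring training_step; it relies on the val_dataloader sketched
    # in the DataModule above.
    def validation_step(self, batch, batch_idx):
        x, target = batch
        output = self.forward(x)
        loss = self.loss_func(output, target[:, 0], target[:, 1])
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        return loss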

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
        )
        return optimizer


def main():
    # Load the Lightning DataModule
    dat = MetaBrick(num_workers=0, discretize=False)
    dat.setup('fit')  # allows input / output features to be configured in the model

    # Load the Lightning Module
    model = SurvModel(lr=1e-2, in_features=dat.in_dims, out_features=dat.out_dims)
    trainer = pl.Trainer(gpus=0, num_sanity_val_steps=0, max_epochs=20, fast_dev_run=False)

    # Train model
    trainer.fit(model, dat)

    # Evaluate manually (no test_step / test_dataloader is defined, so
    # trainer.test is not used here)
    print('Running in Evaluation Mode...')
    model.freeze()

    # Set up test data (prepared by the Lightning DataModule)
    dat.setup('test')

    # Predict survival on the test dataset
    output = model(dat.x_test)

    # Estimate the cumulative baseline hazards from the model output, then map
    # the output to survival probabilities
    cum_haz = coxph.compute_cumulative_baseline_hazards(output, durations=dat.y_test[:, 0], events=dat.y_test[:, 1])
    surv = coxph.output2surv(output, cum_haz[0])

    # The surv_df dataframe needs to be transposed to the format {rows}: duration, {cols}: each individual
    surv_df = pd.DataFrame(surv.transpose(1, 0).numpy())
    print(surv_df)

    ev = EvalSurv(surv_df, dat.df_test.duration.values, dat.df_test.event.values, censor_surv='km')
    time_grid = np.linspace(dat.df_test.duration.values.min(), dat.df_test.duration.values.max(), 100)

    print(f"Concordance: {ev.concordance_td()}")
    print(f"Brier Score: {ev.integrated_brier_score(time_grid)}")

if __name__ == '__main__':
    main()
4 changes: 1 addition & 3 deletions examples/lightning_logistic_hazard.py
@@ -99,7 +99,6 @@ def _preprocess_features(cls, df_train: pd.DataFrame, df_test: pd.DataFrame) ->

# Survival model class


class SurvModel(pl.LightningModule):
'''
Defines model, optimizers, forward step, and training step.
@@ -182,5 +181,4 @@ def main():
print(f"Concordance: {ev.concordance_td()}")

if __name__ == '__main__':
main()

main()
148 changes: 148 additions & 0 deletions examples/torch_cox.py
@@ -0,0 +1,148 @@
'''
A minimal example of how to fit a CoxPH model with a vanilla torch training loop.
The point of this example is to make it simple to use the CoxPH models in other frameworks
that are not based on torchtuples.
'''
from typing import Tuple

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pycox.datasets import metabric
from pycox.evaluation import EvalSurv
from pycox.models import coxph

# For preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper

def get_metabrick_train_val_test() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Get the METABRIC dataset split into training, validation, and test dataframes."""
    df_train = metabric.read_df()
    df_test = df_train.sample(frac=0.2)
    df_train = df_train.drop(df_test.index)
    df_val = df_train.sample(frac=0.2)
    df_train = df_train.drop(df_val.index)
    return df_train, df_val, df_test


def preprocess_features(df_train: pd.DataFrame, df_val: pd.DataFrame, df_test: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Preprocess the covariates of the training, validation, and test sets and
    return a tensor for each.
    """
    cols_standardize = ["x0", "x1", "x2", "x3", "x8"]
    cols_leave = ["x4", "x5", "x6", "x7"]

    standardize = [([col], StandardScaler()) for col in cols_standardize]
    leave = [(col, None) for col in cols_leave]
    x_mapper = DataFrameMapper(standardize + leave)

    # Fit the scaler on the training set only, then apply it to val and test
    x_train = x_mapper.fit_transform(df_train).astype("float32")
    x_val = x_mapper.transform(df_val).astype("float32")
    x_test = x_mapper.transform(df_test).astype("float32")

    return torch.from_numpy(x_train), torch.from_numpy(x_val), torch.from_numpy(x_test)


def make_mlp(in_features: int, out_features: int) -> nn.Module:
    """Make a simple torch net"""
    net = nn.Sequential(
        nn.Linear(in_features, 32),
        nn.ReLU(),
        nn.BatchNorm1d(32),
        nn.Dropout(0.1),
        nn.Linear(32, 32),
        nn.ReLU(),
        nn.BatchNorm1d(32),
        nn.Dropout(0.1),
        nn.Linear(32, out_features),
    )
    return net
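# Note: the BatchNorm and Dropout layers above behave differently at inference
# time, which is why the net is switched to eval mode before predicting on the
# test set in main() below.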

def get_target(df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
    '''
    Takes a pandas dataframe and converts the duration and event targets into np.arrays.
    '''
    duration = df['duration'].to_numpy().reshape(-1, 1)
    event = df['event'].to_numpy().reshape(-1, 1)

    return duration, event

def main() -> None:
    # Get the METABRIC dataset split into train, validation, and test sets
    # np.random.seed(1234)
    # torch.manual_seed(123)
    df_train, df_val, df_test = get_metabrick_train_val_test()

    # Preprocess features and targets
    x_train, x_val, x_test = preprocess_features(df_train, df_val, df_test)
    y_train = torch.from_numpy(np.concatenate(get_target(df_train), axis=1))
    y_test = torch.from_numpy(np.concatenate(get_target(df_test), axis=1))
    y_val = torch.from_numpy(np.concatenate(get_target(df_val), axis=1))

    # TODO: perhaps rename these to something like test_target?
    durations_test, events_test = get_target(df_test)
    # Validation data is prepared here but not used in the simple loop below
    val = x_val, y_val

    # Make an MLP neural network
    in_features = x_train.shape[1]
    out_features = 1
    net = make_mlp(in_features, out_features)

    batch_size = 256
    epochs = 20

    train_dataset = TensorDataset(x_train, y_train)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)

    # if verbose:
    #     print('Durations and events in order:')
    #     print(y_train[:, 0])
    #     print(y_train[:, 1])

    # Set optimizer and loss function (optimization criterion)
    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
    loss_func = coxph.CoxPHLoss()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(train_dataloader):
            x, target = data
            optimizer.zero_grad()
            output = net(x)
            loss = loss_func(output, target[:, 0], target[:, 1])  # needs output, durations, events
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"epoch: {epoch} -- loss: {running_loss / len(train_dataloader)}")

    # Predict survival for the test set:
    # set the net to evaluation mode and turn off gradients
    net.eval()
    with torch.no_grad():
        output = net(x_test)
        # Cumulative baseline hazards calculation
        cum_haz = coxph.compute_cumulative_baseline_hazards(output, durations=y_test[:, 0], events=y_test[:, 1])

    surv = coxph.output2surv(output, cum_haz[0])
    # The dataframe needs to be transposed to the format {rows}: duration, {cols}: each individual
    surv_df = pd.DataFrame(surv.transpose(1, 0).numpy())
    print(surv_df)

    # Print the test set concordance index and integrated Brier score
    ev = EvalSurv(surv_df, df_test.duration.values, df_test.event.values, censor_surv='km')
    time_grid = np.linspace(df_test.duration.values.min(), df_test.duration.values.max(), 100)

    print(f"Concordance: {ev.concordance_td()}")
    print(f"Brier Score: {ev.integrated_brier_score(time_grid)}")


if __name__ == "__main__":
    main()