Torch lightning example for MLP logistic hazard model #66

Merged Feb 1, 2021
examples/lightning_logistic_hazard.py (new file, 187 additions)
"""A minimal example of how to fit a LogisticHazard model with pytorch lightning
The point of this example is to make it simple to use the LogisticHazard models in other frameworks
that are not based on torchtuples.

rohanshad marked this conversation as resolved.
Show resolved Hide resolved
"""
from typing import Tuple

import numpy as np
import pandas as pd

import torch

import pytorch_lightning as pl
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pycox.datasets import metabric
from pycox.evaluation import EvalSurv
from pycox.models import logistic_hazard

# For preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper

# Lightning DataModule
class metabrick(pl.LightningDataModule):
def __init__(self, batch_size=256, num_durations=10, num_workers=0, verbose=False):
super().__init__()
self.batch_size = batch_size
self.num_durations = num_durations
self.verbose = verbose
self.num_workers = num_workers

def setup(self, stage=None):
'''
Get the METABRIC dataset split into a training dataframe and a testing dataframe.
Preprocess features and targets (duration and event), and discretize time into
'num_durations' equidistant points.
'''

#Load and split the dataset into train and test (with a separate val set, this could be done under stage == 'fit')
df_train = metabric.read_df()
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)

self.x_train, self.x_test = self.__preprocess_features(df_train, df_test)
self.labtrans = logistic_hazard.LabTransDiscreteTime(self.num_durations)

if stage == 'fit' or stage is None:
#Pre-process features and targets
self.y_train = self.labtrans.fit_transform(df_train.duration.values, df_train.event.values)
self.y_train_duration = torch.from_numpy(self.y_train[0])
self.y_train_event = torch.from_numpy(self.y_train[1])
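# Note: labtrans maps each (duration, event) pair to an integer duration index
# in [0, num_durations - 1] plus a float event indicator (1. = event, 0. = censored),
# which is the target format NLLLogistiHazardLoss expects.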

#Create training dataset
self.train_set = TensorDataset(self.x_train, self.y_train_duration, self.y_train_event)

#Input and output dimensions for building net
self.in_dims = self.x_train.shape[1]
self.out_dims = self.labtrans.out_features

if stage == 'test' or stage is None:
#Return test dataframe
self.df_test = df_test

def train_dataloader(self):
'''
Build the training dataloader.
num_workers is set to 0 by default because of a threading issue.
'''
train_loader = DataLoader(
dataset=self.train_set,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)
return train_loader

@classmethod
def __preprocess_features(cls, df_train: pd.DataFrame, df_test: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
'''
Preprocess the covariates of the training and test set and return tensors of the
training covariates and test covariates.
'''
cols_standardize = ["x0", "x1", "x2", "x3", "x8"]
cols_leave = ["x4", "x5", "x6", "x7"]

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]
x_mapper = DataFrameMapper(standardize + leave)
[Review comment from the owner] You don't have to use DataFrameMapper if you prefer some other way of preprocessing data. I tend to use it, but it is generally not that common.

x_train = x_mapper.fit_transform(df_train).astype("float32")
x_test = x_mapper.transform(df_test).astype("float32")
return torch.from_numpy(x_train), torch.from_numpy(x_test)
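# As the review comment above notes, DataFrameMapper is optional. A minimal sketch
# of equivalent preprocessing with plain scikit-learn (an illustrative alternative,
# not part of the example; ColumnTransformer is sklearn's standard per-column API):
#
# from sklearn.compose import ColumnTransformer
# ct = ColumnTransformer(
#     [("std", StandardScaler(), cols_standardize)],
#     remainder="passthrough",  # pass cols_leave through untouched
# )
# x_train = ct.fit_transform(df_train[cols_standardize + cols_leave]).astype("float32")
# x_test = ct.transform(df_test[cols_standardize + cols_leave]).astype("float32")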

# Survival model class
class surv_model(pl.LightningModule):
def __init__(self, lr, in_features, out_features):
super().__init__()
# Potentially allow a variable to specify an appropriate loss function here,
# e.g.: loss_func = logistic_hazard.NLLLogistiHazardLoss()
#       self.loss = loss_func
self.save_hyperparameters()
self.lr = lr
self.in_features = in_features
self.out_features = out_features

#Define Model Here (in this case MLP)
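# (Two 32-unit hidden layers with BatchNorm and Dropout, mirroring the MLPVanilla
# net used in the torchtuples-based pycox examples.)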
self.net = nn.Sequential(
nn.Linear(self.in_features, 32),
nn.ReLU(),
nn.BatchNorm1d(32),
nn.Dropout(0.1),
nn.Linear(32, 32),
nn.ReLU(),
nn.BatchNorm1d(32),
nn.Dropout(0.1),
nn.Linear(32, self.out_features),
)

# Define loss function:
self.loss_func = logistic_hazard.NLLLogistiHazardLoss()
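# (Note: "NLLLogistiHazardLoss" is the actual spelling of the class name in pycox.)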

def forward(self, x):
return self.net(x)

#A training_step is required; a validation_step is usually defined too, but this dataset only has train + test splits, so val is left out (see the sketch after configure_optimizers).
def training_step(self, batch, batch_idx):
x, duration, event = batch
output = self.forward(x)
loss = self.loss_func(output, duration, event)

# progress bar logging metrics (add custom metric definitions later if useful?)
self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True)
return loss

# def test_step(self, batch, batch_idx):
# x, duration, event = batch
# output = self.forward(x)
# surv = logistic_hazard.output2surv(output)

# return surv
# # surv_df = pd.DataFrame(surv.numpy().transpose(), labtrans.cuts)
# # ev = EvalSurv(surv_df, duration.numpy().transpose(), event.numpy().transpose())
# # print(ev)

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
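# If the DataModule also produced a validation split, a matching validation_step
# would be a small variation on training_step (a sketch, assuming a val_dataloader
# yielding the same (x, duration, event) batches):
#
# def validation_step(self, batch, batch_idx):
#     x, duration, event = batch
#     loss = self.loss_func(self(x), duration, event)
#     self.log('val_loss', loss, on_epoch=True, prog_bar=True)
#     return loss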

def main():
#Load Lightning DataModule
dat = metabrick(num_workers=0)
dat.setup('fit') #allows for input / output features to be configured in the model

#Load Lightning Module
model = surv_model(lr=1e-3, in_features=dat.in_dims, out_features=dat.out_dims)
trainer = pl.Trainer(gpus=0, num_sanity_val_steps=0, max_epochs=20, fast_dev_run=False)

#Train model
trainer.fit(model, dat)

#Freeze the trained model for evaluation (this is the final model, not a restored checkpoint)
[Review comment from the owner] Do you actually load from a checkpoint here? Isn't this just the final model?

[Reply from the author] Yep, you're right, that's just the final model. Fixed.
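# For reference, restoring a best checkpoint would look roughly like this
# (a sketch using Lightning's standard ModelCheckpoint callback; no
# checkpointing is configured in this example):
#
# from pytorch_lightning.callbacks import ModelCheckpoint
# checkpoint_cb = ModelCheckpoint(monitor='loss')
# trainer = pl.Trainer(gpus=0, max_epochs=20, callbacks=[checkpoint_cb])
# trainer.fit(model, dat)
# model = surv_model.load_from_checkpoint(checkpoint_cb.best_model_path)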

print('Running in Evaluation Mode...')
model.freeze()

#Setup test data (prepared from lightning module)
dat.setup('test')

#Predict survival on testing dataset
output = model(dat.x_test)
surv = logistic_hazard.output2surv(output)
surv_df = pd.DataFrame(surv.numpy().transpose(), dat.labtrans.cuts)
ev = EvalSurv(surv_df, dat.df_test.duration.values, dat.df_test.event.values)

#Print evaluation metrics
print(f"Concordance: {ev.concordance_td()}")

if __name__ == '__main__':
main()