Thoughts about integration with pytorch-lightning? #60
Thank you for the kind words! In relation to #25, I made this example in this branch where I propose a change to the LogisticHazard model so that it can be fitted with just vanilla pytorch (no torchtuples stuff). Could you take a look at that and see if you're able to make it work with pytorch-lightning? It would really help to get some feedback on these changes before I start refactoring all the models :)
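For readers following along, here is a rough sketch of what "fitted with just vanilla pytorch" could look like, assuming the refactored branch keeps the `nll_logistic_hazard` loss from `pycox.models.loss`; `net`, the tensors, and the hyperparameters are hypothetical:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pycox.models.loss import nll_logistic_hazard

def fit_vanilla(net, x, idx_durations, events, epochs=10, lr=1e-2, batch_size=256):
    """Plain PyTorch training loop for a discrete-time LogisticHazard net.

    `net` outputs one logit per discrete time interval; `idx_durations`
    holds the discretized duration index for each individual.
    """
    loader = DataLoader(TensorDataset(x, idx_durations, events),
                        batch_size=batch_size, shuffle=True)
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    for _ in range(epochs):
        for xb, db, eb in loader:
            opt.zero_grad()
            loss = nll_logistic_hazard(net(xb), db, eb)
            loss.backward()
            opt.step()
    return net
```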
Perfect! I'll take a crack at this soon (hopefully within this week) and circle back. Thanks 👍
Got it all working in a new conda environment, and I've successfully ported the example to pytorch-lightning. I set up the dataset within a lightning DataModule and packaged the pre-processing functions there too. Since each model may require slightly different pre-processing steps, it might make sense to define all those preprocessing functions within the dataset module itself. I set up the data for train / test here too. The model, training logic, metrics, loss, and optimizer all fit in a surv_model LightningModule. The trainer function trains it all and spits out a progress bar, tracks experiment versions, and can dump logs to csv / tensorboard as required. A rough sketch of this layout follows below.
The only thing I keep in vanilla pytorch is the testing phase, since the metrics are calculated directly on a pandas DataFrame, obviating the need for a DataLoader. Let me know if you'd want me to open a PR on that branch so you can see what this looks like.
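A minimal sketch of the layout described above, assuming a generic survival dataset; `SurvDataModule` and its internals are hypothetical names, not code from the PR:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset, random_split

class SurvDataModule(pl.LightningDataModule):
    """Hypothetical DataModule bundling the dataset and its pre-processing."""
    def __init__(self, batch_size=256):
        super().__init__()
        self.batch_size = batch_size

    def _preprocess(self):
        # Placeholder pre-processing: real code would load the data and apply
        # model-specific transforms (e.g. label discretization) here.
        x = torch.randn(1000, 9)
        duration = torch.randint(0, 10, (1000,))
        event = torch.randint(0, 2, (1000,)).float()
        ds = TensorDataset(x, duration, event)
        return random_split(ds, [800, 200])

    def setup(self, stage=None):
        self.train_ds, self.val_ds = self._preprocess()

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)
```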
Great work @rohanshad!
#66 Here you go ^
Let me know if you'd like to create an example for CoxPH; the workings and estimators seem to be a bit different from the logistic_hazard models. I can carry on from there and attempt to make a flexible-ish lightning module that works with CoxPH too.
I think the Cox models will be a bit harder to make work (though CoxPH is likely the simplest). Currently, computation of the non-parametric baseline hazards is part of the CoxPH class. We probably need to factor that out in a similar way as I did for the logistic-hazard. But you are of course more than welcome to give it a go! Right now I have too much to do between work and revisions, so I can't prioritise this (I imagine it's quite some work), but I'll get started as soon as I get the time.
I guess this can stay open until all of pycox can be used with pytorch-lightning.
@rohanshad If you want to take a crack at CoxPH in pytorch-lightning, I've now made a refactored version of CoxPH in https://github.com/havakv/pycox/blob/refactor_out_torchtuples/pycox/models/coxph.py that should be straightforward to use. Some docs and tests are still missing, but I'll add those later. By using compute_cumulative_baseline_hazards and output2surv, I think it shouldn't be too much work. Let me know if you run into any issues!
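In the meantime, a rough sketch of what the training side might look like in Lightning, using pycox's `CoxPHLoss` (the mini-batch approximates the risk sets, as in pycox); `net` is a placeholder module, and baseline hazards would still be computed after training via the refactored helpers mentioned above:

```python
import torch
import pytorch_lightning as pl
from pycox.models.loss import CoxPHLoss  # pycox's Cox partial log-likelihood loss

class LitCoxPH(pl.LightningModule):
    """A minimal sketch: `net` maps covariates to a single log-partial-hazard."""
    def __init__(self, net, lr=1e-3):
        super().__init__()
        self.net = net
        self.loss_func = CoxPHLoss()
        self.lr = lr

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        # The partial likelihood is computed over the whole mini-batch, so
        # each batch serves as an approximation of the risk sets.
        x, duration, event = batch
        return self.loss_func(self(x), duration, event)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```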
@havakv Thank you for the refactorings of CoxPH and LogisticHazard. Do you have any plans to refactor PC-Hazard? Or should I be able to use them like the other two? I am currently doing this. Yet, compared with the LogisticHazard and CoxPH, I am not getting great performance.

```python
import pytorch_lightning as pl
import pandas as pd
import torch
import torch.nn.functional as F
from pycox.models.loss import nll_pc_hazard_loss
from pycox.evaluation import EvalSurv
from pycox.models.utils import pad_col, make_subgrid

class DummyModel(pl.LightningModule):
    def __init__(self, duration_index=None):
        super().__init__()
        self.net = SomeModel  # placeholder: an nn.Module instance
        self.loss_func = nll_pc_hazard_loss
        self.duration_index = duration_index

    def forward(self, x):
        return self.net.forward(x)

    def common_step(self, batch, batch_idx, stage):
        x, duration, event, interval = batch
        preds = self(x)
        loss = self.loss_func(preds, duration, event, interval)
        if stage == "train":
            return {"loss": loss}
        else:
            return {"loss": loss, "preds": preds, "event": event, "duration": duration}

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, 'train')

    def training_epoch_end(self, outs):
        self.logger.experiment.add_scalar(
            "loss/train", torch.mean(torch.stack([x['loss'] for x in outs])), self.current_epoch)

    def validation_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, 'val')

    def validation_epoch_end(self, outs):
        self.logger.experiment.add_scalar(
            "loss/val", torch.mean(torch.stack([x['loss'] for x in outs])), self.current_epoch)
        predictions = torch.vstack([x['preds'] for x in outs])
        durations = torch.vstack([x['duration'] for x in outs])
        events = torch.vstack([x['event'] for x in outs])
        surv_df = self.predict_surv_df(predictions, sub=10, duration_index=self.duration_index)
        ev = EvalSurv(surv_df, durations.cpu().numpy().reshape(-1,), events.cpu().numpy().reshape(-1,))
        # concordance_td is a time-dependent concordance index, not an AUROC
        self.logger.experiment.add_scalar("val_concordance_td", ev.concordance_td(), self.current_epoch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=SomeLearningRate)  # placeholder LR
        return optimizer

    def predict_surv_df(self, preds, sub, duration_index):
        n = preds.shape[0]
        hazard = F.softplus(preds).view(-1, 1).repeat(1, sub).view(n, -1).div(sub)
        hazard = pad_col(hazard, where='start')
        surv = hazard.cumsum(1).mul(-1).exp()
        surv = surv.cpu().numpy()
        index = None
        if duration_index is not None:
            index = make_subgrid(duration_index, sub)
        return pd.DataFrame(surv.transpose(), index)
```

The DataLoader is aligned with the PC-Hazard notebook. In this way, the duration_index corresponds to labtrans.cuts.
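(For completeness, running the module above would look roughly like the following; `dm` is a hypothetical LightningDataModule yielding `(x, duration, event, interval)` batches, and the `gpus` flag matches the pytorch-lightning versions of that era.)

```python
model = DummyModel(duration_index=labtrans.cuts)
trainer = pl.Trainer(max_epochs=100, gpus=1)
trainer.fit(model, dm)
```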
Hi @yorickvanzweeden. From what I can see, your code should work, so I don't really know why you're not getting the results you want. Have you tried comparing these results with

```python
model = PCHazard(net, optimizer, duration_index=labtrans.cuts)
model.fit(...)
model.predict_surv_df(...)
```

to check if it is just the PCHazard that doesn't perform well, or if there is something with your implementation?
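For reference, a minimal sketch of that baseline check, assuming the data setup from the PC-Hazard notebook (`x_train`, `y_train`, `x_test`, `durations_test`, `events_test`, and `labtrans` are assumed to already exist):

```python
import torchtuples as tt
from pycox.models import PCHazard
from pycox.evaluation import EvalSurv

# Network and model roughly mirroring the PC-Hazard notebook defaults.
net = tt.practical.MLPVanilla(x_train.shape[1], [32, 32], labtrans.out_features)
model = PCHazard(net, tt.optim.Adam(0.01), duration_index=labtrans.cuts, sub=10)
model.fit(x_train, y_train, batch_size=256, epochs=100, verbose=False)

# Evaluate the same way as the Lightning version, for an apples-to-apples comparison.
surv = model.predict_surv_df(x_test)
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
print(ev.concordance_td())
```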
Thanks @havakv for your reply. I suspect it is due to the difficulty of the problem in combination with hyperparameters that have yet to be optimized. |
Great work here havakv, looks really modular and well thought out.
I'm interested in using some of the time-to-event tools here with high dimensional imaging inputs. I've been building up a medical imaging codebase using pytorch-lightning for a little bit mostly because of how modular & convenient it makes iterating over multiple experiments on a cluster environment.
Do you have any ideas of how best to re-organize some of pycox.models (PCHazard for example) to torch-lightning before I start on this? What I'm really interested in is being able to use the pytorch lightning trainer.