
Thoughts about integration with pytorch-lightning? #60

Open
rohanshad opened this issue Jan 12, 2021 · 12 comments

@rohanshad

Great work here, havakv; it looks really modular and well thought out.

I'm interested in using some of the time-to-event tools here with high-dimensional imaging inputs. I've been building up a medical imaging codebase with pytorch-lightning for a while, mostly because of how modular and convenient it makes iterating over multiple experiments in a cluster environment.

Do you have any ideas on how best to reorganize some of pycox.models (PCHazard, for example) for pytorch-lightning before I start on this? What I'm really interested in is being able to use the pytorch-lightning Trainer.

@havakv
Owner

havakv commented Jan 12, 2021

Thank you for the kind words!
I think pycox is way too dependent on torchtuples (which is quite limited), and this has been discussed in #25.
Pycox could really benefit from working with pytorch-lightning, but I don't think it needs any special integration, just a way to decouple pycox from torchtuples. One could then give examples of how to fit models with pytorch-lightning.

In relation to #25, I made this example in this branch, where I propose a change to the LogisticHazard model so that it can be fitted with just vanilla PyTorch (no torchtuples stuff). Could you take a look at that and see if you're able to make it work with pytorch-lightning?
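To illustrate why dropping torchtuples is enough: the LogisticHazard loss is just a function of the network output and the discretized labels, so any training loop (plain PyTorch, or a LightningModule's `training_step`) can call it. Below is a NumPy sketch of that negative log-likelihood, written for illustration only; `nll_logistic_hazard_np` is a hypothetical name, and pycox's own torch version (`nll_logistic_hazard` in `pycox.models.loss`) is what a real model would use.

```python
import numpy as np


def nll_logistic_hazard_np(phi, idx_durations, events):
    """Mean negative log-likelihood of a discrete-time logistic-hazard model.

    Illustrative NumPy sketch (hypothetical helper, not pycox's code).
    phi: (n, m) logits, one per individual and time interval.
    idx_durations: (n,) interval index where each individual had the event
        or was censored.
    events: (n,) 1 for an observed event, 0 for censoring.
    """
    phi = np.asarray(phi, dtype=float)
    idx = np.asarray(idx_durations, dtype=int)
    events = np.asarray(events, dtype=float)
    n, m = phi.shape
    # Binary targets: 1 only in the event interval of uncensored individuals.
    y = np.zeros((n, m))
    y[np.arange(n), idx] = events
    # Only intervals up to and including idx contribute to the likelihood.
    mask = np.arange(m)[None, :] <= idx[:, None]
    # Numerically stable binary cross-entropy with logits.
    bce = np.maximum(phi, 0) - phi * y + np.log1p(np.exp(-np.abs(phi)))
    return float((bce * mask).sum(axis=1).mean())
```

Because the loss only needs `(phi, idx_durations, events)`, the model class no longer has to own the training loop.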

Some feedback on these changes would really help before I start refactoring all the models :)

@rohanshad
Author

Perfect!

I'll take a crack at this soon (hopefully within this week) and circle back. Thanks 👍

@rohanshad
Author

rohanshad commented Jan 25, 2021

Got it all working in a new conda environment, and I've successfully ported the example to pytorch-lightning. I set up the dataset within a Lightning DataModule and packaged the preprocessing functions there too. Since each model may require slightly different preprocessing steps, it might make sense to define all of those functions within the data module itself. I set up the train/test data there too.
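One way to keep model-specific preprocessing next to the data is a small fit/transform object that the DataModule fits on the training split only and then applies to every split. The sketch below is hypothetical (`DurationDiscretizer` is not a pycox or Lightning class) and deliberately simplified: it rounds every duration up to the next grid point, whereas pycox's label transforms (e.g. `LogisticHazard.label_transform`) have their own rules for censored observations.

```python
import numpy as np


class DurationDiscretizer:
    """Hypothetical preprocessing step for discrete-time survival models.

    Fits an equidistant grid on the training durations and maps durations
    from any split (train, val or test) to integer grid indices.
    """

    def __init__(self, num_durations):
        self.num_durations = num_durations
        self.cuts = None

    def fit(self, durations):
        # Equidistant grid from 0 to the largest *training* duration.
        self.cuts = np.linspace(0.0, float(np.max(durations)), self.num_durations)
        return self

    def transform(self, durations, events):
        if self.cuts is None:
            raise RuntimeError("call fit() on the training data first")
        durations = np.asarray(durations, dtype=float)
        # Round each duration up to the next grid point; values beyond the
        # last grid point are clipped to the last index.
        idx = np.searchsorted(self.cuts, durations, side="left")
        idx = np.minimum(idx, self.num_durations - 1)
        return idx.astype(np.int64), np.asarray(events, dtype=np.float32)
```

In a DataModule, `fit` would run once on the training durations (e.g. in `setup`), and the fitted `cuts` would double as the `duration_index` used later for prediction.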

The model, training logic, metrics, loss and optimizer all fit in a surv_model LightningModule. The Trainer trains it all, shows a progress bar, tracks experiment versions, and can dump logs to CSV/TensorBoard as required:

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type                 | Params
---------------------------------------------------
0 | net       | Sequential           | 1 K   
1 | loss_func | NLLLogistiHazardLoss | 0     
Epoch 19: 100%|███████| 6/6 [00:01<00:00,  4.23it/s, loss=2.269, v_num=34, loss_step=2.14, loss_epoch=2.26]
Running in Evaluation Mode...
Concordance: 0.6252826147316648

The only thing I keep in vanilla PyTorch is the testing phase, since the metrics are calculated directly on a pandas DataFrame, which removes the need for a DataLoader. Let me know if you'd like me to open a PR on that branch so you can see what this looks like.
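Since evaluation works on a pandas DataFrame, the only step after the forward pass is turning per-interval hazards into survival curves, S(t_j) = ∏_{k ≤ j} (1 − h_k). A minimal sketch, assuming the network outputs have already been passed through a sigmoid to give discrete hazards and that `cuts` is the time grid from the label transform; `hazards_to_surv_df` is a hypothetical helper. The layout (one row per time, one column per individual) is the one pycox's `EvalSurv` takes.

```python
import numpy as np
import pandas as pd


def hazards_to_surv_df(hazards, cuts):
    """Turn discrete hazards (n_individuals, n_times) into a survival DataFrame.

    Hypothetical helper: rows are grid times (index = cuts), columns are
    individuals.
    """
    hazards = np.asarray(hazards, dtype=float)
    # S(t_j) = prod_{k <= j} (1 - h_k), computed along the time axis.
    surv = np.cumprod(1.0 - hazards, axis=1)
    return pd.DataFrame(surv.T, index=np.asarray(cuts, dtype=float))
```

The resulting DataFrame can then be passed to `EvalSurv` together with the test durations and events, with no DataLoader involved.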

@havakv
Owner

havakv commented Jan 26, 2021

Great work @rohanshad!
Sure, you can open a PR! It's much simpler to discuss when we have some concrete examples.

@rohanshad
Author

Here you go: #66

@rohanshad
Author

Let me know if you'd like to create an example for CoxPH; its workings and estimators seem to be a bit different from the logistic-hazard models. I can carry on from there and attempt to make a flexible-ish Lightning module that works with CoxPH too.
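The main difference in training is that the Cox loss is not a per-individual sum over time intervals: each event's term depends on the whole risk set (everyone still at risk at that time), so it is computed over a batch sorted by duration. A NumPy sketch of the negative partial log-likelihood, for illustration only; `cox_ph_loss_np` is a hypothetical name (pycox has its own torch implementation in `pycox.models.loss`), and ties are simply broken by sort order rather than handled with Breslow or Efron corrections.

```python
import numpy as np


def cox_ph_loss_np(log_h, durations, events):
    """Negative Cox partial log-likelihood, normalized by the number of events.

    Illustrative sketch (hypothetical helper). log_h: (n,) network outputs
    g(x_i); durations, events: (n,) arrays. Ties are not treated specially.
    """
    log_h = np.asarray(log_h, dtype=float)
    events = np.asarray(events, dtype=float)
    # Sort by descending duration so a cumulative sum over the sorted array
    # gives, for each individual, the sum over their risk set {j: T_j >= T_i}.
    order = np.argsort(-np.asarray(durations, dtype=float), kind="stable")
    log_h, events = log_h[order], events[order]
    # log(sum_{j in risk set} exp(g_j)), shifted by the max for stability.
    gamma = log_h.max()
    log_risk = np.log(np.cumsum(np.exp(log_h - gamma))) + gamma
    return float(-((log_h - log_risk) * events).sum() / events.sum())
```

With mini-batches, the risk set only contains the individuals in the batch, so the batch loss is an approximation of the full-data partial likelihood.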

@havakv
Owner

havakv commented Feb 4, 2021

I think the Cox models will be a bit harder to make work (though CoxPH is likely the simplest). Currently, the computations of the non-parametric baseline hazards are part of the CoxPH class, so they probably need to be factored out in a similar way as I did for the logistic hazard. But you are of course more than welcome to give it a go!
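For reference, the non-parametric part is usually the Breslow estimator, Ĥ0(t) = Σ_{t_i ≤ t} d_i / Σ_{j: T_j ≥ t_i} exp(g(x_j)). It needs the model's outputs on the whole training set after fitting, which is why it does not fit naturally into a per-batch training step. A pandas sketch, assuming the outputs g(x) have already been collected for the training data; the function name is hypothetical and this is not pycox's implementation.

```python
import numpy as np
import pandas as pd


def breslow_cumulative_baseline_hazard(log_h, durations, events):
    """Breslow estimate of the cumulative baseline hazard H0(t).

    Hypothetical helper. log_h: (n,) model outputs g(x_i) on the training
    data. Returns a Series indexed by the unique durations (ascending).
    """
    df = pd.DataFrame({
        "expg": np.exp(np.asarray(log_h, dtype=float)),
        "durations": np.asarray(durations, dtype=float),
        "events": np.asarray(events, dtype=float),
    })
    # Per unique time: number of events and sum of exp(g) of those leaving.
    grouped = df.groupby("durations").agg({"expg": "sum", "events": "sum"})
    # Risk-set sum at time t: exp(g) summed over everyone with duration >= t.
    risk = grouped["expg"].iloc[::-1].cumsum().iloc[::-1]
    baseline_hazard = grouped["events"] / risk
    return baseline_hazard.cumsum().rename("cumulative_baseline_hazard")
```

With all outputs equal to zero this reduces to the Nelson-Aalen estimator, which is a handy sanity check.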

Right now I have too much to do between work and revisions, so I can't prioritise this (I imagine it's quite some work), but I'll get started as soon as I have the time.

@havakv havakv reopened this Feb 6, 2021
@havakv
Owner

havakv commented Feb 6, 2021

I guess this can stay open until all of pycox can be used with pytorch-lightning.

@havakv
Copy link
Owner

havakv commented Feb 6, 2021

@rohanshad If you want to take a crack at CoxPH in pytorch-lightning, I've now made a refactored version of CoxPH in https://github.com/havakv/pycox/blob/refactor_out_torchtuples/pycox/models/coxph.py that should be straightforward to use. Some docs and tests are still missing, but I'll add those later.

By using compute_cumulative_baseline_hazards and output2surv, I think it shouldn't be too much work. Let me know if you run into any issues!
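Once the cumulative baseline hazard is available, turning the network output into survival curves is a single formula, S(t | x) = exp(−Ĥ0(t) · exp(g(x))). The sketch below shows that step in NumPy/pandas. It is meant to illustrate what a function like `output2surv` has to compute, not to reproduce its signature in the refactor branch; `cox_output_to_surv_df` is a hypothetical name.

```python
import numpy as np
import pandas as pd


def cox_output_to_surv_df(log_h, cum_baseline_hazard):
    """Survival curves S(t | x) = exp(-H0(t) * exp(g(x))).

    Hypothetical helper. log_h: (n,) model outputs g(x) for the individuals
    to predict; cum_baseline_hazard: pandas Series H0(t) indexed by time.
    Returns a DataFrame with one row per time and one column per individual.
    """
    expg = np.exp(np.asarray(log_h, dtype=float))
    h0 = cum_baseline_hazard.to_numpy(dtype=float)
    # Outer product: (n_times, 1) * (1, n_individuals).
    surv = np.exp(-h0[:, None] * expg[None, :])
    return pd.DataFrame(surv, index=cum_baseline_hazard.index)
```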

@yorickvanzweeden

yorickvanzweeden commented May 25, 2021

@havakv Thank you for the refactorings of CoxPH and LogisticHazard. Do you have any plans to refactor PC-Hazard? Or should I be able to use it like the other two?

I am currently doing the following. However, compared with LogisticHazard and CoxPH, I am not getting great performance.

import pytorch_lightning as pl
import pandas as pd
import torch
import torch.nn.functional as F

from pycox.models.loss import nll_pc_hazard_loss
from pycox.evaluation import EvalSurv
from pycox.models.utils import pad_col, make_subgrid

class DummyModel(pl.LightningModule):
    # Written against the pytorch-lightning 1.x API
    # (training_epoch_end / validation_epoch_end were removed in 2.0).
    def __init__(self, net, lr=1e-3, duration_index=None, sub=10):
        super().__init__()
        self.net = net  # any nn.Module with len(duration_index) - 1 outputs
        self.lr = lr
        self.sub = sub
        self.loss_func = nll_pc_hazard_loss
        self.duration_index = duration_index

    def forward(self, x):
        return self.net(x)

    def common_step(self, batch, batch_idx, stage):
        # duration: interval index, interval: fraction of that interval observed
        x, duration, event, interval = batch
        preds = self(x)
        loss = self.loss_func(preds, duration, event, interval)

        if stage == "train":
            return {"loss": loss}
        return {"loss": loss, "preds": preds, "event": event, "duration": duration}

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, "train")

    def training_epoch_end(self, outs):
        avg_loss = torch.stack([x["loss"] for x in outs]).mean()
        self.logger.experiment.add_scalar("loss/train", avg_loss, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, "val")

    def validation_epoch_end(self, outs):
        avg_loss = torch.stack([x["loss"] for x in outs]).mean()
        self.logger.experiment.add_scalar("loss/val", avg_loss, self.current_epoch)

        # duration/event are 1-D per batch, so concatenate rather than vstack
        # (vstack fails when the last batch is smaller)
        predictions = torch.cat([x["preds"] for x in outs])
        durations = torch.cat([x["duration"] for x in outs]).view(-1)
        events = torch.cat([x["event"] for x in outs]).view(-1)

        surv_df = self.predict_surv_df(predictions, self.sub, self.duration_index)
        # NOTE: EvalSurv compares durations with the time index of surv_df.
        # If the batch holds interval indices (as the PC-Hazard label transform
        # produces), pass the original continuous durations here instead.
        ev = EvalSurv(surv_df, durations.cpu().numpy(), events.cpu().numpy())
        self.logger.experiment.add_scalar("val_concordance_td", ev.concordance_td(), self.current_epoch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def predict_surv_df(self, preds, sub, duration_index):
        n = preds.shape[0]
        # Piecewise-constant hazard: spread each interval's softplus output
        # evenly over `sub` sub-intervals
        hazard = F.softplus(preds).view(-1, 1).repeat(1, sub).view(n, -1).div(sub)
        hazard = pad_col(hazard, where="start")
        surv = hazard.cumsum(1).mul(-1).exp()
        surv = surv.detach().cpu().numpy()

        index = None
        if duration_index is not None:
            index = make_subgrid(duration_index, sub)
        return pd.DataFrame(surv.transpose(), index)

The DataLoader is aligned with the PC-Hazard notebook, so the duration_index corresponds to PCHazard.label_transform(num_durations).cuts.
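The alignment between `duration_index` and the network output can be checked without a network. Assuming one output per interval (len(cuts) − 1 outputs) and `sub` points per interval, the survival DataFrame has (len(cuts) − 1) · sub + 1 rows. The NumPy sketch below mirrors the arithmetic of `predict_surv_df` in the code above: softplus outputs spread evenly over sub-intervals, a zero column prepended, a cumulative sum, then exp(−·). `make_subgrid_np` and `pc_hazard_surv_df` are hypothetical stand-ins for pycox's `make_subgrid` and the model's prediction.

```python
import numpy as np
import pandas as pd


def make_subgrid_np(grid, sub):
    """Split every interval of `grid` into `sub` equal parts (hypothetical helper).

    Returns (len(grid) - 1) * sub + 1 points, ending at grid[-1].
    """
    grid = np.asarray(grid, dtype=float)
    pieces = [np.linspace(a, b, sub + 1)[:-1] for a, b in zip(grid[:-1], grid[1:])]
    return np.concatenate(pieces + [grid[-1:]])


def pc_hazard_surv_df(hazard_per_interval, cuts, sub):
    """Survival on the subgrid for a piecewise-constant hazard (hypothetical helper).

    hazard_per_interval: (n, len(cuts) - 1) non-negative values, e.g. the
    softplus of the network output, one per interval.
    """
    eta = np.asarray(hazard_per_interval, dtype=float)
    n = eta.shape[0]
    # Each interval's hazard is spread evenly over its `sub` sub-intervals.
    h = np.repeat(eta, sub, axis=1) / sub
    # Prepend a zero column so that S = 1 at the first grid point.
    h = np.concatenate([np.zeros((n, 1)), h], axis=1)
    surv = np.exp(-np.cumsum(h, axis=1))
    return pd.DataFrame(surv.T, index=make_subgrid_np(cuts, sub))
```

If the number of network outputs does not match len(cuts) − 1, the index and the survival values no longer line up, which is worth ruling out before tuning hyperparameters.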

@havakv
Owner

havakv commented May 30, 2021

Hi @yorickvanzweeden.
I should refactor PCHazard in the same way as for LogisticHazard (just struggling to find the time).

From what I can see, your code should work, so I don't really know why you're not getting the results you want. Have you tried comparing these results with

model = PCHazard(net, optimizer, duration_index=labtrans.cuts)
model.fit(...)
model.predict_surv_df(...)

to check whether it is just PCHazard that doesn't perform well, or whether there is something wrong with your implementation?

@yorickvanzweeden

Thanks @havakv for your reply. I suspect it is due to the difficulty of the problem in combination with hyperparameters that have yet to be optimized.


3 participants