Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch lightning example for MLP logistic hazard model #66

Merged
merged 2 commits into from
Feb 1, 2021

Conversation

rohanshad
Copy link

Traditionally the test_step function within the surv_model LightningModule class handles loading the data + calculating metrics for the test stage. Given the way the metrics are calculated I figured it's best to handle the test stage in vanilla pytorch.

Copy link
Owner

@havakv havakv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally looks good to me! Great work!

I think it is fine to keep it minimal, but if you would prefer to add a validation stop and testing within pythoch-lighting. that would be fine too.

examples/lightning_logistic_hazard.py Outdated Show resolved Hide resolved
examples/lightning_logistic_hazard.py Outdated Show resolved Hide resolved
examples/lightning_logistic_hazard.py Outdated Show resolved Hide resolved
examples/lightning_logistic_hazard.py Outdated Show resolved Hide resolved
examples/lightning_logistic_hazard.py Outdated Show resolved Hide resolved
examples/lightning_logistic_hazard.py Outdated Show resolved Hide resolved
examples/lightning_logistic_hazard.py Outdated Show resolved Hide resolved
examples/lightning_logistic_hazard.py Outdated Show resolved Hide resolved

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]
x_mapper = DataFrameMapper(standardize + leave)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't have to use DataFrameMapper if you prefer some other way of preprocessing data. I tend to use it, but it is generally not that common.

#Train model
trainer.fit(model,dat)

#Load model from best checkpoint & freeze
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you actually load from a checkpoint here? Isn't this just the final model?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep you're right, that's just the final model. Fixed

@havakv
Copy link
Owner

havakv commented Jan 31, 2021

We should probably also change all indentation tabs to spaces to be consistent with the rest of pycox. I can do that for you if you want.

@rohanshad
Copy link
Author

Thanks for the review! I'll refactor + address the comments and commit shortly

@havakv havakv merged commit 7ed3e02 into havakv:refactor_out_torchtuples Feb 1, 2021
@rohanshad rohanshad deleted the lightning_example branch February 2, 2021 17:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants