-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Torch lightning example for MLP logistic hazard model #66
Torch lightning example for MLP logistic hazard model #66
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally looks good to me! Great work!
I think it is fine to keep it minimal, but if you would prefer to add a validation stop and testing within pythoch-lighting. that would be fine too.
|
||
standardize = [([col], StandardScaler()) for col in cols_standardize] | ||
leave = [(col, None) for col in cols_leave] | ||
x_mapper = DataFrameMapper(standardize + leave) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't have to use DataFrameMapper if you prefer some other way of preprocessing data. I tend to use it, but it is generally not that common.
#Train model | ||
trainer.fit(model,dat) | ||
|
||
#Load model from best checkpoint & freeze |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you actually load from a checkpoint here? Isn't this just the final model?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep you're right, that's just the final model. Fixed
We should probably also change all indentation tabs to spaces to be consistent with the rest of pycox. I can do that for you if you want. |
Thanks for the review! I'll refactor + address the comments and commit shortly |
…equirements.txt
Traditionally the
test_step
function within thesurv_model
LightningModule class handles loading the data + calculating metrics for the test stage. Given the way the metrics are calculated I figured it's best to handle the test stage in vanilla pytorch.