Use state_dict Torchmetrics Serialization #2116
Conversation
This adds a backwards compatibility test, very nice! But is this change itself backwards compatible? The checkpoint appears to be built off dev, not 0.13.5.
Also, please show a screenshot of the test passing manually, since it only runs on the daily schedule (because it's remote).
Thanks for digging through this @nik-mosaic! Left a few comments and questions, mostly around the usage of private attributes.
Why do we not need the backwards compatibility hacks with this? Won't old checkpoints still be in the old format?
I have added a comment addressing this directly in the code. The explanation is as follows: given the rest of a Composer checkpoint, the state_dict() and _computed attributes of a Torchmetrics object are enough information to recreate it when deserializing. We serialize only this minimal metric information to maximize backwards compatibility: old checkpoints will continue to be compatible even if other Torchmetrics attributes change.
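A minimal sketch of that idea, assuming hypothetical save_metric/load_metric helpers (the PR's actual code may differ); `_computed` is the private torchmetrics cache of the last compute() result mentioned above:

```python
from typing import Any, Dict

from torchmetrics import Metric


def save_metric(metric: Metric) -> Dict[str, Any]:
    # Persist only the metric's state_dict() and its cached computed value;
    # everything else is recreated from the model when the checkpoint loads.
    return {
        'state_dict': metric.state_dict(),
        '_computed': metric._computed,  # private cache of the last compute()
    }


def load_metric(fresh_metric: Metric, saved: Dict[str, Any]) -> Metric:
    # fresh_metric is a freshly constructed default instance of the same
    # metric class, obtained from the model at load time.
    fresh_metric.load_state_dict(saved['state_dict'])
    fresh_metric._computed = saved['_computed']
    return fresh_metric
```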
LGTM, thank you for taking this on!
Serialize and load torchmetrics through state_dict() and load_state_dict() instead of pickle
What does this PR do?
Our metrics objects were previously serialized using pickle. Pickling captures many fields that are unnecessary and can change from version to version, causing a mismatch when we stop and restart training with a slightly different configuration or upgrade Torchmetrics.
This PR changes the saving and loading of Composer state objects so that state.train_metrics and state.eval_metrics use Torchmetrics' built-in `state_dict()` method instead of pickle. The `state_dict()` method essentially returns a dictionary of `<metric_name, metric_value>` pairs, without any of the other data. When loading a state dict from its serialized version, we recreate the metric as follows (see the sketch after this list):
(1) Get the `state.model`.
(2) Call `model.get_metrics()`, which creates a default version of the metrics.
(3) Loop over the dictionary of `<metric_name, metric_value>` pairs from the serialization, matching metric names to the metrics from `model.get_metrics()` and populating values with the saved `metric_value`. If a saved metric isn't present in the model, we skip it.
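A sketch of steps (1)-(3), assuming `serialized` maps metric names to their saved state dicts; the helper name and exact signatures are illustrative, not the PR's actual code:

```python
from typing import Any, Dict


def restore_metrics(state, serialized: Dict[str, Dict[str, Any]], is_train: bool):
    # (1) + (2): ask the model for default instances of its metrics.
    metrics = state.model.get_metrics(is_train=is_train)
    # (3): match saved names to the defaults and load the saved values.
    for name, metric_state in serialized.items():
        if name not in metrics:
            continue  # saved metric no longer exists on the model; skip it
        metrics[name].load_state_dict(metric_state)
    return metrics
```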
Other changes in this PR:
- Moved `_ensure_metrics_device_and_dtype`, which is only necessary for DeepSpeed models, behind an `is_model_deepspeed()` check (sketched below). This should fix CO-1910.
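Illustrative shape of that gating; `is_model_deepspeed` is a real composer.utils helper, while the surrounding `state` and `metrics` variables are assumed from context:

```python
from composer.utils import is_model_deepspeed

# Per the PR, this device/dtype fix-up is only needed for DeepSpeed models,
# so skip it everywhere else.
if is_model_deepspeed(state.model):
    _ensure_metrics_device_and_dtype(metrics)
```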
Testing:
- Added `test_load_remote_checkpoint`, a backwards compatibility checkpoint test (see the sketch below). I have uploaded a checkpoint saved with Composer 0.13.5 and default dependencies (torchmetrics 0.11.3, etc.). This test downloads the checkpoint and ensures equivalency with a currently trained version. As we continue to push to Composer, this test will become more useful: our remote checkpoint will remain frozen while our local one will include all future trainer/model/metrics changes.
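A hypothetical pytest-style outline of such a test; the checkpoint URI, fixtures, and marker are placeholders, not the PR's actual test code:

```python
import pytest

from composer.trainer import Trainer

# Placeholder URI for the frozen checkpoint saved with Composer 0.13.5.
FROZEN_CHECKPOINT = 's3://<bucket>/composer-0.13.5/checkpoint.pt'


@pytest.mark.daily  # downloads a remote artifact, so only run daily
def test_load_remote_checkpoint(model, train_dataloader):  # fixtures assumed
    # Constructing the Trainer with load_path restores the full state,
    # exercising the new metric deserialization path; a regression there
    # would raise before the assertion below.
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration='1ba',
        load_path=FROZEN_CHECKPOINT,
    )
    assert trainer.state.train_metrics is not None
```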
What issue(s) does this change relate to?
CO-1918, CO-1907, CO-1853.