
Use state_dict for Torchmetrics Serialization #2116

Merged: 54 commits into mosaicml:dev from nik/torchmetrics on Apr 27, 2023

Conversation

@nik-mosaic (Contributor) commented on Apr 4, 2023

What does this PR do?

Our metrics objects were previously serialized using pickle. Pickling includes many fields that are unnecessary and may change between versions, causing mismatches when we stop and restart training with a slightly different configuration or upgrade Torchmetrics.

This PR changes the saving and loading of Composer state objects so that state.train_metrics and state.eval_metrics use Torchmetrics' built-in state_dict() method instead of pickle. The state_dict() method essentially returns a dictionary of <metric_name, metric_value> pairs, without any of the other data.

When loading a state dict from its serialized version, we recreate the metrics as follows (a sketch follows the list):
(1) Get the state.model
(2) Call model.get_metrics(), which creates a default version of the metrics
(3) Loop over the dictionary of <metric_name, metric_value> pairs from the serialization, matching metric names to the metrics from model.get_metrics() and populating values with the saved metric_value. If a saved metric isn't present in the model, we skip it.
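
A minimal sketch of that load path, assuming Composer's model.get_metrics() API; the helper name and the serialized_metrics argument are illustrative, not the actual code in state.py:

```python
def load_metrics_state(state, serialized_metrics):
    """Hypothetical sketch of steps (1)-(3) above."""
    # (1) + (2): get the model and recreate default metric objects.
    model_metrics = state.model.get_metrics(is_train=True)

    # (3): match saved metric names to the fresh metrics and restore values.
    for metric_name, metric_state in serialized_metrics.items():
        if metric_name not in model_metrics:
            continue  # saved metric is not present on the model; skip it
        model_metrics[metric_name].load_state_dict(metric_state)
    return model_metrics
```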

Other changes in this PR:

  • Gate _ensure_metrics_device_and_dtype, which is only necessary for DeepSpeed models, behind an is_model_deepspeed() check (see the sketch below). This should fix CO-1910.
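
A sketch of this gate, assuming Composer's is_model_deepspeed utility; the surrounding call site is illustrative:

```python
from composer.utils import is_model_deepspeed

# The device/dtype fixup is only needed for DeepSpeed-wrapped models,
# so skip it everywhere else.
if is_model_deepspeed(state.model):
    _ensure_metrics_device_and_dtype(metrics)
```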

Testing:

  • We have modified the test_checkpoint.py and test_sharded_checkpoint.py files to include a metric equality check wherever we have a weight equality or optimizer equality check (a sketch of such a check follows this list).
  • Added test_load_remote_checkpoint, a backwards-compatibility checkpoint test. I have uploaded a checkpoint saved with Composer 0.13.5 and default dependencies (torchmetrics 0.11.3, etc.). The test downloads the checkpoint and verifies equivalence with a currently trained version. As Composer development continues, this test will become more useful, because the remote checkpoint stays frozen while the local one picks up all future trainer/model/metrics changes.
  • Locally, all tests pass on torchmetrics 0.11.4, but we do not bump the version in this PR.
  • @coryMosaicML has verified that CO-1910 has been fixed with this PR.
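
A sketch of the kind of metric equality assertion added to those tests; the helper is hypothetical, comparing metrics through the same state_dict() representation the checkpoint now stores:

```python
import torch

def assert_metrics_equal(metrics_a, metrics_b):
    """Hypothetical helper: compare two dicts of torchmetrics by state."""
    assert metrics_a.keys() == metrics_b.keys()
    for name in metrics_a:
        state_a = metrics_a[name].state_dict()
        state_b = metrics_b[name].state_dict()
        assert state_a.keys() == state_b.keys()
        for key in state_a:
            torch.testing.assert_close(state_a[key], state_b[key])
```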

What issue(s) does this change relate to?

CO-1918, CO-1907, CO-1853.

@nik-mosaic marked this pull request as ready for review on April 25, 2023 15:03
@mvpatel2000 requested a review from bcui19 on April 25, 2023 20:39
@mvpatel2000 (Contributor) left a comment:

This adds a backwards compatibility test -- very nice! But is this change itself backwards compatible? The checkpoint appears to be off dev, not 0.13.5.

Also, please show a screenshot of the test passing manually, as it only runs in the daily suite (since it's remote).

(Review thread on tests/trainer/test_checkpoint.py)
@mvpatel2000 requested a review from eracah on April 25, 2023 20:43
@dakinggg (Contributor) left a comment:

Thanks for digging through this @nik-mosaic! Left a few comments and questions, mostly around the usage of private attributes.

(Review threads on composer/core/state.py, composer/trainer/trainer.py, and tests/trainer/test_checkpoint.py)
@nik-mosaic (Contributor, Author) commented on Apr 27, 2023

Here it is passing pytest tests/trainer/test_checkpoint.py -m remote. The remote checkpoint file has changed; it is now one created off the Composer dev branch. To get this test to pass, we added a case to the serialization/deserialization section of state.py to support deserializing a pickled Torchmetrics object: old Composer checkpoints contain Torchmetrics objects, while new ones contain dictionaries of metric tensors. A sketch of this fallback follows the screenshot.
[Screenshot: remote checkpoint test passing]
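
A minimal sketch of that fallback; the function and argument names are hypothetical (the real logic lives in composer/core/state.py):

```python
from torchmetrics import Metric

def deserialize_metric(name, serialized, model_metrics):
    """Hypothetical sketch: accept both old- and new-style checkpoints."""
    if isinstance(serialized, Metric):
        # Old-style checkpoint: the whole Torchmetrics object was pickled.
        return serialized
    # New-style checkpoint: a plain dictionary of metric state tensors.
    metric = model_metrics[name]
    metric.load_state_dict(serialized)
    return metric
```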

@mvpatel2000 (Contributor) left a comment:

Why do we not need the backwards compatibility hacks with this? Won't they still be in the old format?

(Review threads on composer/core/state.py)
@nik-mosaic (Contributor, Author) commented on Apr 27, 2023

Why do we not need the backwards compatibility hacks with this? Won't they still be in the old format?

I have added a comment addressing this directly in the code. The explanation is as follows:

Given the rest of a Composer checkpoint, the state_dict() and _computed attributes of a Torchmetrics object are enough information to recreate it on deserialization. We serialize only this minimal metric information to maximize backwards compatibility: old checkpoints will continue to load even if other Torchmetrics attributes change. A sketch of the saved format follows.
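
A sketch of what gets saved per metric under this scheme; the function name and dictionary layout are illustrative, not Composer's exact checkpoint schema:

```python
def serialize_metric(metric):
    """Hypothetical sketch: persist only the minimal per-metric state."""
    return {
        'state_dict': metric.state_dict(),  # the metric's raw state tensors
        '_computed': metric._computed,      # torchmetrics' cached compute() result
    }
```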

@mvpatel2000 (Contributor) left a comment:

LGTM, thank you for taking this on!

@nik-mosaic merged commit 389255b into mosaicml:dev on Apr 27, 2023
@nik-mosaic deleted the nik/torchmetrics branch on April 27, 2023 04:07
dakinggg pushed a commit that referenced this pull request on Apr 27, 2023:
Serialize and load torchmetrics through state_dict() and load_state_dict() instead of pickle
@mvpatel2000 mentioned this pull request on May 15, 2023