
Skip saving of direct serialization fields #445

Merged: 4 commits into dev on Feb 10, 2022

Conversation

@ravi-mosaicml (Contributor) commented on Feb 9, 2022

When the Trainer is constructed, the user passes in `max_duration`, `grad_accum`, and `precision` via `Trainer.__init__`. These values should take precedence over whatever was in a checkpoint, so they should not be saved in the checkpoint in the first place.

Since these were the only direct serialization attributes, this PR also cleans up the state logic so that only the attributes to be serialized need to be listed; there is no longer a need to list which fields are *not* serialized. `fields` is also renamed to `attrs`, since fields are a dataclass concept and State is no longer a dataclass.

The state serialization is further cleaned up by refactoring the deepspeed logic out of State and into the Trainer, next to where deepspeed is initialized. The state should not need to know whether the model is a deepspeed model.

Closes #441.
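
For readers skimming the diff, here is a minimal sketch of the allow-list approach described above. The attribute names and the `serialized_attrs` tuple are illustrative, not Composer's actual implementation:

```python
from typing import Any, Dict


class State:
    # Only attributes listed here end up in the checkpoint. Values passed to
    # Trainer.__init__ (max_duration, grad_accum, precision) are deliberately
    # omitted, so the constructor arguments always win on resume.
    serialized_attrs = ("model", "optimizers", "schedulers", "timer")

    def state_dict(self) -> Dict[str, Any]:
        state = {}
        for attr in self.serialized_attrs:
            value = getattr(self, attr)
            # Defer to the object's own state_dict() when it has one (e.g. nn.Module).
            state[attr] = value.state_dict() if hasattr(value, "state_dict") else value
        return state

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        for attr, value in state.items():
            target = getattr(self, attr, None)
            if hasattr(target, "load_state_dict"):
                target.load_state_dict(value)
            else:
                setattr(self, attr, value)
```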

@ajaysaini725 (Member) left a review comment

Couple of questions; otherwise LGTM.

Review threads on composer/core/state.py (resolved)
@hanlint (Contributor) commented on Feb 10, 2022

I agree on max_duration since it's a required input to the Trainer, but I'm less convinced about not saving grad_accum and precision. Those seem like ones we would want to persist in the checkpoint. If someone trains in AMP, saves the checkpoint, then loads it up again without specifying the precision, would we default to training in FP32?

An ancillary question -- would we actually want logic where the checkpoint is loaded and then any user-provided arguments to the init override it? Right now, from what I understand, they are silently ignored.

@ravi-mosaicml (Contributor, Author) commented

I think grad_accum is very hardware-specific -- e.g. you might start training on 3080s and then resume on 3090s -- so I don't think it should be persisted in the checkpoint.

Though precision is also hardware-specific (e.g. some hardware supports bf16; other hardware supports only fp16 or amp), I could go either way on it theoretically. However, since precision is passed in on init (perhaps implicitly, through a default), I think whatever is passed to the trainer should be used. This does not affect YAHP usage, since precision is serialized in the hparams (https://github.com/mosaicml/composer/blob/8b625ad167f9838580e907c7b411b63f539080c3/examples/run_composer_trainer.py).

This change also makes it consistent with deepspeed, since we don't save the deepspeed config (including precision) in the checkpoint.

It also ensures that all arguments to the trainer init (except for seed) are used: every other parameter is either used to construct the classes into which the state is loaded, or is not saved.

@ravi-mosaicml merged commit ea95116 into dev on Feb 10, 2022
@ravi-mosaicml deleted the ravi/skip_serialization_of_trainer_init_fields branch on February 10, 2022 at 01:38
@hanlint (Contributor) commented on Feb 10, 2022

Hmm, OK -- it might be good to get the original issue requester and users (cc @siriuslee and maybe @A-Jacobson) to opine on the expected user experience here.

@hanlint (Contributor) commented on Feb 10, 2022

The use case for saving precision especially is that users will want to know which precision the checkpointed model was trained with -- it has implications for what precision they should deploy the model in for inference.

@ravi-mosaicml (Contributor, Author) commented

Sounds good. If precision and/or grad_accum should be persisted, then we need to update the trainer's init signature to `grad_accum: Optional[int] = None`, where `None` means use the value in the checkpoint, or, if there is no checkpoint, the current default. Can implement in a separate PR if desired.
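
A hypothetical sketch of the resolution order that proposal implies (the function and default names below are illustrative, not the actual Trainer API):

```python
from typing import Optional

# Illustrative library default; the real default would live in the Trainer.
DEFAULT_GRAD_ACCUM = 1


def resolve_grad_accum(init_arg: Optional[int], checkpoint_value: Optional[int]) -> int:
    """An explicit init argument wins, then the checkpoint value, then the default."""
    if init_arg is not None:
        return init_arg
    if checkpoint_value is not None:
        return checkpoint_value
    return DEFAULT_GRAD_ACCUM
```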

@ravi-mosaicml (Contributor, Author) commented

Created #451 to track this.

A-Jacobson pushed a commit that referenced this pull request Feb 10, 2022
coryMosaicML pushed a commit to coryMosaicML/composer that referenced this pull request Feb 23, 2022
Linked issue: Loading a checkpoint overwrites max_duration (#441)