Skip saving of direct serialization fields #445
Conversation
When the Trainer is constructed, the user passes in `max_duration`, `grad_accum`, and `precision` via `Trainer.__init__`. These values should take precedence over whatever was in a checkpoint, so there is no need to serialize them. Since these were the only direct serialization attributes, this change also cleans up the state logic so that only the attributes to be serialized need to be listed; there is no need to list the fields that are skipped. (Renamed `fields` to `attrs` to reflect that State is no longer a dataclass, since fields are a dataclass concept.) Also cleaned up the state serialization by refactoring the deepspeed logic out of the state and into the trainer, next to where deepspeed is initialized. The state should not need to know whether the model is deepspeed or not. Closes #441.
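For illustration, here is a minimal sketch of the allow-list style of serialization described above. This is not the actual composer implementation; the class layout and attribute names are hypothetical:

```python
# Sketch only: serialize an explicit allow-list of attributes, rather than
# tracking which fields to skip. Attributes such as max_duration, grad_accum,
# and precision are simply absent from the list, so they always come from
# Trainer.__init__ on resume.
from typing import Any, Dict

class State:
    # Hypothetical allow-list of attributes included in the state_dict.
    serialized_attrs = ("model", "optimizers", "schedulers", "epoch", "step")

    def state_dict(self) -> Dict[str, Any]:
        state = {}
        for attr in self.serialized_attrs:
            value = getattr(self, attr)
            # Delegate to the object's own state_dict when it has one.
            state[attr] = value.state_dict() if hasattr(value, "state_dict") else value
        return state

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        for attr, value in state.items():
            target = getattr(self, attr, None)
            if hasattr(target, "load_state_dict"):
                target.load_state_dict(value)
            else:
                setattr(self, attr, value)
```

Listing only what is serialized means any attribute not on the list is excluded by default, which is how init-time arguments end up taking precedence over the checkpoint.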
Couple of questions, otherwise LGTM.
Agreed. An ancillary question is: would we actually want logic where the checkpoint is loaded, and then any user-provided arguments through the init override it? Right now, from what I understand, they are silently ignored.
I think grad accum is very hardware-specific (e.g. start training on 3080s, then resume on 3090s), so I don't think it should be persisted in the checkpoint. Precision is also hardware-specific (e.g. some hardware supports bf16; other hardware supports only fp16 or amp), though I could go either way on it theoretically. However, this change also makes it consistent with deepspeed, since we don't save the deepspeed config (including precision) in the checkpoint. It also ensures that all arguments in the trainer init (except for seed) are used: all other parameters either are used to construct the classes into which the state is loaded, or are not saved.
Hmm OK -- might be good to get the original issue requester and users to opine on the expected user experience here. cc: @siriuslee and maybe @A-Jacobson.
The use case for saving …
Sounds good. If precision and/or grad_accum should be persisted, then we need to update the trainer's init signature to be …
Created #451 to track this. |
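For illustration, the signature change being suggested would presumably look something like the following sketch, where omitted arguments fall back to checkpointed values. All names and defaults here are illustrative assumptions, not the actual composer API:

```python
# Hypothetical sketch of the optional-with-fallback pattern discussed above:
# an explicit init argument always wins; otherwise fall back to whatever the
# checkpoint stored, then to a default.
from typing import Optional

class Trainer:
    def __init__(self,
                 grad_accum: Optional[int] = None,
                 precision: Optional[str] = None,
                 checkpoint: Optional[dict] = None):
        checkpoint = checkpoint or {}
        self.grad_accum = grad_accum if grad_accum is not None else checkpoint.get("grad_accum", 1)
        self.precision = precision if precision is not None else checkpoint.get("precision", "fp32")
```

With `None` defaults, there is no ambiguity about whether the user set a value, so nothing needs to be silently ignored.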