adamw_bf16 compatibility with unsloth checkpointing #2530
Merged
This pull request improves the custom bfloat16 AdamW optimizer's support for RamTorch, where model weights may reside on the CPU while gradients live on the GPU. The changes ensure optimizer state tensors are created on, and migrated to, the correct device, and that parameter updates handle cross-device operations safely.
Device handling improvements for RamTorch compatibility:
- When initializing optimizer state (`exp_avg`, `exp_avg_sq`, `shift`), tensors are now created on the same device as the gradient, supporting cases where weights are on CPU and gradients on GPU.
- In the `_make_step` function, logic was added to handle parameter updates when the parameter and state tensors are on different devices: the `shift` tensor is moved to the parameter's device for the update, and the truncation error is computed and moved back to the state device (a minimal sketch follows this list). [1] [2]
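For readers unfamiliar with the optimizer's internals, here is a minimal sketch of the two behaviors above. It assumes `shift` is a Kahan-style error-feedback buffer (consistent with the truncation-error handling described in this PR); the function names `init_state` and `step_param`, the hyperparameter defaults, and the omission of weight decay are illustrative and not the repository's actual API. Only the state keys `exp_avg`, `exp_avg_sq`, `shift` and the `_make_step` name come from the PR itself.

```python
import torch


def init_state(p: torch.Tensor) -> dict:
    # PR behavior #1: allocate state on the *gradient's* device, so
    # RamTorch setups (CPU weights, GPU gradients) get GPU-side state.
    device = p.grad.device if p.grad is not None else p.device
    return {
        "exp_avg": torch.zeros_like(p, dtype=torch.bfloat16, device=device),
        "exp_avg_sq": torch.zeros_like(p, dtype=torch.bfloat16, device=device),
        "shift": torch.zeros_like(p, dtype=torch.bfloat16, device=device),
    }


@torch.no_grad()
def step_param(p, state, step, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    # Hypothetical stand-in for the body produced by `_make_step`;
    # weight decay is omitted for brevity.
    g = p.grad
    exp_avg, exp_avg_sq, shift = state["exp_avg"], state["exp_avg_sq"], state["shift"]

    # Moment updates run on the gradient/state device.
    exp_avg.mul_(betas[0]).add_(g, alpha=1 - betas[0])
    exp_avg_sq.mul_(betas[1]).addcmul_(g, g, value=1 - betas[1])

    bias1 = 1 - betas[0] ** step
    bias2 = 1 - betas[1] ** step
    # Accumulate the bias-corrected step into the bf16 `shift` buffer.
    update = (exp_avg / bias1) / ((exp_avg_sq / bias2).sqrt() + eps)
    shift.add_(update, alpha=-lr)

    # PR behavior #2: move `shift` to the parameter's device, apply the
    # update in fp32, then send the bf16 truncation error back to the
    # state device as the new `shift` contents.
    shift_on_p = shift.to(p.device)
    buf = p.float() + shift_on_p.float()   # exact sum in fp32
    new_p = buf.to(p.dtype)                # rounds to bf16
    p.copy_(new_p)
    trunc_err = buf - new_p.float()        # what the rounding lost
    shift.copy_(trunc_err.to(shift.device, dtype=shift.dtype))
```

Carrying the truncation error back on the state device keeps the compensated-summation loop self-contained there, so the only per-step cross-device traffic is the `shift` transfer and the error transfer back.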