Conversation

bghira commented on Jan 29, 2026

This pull request improves support for RamTorch in the custom bfloat16 AdamW optimizer, covering configurations where model weights reside on the CPU while gradients live on the GPU. The changes ensure optimizer state tensors are created on, and migrated to, the correct device, and that parameter updates handle cross-device operations safely.

Device handling improvements for RamTorch compatibility:

  • When initializing optimizer state (exp_avg, exp_avg_sq, shift), tensors are now created on the same device as the gradient, supporting cases where weights are on CPU and gradients on GPU.
  • During each optimizer step, state tensors are migrated to the current gradient's device if needed, ensuring consistency when resuming from checkpoints or when device placement changes.
  • In the _make_step function, logic was added to handle parameter updates when the parameter and state tensors are on different devices: the shift tensor is moved to the parameter's device for the update, and the truncation error is computed and moved back to the state device. [1] [2] Hedged sketches of both patterns follow this list.
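
For illustration only, here is a minimal sketch of the state creation and migration pattern described in the first two points. The function name step_state_handling and the state dict layout are assumptions for this example, not the repository's actual code; the exp_avg, exp_avg_sq, and shift names follow the PR text.

```python
import torch

# Hedged sketch: hypothetical excerpt of a bf16 AdamW step() showing the
# device-handling pattern described above. Only the tensor names follow the
# PR text; everything else is illustrative.
@torch.no_grad()
def step_state_handling(param: torch.Tensor, state: dict) -> None:
    grad = param.grad
    if grad is None:
        return

    # 1) Lazy state init: allocate the moments (and the shift buffer) on the
    #    *gradient's* device, not the parameter's. With RamTorch the parameter
    #    may live on the CPU while its gradient is on the GPU.
    if len(state) == 0:
        state["step"] = 0
        state["exp_avg"] = torch.zeros_like(grad, dtype=torch.float32)
        state["exp_avg_sq"] = torch.zeros_like(grad, dtype=torch.float32)
        state["shift"] = torch.zeros_like(grad, dtype=torch.bfloat16)

    # 2) Migrate previously created (or checkpoint-restored) state to the
    #    gradient's current device if placement has changed.
    for key in ("exp_avg", "exp_avg_sq", "shift"):
        if state[key].device != grad.device:
            state[key] = state[key].to(grad.device)
```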
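The third point can be illustrated with a similar sketch. Assuming a Kahan-style compensation scheme (implied by the shift tensor and the truncation-error wording, but not spelled out in the PR), the update is applied on the parameter's device and the residual is carried back to the state device. The function name make_step_cross_device and its signature are assumptions, not the repository's _make_step.

```python
import torch

# Hedged sketch (not the repository's actual _make_step): cross-device
# handling for a compensated bf16 parameter update. `update` is the AdamW
# step already computed on the state device (where shift lives).
@torch.no_grad()
def make_step_cross_device(param: torch.Tensor, shift: torch.Tensor,
                           update: torch.Tensor) -> None:
    # Accumulate this step's update into the shift/compensation buffer.
    shift.add_(update)

    if shift.device != param.device:
        # Move a copy of the shift to the parameter's device so the in-place
        # parameter update stays a same-device operation.
        shift_on_param = shift.to(param.device)
        old_param = param.clone()
        param.add_(shift_on_param.to(param.dtype))
        # Truncation error: the part of the shift the low-precision parameter
        # failed to absorb. Compute it on the parameter's device, then move it
        # back so the shift buffer stays on the state device.
        truncation_error = shift_on_param - (param - old_param).to(shift.dtype)
        shift.copy_(truncation_error.to(shift.device))
    else:
        # Same-device path: identical logic without any transfers.
        old_param = param.clone()
        param.add_(shift.to(param.dtype))
        shift.sub_((param - old_param).to(shift.dtype))
```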

bghira merged commit 87216d8 into main on Jan 29, 2026
2 checks passed
bghira deleted the bugfix/adamw-bf16-with-unsloth-and-ramtorch branch on January 29, 2026 at 22:42