feat(train): add accelerate for multi gpu training #2154
- Added support for multi-GPU training by introducing an `accelerator` parameter in training functions.
- Updated `update_policy` to handle gradient updates based on the presence of an accelerator.
- Modified logging to prevent duplicate messages in non-main processes.
- Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage.
- Updated `MetricsTracker` to account for the number of processes when calculating metrics.
- Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency.
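The gradient-update and metrics changes listed above can be illustrated with a minimal sketch. This is not the PR's code: `backward_step` and `samples_seen` are hypothetical helper names, and the only parts assumed from the `accelerate` library are the `Accelerator.backward(loss)` method and the `Accelerator.num_processes` attribute. The accelerator is typed loosely so the sketch runs without `accelerate` installed.

```python
from typing import Any, Optional


def backward_step(loss: Any, accelerator: Optional[Any] = None) -> None:
    """Run the backward pass, delegating to the accelerator when one is given.

    Hypothetical helper; `loss` is anything exposing `.backward()`.
    """
    if accelerator is not None:
        # Multi-process path: let the accelerator drive the backward pass.
        accelerator.backward(loss)
    else:
        # Single-process path: plain PyTorch-style backward.
        loss.backward()


def samples_seen(steps: int, batch_size: int, accelerator: Optional[Any] = None) -> int:
    """Count samples processed across all processes (hypothetical helper).

    Each process consumes its own batch per step, so the global count
    scales with the number of processes.
    """
    num_processes = accelerator.num_processes if accelerator is not None else 1
    return steps * batch_size * num_processes
```

Passing `accelerator=None` keeps the single-device behavior unchanged, which is one plausible way to make the new parameter optional for existing callers.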
…esses
- Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode.
- This change enhances the clarity and consistency of logging during training sessions.
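One way to avoid duplicate log lines across processes is to lower verbosity everywhere except the main process. The sketch below is an assumption about the general approach, not the PR's `init_logging` implementation: `init_logging_sketch` and the logger name are hypothetical, and the only thing assumed from `accelerate` is the `Accelerator.is_main_process` property.

```python
import logging
from typing import Any, Optional


def init_logging_sketch(accelerator: Optional[Any] = None) -> logging.Logger:
    """Return a logger that is verbose only on the main process.

    Hypothetical helper: without an accelerator, the single process is
    treated as the main one.
    """
    logger = logging.getLogger("train_sketch")
    is_main = accelerator is None or accelerator.is_main_process
    # Non-main processes only report errors, so INFO lines appear once.
    logger.setLevel(logging.INFO if is_main else logging.ERROR)
    return logger
```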
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…gface/lerobot into feat/accelerate-melt-gpus
At first look, the PR is done very well, great job!
I just have two comments; I'll give it a deeper dive tomorrow and test it.
For me it's LGTM; waiting for @imstevenpmwork's approval.
317a0bc to 0ef404c
This PR adds accelerate integration and docs to LeRobot. We keep it basic and do not yet add all accelerate options (DeepSpeed etc.).
`lerobot_train.py` and so we can utilize their methods for device discovery etc. Tested