Add extra performance features for EMAModel, torch._foreach operations and better support for non-blocking CPU offloading #7685
Conversation
Very nice work! Thanks for providing the reasoning too.
Could we also add a test for this feature?

Could you elaborate on this a bit? Do you mean we should expose arguments to the users via CLI args so that they can control EMA stuff with more granularity?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
can you check whether the decay works correctly for you in a small test loop? for me, it does not, and i have to do `ema_unet.optimization_step = global_step` (this is a problem in the prev code too)
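A minimal sketch of the kind of loop meant here, assuming diffusers' `EMAModel` with the warmup decay schedule (the toy `nn.Linear` model is just for illustration):

```python
import torch
from diffusers.training_utils import EMAModel

# Toy loop for eyeballing the decay schedule: with warmup enabled, the effective
# decay should ramp up as optimization_step advances rather than staying flat.
model = torch.nn.Linear(4, 4)
ema = EMAModel(model.parameters(), decay=0.999, use_ema_warmup=True)

for global_step in range(1, 11):
    with torch.no_grad():
        for p in model.parameters():
            p.add_(0.01 * torch.randn_like(p))  # pretend the optimizer moved the weights
    ema.step(model.parameters())
    # optimization_step should track global_step without any manual assignment
    print(global_step, ema.optimization_step, ema.get_decay(ema.optimization_step))
```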
I do have a basic example for it in the SD1.5 text-to-image training script right now. I think that exposing
So early testing on this version is showing that the offloading is being a bit more troublesome with the standard training script, and I'm not seeing the compute overlap that I expect based on just looking at GPU usage. I know that it should work, since I have a training run active on another machine right now with almost nonstop 100% usage while offloading two EMA states at once. I'll have to profile it to see which differences matter here; the main ones are that my working config does not use Accelerate, is running on a headless server (my testing machine uses WSL), and uses a training loop wrapped in torch.compile that very carefully avoids materializing tensors on CPU.
I would add that as a CLI argument but not set it to True by default. We should make a note about this feature in our docs and let users tinker around with it first. If there's sufficient reception, we can definitely change its default to True.

Yeah, would love to know the differences. I think not using
This appears to have been correct. Accelerate's dataloader uses blocking transfers, where my other training script had pinned memory and exclusively non-blocking transfers. I submitted a patch that should fix that and make Accelerate's dataloader perform better in these situations: huggingface/accelerate#2685

While it is faster now, there is one remaining apparent performance issue where there are a bunch of

I strongly suspect this is a WSL issue; when our other training machine is free I will see if it happens there as well. Regardless of this issue, in its current state this is faster than a blocking transfer would be.
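For reference, a minimal sketch of the general PyTorch idiom that makes these host-to-device transfers overlap with compute (not code from this PR; the dataset here is a stand-in):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset; the point is pin_memory=True plus non_blocking=True on the copy.
dataset = TensorDataset(torch.randn(1024, 3, 64, 64))
loader = DataLoader(dataset, batch_size=16, pin_memory=True)

device = torch.device("cuda")
for (batch,) in loader:
    # With a pinned source tensor, this H2D copy is queued asynchronously and can
    # overlap with GPU work already in flight instead of blocking the host.
    batch = batch.to(device, non_blocking=True)
    # ... forward/backward on `batch` is stream-ordered after the copy
```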
After testing/profiling on a different machine, I'm fairly confident that the non-blocking offload is working as well as a given environment and the other parts of the code permit it to work. The most significant factors are the dataloader transfers with Accelerate (patch submitted already) and the

I'll add foreach as a command-line arg later, then mark this as ready, and then I can work on docs and possibly propagate the changes to the other example scripts.
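A minimal sketch of how that flag could be wired up in the example script's parser (the `--foreach_ema` name follows `args.foreach_ema` from the commits in this PR; the rest is illustrative):

```python
import argparse

parser = argparse.ArgumentParser(description="SD1.5 text-to-image training arguments (sketch)")
# Opt-in flag: keep the torch._foreach_* EMA path disabled by default.
parser.add_argument(
    "--foreach_ema",
    action="store_true",
    help="Use faster torch._foreach_* ops for the EMA update (slightly higher peak memory).",
)
args = parser.parse_args()

# Later, when constructing the EMA model (illustrative; other arguments omitted):
# ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, foreach=args.foreach_ema)
```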
Thank you for investigating. We can just go with one example and then open it up to the community to follow your PR as a reference.

How significant is this one?
They're not a huge issue if it is being called at most once per gradient update, because it necessarily must complete the steps to actually materialize the loss value. I don't have profiling data since I knew it would be a problem and commented them out immediately when testing, but the main offenders would be the tqdm postfix update and the updating of

Mitigating it without losing too much functionality would look like:
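A minimal sketch of that idea, assuming a typical Accelerate training loop where `accelerator.sync_gradients` marks gradient-update steps and `loss` is a CUDA tensor (`compute_loss` is a hypothetical helper):

```python
# Assumes accelerator, unet, optimizer, train_dataloader, progress_bar and
# global_step come from the surrounding training script.
for step, batch in enumerate(train_dataloader):
    with accelerator.accumulate(unet):
        loss = compute_loss(unet, batch)   # hypothetical helper; result stays on GPU
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    if accelerator.sync_gradients:
        global_step += 1
        # One host/device sync per gradient update instead of one per forward pass.
        loss_value = loss.detach().item()
        progress_bar.set_postfix(loss=loss_value)
        accelerator.log({"train_loss": loss_value}, step=global_step)
```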
In my experience, this causes nearly no perceptible overhead. You won't see the loss for every forward pass, but the loss for the gradient update is arguably what you actually want to be tracking anyway. I could set up a PR for this if there's interest.
Oh thanks so much. If this is not causing significant overhead, I would like to keep it as is for now.
@sayakpaul is this good to merge?
I assume not. But I will defer to @drhead for confirmation.
It's effectively complete except for linting.
This is good to merge. We'd need to:
- Fix the code linting issues.
- Add a block about this feature in the README of train_text_to_image.py. This is quite beneficial.

Additionally, do you think we could add a test for this here: https://github.com/huggingface/diffusers/blob/main/tests/others/test_ema.py?
I have implemented the first two -- which tests are you interested in having? I would think running (nearly?) all of the tests on the
Thanks so much. Sorry for not being clear about the tests.

That sounds good to me. We could have a separate test module,
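One possible shape for such a module, sketched under the assumption that the existing tests/others/test_ema.py keeps its cases in an `EMAModelTests` class that builds the EMA model through a shared helper (both `EMAModelTests` and `get_ema_model` are assumed names):

```python
import unittest

# Assumed import path and class name for the existing EMA tests.
from tests.others.test_ema import EMAModelTests


class EMAModelTestsForeach(EMAModelTests):
    """Re-run every existing EMAModel test against the torch._foreach_* code path."""

    def get_ema_model(self, *args, **kwargs):
        kwargs["foreach"] = True  # only difference: exercise the foreach update path
        return super().get_ema_model(*args, **kwargs)


if __name__ == "__main__":
    unittest.main()
```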
[param.data for param in parameters],
[s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
Just one question. Should we add `non_blocking=True` here as well?
I don't think there's much benefit to doing so. I expect that copy_to is going to be used only for validation and model saving and won't be used every step, so in my opinion there's little benefit to having a non-blocking transfer for something that will probably be used, at the absolute most, every several minutes, and more realistically hours apart. I've also had to troubleshoot a few issues with unexpectedly increased VRAM usage (presumably from tensors not being removed from memory fast enough) when switching between validation and training. Combined with the risk that someone might do something like a non-blocking copy_to of an EMA state into the model's parameters and then save the model, which might result in save being called on non-ready tensors, I think it is safer for these to just be blocking.
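For context, a minimal sketch of how copy_to typically gets used around validation in the example scripts, where keeping it blocking is cheap (the `run_validation` helper and the step/arg names are placeholders):

```python
# Assumes accelerator, unet, ema_unet, global_step and args from the surrounding script.
if accelerator.is_main_process and global_step % args.validation_steps == 0:
    # Stash the live training weights, then swap the EMA weights in.
    ema_unet.store(unet.parameters())
    ema_unet.copy_to(unet.parameters())  # blocking: weights are fully ready here

    run_validation(unet)                 # placeholder helper; saving would also be safe here

    # Put the training weights back before resuming optimization.
    ema_unet.restore(unet.parameters())
```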
Alright. That sounds reasonable to me.
@drhead a gentle ping here :)
I am happy to merge this PR once the conflicts are resolved and a simple test suite is added. I feel the test suite is important here.
I admittedly haven't done anything with test suites in a very long time -- do I need to do anything other than adding the class like I've done?
All good. Will merge once the CI is green. Thanks a bunch for your contributions!
Seems like there is a test failure: https://github.com/huggingface/diffusers/actions/runs/9622457113/job/26543763957?pr=7685#step:10:361
i've been testing a variant of this.
Perhaps these and the mps fix could be clubbed in a separate PR?
well, great work on the conceptual design of this change. on a macbook pro m3 max it improves iteration times such that EMA produces no visible impact on training speed, even if we never move it to the GPU.
Thanks a lot for working on this feature and for iterating on it.
…s and better support for non-blocking CPU offloading (#7685)

* Add support for _foreach operations and non-blocking to EMAModel
* default foreach to false
* add non-blocking EMA offloading to SD1.5 T2I example script
* fix whitespace
* move foreach to cli argument
* linting
* Update README.md re: EMA weight training
* correct args.foreach_ema
* add tests for foreach ema
* code quality
* add foreach to from_pretrained
* default foreach false
* fix linting

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: drhead <a@a.a>

What does this PR do?
This adds a few extra things to EMAModel intended to reduce its overhead and allow overlapping compute and transfers better.
First, I added a new option, `foreach`, that, if set, will use `torch._foreach` functions for performing parameter updates and in-place copies. This should reduce kernel launch overhead on these operations by a fair amount. It can increase the peak memory usage of these operations, so it is disabled by default.

Second, I added a function for pinning memory for shadow parameters, alongside an option to pass through the `non_blocking` parameter for the `EMAModel.to()` function (defaults to `False`). When used together, these should allow users to easily process EMA updates asynchronously while offloading the parameters to the CPU. With this, it should be possible to handle EMA updates just as fast as if they lived on the GPU, as long as the regular synchronous transfers weren't already taking longer than the entire training step.

I think this does need implementation in the training examples and further testing before merging, including profiling performance gains vs. standard for-loop implementations and vs. regular blocking CPU offload. I would also greatly appreciate someone verifying that this works on DeepSpeed, as I don't have access to suitable hardware to test it on a multi-device setup. I will implement an example for the SD training script soon.
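A minimal usage sketch of the combination described above. `EMAModel(..., foreach=True)`, the `non_blocking` pass-through on `.to()`, and CPU offloading follow this PR's description; the pinning method name `pin_memory()` and the training-loop helpers are assumptions for illustration:

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

device = torch.device("cuda")
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
).to(device)

# foreach=True batches the EMA update into torch._foreach_* calls.
ema_unet = EMAModel(
    unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config, foreach=True
)

# Keep the shadow parameters in pinned CPU memory so transfers can run asynchronously.
ema_unet.pin_memory()                 # assumed name of the new pinning helper
ema_unet.to("cpu", non_blocking=True)

for batch in train_dataloader:        # placeholder dataloader
    loss = train_step(unet, batch)    # hypothetical helper: forward/backward/optimizer step
    # The EMA update runs against the CPU-resident shadow params; with pinned memory
    # and non-blocking transfers it can overlap with the next step's GPU work.
    ema_unet.step(unet.parameters())
```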
Before submitting
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Training examples: @sayakpaul