Skip to content

Conversation

@drhead
Copy link
Contributor

@drhead drhead commented Apr 16, 2024

What does this PR do?

This adds a few extra things to EMAModel intended to reduce its overhead and allow overlapping compute and transfers better.

First, I added a new option for foreach that, if set, will use torch._foreach functions for performing parameter updates and in-place copies. This should reduce kernel launch overhead on these operations by a fair amount. It can increase the peak memory usage of these operations, so it is disabled by default.

Second, I added a function for pinning memory for shadow parameters alongside an option to pass through the non_blocking parameter for the EMAModel.to() function (defaults to False). When used together, this should allow users to easily process EMA updates asynchronously while offloading the parameters to the CPU. Using this, it should be possible to handle EMA updates just as fast as if they lived on the GPU as long as it wasn't already taking longer than the entire training step with regular synchronous transfers.

I think this does need implementation in the training examples and further testing before merging, including profiling performance gains vs. standard for loop implementations and versus regular blocking CPU offload. I would also greatly appreciate someone verifying that this works on deepspeed as I don't have access to suitable hardware to test it on a multi-device setup. I will implement an example for the SD training script soon.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Training examples: @sayakpaul

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice work! Thanks for providing the reasoning too.

Could we also add a test for this feature?

@sayakpaul
Copy link
Member

I think this does need implementation in the training examples and further testing before merging, including profiling performance gains vs. standard for loop implementations and versus regular blocking CPU offload.

Could you elaborate this a bit? Do you mean we should expose arguments to the users via CLI args so that they can control EMA stuff with more granularity?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bghira
Copy link
Contributor

bghira commented Apr 16, 2024

can you check whether the decay works correctly for you in a small test loop? for me, it does not. and i have to do:

                    ema_unet.optimization_step = global_step

(this is a problem in the prev code too)

@drhead
Copy link
Contributor Author

drhead commented Apr 16, 2024

I think this does need implementation in the training examples and further testing before merging, including profiling performance gains vs. standard for loop implementations and versus regular blocking CPU offload.

Could you elaborate this a bit? Do you mean we should expose arguments to the users via CLI args so that they can control EMA stuff with more granularity?

I do have a basic example for it in the SD1.5 text to image training script right now. I think that exposing foreach as a CLI arg is a bit superfluous, so I'm only exposing the non-blocking CPU offload for now since I expect that this would be far more useful for most users. Once I profile it I will see what the memory usage impact of foreach is to see if it is low enough that it would be an appropriate default.

@drhead
Copy link
Contributor Author

drhead commented Apr 17, 2024

So early testing on this version is showing that the foreach implementation is visibly faster even without having run a profiler on it, and doesn't seem to use any more VRAM (testing on SD1.5 training with batch size 1 and gradient checkpointing). Based on that, I think it is safe to set it to True by default.

The offloading is being a bit more troublesome with the standard training script and I'm not seeing the compute overlap that I expect, based on just looking at GPU usage. I know that it should work, since I have a training run active on another machine right now with almost nonstop 100% usage while offloading two EMA states at once. I'll have to profile it to see what differences matter here, the main ones are that my working config does not use Accelerate, is running on a headless server (my testing machine uses WSL), and uses a training loop wrapped in torch.compile and very carefully avoids materializing tensors on CPU.

@sayakpaul
Copy link
Member

So early testing on this version is showing that the foreach implementation is visibly faster even without having run a profiler on it, and doesn't seem to use any more VRAM (testing on SD1.5 training with batch size 1 and gradient checkpointing). Based on that, I think it is safe to set it to True by default.

I would add that as a CLI argument but not set it to True by default. We should make a note about this feature in our docs and let the users tinker around with it first. If there's sufficient reception, we can definitely change to True as its default.

The offloading is being a bit more troublesome with the standard training script and I'm not seeing the compute overlap that I expect, based on just looking at GPU usage. I know that it should work, since I have a training run active on another machine right now with almost nonstop 100% usage while offloading two EMA states at once. I'll have to profile it to see what differences matter here, the main ones are that my working config does not use Accelerate, is running on a headless server (my testing machine uses WSL), and uses a training loop wrapped in torch.compile and very carefully avoids materializing tensors on CPU.

Yeah would love to know the differences. I think not using accelerate is the primary difference here but of course, we need to know more and better.

@drhead
Copy link
Contributor Author

drhead commented Apr 18, 2024

Yeah would love to know the differences. I think not using accelerate is the primary difference here but of course, we need to know more and better.

This appears to have been correct. Accelerate's dataloader uses blocking transfers, where my other training script had pinned memory and exclusively non-blocking transfers. Submitted a patch that should fix that and make Accelerate's dataloader perform better in these situations: huggingface/accelerate#2685

While it is faster now, there is one remaining apparent performance issue where there are a bunch of cudaHostMalloc calls for some (but not all) of the offloaded params when it is initiating the DtoH transfer:
image

I strongly suspect this is a WSL issue, when our other training machine is free I will see if this issue happens there as well. Regardless of this issue though, in its current state this is faster than a blocking transfer would be.

@drhead
Copy link
Contributor Author

drhead commented Apr 20, 2024

After testing/profiling on a different machine I'm fairly confident that the non-blocking offload is working as well as a given environment/other parts of the code can permit it to work. Most significantly dataloader data transfers with Accelerate (patch submitted already) and the .item() calls on metrics being reported, as well as a few other things that probably really can't be/aren't worth reporting without wrapping the training loop in torch.compile(), and which are fairly out of scope (and torch.compile is currently far too unstable to include in an example script IMO).

I'll add foreach as a commandline arg later, and then will mark as ready, then can work on docs and possibly propagating changes to other example scripts.

@sayakpaul
Copy link
Member

Thank you for investigating. We can just with one example and then open it up to the community to follow your PR as a reference.

and the .item() calls on metrics being reported

How significant is this one?

@drhead
Copy link
Contributor Author

drhead commented Apr 20, 2024

and the .item() calls on metrics being reported

How significant is this one?

They're not a huge issue if it is being called at most once per gradient update, because it necessarily must complete the steps to actually materialize the loss value. I don't have profiling data since I knew it would be a problem and commented them out immediately when testing, but the main offenders would be the tqdm postfix update and the updating of train_loss every forward pass (which is fine if not using gradient accumulation but isn't letting the dispatch queue fill up as much as it could otherwise).

Mitigating it without losing too much functionality would look like:

  • change train_loss to initialize as a scalar tensor on the accelerator.device
  • if accelerate.gather is differentiable/isn't already detaching, we should detach the loss value being passed to it (unlikely that this is causing huge issues but this isn't a bad practice)
  • remove the .item() call for accumulating avg_loss into train_loss (this will now be dispatched like every other operation in the chain)
  • under the sync_gradients branch:
    • add .item() to train_loss in accelerator.log
    • move the set_postfix portion to the sync_gradients branch and change it to use train_loss.item() -- we already materialized it so doing it again can't hurt
    • remove the train_loss = 0 line and replace it with train_loss *= 0 after the set_postfix portion, because that lets us reuse the buffer

In my experiences, this causes nearly no perceptible overhead. You won't see the loss for every forward pass, but the loss for the gradient update is arguably what you actually want to be tracking anyways. I could set up a PR for this if interested.

@sayakpaul
Copy link
Member

Oh thanks so much. If this is not causing significant overheads, I would like to keep it as is for now.

@drhead drhead marked this pull request as ready for review April 20, 2024 16:06
@yiyixuxu
Copy link
Collaborator

@sayakpaul is this good to merge?

@sayakpaul
Copy link
Member

I assume not. But I will defer to @drhead for confirming.

@drhead
Copy link
Contributor Author

drhead commented Apr 23, 2024

I assume not. But I will defer to @drhead for confirming.

It's effectively complete except for linting

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is good to merge. We'd need to:

  • Fix the code linting issues.
  • Add a block about this feature in the README of train_text_to_image.py. This is quite beneficial.

Additionally, do you think we could add a test for this here: https://github.com/huggingface/diffusers/blob/main/tests/others/test_ema.py?

@drhead
Copy link
Contributor Author

drhead commented Apr 24, 2024

This is good to merge. We'd need to:

* Fix the code linting issues.

* Add a block about this feature in the README of `train_text_to_image.py`. This is quite beneficial.

Additionally, do you think we could add a test for this here: https://github.com/huggingface/diffusers/blob/main/tests/others/test_ema.py?

I have implemented the first two -- which tests are you interested in having? I would think running (nearly?) all of the tests on the foreach implementation to ensure they both pass would be appropriate, but if you think less than that is needed then let me know.

@sayakpaul
Copy link
Member

Thanks so much.

Sorry for not being clear about the tests.

I would think running (nearly?) all of the tests on the foreach implementation to ensure they both pass would be appropriate, but if you think less than that is needed then let me know.

That sounds good to me. We could have a separate test module, test_ema_for_each.py and add all the tests there. But I am okay if you rather create a separate EMAModelForEachTests class in the same test_ema.py script and add the tests there.

Comment on lines +374 to +375
[param.data for param in parameters],
[s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just one question. Should we add non_blocking=True here as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there's much benefit to doing so. I expect that copy_to is going to be used only for validation and model saving and won't be used every step, so in my opinion there's little benefit to having a non-blocking transfer for something that'll probably be used at the absolute most every several minutes and more realistically hours apart. I've also had to troubleshoot a few issues with unexpected increased VRAM usage (presumably from tensors not being removed from memory fast enough) when switching between validation and training, so with that (combined with the risk that someone might do something like a non-blocking copy_to of an EMA state to the model's parameters and save the model, which might result in save being called on non-ready tensors) I think it is safer for these to just be blocking.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright. That sounds reasonable to me.

@sayakpaul
Copy link
Member

@drhead a gentle ping here :)

@sayakpaul
Copy link
Member

I am happy to merge this PR once the conflicts are resolved and a simple test suite is added. I feel the test suite is important here.

@drhead
Copy link
Contributor Author

drhead commented Jun 22, 2024

I am happy to merge this PR once the conflicts are resolved and a simple test suite is added. I feel the test suite is important here.

I admittedly haven't done anything with test suites in a very long time -- do I need to do anything other than merely adding the class like I've done?

@sayakpaul
Copy link
Member

All good. Will merge once the CI is green. Thanks a bunch for your contributions!

@sayakpaul
Copy link
Member

@bghira
Copy link
Contributor

bghira commented Jun 22, 2024

pin_memory needs to be blocked from running if torch.backends.mps.is_available()

@bghira
Copy link
Contributor

bghira commented Jun 22, 2024

2024-06-22 06:27:12,238 [ERROR] (__main__) Failed to pin EMA model to CPU: cannot pin 'MPSBFloat16Type' only dense CPU tensors can be pinned

@bghira
Copy link
Contributor

bghira commented Jun 22, 2024

i've been testing a variant of this.

  • we can add arg ema_cpu_only to keep EMA on CPU forever if we don't pin anything, but instead, use s_param.sub_(one_minus_decay * (s_param - param.to(s_param.device))) for calculation to move the base model param to the CPU instead of EMA to GPU. this didn't noticeably increase calculation runtime but does noticeably reduce vram
  • we can add arg ema_update_interval to only update eg. every 5-100 steps
  • save_pretrained needs max_shard_size added as an arg, and pass through to the base method

@sayakpaul
Copy link
Member

Perhaps these and the mps fix could be clubbed in a separate PR?

@bghira
Copy link
Contributor

bghira commented Jun 22, 2024

well, great work on the conceptual design of this change. on a macbook pro m3 max it improves iteration times such that EMA produces no visible impact on training speed, even if we never move it to the GPU.

@sayakpaul sayakpaul merged commit 2ada094 into huggingface:main Jun 24, 2024
@sayakpaul
Copy link
Member

Thanks a lot for working on this feature and for iterating on it.

sayakpaul added a commit that referenced this pull request Dec 23, 2024
…s and better support for non-blocking CPU offloading (#7685)

* Add support for _foreach operations and non-blocking to EMAModel

* default foreach to false

* add non-blocking EMA offloading to SD1.5 T2I example script

* fix whitespace

* move foreach to cli argument

* linting

* Update README.md re: EMA weight training

* correct args.foreach_ema

* add tests for foreach ema

* code quality

* add foreach to from_pretrained

* default foreach false

* fix linting

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: drhead <a@a.a>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants