
Conversation

@dg845 (Collaborator) commented Jan 26, 2024

What does this PR do?

This PR enables gradient checkpointing for UNet2DModel by setting the _supports_gradient_checkpointing flag to True. Since UNet2DConditionModel has _supports_gradient_checkpointing = True, it seems like UNet2DModel should support gradient checkpointing as well, unless I'm missing something.
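
Concretely, the initial change is just the class attribute (a minimal sketch of the diff):

class UNet2DModel(ModelMixin, ConfigMixin):
    # Opt the model into ModelMixin.enable_gradient_checkpointing() / disable_gradient_checkpointing().
    _supports_gradient_checkpointing = True

# Usage after this change (sketch):
# model = UNet2DModel(...)
# model.enable_gradient_checkpointing()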

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten
@sayakpaul

@sayakpaul (Member) left a comment

Do we not also have to configure the gradient checkpointing blocks, like how we do here?
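
i.e., something along these lines (an illustrative sketch only; the real implementation may check specific block classes rather than hasattr, mirroring what UNet2DConditionModel does):

def _set_gradient_checkpointing(self, module, value=False):
    # Propagate the flag to every sub-block that exposes a gradient_checkpointing attribute.
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing = value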

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dg845 (Collaborator, Author) commented Jan 27, 2024

Do we not also have to configure the gradient checkpointing blocks, like how we do here?

You're right, I missed this 😅.

@dg845 (Collaborator, Author) commented Jan 27, 2024

The UNetMidBlock2D, AttnDownBlock2D, and AttnUpBlock2D blocks used in UNet2DModel currently do not have gradient checkpointing implemented, so I have added gradient checkpointing to each of those blocks.
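
Roughly, each of these blocks' forward now follows the usual diffusers gradient checkpointing pattern (a simplified sketch; the scale/lora_scale handling is discussed below):

for resnet, attn in zip(self.resnets, self.attentions):
    if self.training and self.gradient_checkpointing:
        # Recompute the resnet's activations in the backward pass instead of storing them.
        hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(resnet),  # create_custom_forward as written below
            hidden_states,
            temb,
        )
    else:
        hidden_states = resnet(hidden_states, temb)
    hidden_states = attn(hidden_states)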

As a note, in their current forward methods, AttnDownBlock2D and AttnUpBlock2D use the scale keyword argument when calling resnet, e.g.:

cross_attention_kwargs.update({"scale": lora_scale})
hidden_states = resnet(hidden_states, temb, scale=lora_scale)

So I have written the create_custom_forward function as

def create_custom_forward(module, return_dict=None):
    def custom_forward(*inputs, **kwargs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict, **kwargs)
        else:
            return module(*inputs, **kwargs)

    return custom_forward

This has the potential to cause problems if return_dict is also supplied through kwargs (see https://docs.python.org/3/tutorial/controlflow.html#function-examples).
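
For concreteness, a toy example of the failure mode (f is a made-up stand-in, not diffusers code; create_custom_forward is the function written above):

def f(x, scale=1.0, return_dict=False):
    # Hypothetical stand-in for a block's forward.
    return (x * scale, return_dict)

fn = create_custom_forward(f, return_dict=True)
fn(2.0, scale=0.5)                     # fine: returns (1.0, True)
fn(2.0, scale=0.5, return_dict=False)  # TypeError: f() got multiple values for keyword argument 'return_dict'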

CrossAttnDownBlock2D handles this by omitting the scale argument entirely when calling custom_forward:

hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    temb,
    **ckpt_kwargs,
)

which seems wrong when lora_scale is not ResnetBlock2D.forward's default scale value of 1.0 because the forward passes with and without gradient checkpointing are not equivalent.

Since scale is currently the third parameter of ResnetBlock2D.forward, we can probably supply it as a positional argument and still use CrossAttnDownBlock2D's create_custom_forward implementation. I'm not sure which approach is best.
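
Concretely, the positional-argument version would look something like this (a sketch, assuming ResnetBlock2D.forward keeps the signature (input_tensor, temb, scale=1.0)):

hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    temb,
    lora_scale,  # supplied positionally in place of the scale keyword argument
    **ckpt_kwargs,
)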

@sayakpaul (Member) commented

which seems wrong when lora_scale is not ResnetBlock2D.forward's default scale value of 1.0 because the forward passes with and without gradient checkpointing are not equivalent.

I think this is still fine because lora_scale is never going to interfere during training. This still seems reasonable to me because there are always some subtle differences between training and inference in a model's forward pass. WDYT?

@sayakpaul (Member) left a comment

Just some nits, but looks very good. Nice test, too.

dg845 added 5 commits January 27, 2024 16:22
… positional arg when gradient checkpointing for AttnDownBlock2D/AttnUpBlock2D.
…checkpointing for CrossAttnDownBlock2D/CrossAttnUpBlock2D as well.
@dg845 (Collaborator, Author) commented Jan 28, 2024

Regarding #6718 (comment): I think in this case the best short-term solution is to use the standard create_custom_forward implementation and supply lora_scale as a positional argument when calling resnet through the custom_forward function. This allows the gradient checkpointing forward pass to be the same as the non-gradient checkpointing forward pass during training (which is also the same as the forward pass during inference). [Note that this is implemented in e0ee9ca and 8756be5.] This does introduce some dependency on the order of the positional arguments in ResnetBlock2D.forward, but I think that's probably fine since the ResnetBlock2D API is likely to remain stable over time.

In the long term, at least in src/diffusers/models/unets/unet_2d_blocks.py, I think it might make sense to revisit the create_custom_forward implementation. A quick search through that file indicates that create_custom_forward is never called with the return_dict keyword argument. So perhaps a more general implementation like

def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)

    return custom_forward

could be used, and return_dict could be supplied through kwargs if necessary: for example, if we're using torch.utils.checkpoint.checkpoint, something like

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    temb,
    scale=lora_scale,
    return_dict=True,
    **ckpt_kwargs,
)

Note that torch.utils.checkpoint.checkpoint forwards extra keyword arguments to the wrapped function, at least in non-reentrant mode (use_reentrant=False), which is what ckpt_kwargs selects for torch >= 1.11.
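
For example, a toy check of the kwargs pass-through with non-reentrant checkpointing (block here is a made-up stand-in, not a diffusers module):

import torch
from torch.utils.checkpoint import checkpoint

def block(x, scale=1.0, return_dict=False):
    out = torch.relu(x) * scale
    return {"sample": out} if return_dict else out

x = torch.randn(4, requires_grad=True)
# Both scale and return_dict are forwarded to block by the non-reentrant checkpoint.
out = checkpoint(block, x, scale=0.5, return_dict=True, use_reentrant=False)
out["sample"].sum().backward()  # gradients flow through the recomputed forward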

@sayakpaul (Member) commented

@dg845 I don't think I quite follow the concern fully.

Could you maybe try to demonstrate the issue with a simpler example?

which seems wrong when lora_scale is not ResnetBlock2D.forward's default scale value of 1.0 because the forward passes with and without gradient checkpointing are not equivalent.

Would like to see when this case arises. From what I understand, gradient checkpointing is used during training, and lora_scale is never supposed to be supplied during training. So I don't quite understand how a discrepancy arises here. Maybe I am missing something.

I would like to keep the legacy blocks as they are unless absolutely necessary. This is why I am asking for a simpler example, to understand the consequences.

@dg845 (Collaborator, Author) commented Jan 28, 2024

which seems wrong when lora_scale is not ResnetBlock2D.forward's default scale value of 1.0 because the forward passes with and without gradient checkpointing are not equivalent.

Sorry, I should have made it clear that the above follows from my belief that the lora_scale should be supplied during training.

My understanding is that in the original LoRA paper the LoRA scale parameter $\alpha$ is a hyperparameter during training:

[image "lora_alpha_param": excerpt from the LoRA paper describing the $\alpha$ scaling hyperparameter]

I think in practice $\alpha$ is typically held constant and the learning rate is tuned during training (following the highlighted section), but theoretically we could treat $\alpha$ and the learning rate as independent hyperparameters and tune them both.

Similarly, if we look at peft.tuners.lora.layer.Linear, the forward method does not disable scaling the LoRA update $\Delta W$ during training:

https://github.com/huggingface/peft/blob/bfc102c0c095dc9094cdd3523b729583bfad4688/src/peft/tuners/lora/layer.py#L318-L320

unlike for something like dropout where the forward pass would be different depending on whether torch.nn.Module.training is set.
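
Schematically, a simplified LoRA linear looks like this (a sketch to illustrate the point, not the actual peft class):

import torch.nn as nn

class LoRALinearSketch(nn.Module):
    def __init__(self, in_features, out_features, rank=4, alpha=4.0, p_dropout=0.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.lora_A = nn.Linear(in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_features, bias=False)
        self.dropout = nn.Dropout(p_dropout)  # behaves differently in train vs. eval mode
        self.scaling = alpha / rank           # applied unconditionally, in both train and eval

    def forward(self, x):
        # The LoRA update is always scaled, regardless of self.training.
        return self.base(x) + self.lora_B(self.lora_A(self.dropout(x))) * self.scaling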

So in my view the discrepancy between the gradient checkpointing code and non-gradient checkpointing code in e.g. CrossAttnDownBlock2D when lora_scale != 1.0 is a bug because

  1. The gradient checkpointing forward pass differs from the non-gradient checkpointing forward pass during training.
  2. For the same reason, the gradient checkpointing forward pass during training differs from the forward pass during inference, which will result in a train-test mismatch if we train using gradient checkpointing.

Practically speaking, we might not consider the train-test mismatch that arises to be that bad, since we may want to tune the scaling of the LoRA update during inference anyway (e.g. if we are performing inference with multiple LoRAs simultaneously).
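
As a toy illustration of the mismatch (resnet_like is a made-up stand-in for ResnetBlock2D.forward, not diffusers code):

import torch
from torch.utils.checkpoint import checkpoint

def resnet_like(x, temb, scale=1.0):
    # The output depends on scale, like ResnetBlock2D.forward does under the legacy LoRA path.
    return x + temb * scale

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward

x = torch.randn(4, requires_grad=True)
temb = torch.randn(4)
lora_scale = 0.5

eager = resnet_like(x, temb, scale=lora_scale)
# Mirrors the current checkpointing path in CrossAttnDownBlock2D: scale is dropped,
# so the checkpointed call silently runs with the default scale of 1.0.
ckpt = checkpoint(create_custom_forward(resnet_like), x, temb, use_reentrant=False)

print(torch.allclose(eager, ckpt))  # False whenever lora_scale != 1.0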

@dg845 (Collaborator, Author) commented Jan 28, 2024

That being said, perhaps it's better if I move the changes (especially to CrossAttnDownBlock2D and CrossAttnUpBlock2D in 8756be5) out of this PR, and revisit this in a separate issue/PR.

@sayakpaul (Member) commented

alpha is not the same as the scale parameter in LoRA training, in my understanding. scale is an inference-time parameter and shouldn't influence training, whereas alpha could be tuned during LoRA training. For LoRA training, we rely on PEFT. If we want to expose alpha for training, we can easily do so. Example:

But this discussion is starting to deviate from the original topic of the PR a bit IMO.

The gradient checkpointing forward pass differs from the non-gradient checkpointing forward pass during training.
For the same reason, the gradient checkpointing forward pass during training differs from the forward pass during inference, which will result in a train-test mismatch if we train using gradient checkpointing.

^ I agree with this. And maybe this could be handled first in a separate PR, and then we revisit this PR. Does that work?

@dg845 (Collaborator, Author) commented Jan 29, 2024

And maybe this could be handled first in a separate PR and then we revisit this PR. Does that work?

Sounds good :). To be more precise, would something like this sound good to you?

  1. In this PR, gradient checkpointing is implemented for UNet2DModel and its associated blocks such as AttnDownBlock2D/AttnUpBlock2D in a way which is exactly parallel to the current gradient checkpointing implementation in UNet2DConditionModel and CrossAttnDownBlock2D/CrossAttnUpBlock2D.
  2. The question of how the scale parameter should be handled for the legacy LoRA implementation will be revisited in a separate issue and/or PR.

dg845 added 2 commits January 28, 2024 18:19
…radient checkpointing for CrossAttnDownBlock2D/CrossAttnUpBlock2D as well."

This reverts commit 8756be5.
…ions exactly parallel to CrossAttnDownBlock2D/CrossAttnUpBlock2D implementations.
@sayakpaul (Member) commented

Yeah that is right.

@dg845 (Collaborator, Author) commented Jan 30, 2024

I have updated the gradient checkpointing implementation in this PR to be exactly parallel to that of UNet2DConditionModel and opened a new issue regarding the LoRA scale parameter at #6756.

@sayakpaul (Member) left a comment

Thanks a mile!

woshiyyya and others added 19 commits February 2, 2024 14:53
…a movement. (huggingface#6704)

* load cumprod tensor to device

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>

* fixing ci

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>

* make fix-copies

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>

---------

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
…uggingface#6736)

Fix bug in ResnetBlock2D.forward when not USE_PEFT_BACKEND and using scale_shift for time emb where the lora scale  gets overwritten.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update train_diffusion_dpo.py

Address huggingface#6702

* Update train_diffusion_dpo_sdxl.py

* Empty-Commit

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
…ss (huggingface#6762)

* add is_flaky to test_model_cpu_offload_forward_pass

* style

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* update

* update

* updaet

* add tests and docs

* clean up

* add to toctree

* fix copies

* pr review feedback

* fix copies

* fix tests

* update docs

* update

* update

* update docs

* update

* update

* update

* update
---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Alvaro Somoza <somoza.alvaro@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
move sigma to device

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
* add

* remove transformer

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
…gface#6738)

* harmonize the module structure for models in tests

* make the folders modules.

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update testing_utils.py

* Update testing_utils.py
@github-actions (Contributor) commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the "stale (Issues that haven't received updates)" label on Feb 27, 2024
@sayakpaul removed the "stale (Issues that haven't received updates)" label on Feb 27, 2024
@sayakpaul (Member) commented

I think the PR is borked. Should we open a new PR instead? @dg845

dg845 added a commit to dg845/diffusers that referenced this pull request Mar 4, 2024
@dg845 (Collaborator, Author) commented Mar 4, 2024

Created a new PR with the changes at #7201. Will close this PR.

@dg845 dg845 closed this Mar 4, 2024
yiyixuxu pushed a commit that referenced this pull request Dec 20, 2024
* Port UNet2DModel gradient checkpointing code from #6718.


---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
Foundsheep pushed a commit to Foundsheep/diffusers that referenced this pull request Dec 23, 2024
* Port UNet2DModel gradient checkpointing code from huggingface#6718.


---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* Port UNet2DModel gradient checkpointing code from #6718.


---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>