Enable Gradient Checkpointing for UNet2DModel #6718
Conversation
Don't we also have to configure the gradient checkpointing blocks like how we do here?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
You're right, I missed this 😅.
… and AttnUpBlock2D.
…ard for gradient checkpointing in AttnDownBlock2D and AttnUpBlock2D.
As a note, in their current form the forward implementations do the following:

diffusers/src/diffusers/models/unets/unet_2d_blocks.py, lines 1045 to 1046 (d4c7ab7)

So I have written the gradient checkpointing code as:

diffusers/src/diffusers/models/unets/unet_2d_blocks.py, lines 1072 to 1079 (e837857)

This has the potential to cause problems if … For comparison, the existing CrossAttnDownBlock2D gradient checkpointing code is:

diffusers/src/diffusers/models/unets/unet_2d_blocks.py, lines 1183 to 1188 (d4c7ab7)

which seems wrong when the checkpointing and non-checkpointing code paths are expected to match.
    
          
I think this is still fine because …
    
Just some nits, but looks very good. Nice test, too.
… positional arg when gradient checkpointing for AttnDownBlock2D/AttnUpBlock2D.
…checkpointing for CrossAttnDownBlock2D/CrossAttnUpBlock2D as well.
Regarding #6718 (comment): I think in this case the best short-term solution is to use the standard … In the long term, at least in …,

```python
def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)

    return custom_forward
```

could be used, and

```python
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    temb,
    scale=lora_scale,
    return_dict=True,
    **ckpt_kwargs,
)
```

Note that …
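To make the keyword-argument point concrete, here is a minimal, self-contained sketch (using a toy module rather than the actual diffusers blocks; all names are illustrative) of why a `custom_forward(*inputs)` wrapper without `**kwargs` cannot receive `scale`, while a `**kwargs`-forwarding variant combined with non-reentrant checkpointing can:

```python
import torch


def create_custom_forward(module):
    # Forwarding **kwargs lets keyword arguments such as `scale` reach the wrapped module;
    # a variant without **kwargs would silently leave `scale` at its default value.
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)

    return custom_forward


class ToyResnet(torch.nn.Module):
    # Toy stand-in for a resnet block whose output depends on a `scale` keyword argument.
    def forward(self, hidden_states, temb, scale=1.0):
        return hidden_states * scale + temb


resnet = ToyResnet()
hidden_states = torch.randn(2, 4, requires_grad=True)
temb = torch.zeros(2, 4)

# Non-reentrant checkpointing forwards keyword arguments to the wrapped function,
# so `scale=0.5` actually reaches ToyResnet.forward here.
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    temb,
    scale=0.5,
    use_reentrant=False,
)
out.sum().backward()  # activations are recomputed during this backward pass
```

With `use_reentrant=True`, `torch.utils.checkpoint.checkpoint` does not support passing keyword arguments through to the wrapped function, which is part of why the non-reentrant path matters for this discussion.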
    
@dg845 I think I don't quite follow the concern fully. Could you maybe try to demonstrate the issue with a simpler example? I would like to see when this case arises. From what I understand, gradient checkpointing is used during training, and …

I would like to keep the legacy blocks as is until and unless absolutely necessary. This is why I am asking for a simpler example to understand the consequences.
    
          
Sorry, I should have made it clear that the above follows from my belief that the … My understanding is that in the original LoRA paper the LoRA scale parameter …

I think in practice … Similarly, if we look at …, unlike for something like dropout where the forward pass would be different depending on whether the model is in training mode. So in my view the discrepancy between the gradient checkpointing code and non-gradient checkpointing code in e.g. …

Practically speaking, we might not consider the train-test mismatch that arises to be that bad, since we may want to tune the scaling of the LoRA update during inference anyway (e.g. if we are performing inference with multiple LoRAs simultaneously).
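For background on the `scale` discussion, a hedged sketch (toy class and names, not the diffusers implementation) of how a LoRA-style scale usually enters the forward pass; because it only multiplies the low-rank update, dropping it or changing its value between the checkpointing and non-checkpointing paths changes the output whenever the LoRA weights are nonzero:

```python
import torch


class ToyLoRALinear(torch.nn.Module):
    # Toy stand-in: a base layer plus a scaled low-rank (LoRA) update.
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.base = torch.nn.Linear(in_features, out_features)
        self.lora_down = torch.nn.Linear(in_features, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, out_features, bias=False)

    def forward(self, x, scale=1.0):
        # `scale` only multiplies the low-rank update; unlike dropout, it does not
        # branch on training vs. eval mode, so the same value applies in both.
        return self.base(x) + scale * self.lora_up(self.lora_down(x))


layer = ToyLoRALinear(8, 8)
x = torch.randn(2, 8)
out_full = layer(x, scale=1.0)
out_none = layer(x, scale=0.0)  # equivalent to ignoring the LoRA update entirely
```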
    
That being said, perhaps it's better if I move the changes (especially to …) into a separate PR.
    
> But this discussion is starting to deviate from the original topic of the PR a bit IMO.

^ this I agree. And maybe this could be handled first in a separate PR and then we revisit this PR. Does that work?
    
          
Sounds good :). To be more precise, would something like this sound good to you? …
    
…radient checkpointing for CrossAttnDownBlock2D/CrossAttnUpBlock2D as well." This reverts commit 8756be5.
…ions exactly parallel to CrossAttnDownBlock2D/CrossAttnUpBlock2D implementations.
Yeah, that is right.
    
I have updated the gradient checkpointing implementation in this PR to be exactly parallel to that of CrossAttnDownBlock2D/CrossAttnUpBlock2D.
    
Thanks a mile!
…a movement. (huggingface#6704) * load cumprod tensor to device Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com> * fixing ci Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com> * make fix-copies Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com> --------- Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
…uggingface#6736) Fix bug in ResnetBlock2D.forward when not USE_PEFT_BACKEND and using scale_shift for time emb where the lora scale gets overwritten. Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update train_diffusion_dpo.py Address huggingface#6702 * Update train_diffusion_dpo_sdxl.py * Empty-Commit --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
…ss (huggingface#6762) * add is_flaky to test_model_cpu_offload_forward_pass * style * update --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
add ipo and hinge loss to dpo trainer
* update * update * updaet * add tests and docs * clean up * add to toctree * fix copies * pr review feedback * fix copies * fix tests * update docs * update * update * update docs * update * update * update * update
add missing param
--------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Alvaro Somoza <somoza.alvaro@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
move sigma to device Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
--------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
* add * remove transformer --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
…gface#6738) * harmonize the module structure for models in tests * make the folders modules. --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update testing_utils.py * Update testing_utils.py
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
    
I think the PR is borked. Should we open a new PR instead? @dg845
    
Created a new PR with the changes at #7201. Will close this PR.
    
* Port UNet2DModel gradient checkpointing code from #6718. --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: hlky <hlky@hlky.ac>

What does this PR do?

This PR enables gradient checkpointing for `UNet2DModel` by setting the `_supports_gradient_checkpointing` flag to `True`. Since `UNet2DConditionModel` has `_supports_gradient_checkpointing = True`, it seems like `UNet2DModel` should support gradient checkpointing as well, unless I'm missing something.
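As a usage sketch (assuming this PR's change so that `enable_gradient_checkpointing()` no longer raises for `UNet2DModel`; the config values below are just an example):

```python
import torch
from diffusers import UNet2DModel

# Small, arbitrary config so the example runs quickly.
model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
model.enable_gradient_checkpointing()  # works only when _supports_gradient_checkpointing is True
model.train()

sample = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])
out = model(sample, timestep).sample
out.mean().backward()  # checkpointed blocks recompute activations here, saving memory
```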
Before submitting

Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@patrickvonplaten
@sayakpaul