-
Notifications
You must be signed in to change notification settings - Fork 6.6k
[torch.compile] fix graph break problems partially #5453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): | ||
| # Forward upsample size to force interpolation output size. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.compile() fails to compile these kinds of iterators right now.
|
The failing test seems unrelated to the PR. |
patrickvonplaten
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
| is_unet_compiled = is_compiled_module(self.unet) | ||
| is_controlnet_compiled = is_compiled_module(self.controlnet) | ||
| is_torch_higher_equal_than_2_1 = version.parse(version.parse(torch.__version__).base_version) >= version.parse( | ||
| "2.1" | ||
| ) | ||
| with self.progress_bar(total=num_inference_steps) as progress_bar: | ||
| for i, t in enumerate(timesteps): | ||
| # Relevant thread: | ||
| # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 | ||
| if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1: | ||
| torch._inductor.cudagraph_mark_step_begin() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm this is in some sense a breaking change from PT, do we really have to add version specific code here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no other way to support compiled ControlNets otherwise in PT 2.1, sadly.
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
Show resolved
Hide resolved
patrickvonplaten
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM in general,just some nits. Let's also flag this issue for PT to take a look. Torch compile seems to be backward broken here between 2.1 and 2.0
| for i, t in enumerate(timesteps): | ||
| # Relevant thread: | ||
| # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 | ||
| if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_than_2_1: | |
| if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: |
Do we need this really? It's not super pretty and looks like a bug in PT 2.1 . Also are we sure the code works fine with PT 2.0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Discussed internally.
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
Outdated
Show resolved
Hide resolved
|
@patrickvonplaten let me know your final thoughts on merging this PR. Personally, I am okay hearing back from the PT folks and then deciding the course of action here. But I guess, it's good for us to be at least aware of the situation and a reasonable workaround. |
|
Ok to merge from my side. It'll take a while until this would be fixed in PT 2.1, so think no matter what we should merge this. Great job! |
* fix: controlnet graph? * fix: sample * fix: * remove print * styling * fix-copies * prevent more graph breaks? * prevent more graph breaks? * see? * revert. * compilation. * rpopagate changes to controlnet sdxl pipeline too. * add: clean version checking.
* fix: controlnet graph? * fix: sample * fix: * remove print * styling * fix-copies * prevent more graph breaks? * prevent more graph breaks? * see? * revert. * compilation. * rpopagate changes to controlnet sdxl pipeline too. * add: clean version checking.
* fix: controlnet graph? * fix: sample * fix: * remove print * styling * fix-copies * prevent more graph breaks? * prevent more graph breaks? * see? * revert. * compilation. * rpopagate changes to controlnet sdxl pipeline too. * add: clean version checking.
* fix: controlnet graph? * fix: sample * fix: * remove print * styling * fix-copies * prevent more graph breaks? * prevent more graph breaks? * see? * revert. * compilation. * rpopagate changes to controlnet sdxl pipeline too. * add: clean version checking.
Fixes graph break problems for T2I Adapters (both SD and SDXL).
ControlNets are still failing. I am trying to get to the bottom of it.@DN6, if you want to double-check the fixes proposed in this PR, I'd appreciate it.