
excessive graph breaks on attention.py and attention_processor.py for control_net on torch.compile #3218

Closed
shingjan opened this issue Apr 25, 2023 · 7 comments

@shingjan

Describe the bug

I tried to run the controlnet example from this blog post and it turned out that the BasicTransformerBlock is causing a large number of graph breaks (>100) on a single controlnet pipeline. Ideally the whole BasicTransformerBlock.forward should be included in a single frame for speedups. The exact reason for the graph breaks is:

call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': TensorVariable(), 'attention_mask': ConstantVariable(NoneType)}

for both self attention and cross attention. Is there a way to reduce the graph breaks so that StableDiffusionControlNetPipeline works better with torch.compile?

Reproduction

```python
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import cv2
from PIL import Image
import torch
import numpy as np

image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)

image = np.array(image)

low_threshold = 100
high_threshold = 200

# Build the Canny edge map used as the ControlNet conditioning image
image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

import torch._dynamo as dynamo

# Compile the whole generation function so dynamo.explain can report every graph break
@dynamo.optimize("inductor")
def generate(prompt):
    generator = [torch.Generator(device="cuda").manual_seed(2) for i in range(len(prompt))]
    return pipe(
        prompt,
        canny_image,
        negative_prompt=["monochrome, lowres, bad anatomy, worst quality, low quality"] * len(prompt),
        num_inference_steps=10,
        generator=generator,
    )

prompt = ", best quality, extremely detailed"
prompt = [t + prompt for t in ["Sandra Oh", "Kim Kardashian", "rihanna", "taylor swift"]]
# dynamo.explain returns a tuple; its last element is the verbose graph-break report
ex = dynamo.explain(generate, prompt)[-1]
print(ex)
```

Logs

```shell
graph #169 break reason: call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': ConstantVariable(NoneType), 'attention_mask': ConstantVariable(NoneType)} after 3
stack:   File "/home/yj/diffusers/src/diffusers/models/attention.py", line 313, in forward
    attn_output = self.attn1(
  File "/home/yj/pytorch/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yj/diffusers/src/diffusers/models/attention_processor.py", line 267, in forward
    return self.processor(
 
graph #171 break reason: call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': TensorVariable(), 'attention_mask': ConstantVariable(NoneType)} after 1
stack:   File "/home/yj/diffusers/src/diffusers/models/attention.py", line 331, in <resume in forward>
    attn_output = self.attn2(
  File "/home/yj/pytorch/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yj/diffusers/src/diffusers/models/attention_processor.py", line 267, in forward
    return self.processor(
```

System Info

Ubuntu 20.04 with cuda 11.8

diffusers 0.16.0.dev0 /home/yj/diffusers
torch 2.1.0a0+git0bbf8a9 /home/yj/pytorch

shingjan added the bug label on Apr 25, 2023
@sayakpaul (Member)

Cc: @pcuenca

@patrickvonplaten (Contributor)

@shingjan, we advise optimizing only the unet part of the pipeline with torch inductor. Could you instead try:

pipe.unet = torch.compile(pipe.unet, backend='inductor')

Also see: https://huggingface.co/docs/diffusers/optimization/torch2.0#using-accelerated-transformers-and-torchcompile
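
For reference, here is that suggestion applied to the reproduction above as a minimal sketch. It reuses `pipe` and `canny_image` from the reproduction script; compiling `pipe.controlnet` as well, and moving the pipeline to CUDA with `.to("cuda")` instead of `enable_model_cpu_offload()`, are my own assumptions rather than requirements from the linked docs:

```python
import torch

# Reuse `pipe` and `canny_image` from the reproduction script above.
pipe = pipe.to("cuda")  # assumption: skip CPU offload while experimenting with torch.compile

# Compile only the unet (and, as an extra assumption, the controlnet)
# instead of wrapping the whole generation function.
pipe.unet = torch.compile(pipe.unet, backend="inductor")
pipe.controlnet = torch.compile(pipe.controlnet, backend="inductor")

image = pipe(
    "Sandra Oh, best quality, extremely detailed",
    canny_image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=10,
).images[0]
```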

@shingjan (Author)

@patrickvonplaten thanks for the response! I think AttnProcessor2_0 is heavily used in the unet, so even if only the unet is compiled, the graph breaks persist.
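
One way to confirm where the breaks come from is to run `dynamo.explain` on the unet alone. Below is a sketch; the dummy input shapes (SD 1.5 latents and CLIP text embeddings, batch size 2) are my assumptions, and `pipe` is the pipeline from the reproduction above moved to CUDA in fp16:

```python
import torch
import torch._dynamo as dynamo

# Dummy inputs with typical SD 1.5 shapes (assumed for illustration):
# 64x64 latents and 77-token CLIP embeddings, batch size 2 for CFG.
sample = torch.randn(2, 4, 64, 64, dtype=torch.float16, device="cuda")
timestep = torch.tensor(999, device="cuda")
encoder_hidden_states = torch.randn(2, 77, 768, dtype=torch.float16, device="cuda")

# If AttnProcessor2_0 causes graph breaks, they show up here even
# without the rest of the pipeline.
report = dynamo.explain(pipe.unet, sample, timestep, encoder_hidden_states)[-1]
print(report)
```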

@patrickvonplaten (Contributor)

I don't fully understand this. What exactly is the issue here? Can we reproduce it somehow?

@patrickvonplaten (Contributor)

@shingjan,

This might help solve it actually: #3286

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label on May 25, 2023
@shingjan (Author)

@patrickvonplaten Sorry for the late reply. Yes, I did a rebase and most of the graph breaks are gone on diffusers==0.16.1. The maybe_allow_in_graph and is_compiled helpers are very useful. Closing this one as fixed by #3286.
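
For anyone landing here later: the `maybe_allow_in_graph` helper mentioned above marks the transformer blocks as safe for dynamo to inline into the compiled graph instead of breaking on them. A simplified re-creation of the idea (my own sketch, not the actual diffusers source; see #3286 for the real change) looks like this:

```python
import torch
import torch.nn as nn
from packaging import version

def maybe_allow_in_graph(cls):
    # Sketch of the helper: on torch >= 2.0, tell dynamo it may inline this
    # class into the compiled graph instead of graph-breaking on it.
    if version.parse(torch.__version__) >= version.parse("2.0"):
        import torch._dynamo
        return torch._dynamo.allow_in_graph(cls)
    return cls

@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    # Stand-in for the real block; only the decorator matters for this sketch.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return self.proj(hidden_states)
```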
