
load_lora_weights does not work with torch.compile #5227

@ziniuwan

Describe the bug

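load_lora_weights does not work with torch.compile: both StableDiffusionPipeline and StableDiffusionXLPipeline fail. After loading a LoRA and compiling the UNet, the first pipeline call raises torch._dynamo.exc.Unsupported (see Logs).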

It works in diffusers 0.18.2 and fails starting with 0.19.0.

Reproduction

import torch
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

# Load the LoRA state dict from disk and apply it to the pipeline.
lora = load_file("myLora.safetensors", device="cuda")
pipe.load_lora_weights(lora)

# Compile the UNet; with fullgraph=True any graph break is raised as an error.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

# Compilation is lazy, so the error surfaces here, on the first call.
images = pipe(
    prompt="1girl",
    num_inference_steps=20,
).images

Logs

Traceback (most recent call last):
  File "/home/ziniu/code/test.py", line 13, in <module>
    images = pipe(
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 680, in __call__
    noise_pred = self.unet(
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1002, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 474, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 244, in call_function
    return tx.inline_user_function_return(
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1806, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in inline_call_
    tracer.run()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1014, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 474, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 244, in call_function
    return tx.inline_user_function_return(
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1806, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in inline_call_
    tracer.run()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 965, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 474, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 244, in call_function
    return tx.inline_user_function_return(
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1806, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in inline_call_
    tracer.run()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 965, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 474, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 744, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 74, in call_method
    ).call_function(tx, [self.objvar] + args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 259, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 92, in call_function
    return tx.inline_user_function_return(
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1806, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in inline_call_
    tracer.run()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 965, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 474, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 288, in call_function
    return self.obj.call_method(
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 522, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 253, in call_method
    raise unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 71, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_method NNModuleVariable() _conv_forward [TensorVariable(), TensorVariable(), TensorVariable()] {}

from user code:
   File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/diffusers/models/lora.py", line 167, in forward
    return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
  File "/home/ziniu/.conda/envs/sdxl/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
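The two "from user code" frames show where tracing stops: the LoRA-aware conv layer in diffusers/models/lora.py (line 167) computes the base output with super().forward(hidden_states), which lands in nn.Conv2d.forward and its self._conv_forward(...) call, and Dynamo on PyTorch 2.0.1 reports that method call on the module as unsupported. Below is a minimal, diffusers-free sketch of that layer pattern next to a variant that calls torch.nn.functional.conv2d directly and so never goes through _conv_forward. The class names and the make_lora helper are hypothetical, not diffusers code; the functional variant assumes padding_mode="zeros" (the nn.Conv2d default). The sketch only shows that both variants compute the same result in eager mode.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_lora(in_channels: int, out_channels: int, rank: int = 4) -> nn.Module:
    # Hypothetical LoRA branch: a rank-r 3x3 down-projection followed by a 1x1 up-projection.
    return nn.Sequential(
        nn.Conv2d(in_channels, rank, kernel_size=3, padding=1, bias=False),
        nn.Conv2d(rank, out_channels, kernel_size=1, bias=False),
    )


class SuperForwardLoRAConv(nn.Conv2d):
    """Hypothetical layer shaped like the failing frame: super().forward() plus a LoRA branch."""

    def __init__(self, *args, lora_layer=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def forward(self, hidden_states, scale: float = 1.0):
        # nn.Conv2d.forward calls self._conv_forward(...), the call named in the Unsupported error.
        out = super().forward(hidden_states)
        if self.lora_layer is not None:
            out = out + scale * self.lora_layer(hidden_states)
        return out


class FunctionalLoRAConv(SuperForwardLoRAConv):
    """Same math, but calls F.conv2d directly so _conv_forward is never invoked."""

    def forward(self, hidden_states, scale: float = 1.0):
        # Matches nn.Conv2d only for padding_mode="zeros" (the default).
        out = F.conv2d(
            hidden_states, self.weight, self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )
        if self.lora_layer is not None:
            out = out + scale * self.lora_layer(hidden_states)
        return out


torch.manual_seed(0)
lora = make_lora(4, 8)
layer_a = SuperForwardLoRAConv(4, 8, kernel_size=3, padding=1, lora_layer=lora)
layer_b = FunctionalLoRAConv(4, 8, kernel_size=3, padding=1, lora_layer=lora)
layer_b.load_state_dict(layer_a.state_dict())  # copy base weights; the LoRA branch is shared

x = torch.randn(1, 4, 16, 16)
with torch.no_grad():
    out_a = layer_a(x, scale=0.5)
    out_b = layer_b(x, scale=0.5)

print(out_a.shape, torch.allclose(out_a, out_b))
```

Whether the functional variant actually avoids the Unsupported error under torch.compile(..., fullgraph=True) on 2.0.1 is an assumption this sketch does not verify; wrapping one of these layers in a small module, compiling it, and calling it once would be a quick way to check.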

System Info

  • diffusers version: 0.21.3
  • Platform: Linux-5.15.0-1022-aws-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Huggingface_hub version: 0.17.2
  • Transformers version: 4.33.2
  • Accelerate version: 0.23.0
  • xFormers version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@patrickvonplaten @sayakpaul @williamberman

Labels: bug (Something isn't working)
