RuntimeError: torch._dynamo.optimize is called on a non function object. #2775

@xiaosaxuexige

Describe the bug

I am trying to run the example provided by the authors, which uses the pokemon dataset. I followed the instructions on Hugging Face and did not modify the code, but I got this error:

RuntimeError:

torch._dynamo.optimize is called on a non function object.
If this is a callable class, please wrap the relevant code into a function and optimize the
wrapper function.

Reproduction

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16" train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"

Logs

03/22/2023 14:32:42 - INFO - __main__ - ***** Running training *****
03/22/2023 14:32:42 - INFO - __main__ -   Num examples = 833
03/22/2023 14:32:42 - INFO - __main__ -   Num Epochs = 72
03/22/2023 14:32:42 - INFO - __main__ -   Instantaneous batch size per device = 1
03/22/2023 14:32:42 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 4
03/22/2023 14:32:42 - INFO - __main__ -   Gradient Accumulation steps = 4
03/22/2023 14:32:42 - INFO - __main__ -   Total optimization steps = 15000
Steps:   0%|          | 0/15000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train_text_to_image.py", line 789, in <module>
    main()
  File "train_text_to_image.py", line 730, in main
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  File "/home/wangfuling/wangfuling/ENTER/envs/dfs/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wangfuling/wangfuling/ENTER/envs/dfs/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/wangfuling/wangfuling/ENTER/envs/dfs/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 229, in __call__
    raise RuntimeError(
RuntimeError: 

torch._dynamo.optimize is called on a non function object.
If this is a callable class, please wrap the relevant code into a function and optimize the
wrapper function.

>> class CallableClass:
>>     def __init__(self):
>>         super().__init__()
>>         self.relu = torch.nn.ReLU()
>>
>>     def __call__(self, x):
>>         return self.relu(torch.sin(x))
>>
>>     def print_hello(self):
>>         print("Hello world")
>>
>> mod = CallableClass()

If you want to optimize the __call__ function and other code, wrap that up in a function

>> def wrapper_fn(x):
>>     y = mod(x)
>>     return y.sum()

and then optimize the wrapper_fn

>> opt_wrapper_fn = torch._dynamo.optimize(wrapper_fn)
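For reference, the wrapper pattern described in the message above can be written as a small runnable sketch. It is only an illustration of the pattern, not a fix for train_text_to_image.py. Two assumptions differ from the printed message: it uses torch.compile (available in PyTorch 2.0) instead of the torch._dynamo.optimize call shown there, and it passes backend="eager" so that it runs without a compiler toolchain.

```python
import torch


class CallableClass:
    """A plain callable object (not an nn.Module), mirroring the error message."""

    def __init__(self):
        self.relu = torch.nn.ReLU()

    def __call__(self, x):
        return self.relu(torch.sin(x))


mod = CallableClass()


def wrapper_fn(x):
    # An ordinary function that calls the object; this function is what gets compiled.
    y = mod(x)
    return y.sum()


# Assumption: backend="eager" is chosen only so the sketch needs no C++/Triton toolchain.
opt_wrapper_fn = torch.compile(wrapper_fn, backend="eager")

x = torch.randn(8)
result = opt_wrapper_fn(x)   # compiled path
expected = wrapper_fn(x)     # uncompiled reference
print(torch.allclose(result, expected))
```

The compiled wrapper and the plain function should agree on the same input, which is what the final line checks.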

System Info

  • diffusers version: 0.15.0.dev0
  • Platform: Linux-3.10.0-1160.66.1.el7.x86_64-x86_64-with-glibc2.10
  • Python version: 3.8.0
  • PyTorch version (GPU?): 2.0.0+cu117 (True)
  • Huggingface_hub version: 0.13.2
  • Transformers version: 4.27.1
  • Accelerate version: 0.17.1
  • xFormers version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:
