[Examples] ControlNet (PyTorch) example fails when using mixed-precision #2991

@sayakpaul

Describe the bug

The ControlNet training example (PyTorch variant) fails with `RuntimeError: expected scalar type Float but found Half` when run with mixed precision (`--mixed_precision=fp16`).

Here's the command I used:

accelerate launch train_controlnet.py \
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
--controlnet_model_name_or_path=lllyasviel/sd-controlnet-openpose \
--output_dir=sayak_pose_model --dataset_name=sayakpaul/poses-controlnet-dataset \
--resolution=512 --learning_rate=1e-5 \
--validation_image=sayak-pose.jpeg --validation_prompt "a man standing on a rock, a detailed high-quality professional image" \
--validation_steps=1000 \
--train_batch_size=1 --gradient_accumulation_steps=4 \
--report_to=wandb --max_train_steps=100000 \
--tracker_project_name=pose_train \
--checkpointing_steps=10000 \
--image_column=original_image --conditioning_image_column=condtioning_image --caption_column=caption \
--mixed_precision=fp16

`sayak-pose.jpeg` can be downloaded with:

wget https://datasets-server.huggingface.co/assets/sayakpaul/poses-controlnet-dataset/--/sayakpaul--poses-controlnet-dataset/train/2/condtioning_image/image.jpg -O sayak-pose.jpeg

Reproduction

Run the command above with `--mixed_precision=fp16`.

Logs

Traceback (most recent call last):
  File "train_controlnet.py", line 1046, in <module>
    main(args)
  File "train_controlnet.py", line 971, in main
    model_pred = unet(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 679, in forward
    sample = upsample_block(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 1891, in forward
    hidden_states = resnet(hidden_states, temb)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/resnet.py", line 555, in forward
    hidden_states = self.norm1(hidden_states)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/normalization.py", line 273, in forward
    return F.group_norm(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/functional.py", line 2528, in group_norm
    return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: expected scalar type Float but found Half
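
The error message points to a dtype mismatch somewhere around `self.norm1`: one side of the `group_norm` call is fp32 and the other fp16. The traceback does not say which side is which, so the snippet below is only a diagnostic sketch, not a fix. `find_dtype_mismatches` is a hypothetical helper (not part of `diffusers` or the training script), and the `GroupNorm` layer with arbitrary sizes is a stand-in for the real resnet norm. It lists the parameters whose dtype differs from that of the incoming tensor and runs on CPU.

```python
import torch
from torch import nn


def find_dtype_mismatches(module: nn.Module, sample: torch.Tensor):
    """Return (name, dtype) for every parameter whose dtype differs from sample's.

    Hypothetical debugging helper; not part of diffusers.
    """
    return [
        (name, param.dtype)
        for name, param in module.named_parameters()
        if param.dtype != sample.dtype
    ]


# Stand-in for a resnet norm layer; the sizes are arbitrary.
norm = nn.GroupNorm(num_groups=4, num_channels=8)  # parameters default to float32
hidden_states = torch.randn(1, 8, 4, 4).half()     # fp16 activations

mismatches = find_dtype_mismatches(norm, hidden_states)
for name, dtype in mismatches:
    print(f"{name}: {dtype} (input is {hidden_states.dtype})")
```

Calling the same helper on `unet` with the tensor passed to it just before line 971 of `train_controlnet.py` should show which parameters, if any, are in a different precision than the inputs.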

System Info

- `diffusers` version: 0.15.0.dev0
- Platform: Linux-4.19.0-23-cloud-amd64-x86_64-with-glibc2.10
- Python version: 3.8.16
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.13.2
- Transformers version: 4.26.1
- Accelerate version: 0.18.0
- xFormers version: 0.0.16
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No

Labels

bug
