Closed
Labels: bug (Something isn't working)
Description
Describe the bug
The ControlNet training example (PyTorch variant) fails when run with mixed precision (`--mixed_precision=fp16`).
Here's the command I used:
accelerate launch train_controlnet.py \
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
--controlnet_model_name_or_path=lllyasviel/sd-controlnet-openpose \
--output_dir=sayak_pose_model --dataset_name=sayakpaul/poses-controlnet-dataset \
--resolution=512 --learning_rate=1e-5 \
--validation_image=sayak-pose.jpeg --validation_prompt "a man standing on a rock, a detailed high-quality professional image" \
--validation_steps=1000 \
--train_batch_size=1 --gradient_accumulation_steps=4 \
--report_to=wandb --max_train_steps=100000 \
--tracker_project_name=pose_train \
--checkpointing_steps=10000 \
--image_column=original_image --conditioning_image_column=condtioning_image --caption_column=caption \
--mixed_precision=fp16
`sayak-pose.jpeg` can be downloaded with:
wget https://datasets-server.huggingface.co/assets/sayakpaul/poses-controlnet-dataset/--/sayakpaul--poses-controlnet-dataset/train/2/condtioning_image/image.jpg -O sayak-pose.jpeg
Reproduction
Code snippet provided.
Logs
Traceback (most recent call last):
  File "train_controlnet.py", line 1046, in <module>
    main(args)
  File "train_controlnet.py", line 971, in main
    model_pred = unet(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 679, in forward
    sample = upsample_block(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 1891, in forward
    hidden_states = resnet(hidden_states, temb)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/resnet.py", line 555, in forward
    hidden_states = self.norm1(hidden_states)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/normalization.py", line 273, in forward
    return F.group_norm(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/functional.py", line 2528, in group_norm
    return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: expected scalar type Float but found Half
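The final error points at a dtype mismatch inside `F.group_norm`: the tensor reaching `self.norm1` and the layer's parameters do not share a precision. The sketch below does not reproduce the training script and does not establish the cause of this failure. It only illustrates, as an assumption, one way such a mismatch can arise: adding an fp16 tensor to an fp32 tensor promotes the result to fp32, which then no longer matches a layer cast to fp16. `find_dtype_mismatch` is a hypothetical helper written for this illustration and is not part of `diffusers` or PyTorch.

```python
import torch
from torch import nn


def find_dtype_mismatch(module: nn.Module, x: torch.Tensor):
    """Hypothetical helper: list (name, dtype) of parameters whose dtype differs from x.dtype."""
    return [
        (name, param.dtype)
        for name, param in module.named_parameters()
        if param.dtype != x.dtype
    ]


# A GroupNorm layer whose parameters were cast to fp16 (assumed setup).
norm = nn.GroupNorm(num_groups=2, num_channels=4).half()

# An fp16 activation and an fp32 tensor added to it (assumed shapes and values).
x_half = torch.zeros(1, 4, 8, 8, dtype=torch.float16)
residual_fp32 = torch.zeros(1, 4, 8, 8, dtype=torch.float32)

# PyTorch type promotion makes the sum float32.
summed = x_half + residual_fp32

# The promoted tensor no longer matches the fp16 parameters of the layer.
mismatches = find_dtype_mismatch(norm, summed)
print(summed.dtype, mismatches)
```

A check like this, placed just before the failing call, would show which side of the `group_norm` call is in the unexpected precision.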
System Info
- `diffusers` version: 0.15.0.dev0
- Platform: Linux-4.19.0-23-cloud-amd64-x86_64-with-glibc2.10
- Python version: 3.8.16
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.13.2
- Transformers version: 4.26.1
- Accelerate version: 0.18.0
- xFormers version: 0.0.16
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No