Array shape and Mixed Precision errors in examples/controlnet/train_controlnet.py script #2908

@jeromeku


Describe the bug

2 bugs:

Array shape:

  • In log_validation, when validation images are logged to TensorBoard, calling np.asarray on the PIL validation image yields an array of shape [512, 512, 4] for conditioning_image_1.png (an RGBA PNG), while the other images have shape [512, 512, 3]. The subsequent np.stack call then raises ValueError('all input arrays must have the same shape').
  • The fix is to change formatted_images.append(np.asarray(validation_image)) to formatted_images.append(np.asarray(validation_image)[..., :3]), which drops the alpha channel.
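The mismatch and the proposed slice can be reproduced without TensorBoard by using plain NumPy arrays in place of the PIL images. The helper to_rgb_array below is hypothetical (it is not part of train_controlnet.py); it also covers 2-D grayscale arrays, for which a bare [..., :3] would slice the width axis instead of the channel axis.

```python
import numpy as np


def to_rgb_array(image_array: np.ndarray) -> np.ndarray:
    """Hypothetical helper: coerce a grayscale, RGB, or RGBA array to (H, W, 3)."""
    if image_array.ndim == 2:
        # Grayscale (H, W): repeat the single channel three times.
        return np.stack([image_array] * 3, axis=-1)
    if image_array.shape[-1] == 4:
        # RGBA (H, W, 4): drop alpha, same as the proposed [..., :3] fix.
        return image_array[..., :3]
    # Already (H, W, 3).
    return image_array


# Stand-ins for np.asarray(validation_image) on an RGBA and an RGB PNG.
rgba = np.zeros((512, 512, 4), dtype=np.uint8)
rgb = np.zeros((512, 512, 3), dtype=np.uint8)

try:
    np.stack([rgba, rgb])
except ValueError as err:
    # Reproduces the failure described above.
    print("np.stack failed:", err)

stacked = np.stack([to_rgb_array(rgba), to_rgb_array(rgb)])
print(stacked.shape)  # (2, 512, 512, 3)
```

Calling validation_image.convert("RGB") before np.asarray should be an equivalent option on the PIL side; for RGBA input it also discards the alpha channel rather than compositing it.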

Mixed dtype:

  • Calling the script with mixed precision:
    accelerate launch train_controlnet.py --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 --dataset_name=fusing/fill50k --resolution=512 --learning_rate=1e-5 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" --validation_prompt "red circle with blue background" "cyan circle with brown floral background" --report_to="tensorboard" --max_train_steps=2 --validation_steps=1 --mixed_precision=fp16
    causes the following:

    Traceback (most recent call last):
      File "/notebooks/sandbox/controlnet-finetuning/train_controlnet.py", line 1575, in <module>
        main(args)
      File "/notebooks/sandbox/controlnet-finetuning/train_controlnet.py", line 1496, in main
        model_pred = unet(
      File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
        return forward_call(*args, **kwargs)
      File "/usr/local/lib/python3.9/dist-packages/diffusers/models/unet_2d_condition.py", line 679, in forward
        sample = upsample_block(
      File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
        return forward_call(*args, **kwargs)
      File "/usr/local/lib/python3.9/dist-packages/diffusers/models/unet_2d_blocks.py", line 1891, in forward
        hidden_states = resnet(hidden_states, temb)
      File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
        return forward_call(*args, **kwargs)
      File "/usr/local/lib/python3.9/dist-packages/diffusers/models/resnet.py", line 555, in forward
        hidden_states = self.norm1(hidden_states)
      File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
        return forward_call(*args, **kwargs)
      File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/normalization.py", line 273, in forward
        return F.group_norm(
      File "/usr/local/lib/python3.9/dist-packages/torch/nn/functional.py", line 2530, in group_norm
        return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
    RuntimeError: mixed dtype (CPU): expect parameter to have scalar type of Float

  • It can be fixed by wrapping the unet call in the training forward pass with accelerator.autocast() (the Accelerator's autocast context manager).

  • It is not clear why this happens, since the unet (as well as the vae and text_encoder) are explicitly cast to the mixed-precision weight dtype earlier in the script.
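One possible explanation, an assumption not confirmed in this issue: the trainable controlnet stays in fp32, so the residuals it returns to the unet (assumed here to be down_block_res_samples and mid_block_res_sample) are fp32. When the fp16 unet adds them to its hidden states, type promotion yields an fp32 tensor, which the next fp16 GroupNorm then rejects. That would also be consistent with autocast helping, since autocast runs group_norm in fp32. The sketch below uses NumPy stand-ins to show the promotion rule (PyTorch promotes float16 + float32 to float32 the same way); cast_residuals is a hypothetical helper.

```python
import numpy as np

# Stand-ins (assumption): fp16 activations inside the frozen UNet and an fp32
# residual returned by the trainable ControlNet.
unet_hidden = np.ones((1, 4, 8, 8), dtype=np.float16)
controlnet_residual = np.ones((1, 4, 8, 8), dtype=np.float32)

# Adding them promotes the sum to fp32, so the next fp16 GroupNorm would receive
# an fp32 input paired with fp16 weights.
mixed = unet_hidden + controlnet_residual
print(mixed.dtype)  # float32


def cast_residuals(residuals, dtype):
    """Hypothetical helper: cast every residual to the UNet's weight dtype."""
    return [residual.astype(dtype) for residual in residuals]


# After casting, the sum stays in the UNet's dtype.
fixed = unet_hidden + cast_residuals([controlnet_residual], np.float16)[0]
print(fixed.dtype)  # float16
```

In the script, the analogous (untested) change would be casting each controlnet residual with .to(dtype=weight_dtype) before passing it to unet, assuming weight_dtype is the variable holding the mixed-precision dtype.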

Reproduction

For the TensorBoard logging error:
accelerate launch train_controlnet.py --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 --dataset_name=fusing/fill50k --resolution=512 --learning_rate=1e-5 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" --validation_prompt "red circle with blue background" "cyan circle with brown floral background" --report_to="tensorboard" --max_train_steps=2 --validation_steps=1

For the mixed dtype error:
accelerate launch train_controlnet.py --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 --dataset_name=fusing/fill50k --resolution=512 --learning_rate=1e-5 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" --validation_prompt "red circle with blue background" "cyan circle with brown floral background" --report_to="tensorboard" --max_train_steps=2 --validation_steps=1 --mixed_precision=fp16

Logs

No response

System Info

diffusers 0.15.0.dev0
numpy 1.23.1
torch 2.0.0
transformers 4.27.4

python version 3.9.13

NVIDIA-SMI 510.73.05
Driver Version: 510.73.05
CUDA Version: 11.6
GPU: NVIDIA A100-SXM4-80GB

Metadata

Labels

bug (Something isn't working), stale (Issues that haven't received updates)
