diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 20c4fbe189a1..52a69ac05e7a 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -972,8 +972,10 @@ def load_model_hook(models, input_dir): noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample # Get the target for loss depending on the prediction type