From f91f6bd1ef954eef1f8fadea995b42e3db1a39a3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 4 Apr 2023 09:06:38 +0530 Subject: [PATCH 1/2] fix: norm group test for UNet3D. --- tests/models/test_models_unet_3d_condition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 729367a0c164..5a0d74a3ea5a 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -119,12 +119,11 @@ def test_xformers_enable_works(self): == "XFormersAttnProcessor" ), "xformers is not enabled" - # Overriding because `block_out_channels` needs to be different for this model. + # Overriding to set `norm_num_groups` needs to be different for this model. def test_forward_with_norm_groups(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["norm_num_groups"] = 32 - init_dict["block_out_channels"] = (32, 64, 64, 64) model = self.model_class(**init_dict) model.to(torch_device) From 64bd32053feba0603b5170e79251a5d1b352d470 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 6 Apr 2023 10:04:39 +0530 Subject: [PATCH 2/2] fix: type-casting issue in controlnet training. --- examples/controlnet/train_controlnet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 20c4fbe189a1..52a69ac05e7a 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -972,8 +972,10 @@ def load_model_hook(models, input_dir): noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample # Get the target for loss depending on the prediction type