From f91f6bd1ef954eef1f8fadea995b42e3db1a39a3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 4 Apr 2023 09:06:38 +0530 Subject: [PATCH 1/8] fix: norm group test for UNet3D. --- tests/models/test_models_unet_3d_condition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 729367a0c164..5a0d74a3ea5a 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -119,12 +119,11 @@ def test_xformers_enable_works(self): == "XFormersAttnProcessor" ), "xformers is not enabled" - # Overriding because `block_out_channels` needs to be different for this model. + # Overriding because `norm_num_groups` needs to be different for this model. def test_forward_with_norm_groups(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["norm_num_groups"] = 32 - init_dict["block_out_channels"] = (32, 64, 64, 64) model = self.model_class(**init_dict) model.to(torch_device) From 3249f116997426863fd461ec8814cd12125e5e12 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Apr 2023 12:20:18 +0530 Subject: [PATCH 2/8] fix: unet rejig. --- .../train_instruct_pix2pix.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index a119e12f73d1..77d7f7a7b26f 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -451,19 +451,18 @@ def main(): # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized # from the pre-trained checkpoints. For the extra channels added to the first layer, they are # initialized to zero. 
- if accelerator.is_main_process: - logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") - in_channels = 8 - out_channels = unet.conv_in.out_channels - unet.register_to_config(in_channels=in_channels) - - with torch.no_grad(): - new_conv_in = nn.Conv2d( - in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding - ) - new_conv_in.weight.zero_() - new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) - unet.conv_in = new_conv_in + logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") + in_channels = 8 + out_channels = unet.conv_in.out_channels + unet.register_to_config(in_channels=in_channels) + + with torch.no_grad(): + new_conv_in = nn.Conv2d( + in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding + ) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) + unet.conv_in = new_conv_in # Freeze vae and text_encoder vae.requires_grad_(False) From 1065572e18d8b3e30f0203c603fcc2d2b4f5675e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 6 Apr 2023 19:00:08 +0530 Subject: [PATCH 3/8] fix: unwrapping when running validation inputs. 
--- examples/instruct_pix2pix/train_instruct_pix2pix.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 77d7f7a7b26f..28c635399112 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -894,6 +894,8 @@ def collate_fn(examples): pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=unet, + text_encoder=accelerator.unwrap_model(text_encoder), + vae=accelerator.unwrap_model(vae), revision=args.revision, torch_dtype=weight_dtype, ) @@ -903,7 +905,9 @@ def collate_fn(examples): # run inference original_image = download_image(args.val_image_url) edited_images = [] - with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): + with torch.autocast( + str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" + ): for _ in range(args.num_validation_images): edited_images.append( pipeline( From 65ad3aead321444d5d9a746a774b759224708e9b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 6 Apr 2023 19:26:15 +0530 Subject: [PATCH 4/8] unwrapping the unet too. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 28c635399112..f19f68818924 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -891,9 +891,10 @@ def collate_fn(examples): # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) + # The models need unwrapping for compatibility in distributed training mode. 
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=unet, + unet=accelerator.unwrap_model(unet), text_encoder=accelerator.unwrap_model(text_encoder), vae=accelerator.unwrap_model(vae), revision=args.revision, From 6911aaae311fea449c4eed350950a2a196da4683 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 6 Apr 2023 19:27:44 +0530 Subject: [PATCH 5/8] fix: device. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index f19f68818924..9de27d38b44f 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -963,7 +963,7 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] pipeline = pipeline.to(accelerator.device) - with torch.autocast(str(accelerator.device)): + with torch.autocast(str(accelerator.device).replace(":0", "")): for _ in range(args.num_validation_images): edited_images.append( pipeline( From 1eb59d9e73cb823ac39392184fcc9ddca22be7aa Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 6 Apr 2023 19:36:27 +0530 Subject: [PATCH 6/8] better unwrapping. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 9de27d38b44f..bd05ae049aac 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -889,12 +889,13 @@ def collate_fn(examples): # create pipeline if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
- ema_unet.store(unet.parameters()) - ema_unet.copy_to(unet.parameters()) + unwrapped_unet = accelerator.unwrap_model(unet) + ema_unet.store(unwrapped_unet.parameters()) + ema_unet.copy_to(unwrapped_unet.parameters()) # The models need unwrapping for compatibility in distributed training mode. pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), + unet=unwrapped_unet, text_encoder=accelerator.unwrap_model(text_encoder), vae=accelerator.unwrap_model(vae), revision=args.revision, From a61a99f78fcc9b7e162c5e9b81574d427dc0cea5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 6 Apr 2023 19:41:38 +0530 Subject: [PATCH 7/8] unwrapping before ema. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index bd05ae049aac..8c2e92e016a3 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -886,10 +886,10 @@ def collate_fn(examples): f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) + unwrapped_unet = accelerator.unwrap_model(unet) # create pipeline if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - unwrapped_unet = accelerator.unwrap_model(unet) ema_unet.store(unwrapped_unet.parameters()) ema_unet.copy_to(unwrapped_unet.parameters()) From 06c4a6524399483bcd0eb1003323c70afd910a33 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 6 Apr 2023 19:54:43 +0530 Subject: [PATCH 8/8] unwrapping. 
--- examples/instruct_pix2pix/train_instruct_pix2pix.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 8c2e92e016a3..9de27d38b44f 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -886,16 +886,15 @@ def collate_fn(examples): f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - unwrapped_unet = accelerator.unwrap_model(unet) # create pipeline if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unwrapped_unet.parameters()) - ema_unet.copy_to(unwrapped_unet.parameters()) + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) # The models need unwrapping for compatibility in distributed training mode. pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=unwrapped_unet, + unet=accelerator.unwrap_model(unet), text_encoder=accelerator.unwrap_model(text_encoder), vae=accelerator.unwrap_model(vae), revision=args.revision,