From 872040c201c1896228b5141aab13553ff9f7f1d0 Mon Sep 17 00:00:00 2001
From: charchit7
Date: Mon, 15 Jan 2024 17:13:48 +0530
Subject: [PATCH 1/2] changes for pix2pix_sdxl

---
 .../train_instruct_pix2pix_sdxl.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index de59cb1f0be9..d3a286a4fb5e 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -52,7 +52,7 @@
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
 from diffusers.utils.import_utils import is_xformers_available
-
+from diffusers.utils.torch_utils import is_compiled_module
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.26.0.dev0")
@@ -531,6 +531,11 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -1044,8 +1049,8 @@ def collate_fn(examples):
 
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
                 model_pred = unet(
-                    concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
-                ).sample
+                    concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, return_dict=False
+                )[0]
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@@ -1115,7 +1120,7 @@ def collate_fn(examples):
                         # The models need unwrapping because for compatibility in distributed training mode.
                         pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
-                            unet=accelerator.unwrap_model(unet),
+                            unet=unwrap_model(unet),
                             text_encoder=text_encoder_1,
                             text_encoder_2=text_encoder_2,
                             tokenizer=tokenizer_1,
@@ -1177,7 +1182,7 @@ def collate_fn(examples):
 
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
    if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())

From 818273c85a05dd714d68805a2dd349eb00cf45f3 Mon Sep 17 00:00:00 2001
From: charchit7
Date: Mon, 15 Jan 2024 17:17:50 +0530
Subject: [PATCH 2/2] style fix

---
 examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index d3a286a4fb5e..7a158f5d0d2d 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -54,6 +54,7 @@
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
+
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.26.0.dev0")
@@ -1049,7 +1050,11 @@ def collate_fn(examples):
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
                 model_pred = unet(
-                    concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, return_dict=False
+                    concatenated_noisy_latents,
+                    timesteps,
+                    encoder_hidden_states,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
                 )[0]
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")