From 16dc790bc6b3d8b2be3b01db9cfa916e63075ce5 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 1 Sep 2025 14:25:56 +0300 Subject: [PATCH 1/4] propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script --- .../dreambooth/train_dreambooth_lora_flux.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 2353625c3878..fead3549d389 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1131,8 +1131,17 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 - pipeline = FluxPipeline.from_pretrained( + + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + torch_dtype=torch_dtype, + ) + pipeline = FluxKontextPipeline.from_pretrained( args.pretrained_model_name_or_path, + transformer=transformer, torch_dtype=torch_dtype, revision=args.revision, variant=args.variant, @@ -1149,9 +1158,10 @@ def main(args): pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): - images = pipeline(example["prompt"]).images + with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype): + images = pipeline(prompt=example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() @@ -1159,8 +1169,7 @@ def main(args): image.save(image_filename) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() # Handle the repository creation if accelerator.is_main_process: @@ -1728,6 +1737,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): device=accelerator.device, prompt=args.instance_prompt, ) + else: + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) # Convert images to latent space if args.cache_latents: From 8c305db897cb71d8f87e753588094b7cb5d5b3be Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 1 Sep 2025 14:31:49 +0300 Subject: [PATCH 2/4] propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script --- .../train_dreambooth_lora_flux_advanced.py | 10 +++++++++- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 951b989d7a65..a08ac90e420b 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1399,6 +1399,13 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + torch_dtype=torch_dtype, + ) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -1419,7 +1426,8 @@ def main(args): for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): - images = pipeline(example["prompt"]).images + with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype): + images = pipeline(prompt=example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index fead3549d389..23a88212bfeb 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1139,7 +1139,7 @@ def main(args): variant=args.variant, torch_dtype=torch_dtype, ) - pipeline = FluxKontextPipeline.from_pretrained( + pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, transformer=transformer, torch_dtype=torch_dtype, From 026a810b74a26af60b16598c659be0c02393559b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 1 Sep 2025 14:50:46 +0300 Subject: [PATCH 3/4] propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script --- .../train_dreambooth_lora_flux_advanced.py | 8 +------- examples/dreambooth/train_dreambooth_lora_flux.py | 8 -------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index a08ac90e420b..a46490e8b3bf 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1399,13 +1399,7 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 - transformer = FluxTransformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="transformer", - revision=args.revision, - variant=args.variant, - torch_dtype=torch_dtype, - ) + pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 23a88212bfeb..6c5fdeb28cbb 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1132,16 +1132,8 @@ def main(args): elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 - transformer = FluxTransformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="transformer", - revision=args.revision, - variant=args.variant, - torch_dtype=torch_dtype, - ) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, - transformer=transformer, torch_dtype=torch_dtype, revision=args.revision, variant=args.variant, From db87b08bfcdc47123dfcffaddaaf95ef93a4cc34 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 1 Sep 2025 11:52:39 +0000 Subject: [PATCH 4/4] Apply style fixes --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 6c5fdeb28cbb..bd3a974a17d8 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1150,7 +1150,7 @@ def main(args): pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype): images = pipeline(prompt=example["prompt"]).images