From bbc91791bf41b93af64610886f792c403aac00ed Mon Sep 17 00:00:00 2001 From: timdalxx <48753785+jeongiin@users.noreply.github.com> Date: Sun, 29 Sep 2024 00:16:02 +0900 Subject: [PATCH 1/3] fix the issue on flux dreambooth lora training --- examples/dreambooth/train_dreambooth_lora_flux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 6091622719ee..ddb0d0b12bb6 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -183,7 +183,8 @@ def log_validation( # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - autocast_ctx = nullcontext() + # autocast_ctx = nullcontext() + autocast_ctx = torch.autocast(accelerator.device.type, dtype=torch_dtype) with autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] From 63840b759a3adcd71eb1d1b553854f488f77f33c Mon Sep 17 00:00:00 2001 From: timdalxx <48753785+jeongiin@users.noreply.github.com> Date: Sat, 5 Oct 2024 14:45:59 +0900 Subject: [PATCH 2/3] update : origin main code --- examples/dreambooth/train_dreambooth_lora_flux.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index b3bbb5b53ab4..fcc11386abcf 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -183,8 +183,7 @@ def log_validation( # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - # autocast_ctx = nullcontext() - autocast_ctx = torch.autocast(accelerator.device.type, dtype=torch_dtype) + autocast_ctx = nullcontext() with autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] From d2a0b122809ea610c2473ce2161dfe4387791808 Mon Sep 17 00:00:00 2001 From: timdalxx <48753785+jeongiin@users.noreply.github.com> Date: Sun, 6 Oct 2024 03:12:12 +0900 Subject: [PATCH 3/3] fix the issue on flux dreambooth lora training --- examples/dreambooth/train_dreambooth_lora_flux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index fcc11386abcf..b3bbb5b53ab4 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -183,7 +183,8 @@ def log_validation( # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - autocast_ctx = nullcontext() + # autocast_ctx = nullcontext() + autocast_ctx = torch.autocast(accelerator.device.type, dtype=torch_dtype) with autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]