diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index fcc11386abcf..b3bbb5b53ab4 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -183,7 +183,8 @@ def log_validation( # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - autocast_ctx = nullcontext() + # autocast_ctx = nullcontext() + autocast_ctx = torch.autocast(accelerator.device.type, dtype=torch_dtype) with autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]