From 3a8aa62c2587b8de571b218d98faeffd05873579 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 1 Jul 2024 09:55:51 +0000 Subject: [PATCH] update --- .../train_dreambooth_lora_sd15_advanced.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 02c7f8bb2887..fea145d0b1e3 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -1856,10 +1856,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - if torch.backends.mps.is_available(): - autocast_ctx = nullcontext() - else: - autocast_ctx = torch.autocast(accelerator.device.type) + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) with autocast_ctx: images = [ @@ -1880,7 +1880,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ] } ) - del pipeline torch.cuda.empty_cache()