From 956b2b2d4bbe5fd243ad30d09b1ba0515d7f6fb8 Mon Sep 17 00:00:00 2001 From: Haofan Wang Date: Thu, 23 Mar 2023 23:18:43 +0800 Subject: [PATCH] Update train_text_to_image_lora.py --- examples/research_projects/lora/train_text_to_image_lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index 0ff15ed293e4..fe031df147a4 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -542,9 +542,9 @@ def main(): lora_layers = AttnProcsLayers(unet.attn_processors) # Move unet, vae and text_encoder to device and cast to weight_dtype - unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - text_encoder.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: if is_xformers_available():