diff --git a/examples/lora_dreambooth/train_dreambooth.py b/examples/lora_dreambooth/train_dreambooth.py
index c06a175ba8..9145ecab95 100644
--- a/examples/lora_dreambooth/train_dreambooth.py
+++ b/examples/lora_dreambooth/train_dreambooth.py
@@ -999,7 +999,10 @@ def main(args):
                 pipeline.set_progress_bar_config(disable=True)
 
                 # run inference
-                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                if args.seed is not None:
+                    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                else:
+                    generator = None
                 images = []
                 for _ in range(args.num_validation_images):
                     image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
@@ -1050,15 +1053,17 @@ def main(args):
     if accelerator.is_main_process:
         if args.use_lora:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
-            lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+            unwarpped_unet = accelerator.unwrap_model(unet)
+            state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet))
+            lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True)
             if args.train_text_encoder:
+                unwarpped_text_encoder = accelerator.unwrap_model(text_encoder)
                 text_encoder_state_dict = get_peft_model_state_dict(
-                    text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+                    unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
                 )
                 text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
                 state_dict.update(text_encoder_state_dict)
-                lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+                lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(inference=True)
 
             accelerator.print(state_dict)
             accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt"))