Merge pull request #204 from haofanwang/patch-1
Update train_dreambooth.py
pacman100 committed Mar 24, 2023
2 parents d8c3b6b + b5b3ae3 commit 098962f
Showing 1 changed file with 10 additions and 5 deletions.
examples/lora_dreambooth/train_dreambooth.py (10 additions, 5 deletions)
@@ -999,7 +999,10 @@ def main(args):
     pipeline.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    if args.seed is not None:
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    else:
+        generator = None
     images = []
     for _ in range(args.num_validation_images):
         image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
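The first hunk guards the seeded generator: torch.Generator.manual_seed expects an integer, so an unset args.seed cannot be passed to it, and falling back to generator=None lets the diffusers pipeline sample without a fixed seed. A minimal standalone sketch of the same guard (the helper name and the "cpu" device are illustrative, not part of the commit):

    import torch

    def make_generator(seed, device="cpu"):
        # Only build a seeded generator when a seed was actually supplied;
        # returning None lets the pipeline sample non-deterministically.
        if seed is not None:
            return torch.Generator(device=device).manual_seed(seed)
        return None

    generator = make_generator(42)    # deterministic validation images
    generator = make_generator(None)  # no seed given: falls back to None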
@@ -1050,15 +1053,17 @@ def main(args):
 if accelerator.is_main_process:
     if args.use_lora:
         lora_config = {}
-        state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
-        lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+        unwarpped_unet = accelerator.unwrap_model(unet)
+        state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet))
+        lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True)
         if args.train_text_encoder:
+            unwarpped_text_encoder = accelerator.unwrap_model(text_encoder)
             text_encoder_state_dict = get_peft_model_state_dict(
-                text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+                unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
             )
             text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
             state_dict.update(text_encoder_state_dict)
-            lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+            lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(inference=True)

         accelerator.print(state_dict)
         accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt"))
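The second hunk matters when accelerate wraps the model for distributed training (for example in DistributedDataParallel): the prepared unet and text_encoder are passed through accelerator.unwrap_model first, presumably so that PEFT helpers such as get_peft_model_state_dict and get_peft_config_as_dict see the underlying PeftModel rather than the wrapper. A minimal sketch of the same save path under that assumption (the function name and arguments are illustrative, and the sketch spells the variable "unwrapped" rather than the commit's "unwarpped"):

    import os

    from accelerate import Accelerator
    from peft import get_peft_model_state_dict

    def save_unet_lora(accelerator: Accelerator, unet, output_dir: str, instance_prompt: str):
        # Strip any wrapper added by accelerator.prepare() so the PEFT helper
        # operates on the underlying PeftModel rather than a DDP module.
        unwrapped_unet = accelerator.unwrap_model(unet)
        state_dict = get_peft_model_state_dict(
            unwrapped_unet, state_dict=accelerator.get_state_dict(unet)
        )
        # Write the LoRA weights to disk (the script guards this call with
        # accelerator.is_main_process).
        accelerator.save(state_dict, os.path.join(output_dir, f"{instance_prompt}_lora.pt"))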
