Skip to content

Commit 4ae0695

Browse files
committed
update
1 parent 610a71d commit 4ae0695

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,7 @@ def save_model_hook(models, weights, output_dir):
12901290
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
12911291
get_peft_model_state_dict(model)
12921292
)
1293+
else:
12931294
raise ValueError(f"unexpected save model: {model.__class__}")
12941295

12951296
# make sure to pop weight so that corresponding model is not saved again
@@ -1981,7 +1982,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
19811982
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
19821983
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
19831984
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
1984-
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
1985+
save_file(kohya_state_dict, f"{args.output_dir}/kohya_lora_weights.safetensors")
19851986

19861987
save_model_card(
19871988
model_id if not args.push_to_hub else repo_id,

0 commit comments

Comments
 (0)