Skip to content

Commit

Permalink
MOD: merge kohya lora
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Dec 22, 2023
1 parent 6871e5e commit 86f139d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
4 changes: 2 additions & 2 deletions train_dreambooth_ziplora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def save_model_card(
## Trigger words
You should use a combination of {instance_prompt} and {instance_prompt_2} to trigger the image generation
You should use a combination of {instance_prompt} and {instance_prompt_2} to trigger the image generation
using the two trained concepts.
## Download model
Expand Down Expand Up @@ -1055,7 +1055,7 @@ def main(args):
)
attn_module.to_out[0].set_lora_layer(
initialize_ziplora_layer(
part="to_out.0",
part="to_out_0",
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
init_merger_value=args.init_merger_value,
Expand Down
26 changes: 20 additions & 6 deletions ziplora_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,26 @@ def merge_lora_weights(
key (str): target attn layer's key
prefix (str, optional): prefix for state dict. Defaults to "unet.unet.".
"""
target_key = prefix + key
# target_key = prefix + key

splts = key.split('.')
block = splts[0]

if block == 'down_blocks':
block, n1, t1, n2, t2, n3, t3 = splts
kohya_key = f'lora_unet_input_blocks_{3 * int(n1) + int(n2) + 1}_1_{t2}_{n3}_{t3}'
elif block == 'up_blocks':
block, n1, t1, n2, t2, n3, t3 = splts
kohya_key = f'lora_unet_output_blocks_{3 * int(n1) + int(n2)}_1_{t2}_{n3}_{t3}'
elif block == 'mid_block':
block, t1, n1, t2, n2, t3 = splts
kohya_key = f'lora_unet_middle_block_1_{t2}_{n2}_{t3}'

out = {}
for part in ["to_q", "to_k", "to_v", "to_out.0"]:
down_key = target_key + f".{part}.lora.down.weight"
up_key = target_key + f".{part}.lora.up.weight"
merged_weight = tensors[up_key] @ tensors[down_key]
for part in ["to_q", "to_k", "to_v", "to_out_0"]:
down_key = kohya_key + f"_{part}.lora_down.weight"
up_key = kohya_key + f"_{part}.lora_up.weight"
merged_weight = tensors[up_key].float() @ tensors[down_key].float()
out[part] = merged_weight
return out

Expand Down Expand Up @@ -187,7 +201,7 @@ def insert_ziplora_to_unet(
initialize_ziplora_layer_for_inference(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
part="to_out.0",
part="to_out_0",
**kwargs,
)
)
Expand Down

0 comments on commit 86f139d

Please sign in to comment.