Add drop_keys to drop certain keys from the LoRA network #964

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
networks/lora.py (16 additions, 0 deletions)
```diff
@@ -456,6 +456,10 @@ def create_network(
     conv_block_dims = None
     conv_block_alphas = None
 
+    drop_keys = kwargs.get("drop_keys", None)
+    if drop_keys is not None:
+        drop_keys = drop_keys.split(',')
+
     # rank/module dropout
     rank_dropout = kwargs.get("rank_dropout", None)
     if rank_dropout is not None:
```
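For context: these network kwargs are forwarded from the trainer's `--network_args` option in sd-scripts, so the new key would be supplied as a comma-separated string. A hypothetical invocation (the key names `attn2_to_k,attn2_to_v` are placeholders, not values taken from this PR):

```bash
accelerate launch train_network.py \
  --network_module networks.lora \
  --network_args "drop_keys=attn2_to_k,attn2_to_v"
# ...plus the usual training options
```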
```diff
@@ -480,6 +484,7 @@ def create_network(
         block_alphas=block_alphas,
         conv_block_dims=conv_block_dims,
         conv_block_alphas=conv_block_alphas,
+        drop_keys=drop_keys,
         varbose=True,
     )
```
```diff
@@ -764,6 +769,7 @@ def __init__(
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        module_class: Type[object] = LoRAModule,
+        drop_keys: Optional[List[str]] = None,
        varbose: Optional[bool] = False,
     ) -> None:
         """
```
```diff
@@ -784,6 +790,7 @@ def __init__(
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.drop_keys = drop_keys
 
         if modules_dim is not None:
             print(f"create LoRA network from weights")
```
```diff
@@ -801,6 +808,9 @@ def __init__(
         if self.conv_lora_dim is not None:
             print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
+        if self.drop_keys:
+            print(f"Drop keys: {self.drop_keys}")
+
         # create module instances
         def create_modules(
             is_unet: bool,
```
```diff
@@ -830,6 +840,12 @@ def create_modules(
                             lora_name = prefix + "." + name + "." + child_name
                             lora_name = lora_name.replace(".", "_")
 
+                            if self.drop_keys:
+                                # skip this module entirely if its flattened name contains any drop key
+                                if any(key in lora_name for key in self.drop_keys):
+                                    skipped.append(lora_name)
+                                    continue
+
                             dim = None
                             alpha = None
```
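Since the filter is a plain substring match on the flattened module names, a short standalone sketch (with made-up module names) shows what gets skipped:

```python
# Minimal sketch of the matching rule above, assuming flattened names
# of the form sd-scripts generates ("lora_unet_...", "lora_te_...").
drop_keys = "attn2_to_k,attn2_to_v".split(',')

lora_names = [
    "lora_unet_up_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k",
    "lora_unet_up_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q",
    "lora_te_text_model_encoder_layers_0_self_attn_k_proj",
]

skipped = [n for n in lora_names if any(k in n for k in drop_keys)]
kept = [n for n in lora_names if n not in skipped]

print(skipped)  # only the attn2_to_k module is dropped
print(kept)     # the attn1 and text-encoder modules are kept
```

Note that because this is substring matching, a broad key such as `attn2` would drop every attn2 projection (`to_q`, `to_k`, `to_v`, `to_out`) at once.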