Closed
32 commits
ff3d380
fix: parse lora_alpha correctly
sayakpaul Dec 11, 2023
79b1637
fix
sayakpaul Dec 11, 2023
20fac7b
better conditioning
sayakpaul Dec 11, 2023
981ea82
assertion
sayakpaul Dec 11, 2023
cf132fb
debug
sayakpaul Dec 11, 2023
e4c00bc
debug
sayakpaul Dec 11, 2023
3b27b23
dehug
sayakpaul Dec 11, 2023
0d08249
ifx?
sayakpaul Dec 11, 2023
41b9cd8
fix?
sayakpaul Dec 11, 2023
b868e8a
ifx
sayakpaul Dec 11, 2023
c341111
ifx
sayakpaul Dec 11, 2023
a2792cd
unwrap
sayakpaul Dec 11, 2023
9ecb271
unwrap
sayakpaul Dec 11, 2023
32212b6
json unwrap
sayakpaul Dec 11, 2023
ed333f0
remove print
sayakpaul Dec 11, 2023
fdb1146
Empty-Commit
sayakpaul Dec 11, 2023
bcf0f4a
fix
sayakpaul Dec 11, 2023
24cb282
fix
sayakpaul Dec 11, 2023
49a0f3a
Merge branch 'main' into fix/lora-loading
sayakpaul Dec 15, 2023
f4adaae
move config related stuff in a separate utility.
sayakpaul Dec 15, 2023
57a16f3
fix: import error
sayakpaul Dec 15, 2023
d24e7d3
debug
sayakpaul Dec 15, 2023
09618d0
remove print
sayakpaul Dec 15, 2023
ec9df6f
simplify condition.
sayakpaul Dec 15, 2023
16ac1b2
propagate changes to sd dreambooth lora.
sayakpaul Dec 15, 2023
ece6d89
propagate to sd t2i lora fine-tuning
sayakpaul Dec 15, 2023
8c98a18
propagate to sdxl t2i lora fine-tuning
sayakpaul Dec 15, 2023
765fef7
add: doc strings.
sayakpaul Dec 15, 2023
f145d48
add test
sayakpaul Dec 15, 2023
5d04eeb
fix attribute access.
sayakpaul Dec 15, 2023
255adf0
Merge branch 'main' into fix/lora-loading
sayakpaul Dec 15, 2023
f9c7b32
Merge branch 'main' into fix/lora-loading
sayakpaul Dec 16, 2023
20 changes: 17 additions & 3 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -880,11 +880,16 @@ def save_model_hook(models, weights, output_dir):
unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None

unet_lora_config = None
text_encoder_lora_config = None

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -895,6 +900,8 @@ def save_model_hook(models, weights, output_dir):
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)

def load_model_hook(models, input_dir):
@@ -911,10 +918,12 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_, config=metadata
)

accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1315,17 +1324,22 @@ def compute_text_embeddings(prompt):
unet = unet.to(torch.float32)

unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]

if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
text_encoder_lora_config = text_encoder.peft_config["default"]
else:
text_encoder_state_dict = None
text_encoder_lora_config = None

LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)

# Final inference
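Taken together, the hooks above now serialize the adapter's LoraConfig next to its state dict and feed the config back in when the state is restored. Below is a minimal sketch of that round trip outside the training loop, assuming the extended save_lora_weights / lora_state_dict / load_lora_into_unet signatures introduced in this PR; the model id, output directory, and rank/alpha values are illustrative.

from peft import LoraConfig, get_peft_model_state_dict
from diffusers import StableDiffusionPipeline
from diffusers.loaders import LoraLoaderMixin

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Attach an adapter whose lora_alpha differs from its rank -- exactly the case
# that was previously lost because only the state dict was written to disk.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
pipe.unet.add_adapter(unet_lora_config)

# Save the weights together with their config, as the save hook above does.
LoraLoaderMixin.save_lora_weights(
    save_directory="lora-out",
    unet_lora_layers=get_peft_model_state_dict(pipe.unet),
    unet_lora_config=pipe.unet.peft_config["default"],
)

# lora_state_dict now also returns the stored config ("metadata"), which the
# loaders use to rebuild the adapter with the trained rank and alpha.
fresh = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict("lora-out")
LoraLoaderMixin.load_lora_into_unet(
    lora_state_dict, network_alphas=network_alphas, unet=fresh.unet, config=metadata
)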
28 changes: 24 additions & 4 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1033,13 +1033,20 @@ def save_model_hook(models, weights, output_dir):
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None

unet_lora_config = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1051,6 +1058,9 @@ def save_model_hook(models, weights, output_dir):
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)

def load_model_hook(models, input_dir):
@@ -1070,17 +1080,19 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
)

text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
)

text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
)

accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1616,21 +1628,29 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]

if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None

StableDiffusionXLPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)

# Final inference
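The SDXL DreamBooth script gets the same treatment, with a separate config carried for each of the two text encoders. The reason the config has to travel with the weights at all is that PEFT scales every LoRA update by lora_alpha / r; if lora_alpha cannot be recovered at load time the loader has to guess (commonly assuming alpha equals the rank), which silently changes the adapter's strength even though the weights themselves are identical. A small worked example of the mismatch, with illustrative numbers:

# PEFT applies the LoRA update as delta_W = (lora_alpha / r) * (B @ A).
r, lora_alpha = 4, 8

trained_scale = lora_alpha / r   # 2.0 -- what the adapter was trained with
fallback_scale = r / r           # 1.0 -- what a loader that assumes alpha == r applies

# Same state dict, half the effective strength: this is the drift the saved
# LoraConfig (and the new `config` argument above) is meant to prevent.
print(trained_scale, fallback_scale)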
4 changes: 4 additions & 0 deletions examples/text_to_image/train_text_to_image_lora.py
@@ -833,10 +833,12 @@ def collate_fn(examples):
accelerator.save_state(save_path)

unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]

StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
unet_lora_config=unet_lora_config,
safe_serialization=True,
)

@@ -898,10 +900,12 @@ def collate_fn(examples):
unet = unet.to(torch.float32)

unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
unet_lora_config=unet_lora_config,
)

if args.push_to_hub:
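For inference, the pipeline-level loader can be used as before; with the config now stored alongside the weights, the adapter is expected to come back with the trained rank and alpha (the config forwarding inside load_lora_weights is assumed to mirror the load_model_hook pattern above, and the model id and output directory are illustrative):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Directory written by save_lora_weights above (final weights or a checkpoint).
pipe.load_lora_weights("sd-model-finetuned-lora")

image = pipe("a pokemon with blue eyes", num_inference_steps=25).images[0]
image.save("lora_sample.png")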
29 changes: 25 additions & 4 deletions examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -682,13 +682,20 @@ def save_model_hook(models, weights, output_dir):
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None

unet_lora_config = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -700,6 +707,9 @@ def save_model_hook(models, weights, output_dir):
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)

def load_model_hook(models, input_dir):
@@ -719,17 +729,19 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
)

text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
)

text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
)

accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1194,22 +1206,31 @@ def compute_time_ids(original_size, crops_coords_top_left):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]

if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)

text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)

text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None

StableDiffusionXLPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)

del unet
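Finally, a sketch of the kind of round-trip check the "add test" commit points at: an adapter with alpha different from its rank is saved with its config and restored into a fresh UNet, and the restored config should match. The assertion assumes the loader exposes the rebuilt config as unet.peft_config["default"], the same attribute add_adapter populates; adjust it to however your diffusers version surfaces the restored LoraConfig.

from peft import LoraConfig, get_peft_model_state_dict
from diffusers import UNet2DConditionModel
from diffusers.loaders import LoraLoaderMixin

def check_lora_config_roundtrip(unet: UNet2DConditionModel, tmp_dir: str) -> None:
    # Attach an adapter with lora_alpha != r and save it together with its config.
    config = LoraConfig(
        r=4, lora_alpha=8, target_modules=["to_k", "to_q", "to_v", "to_out.0"]
    )
    unet.add_adapter(config)
    LoraLoaderMixin.save_lora_weights(
        save_directory=tmp_dir,
        unet_lora_layers=get_peft_model_state_dict(unet),
        unet_lora_config=unet.peft_config["default"],
    )

    # Restore into a randomly initialized copy so nothing leaks from the original.
    fresh_unet = UNet2DConditionModel.from_config(unet.config)
    state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(tmp_dir)
    LoraLoaderMixin.load_lora_into_unet(
        state_dict, network_alphas=network_alphas, unet=fresh_unet, config=metadata
    )

    # Assumption: the restored config is reachable via peft_config["default"].
    restored = fresh_unet.peft_config["default"]
    assert restored.r == config.r
    assert restored.lora_alpha == config.lora_alpha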