Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 39 additions & 47 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,70 +558,62 @@ def assign_remaining_weights(assignments, source):
ait_sd[target_key] = value

if any("guidance_in" in k for k in sds_sd):
assign_remaining_weights(
[
Comment on lines -561 to -562
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this going away? From what I understand, assign_remaining_weights() is doing the same thing as what the changes in this PR are doing no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't converting the alphas, so instead of making assign_remaining_weights handle the alpha I thought to use the other function (_convert_to_ai_toolkit) here, which already does. Maybe I'm missing the key reason assign_remaining_weights was being used instead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay. That makes sense.

Could we run the integration tests for Flux LoRA before we merge this PR?

class FluxLoRAIntegrationTests(unittest.TestCase):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't run all the tests due to my PC's limitations, and even for this one I had to disable all the skips to get it to run. But I was able to run the new test successfully. If you need more of the tests run, I can try though.

RUN_SLOW=1 RUN_NIGHTLY=1 python -m pytest tests/lora/test_lora_layers_flux.py -k test_flux_kohya_embedders_conversion
platform linux -- Python 3.13.5, pytest-8.4.2, pluggy-1.6.0
rootdir: /home/rockerboo/code/others/diffusers
configfile: pyproject.toml
plugins: timeout-2.4.0, requests-mock-1.10.0, xdist-3.8.0
collected 122 items / 121 deselected / 1 selected

tests/lora/test_lora_layers_flux.py .

1 passed, 121 deselected, 3 warnings in 555.38s (0:09:15)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant the other tests that we have in the test suite. Let me try to run them.

(
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
None,
),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_guidance_in_in_layer",
"time_text_embed.guidance_embedder.linear_1",
)

_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_guidance_in_out_layer",
"time_text_embed.guidance_embedder.linear_2",
)

if any("img_in" in k for k in sds_sd):
assign_remaining_weights(
[
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_img_in",
"x_embedder",
)

if any("txt_in" in k for k in sds_sd):
assign_remaining_weights(
[
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_txt_in",
"context_embedder",
)

if any("time_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
None,
),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_time_in_in_layer",
"time_text_embed.timestep_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_time_in_out_layer",
"time_text_embed.timestep_embedder.linear_2",
)

if any("vector_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
None,
),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_vector_in_in_layer",
"time_text_embed.text_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_vector_in_out_layer",
"time_text_embed.text_embedder.linear_2",
)

if any("final_layer" in k for k in sds_sd):
Expand Down
7 changes: 7 additions & 0 deletions tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,13 @@ def test_flux_kohya_with_text_encoder(self):

assert max_diff < 1e-3

def test_flux_kohya_embedders_conversion(self):
    """Smoke test: a kohya-format LoRA that contains embedder weights loads
    and unloads without raising.

    NOTE(review): this only verifies that the key conversion does not raise
    during loading; it does not check the numerical output of the converted
    weights.
    """
    # Loading pulls the checkpoint from the Hub and runs the kohya->diffusers
    # key conversion; an unmapped embedder key would raise here.
    self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
    # Unload so the shared pipeline fixture is left clean for later tests.
    # (The original ended with `assert True`, a no-op assertion that implied
    # verification where none exists — removed; success == no exception.)
    self.pipeline.unload_lora_weights()

def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
Expand Down