Skip to content

Commit c67f59f

Browse files
authored
Merge branch 'main' into fix-flux-mellon
2 parents 9dbed17 + 941ac9c commit c67f59f

File tree

7 files changed

+57
-6
lines changed

7 files changed

+57
-6
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
# "Jinja2",
2626
# "peft>=0.11.1",
2727
# "sentencepiece",
28+
# "torchvision",
29+
# "datasets",
30+
# "bitsandbytes",
31+
# "prodigyopt",
2832
# ]
2933
# ///
3034

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
# "Jinja2",
2626
# "peft>=0.11.1",
2727
# "sentencepiece",
28+
# "torchvision",
29+
# "datasets",
30+
# "bitsandbytes",
31+
# "prodigyopt",
2832
# ]
2933
# ///
3034

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,24 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
# /// script
18+
# dependencies = [
19+
# "diffusers @ git+https://github.com/huggingface/diffusers.git",
20+
# "torch>=2.0.0",
21+
# "accelerate>=0.31.0",
22+
# "transformers>=4.41.2",
23+
# "ftfy",
24+
# "tensorboard",
25+
# "Jinja2",
26+
# "peft>=0.11.1",
27+
# "sentencepiece",
28+
# "torchvision",
29+
# "datasets",
30+
# "bitsandbytes",
31+
# "prodigyopt",
32+
# ]
33+
# ///
34+
1735
import argparse
1836
import copy
1937
import itertools

examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,24 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515

16+
# /// script
17+
# dependencies = [
18+
# "diffusers @ git+https://github.com/huggingface/diffusers.git",
19+
# "torch>=2.0.0",
20+
# "accelerate>=0.31.0",
21+
# "transformers>=4.41.2",
22+
# "ftfy",
23+
# "tensorboard",
24+
# "Jinja2",
25+
# "peft>=0.11.1",
26+
# "sentencepiece",
27+
# "torchvision",
28+
# "datasets",
29+
# "bitsandbytes",
30+
# "prodigyopt",
31+
# ]
32+
# ///
33+
1634
import argparse
1735
import copy
1836
import itertools

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
# "Jinja2",
2626
# "peft>=0.14.0",
2727
# "sentencepiece",
28+
# "torchvision",
29+
# "datasets",
30+
# "bitsandbytes",
31+
# "prodigyopt",
2832
# ]
2933
# ///
3034

src/diffusers/models/attention_dispatch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
2121

2222
import torch
23-
import torch.distributed._functional_collectives as funcol
23+
24+
25+
if torch.distributed.is_available():
26+
import torch.distributed._functional_collectives as funcol
2427

2528
from ..utils import (
2629
get_logger,

tests/models/test_modeling_common.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import unittest
2626
import unittest.mock as mock
2727
import uuid
28-
import warnings
2928
from collections import defaultdict
3029
from typing import Dict, List, Optional, Tuple, Union
3130

@@ -2373,14 +2372,15 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
23732372

23742373
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
23752374
# check possibility to ignore the error/warning
2375+
from diffusers.loaders.peft import logger
2376+
23762377
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
23772378
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
23782379
model = self.model_class(**init_dict).to(torch_device)
23792380
model.add_adapter(lora_config)
2380-
with warnings.catch_warnings(record=True) as w:
2381-
warnings.simplefilter("always") # Capture all warnings
2382-
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
2383-
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
2381+
# note: assertNoLogs requires Python 3.10+
2382+
with self.assertNoLogs(logger, level="WARNING"):
2383+
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
23842384

23852385
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
23862386
# check that wrong argument value raises an error

0 commit comments

Comments (0)