Skip to content

Commit 7242b5f

Browse files
FIX Test to ignore warning for enable_lora_hotswap (#12421)
I noticed that the test should be for the option check_compiled="ignore" but it was using check_compiled="warn". This has been fixed, now the correct argument is passed. However, the fact that the test passed means that it was incorrect to begin with. The way that logs are collected does not collect the logger.warning call here (not sure why). To amend this, I'm now using assertNoLogs. With this change, the test correctly fails when the wrong argument is passed.
1 parent b429796 commit 7242b5f

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/models/test_modeling_common.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import unittest
2626
import unittest.mock as mock
2727
import uuid
28-
import warnings
2928
from collections import defaultdict
3029
from typing import Dict, List, Optional, Tuple, Union
3130

@@ -2373,14 +2372,15 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
23732372

23742373
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
23752374
# check possibility to ignore the error/warning
2375+
from diffusers.loaders.peft import logger
2376+
23762377
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
23772378
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
23782379
model = self.model_class(**init_dict).to(torch_device)
23792380
model.add_adapter(lora_config)
2380-
with warnings.catch_warnings(record=True) as w:
2381-
warnings.simplefilter("always") # Capture all warnings
2382-
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
2383-
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
2381+
# note: assertNoLogs requires Python 3.10+
2382+
with self.assertNoLogs(logger, level="WARNING"):
2383+
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
23842384

23852385
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
23862386
# check that wrong argument value raises an error

0 commit comments

Comments
 (0)