|
25 | 25 | import unittest
|
26 | 26 | import unittest.mock as mock
|
27 | 27 | import uuid
|
28 |
| -import warnings |
29 | 28 | from collections import defaultdict
|
30 | 29 | from typing import Dict, List, Optional, Tuple, Union
|
31 | 30 |
|
@@ -2373,14 +2372,15 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
|
2373 | 2372 |
|
2374 | 2373 | def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
|
2375 | 2374 | # check possibility to ignore the error/warning
|
| 2375 | + from diffusers.loaders.peft import logger |
| 2376 | + |
2376 | 2377 | lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
|
2377 | 2378 | init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
2378 | 2379 | model = self.model_class(**init_dict).to(torch_device)
|
2379 | 2380 | model.add_adapter(lora_config)
|
2380 |
| - with warnings.catch_warnings(record=True) as w: |
2381 |
| - warnings.simplefilter("always") # Capture all warnings |
2382 |
| - model.enable_lora_hotswap(target_rank=32, check_compiled="warn") |
2383 |
| - self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}") |
| 2381 | + # note: assertNoLogs requires Python 3.10+ |
| 2382 | + with self.assertNoLogs(logger, level="WARNING"): |
| 2383 | + model.enable_lora_hotswap(target_rank=32, check_compiled="ignore") |
2384 | 2384 |
|
2385 | 2385 | def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
2386 | 2386 | # check that wrong argument value raises an error
|
|
0 commit comments