From c6908be6720d1e2cab51abb10286d16d44bdc8bd Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Tue, 21 Apr 2026 20:17:30 +0800
Subject: [PATCH] fix the test_hotswapping_compiled_model_xxx case failures on XPU

Signed-off-by: Wang, Yi
---
 src/diffusers/loaders/peft.py | 31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index daa078bc25d5..8c75c5976c24 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -311,12 +311,28 @@ def map_state_dict_for_hotswap(sd):
                 state_dict = map_state_dict_for_hotswap(state_dict)
                 check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
                 try:
-                    hotswap_adapter_from_state_dict(
-                        model=self,
-                        state_dict=state_dict,
-                        adapter_name=adapter_name,
-                        config=lora_config,
-                    )
+                    # When enable_lora_hotswap was called, weights were padded to target_rank so that
+                    # torch.compile does not recompile after hotswapping. peft's hotswap_adapter_from_state_dict
+                    # uses swap_tensors for un-compiled models, which replaces the parameter tensor object.
+                    # On backends without CUDA Graphs (e.g. XPU), the compiled inductor code re-fetches
+                    # the real model parameters on every call, and its assert_size_stride checks fail because
+                    # the swapped-in tensor has the smaller (un-padded) rank shape.
+                    # Fix: temporarily signal to peft that this model should be treated as compiled so
+                    # it uses the in-place data copy path (which preserves both the tensor object and the
+                    # padded shape).
+                    _supports_compiled_lora_hotswap = getattr(self, "_supports_compiled_lora_hotswap", False)
+                    _orig_compiled_call_impl = getattr(self, "_compiled_call_impl", None)
+                    if _supports_compiled_lora_hotswap and not hasattr(self, "_orig_mod"):
+                        self._compiled_call_impl = True
+                    try:
+                        hotswap_adapter_from_state_dict(
+                            model=self,
+                            state_dict=state_dict,
+                            adapter_name=adapter_name,
+                            config=lora_config,
+                        )
+                    finally:
+                        self._compiled_call_impl = _orig_compiled_call_impl
                 except Exception as e:
                     logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
                     raise
@@ -830,3 +846,6 @@ def enable_lora_hotswap(
            )

        self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
+        # Remember that this model was prepared for compiled LoRA hotswapping even after the
+        # one-shot prepare kwargs are consumed when the first adapter is loaded.
+        self._supports_compiled_lora_hotswap = True
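
Note for reviewers: the failing scenario roughly follows the compiled-hotswap
recipe from the diffusers docs, run on an XPU device. A minimal sketch of that
flow for reference; the model id, LoRA file names, and target_rank below are
illustrative placeholders, not taken from the failing tests:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("xpu")

    # Pad all LoRA weights up to a common rank so that swapping adapters later
    # does not change parameter shapes and force a recompile.
    pipe.unet.enable_lora_hotswap(target_rank=64)

    pipe.load_lora_weights("lora_0.safetensors")  # first adapter, loaded normally
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
    image = pipe("a prompt").images[0]  # triggers compilation

    # Hotswap a second adapter into the already-compiled model. Without this
    # patch, the inductor-generated assert_size_stride checks fail on XPU,
    # because peft's swap_tensors path replaces the padded parameter tensor
    # with a smaller, un-padded one.
    pipe.load_lora_weights("lora_1.safetensors", hotswap=True, adapter_name="default_0")
    image = pipe("a prompt").images[0]  # reuses the compiled graph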