31 changes: 25 additions & 6 deletions src/diffusers/loaders/peft.py
@@ -311,12 +311,28 @@ def map_state_dict_for_hotswap(sd):
state_dict = map_state_dict_for_hotswap(state_dict)
check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
try:
hotswap_adapter_from_state_dict(
model=self,
state_dict=state_dict,
adapter_name=adapter_name,
config=lora_config,
)
# When enable_lora_hotswap was called, weights were padded to target_rank so that
Member
If we keep things like this, how about turning this into a context manager to keep the code inside of load_lora_adapter more readable? It would also allow re-using the context manager if it's ever needed.
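A minimal sketch of what that context manager could look like, mirroring the logic added in this diff; the helper name `_force_compiled_hotswap_path` is hypothetical and not part of this PR:

```python
from contextlib import contextmanager

@contextmanager
def _force_compiled_hotswap_path(model):
    # Hypothetical helper: temporarily make peft treat `model` as compiled so that
    # hotswapping uses the in-place data-copy path instead of swap_tensors.
    supports = getattr(model, "_supports_compiled_lora_hotswap", False)
    orig = getattr(model, "_compiled_call_impl", None)
    if supports and not hasattr(model, "_orig_mod"):
        model._compiled_call_impl = True
    try:
        yield
    finally:
        model._compiled_call_impl = orig
```

load_lora_adapter would then only need `with _force_compiled_hotswap_path(self):` around the `hotswap_adapter_from_state_dict(...)` call.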

# torch.compile does not recompile after hotswapping. peft's hotswap_adapter_from_state_dict
# uses swap_tensors for un-compiled models, which replaces the parameter tensor object.
# On backends without CUDA Graphs (e.g. XPU), the compiled inductor code re-fetches
# the real model parameters on every call and its assert_size_stride checks fail because
# the swapped-in tensor has the smaller (un-padded) rank shape.
# Fix: temporarily signal to peft that this model should be treated as compiled so
# it uses the in-place data copy path (which preserves both the tensor object and the
# padded shape).
_supports_compiled_lora_hotswap = getattr(self, "_supports_compiled_lora_hotswap", False)
_orig_compiled_call_impl = getattr(self, "_compiled_call_impl", None)
if _supports_compiled_lora_hotswap and not hasattr(self, "_orig_mod"):
self._compiled_call_impl = True
try:
hotswap_adapter_from_state_dict(
model=self,
state_dict=state_dict,
adapter_name=adapter_name,
config=lora_config,
)
finally:
self._compiled_call_impl = _orig_compiled_call_impl
except Exception as e:
logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
raise
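To make the failure mode described in the new comment concrete, here is a small self-contained sketch. The shapes, names, and padding scheme are illustrative assumptions, not the exact peft internals:

```python
import torch

target_rank, rank, in_features = 16, 4, 32

# Parameter padded to target_rank, as enable_lora_hotswap would prepare it.
lora_A = torch.nn.Parameter(torch.zeros(target_rank, in_features))
seen_by_compiled_code = lora_A  # compiled inductor code keeps re-checking this object's size/stride

incoming = torch.randn(rank, in_features)  # new adapter with a smaller rank

# In-place copy path (the compiled branch): the tensor object and its padded shape survive.
lora_A.data.zero_()
lora_A.data[:rank] = incoming
assert seen_by_compiled_code.shape == (target_rank, in_features)

# swap_tensors path (the un-compiled branch): the parameter object now carries the
# un-padded shape, so a size/stride re-check in compiled code would fail.
torch.utils.swap_tensors(lora_A, torch.nn.Parameter(incoming.clone()))
assert lora_A.shape == (rank, in_features)
```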
@@ -830,3 +846,6 @@ def enable_lora_hotswap(
)

self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
# Remember that this model was prepared for compiled LoRA hotswapping even after the
# one-shot prepare kwargs are consumed when the first adapter is loaded.
self._supports_compiled_lora_hotswap = True
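For context, this is roughly the user-side flow that exercises both enable_lora_hotswap and the hotswap path patched above. The model id, adapter ids, device, and target_rank are placeholders, and the API shape follows the diffusers hotswapping docs rather than this PR:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")  # or "xpu"

# Pad LoRA buffers up front so adapters with smaller ranks can be swapped in later.
pipe.enable_lora_hotswap(target_rank=64)
pipe.load_lora_weights("user/lora-one", adapter_name="default_0")

pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
image = pipe("a prompt").images[0]  # first call triggers compilation

# Swap the second adapter into the same (padded) buffers without recompiling.
pipe.load_lora_weights("user/lora-two", hotswap=True, adapter_name="default_0")
image = pipe("another prompt").images[0]
```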