From ac067196deaa311e193312c33663468a66b316ee Mon Sep 17 00:00:00 2001 From: Disty0 Date: Sat, 11 Oct 2025 17:50:06 +0300 Subject: [PATCH 1/3] Add update expected / unexpected keys api to DiffusersQuantizer --- src/diffusers/models/model_loading_utils.py | 17 ++++++++++------ src/diffusers/models/modeling_utils.py | 7 ++++++- src/diffusers/quantizers/base.py | 22 +++++++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 8b48ba6b4873..9da7dce547bf 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -233,7 +233,7 @@ def load_model_dict_into_meta( empty_state_dict = model.state_dict() for param_name, param in state_dict.items(): - if param_name not in empty_state_dict: + if param_name in unexpected_keys: continue set_module_kwargs = {} @@ -260,10 +260,15 @@ def load_model_dict_into_meta( # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 - old_param = model - splits = param_name.split(".") - for split in splits: - old_param = getattr(old_param, split) + if param_name in empty_state_dict: + old_param = model + splits = param_name.split(".") + for split in splits: + old_param = getattr(old_param, split) + else: + # hf_quantizer can add parameters that doesn't exist yet + # they will be in the loaded_state_dict when pre_quantized + old_param = None if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): old_param = None @@ -279,7 +284,7 @@ def load_model_dict_into_meta( # bnb params are flattened. # gguf quants have a different shape based on the type of quantization applied - if empty_state_dict[param_name].shape != param.shape: + if param_name in empty_state_dict and empty_state_dict[param_name].shape != param.shape: if ( is_quantized and hf_quantizer.pre_quantized diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1af7ba9ac511..cf342f46c741 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1564,12 +1564,17 @@ def _load_pretrained_model( dduf_entries: Optional[Dict[str, DDUFEntry]] = None, is_parallel_loading_enabled: Optional[bool] = False, ): + is_quantized = hf_quantizer is not None model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) + if is_quantized: + expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys) missing_keys = list(set(expected_keys) - set(loaded_keys)) - if hf_quantizer is not None: + if is_quantized: missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + if is_quantized: + unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys) # Some models may have keys that are not in the state by design, removing them before needlessly warning # the user. if cls._keys_to_ignore_on_load_unexpected is not None: diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py index 24fc724b4c88..53f6a1b9878e 100644 --- a/src/diffusers/quantizers/base.py +++ b/src/diffusers/quantizers/base.py @@ -110,6 +110,28 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li """ return missing_keys + def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]: + """ + Override this method if you want to adjust the `update_expected_keys`. + + Args: + expected_keys (`list[str]`, *optional*): + The list of the expected keys in the initialized model. + loaded_keys (`list[str]`, *optional*): + The list of the loaded keys in the checkpoint. + """ + return expected_keys + + def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]: + """ + Override this method if you want to adjust the `update_expected_keys`. + + Args: + unexpected_keys (`list[str]`, *optional*): + The list of the unexpected keys in the checkpoint compared to the state dict of the model + """ + return unexpected_keys + def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]: """ returns dtypes for modules that are not quantized - used for the computation of the device_map in case one From ed8be9735f04dc13e1efcf4f889d43ffee906521 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 13 Oct 2025 14:44:46 +0300 Subject: [PATCH 2/3] fix unexpected_keys is None case and add better comment --- src/diffusers/models/model_loading_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 9da7dce547bf..9e7012f69210 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -233,7 +233,7 @@ def load_model_dict_into_meta( empty_state_dict = model.state_dict() for param_name, param in state_dict.items(): - if param_name in unexpected_keys: + if unexpected_keys is not None and param_name in unexpected_keys: continue set_module_kwargs = {} @@ -266,8 +266,9 @@ def load_model_dict_into_meta( for split in splits: old_param = getattr(old_param, split) else: - # hf_quantizer can add parameters that doesn't exist yet - # they will be in the loaded_state_dict when pre_quantized + # hf_quantizer can add parameters that doesn't exist yet in the model and the empty_state_dict + # they will be created in create_quantized_param and hf_quantizer should handle the loading of these parameters + # these parameters will be in the loaded_state_dict from the model file instead when loading a pre_quantized model old_param = None if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): From 745b041409ee096bb29028beb93dc9f4c54e5e49 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 14 Oct 2025 00:35:04 +0300 Subject: [PATCH 3/3] add hf_quantizer comment on shape check --- src/diffusers/models/model_loading_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 9e7012f69210..be093e2a4367 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -285,6 +285,8 @@ def load_model_dict_into_meta( # bnb params are flattened. # gguf quants have a different shape based on the type of quantization applied + # current parameter might not be in the empty_state_dict if the hf_quantizer needs to create it in create_quantized_param + # pass the to be created parameters to create_quantized_param instead if param_name in empty_state_dict and empty_state_dict[param_name].shape != param.shape: if ( is_quantized