diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 8b48ba6b4873..6c6257de0523 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -233,7 +233,7 @@ def load_model_dict_into_meta(
     empty_state_dict = model.state_dict()
 
     for param_name, param in state_dict.items():
-        if param_name not in empty_state_dict:
+        if unexpected_keys is not None and param_name in unexpected_keys:
             continue
 
         set_module_kwargs = {}
@@ -260,10 +260,16 @@ def load_model_dict_into_meta(
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
-        old_param = model
-        splits = param_name.split(".")
-        for split in splits:
-            old_param = getattr(old_param, split)
+        if param_name in empty_state_dict:
+            old_param = model
+            splits = param_name.split(".")
+            for split in splits:
+                old_param = getattr(old_param, split)
+        else:
+            # hf_quantizer can add parameters that don't exist yet in the model or in empty_state_dict;
+            # they will be created in create_quantized_param, and hf_quantizer handles loading them.
+            # When loading a pre-quantized model, these parameters come from the loaded state dict of the model file instead.
+            old_param = None
 
         if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
             old_param = None
@@ -279,7 +285,9 @@ def load_model_dict_into_meta(
 
         # bnb params are flattened.
         # gguf quants have a different shape based on the type of quantization applied
-        if empty_state_dict[param_name].shape != param.shape:
+        # the current parameter might not be in empty_state_dict if hf_quantizer needs to create it in create_quantized_param;
+        # pass the to-be-created parameters to create_quantized_param instead
+        if param_name in empty_state_dict and empty_state_dict[param_name].shape != param.shape:
             if (
                 is_quantized
                 and hf_quantizer.pre_quantized
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 1af7ba9ac511..cf342f46c741 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1564,12 +1564,17 @@ def _load_pretrained_model(
         dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
         is_parallel_loading_enabled: Optional[bool] = False,
     ):
+        is_quantized = hf_quantizer is not None
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
+        if is_quantized:
+            expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
         missing_keys = list(set(expected_keys) - set(loaded_keys))
-        if hf_quantizer is not None:
+        if is_quantized:
             missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+        if is_quantized:
+            unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)
         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
         if cls._keys_to_ignore_on_load_unexpected is not None:
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
index 24fc724b4c88..53f6a1b9878e 100644
--- a/src/diffusers/quantizers/base.py
+++ b/src/diffusers/quantizers/base.py
@@ -110,6 +110,28 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li
         """
         return missing_keys
 
+    def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]:
+        """
+        Override this method if you want to adjust the expected keys.
+
+        Args:
+            expected_keys (`list[str]`, *optional*):
+                The list of the expected keys in the initialized model.
+            loaded_keys (`list[str]`, *optional*):
+                The list of the loaded keys in the checkpoint.
+        """
+        return expected_keys
+
+    def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
+        """
+        Override this method if you want to adjust the unexpected keys.
+
+        Args:
+            unexpected_keys (`list[str]`, *optional*):
+                The list of the unexpected keys in the checkpoint, compared to the state dict of the model.
+        """
+        return unexpected_keys
+
     def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
         """
         returns dtypes for modules that are not quantized - used for the computation of the device_map in case one
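For context, a minimal sketch (not part of this PR) of how a quantizer subclass might use the new hooks. `MyQuantizer` and the `.scales` suffix are made-up illustrations, and the abstract methods of `DiffusersQuantizer` are omitted; only the `update_expected_keys` / `update_unexpected_keys` signatures come from the diff above:

```python
from diffusers.quantizers.base import DiffusersQuantizer


class MyQuantizer(DiffusersQuantizer):
    # Hypothetical quantizer whose pre-quantized checkpoints ship extra
    # "<weight>.scales" tensors that do not exist in the meta-initialized model.

    def update_expected_keys(self, model, expected_keys, loaded_keys):
        # Declare the extra checkpoint tensors as expected so they are not
        # filtered into `unexpected_keys`; load_model_dict_into_meta can then
        # hand them to create_quantized_param instead of skipping them.
        extra = [k for k in loaded_keys if k.endswith(".scales") and k not in expected_keys]
        return expected_keys + extra

    def update_unexpected_keys(self, model, unexpected_keys):
        # Keys this quantizer materializes itself should not be reported
        # back to the user as unexpected.
        return [k for k in unexpected_keys if not k.endswith(".scales")]
```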