18 changes: 12 additions & 6 deletions src/diffusers/models/model_loading_utils.py
@@ -233,7 +233,7 @@ def load_model_dict_into_meta(
empty_state_dict = model.state_dict()

for param_name, param in state_dict.items():
- if param_name not in empty_state_dict:
+ if unexpected_keys is not None and param_name in unexpected_keys:
Member: yeah that's better, actually in transformers we rely on unexpected keys

continue

set_module_kwargs = {}
@@ -260,10 +260,16 @@ def load_model_dict_into_meta(
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
- old_param = model
- splits = param_name.split(".")
- for split in splits:
-     old_param = getattr(old_param, split)
+ if param_name in empty_state_dict:
+     old_param = model
+     splits = param_name.split(".")
+     for split in splits:
+         old_param = getattr(old_param, split)
+ else:
+     # hf_quantizer can add parameters that don't exist yet in the model and the empty_state_dict;
+     # they will be created in create_quantized_param and hf_quantizer should handle the loading of these parameters.
+     # These parameters will be in the loaded_state_dict from the model file instead when loading a pre_quantized model.
+     old_param = None
Comment on lines +263 to +272
Member: yeah indeed this is kind of what we did in _infer_parameter_dtype in transformers

if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
@@ -279,7 +285,7 @@ def load_model_dict_into_meta(

# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
- if empty_state_dict[param_name].shape != param.shape:
+ if param_name in empty_state_dict and empty_state_dict[param_name].shape != param.shape:
Member: just add a small comment for that as we will probably refactor the loading at some point to match what we have in transformers

if (
is_quantized
and hf_quantizer.pre_quantized
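Taken together, the three hunks above make the meta-device loading loop rely on the caller's unexpected_keys and tolerate parameters that only the quantizer knows about. Below is a minimal sketch of the resulting control flow; the function name and the simplified shape-mismatch handling are assumptions for illustration, not the actual diffusers implementation:

```python
import torch


def load_into_meta_sketch(model, state_dict, unexpected_keys=None, hf_quantizer=None):
    # Sketch of the updated loop in load_model_dict_into_meta (simplified).
    is_quantized = hf_quantizer is not None
    empty_state_dict = model.state_dict()

    for param_name, param in state_dict.items():
        # Skip keys the caller already classified as unexpected (mirrors transformers).
        if unexpected_keys is not None and param_name in unexpected_keys:
            continue

        if param_name in empty_state_dict:
            # Walk the module tree to the existing (possibly meta) parameter.
            old_param = model
            for attr in param_name.split("."):
                old_param = getattr(old_param, attr)
            if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
                old_param = None
        else:
            # Quantizer-created parameters (e.g. quantization state) are not in the
            # initialized model; create_quantized_param will build and load them.
            old_param = None

        # Shape checks only apply to parameters the model already declares.
        if param_name in empty_state_dict and empty_state_dict[param_name].shape != param.shape:
            # Hypothetical simplification: only pre-quantized checkpoints may mismatch
            # (bnb params are flattened, gguf quants change shape).
            if not (is_quantized and hf_quantizer.pre_quantized):
                raise ValueError(f"Shape mismatch for {param_name}")

        # old_param (when present) is later used to preserve dtype and contiguity
        # when copying the loaded tensor into the model.
```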
7 changes: 6 additions & 1 deletion src/diffusers/models/modeling_utils.py
@@ -1564,12 +1564,17 @@ def _load_pretrained_model(
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
is_parallel_loading_enabled: Optional[bool] = False,
):
+ is_quantized = hf_quantizer is not None
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
+ if is_quantized:
+     expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
missing_keys = list(set(expected_keys) - set(loaded_keys))
- if hf_quantizer is not None:
+ if is_quantized:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+ if is_quantized:
+     unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)
# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if cls._keys_to_ignore_on_load_unexpected is not None:
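The bookkeeping in this hunk reduces to three set operations, with a quantizer hook applied after each one. A self-contained sketch of that flow (the helper name is hypothetical, not part of the diffusers API):

```python
def classify_checkpoint_keys(model, loaded_keys, hf_quantizer=None):
    """Sketch of the expected/missing/unexpected key bookkeeping with quantizer hooks."""
    is_quantized = hf_quantizer is not None

    expected_keys = list(model.state_dict().keys())
    if is_quantized:
        # Pre-quantized checkpoints may rename keys or ship extra quantization state.
        expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)

    missing_keys = list(set(expected_keys) - set(loaded_keys))
    if is_quantized:
        missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")

    unexpected_keys = list(set(loaded_keys) - set(expected_keys))
    if is_quantized:
        # Keys the quantizer recognizes (or will create itself) are not "unexpected".
        unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)

    return expected_keys, missing_keys, unexpected_keys
```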
22 changes: 22 additions & 0 deletions src/diffusers/quantizers/base.py
@@ -110,6 +110,28 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li
"""
return missing_keys

def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]:
"""
Override this method if you want to adjust the expected keys.
Args:
expected_keys (`list[str]`, *optional*):
The list of the expected keys in the initialized model.
loaded_keys (`list[str]`, *optional*):
The list of the loaded keys in the checkpoint.
"""
return expected_keys

def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
"""
Override this method if you want to adjust the unexpected keys.
Args:
unexpected_keys (`list[str]`, *optional*):
The list of the unexpected keys in the checkpoint compared to the state dict of the model.
"""
return unexpected_keys

def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
"""
returns dtypes for modules that are not quantized - used for the computation of the device_map in case one
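Both new hooks are no-ops on the base class, so a backend only overrides what it needs. As an illustration, a hypothetical pre-quantized backend that ships an extra `.absmax` tensor per weight could fold those keys into the expected set and keep them out of the unexpected list. The class name and the `.absmax` suffix are invented for the example, and it assumes the base class in this file is `DiffusersQuantizer`:

```python
from diffusers.quantizers.base import DiffusersQuantizer


class ExamplePreQuantizedQuantizer(DiffusersQuantizer):
    # Sketch only: a real backend also implements the other required quantizer methods
    # (parameter creation, pre/post weight-loading processing, etc.).

    def update_expected_keys(self, model, expected_keys, loaded_keys):
        # Treat the extra ".absmax" tensors shipped in the checkpoint as expected,
        # so they are matched against the checkpoint instead of being flagged later.
        extra_state = [k for k in loaded_keys if k.endswith(".absmax")]
        return expected_keys + extra_state

    def update_unexpected_keys(self, model, unexpected_keys):
        # Defensive: keys this quantizer will create and load itself should never be
        # dropped or warned about as unexpected.
        return [k for k in unexpected_keys if not k.endswith(".absmax")]
```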