diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 67746ebacef2..560cdebd0cc9 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -18,7 +18,7 @@ import itertools import os import re -from functools import partial +from functools import lru_cache, partial from typing import Any, Callable, List, Optional, Tuple, Union import safetensors @@ -75,7 +75,8 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: return first_tuple[1].device -def get_parameter_dtype(parameter: torch.nn.Module): +@lru_cache(None) +def _get_parameter_dtype(parameter: torch.nn.Module): try: params = tuple(parameter.parameters()) if len(params) > 0: @@ -97,6 +98,16 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: return first_tuple[1].dtype +def get_parameter_dtype(parameter: torch.nn.Module): + try: + return _get_parameter_dtype(parameter) + except TypeError: + # For being backwards compatible and supporting torch modules + # that might not be hashable (e.g. custom modules), we fallback + # into the non-cached version. + return _get_parameter_dtype.__wrapped__(parameter) + + def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): """ Reads a checkpoint file, returning properly formatted errors if they arise.