Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import itertools
import os
import re
from functools import partial
from functools import lru_cache, partial
from typing import Any, Callable, List, Optional, Tuple, Union

import safetensors
Expand Down Expand Up @@ -75,7 +75,8 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
return first_tuple[1].device


def get_parameter_dtype(parameter: torch.nn.Module):
@lru_cache(None)
def _get_parameter_dtype(parameter: torch.nn.Module):
try:
params = tuple(parameter.parameters())
if len(params) > 0:
Expand All @@ -97,6 +98,16 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
return first_tuple[1].dtype


def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of ``parameter``'s parameters (or buffers).

    Dispatches to the ``lru_cache``-decorated ``_get_parameter_dtype`` when
    the module is hashable, and falls back to the uncached implementation
    otherwise, so custom modules that are not hashable keep working.

    Args:
        parameter (`torch.nn.Module`): The module whose dtype is queried.

    Returns:
        `torch.dtype`: The dtype reported by ``_get_parameter_dtype``.
    """
    # Probe hashability explicitly instead of wrapping the cached call in
    # `except TypeError`: a TypeError raised *inside* _get_parameter_dtype
    # would otherwise be misread as "unhashable module" and the function
    # body would silently run a second time via __wrapped__, masking the
    # real error.
    try:
        hash(parameter)
    except TypeError:
        # Backwards-compatible path for torch modules that are not hashable
        # (e.g. custom modules overriding __hash__): call the undecorated
        # function exposed by functools via __wrapped__.
        return _get_parameter_dtype.__wrapped__(parameter)
    # NOTE(review): `lru_cache(None)` keyed on the module keeps a strong
    # reference to every module ever seen and will return a stale dtype if
    # the module is converted afterwards (e.g. `.half()` / `.to(dtype)`) —
    # confirm this trade-off is intended.
    return _get_parameter_dtype(parameter)


def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
Expand Down