From 61b4064443132d712710231ba21b52e4bbf8edde Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Sep 2024 01:09:44 +0200 Subject: [PATCH 1/2] up --- src/diffusers/models/modeling_utils.py | 30 +++++++++++--------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index cfe692dcc54a..7031c2cfcec1 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -93,24 +93,20 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: try: - params = tuple(parameter.parameters()) - if len(params) > 0: - return params[0].dtype - - buffers = tuple(parameter.buffers()) - if len(buffers) > 0: - return buffers[0].dtype - + return next(parameter.parameters()).dtype except StopIteration: - # For torch.nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].dtype + try: + return next(parameter.buffers()).dtype + except StopIteration: + # Forch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype class ModelMixin(torch.nn.Module, PushToHubMixin): From e2e824d5429d12cf5860829672b42e5204e0cfc9 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 24 Sep 2024 13:28:19 -1000 Subject: [PATCH 2/2] Update src/diffusers/models/modeling_utils.py Co-authored-by: Aryan --- src/diffusers/models/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 7031c2cfcec1..9e0c50e8b37b 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -98,7 +98,7 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: try: return next(parameter.buffers()).dtype except StopIteration: - # Forch.nn.DataParallel compatibility in PyTorch 1.5 + # For torch.nn.DataParallel compatibility in PyTorch 1.5 def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]