From 66116d7c34ffadf3d271a3d9b184ef2fa934403c Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 7 Jul 2022 19:11:29 +0200 Subject: [PATCH] fix loading from pretrained for sharded model with `torch_dtype="auto"` --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e1621c6e5a2ba..e8021eea73534 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2037,7 +2037,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P elif not is_sharded: torch_dtype = get_state_dict_dtype(state_dict) else: - one_state_dict = load_state_dict(resolved_archive_file) + one_state_dict = load_state_dict(resolved_archive_file[0]) torch_dtype = get_state_dict_dtype(one_state_dict) del one_state_dict # free CPU memory else: