diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py index e06f0a3766e6..575aa515731b 100644 --- a/src/transformers/sparse.py +++ b/src/transformers/sparse.py @@ -210,6 +210,8 @@ def preprocess_state_dict(pretrained_model_name_or_path): manager = ScheduledModifierManager.from_yaml(recipe) modifiers = [m.__class__.__name__ for m in manager.modifiers] is_qat_recipe = "QuantizationModifier" in modifiers + else: + is_qat_recipe = False if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) state_dict = torch.load(archive_file, map_location="cpu")