diff --git a/src/sparseml/pytorch/utils/model.py b/src/sparseml/pytorch/utils/model.py index 97dcd7adc35..9d676bcc8a4 100644 --- a/src/sparseml/pytorch/utils/model.py +++ b/src/sparseml/pytorch/utils/model.py @@ -80,7 +80,6 @@ def load_model( if path.startswith("zoo:"): path = download_framework_model_by_recipe_type(Model(path)) model_dict = torch.load(path, map_location="cpu") - current_dict = model.state_dict() recipe = model_dict.get("recipe") if recipe: @@ -90,6 +89,7 @@ def load_model( checkpoint_manager = ScheduledModifierManager.from_yaml(recipe) checkpoint_manager.apply_structure(module=model, epoch=epoch) + current_dict = model.state_dict() if "state_dict" in model_dict: model_dict = model_dict["state_dict"]