From 6bd85ec96b302af6798087dd7e9ebe9eaa72753b Mon Sep 17 00:00:00 2001 From: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com> Date: Tue, 9 Aug 2022 15:40:58 +0300 Subject: [PATCH] Fix quant model re-load bug (#978) --- src/sparseml/pytorch/utils/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/utils/model.py b/src/sparseml/pytorch/utils/model.py index 97dcd7adc35..9d676bcc8a4 100644 --- a/src/sparseml/pytorch/utils/model.py +++ b/src/sparseml/pytorch/utils/model.py @@ -80,7 +80,6 @@ def load_model( if path.startswith("zoo:"): path = download_framework_model_by_recipe_type(Model(path)) model_dict = torch.load(path, map_location="cpu") - current_dict = model.state_dict() recipe = model_dict.get("recipe") if recipe: @@ -90,6 +89,7 @@ def load_model( checkpoint_manager = ScheduledModifierManager.from_yaml(recipe) checkpoint_manager.apply_structure(module=model, epoch=epoch) + current_dict = model.state_dict() if "state_dict" in model_dict: model_dict = model_dict["state_dict"]