diff --git a/integrations/keras/prune_resnet20.py b/integrations/keras/prune_resnet20.py index b9157f0faa7..89b468d48fd 100644 --- a/integrations/keras/prune_resnet20.py +++ b/integrations/keras/prune_resnet20.py @@ -65,7 +65,9 @@ def download_model_and_recipe(root_dir: str): Download pretrained model and a pruning recipe """ model_dir = os.path.join(root_dir, "resnet20_v1") - zoo_model = Zoo.load_model( + + # Load base model to prune + base_zoo_model = Zoo.load_model( domain="cv", sub_domain="classification", architecture="resnet_v1", @@ -74,18 +76,21 @@ def download_model_and_recipe(root_dir: str): repo="sparseml", dataset="cifar_10", training_scheme=None, - optim_name="pruned", - optim_category="conservative", + optim_name="base", + optim_category=None, optim_target=None, override_parent_path=model_dir, ) - zoo_model.download() - model_file_path = zoo_model.framework_files[0].downloaded_path() + base_zoo_model.download() + model_file_path = base_zoo_model.framework_files[0].downloaded_path() if not os.path.exists(model_file_path) or not model_file_path.endswith(".h5"): raise RuntimeError("Model file not found: {}".format(model_file_path)) - recipe_file_path = zoo_model.recipes[0].downloaded_path() - if not os.path.exists(recipe_file_path): - raise RuntimeError("Recipe file not found: {}".format(recipe_file_path)) + + # Simply use the recipe stub + recipe_file_path = ( + "zoo:cv/classification/resnet_v1-20/keras/sparseml/cifar_10/pruned-conservative" + ) + return model_file_path, recipe_file_path @@ -132,6 +137,7 @@ def main(): (X_train, y_train), (X_test, y_test) = load_and_normalize_cifar10() model_file_path, recipe_file_path = download_model_and_recipe(root_dir) + print("Load pretrained model") base_model = tf.keras.models.load_model(model_file_path) base_model.summary()