From fc63300808b763eff7bf2a666165dad961dbf9a4 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 31 Mar 2021 12:08:57 -0400 Subject: [PATCH 1/3] Retrain model to work with older TF versions --- integrations/keras/prune_resnet20.py | 31 ++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/integrations/keras/prune_resnet20.py b/integrations/keras/prune_resnet20.py index b9157f0faa7..35e0414d4e3 100644 --- a/integrations/keras/prune_resnet20.py +++ b/integrations/keras/prune_resnet20.py @@ -65,7 +65,9 @@ def download_model_and_recipe(root_dir: str): Download pretrained model and a pruning recipe """ model_dir = os.path.join(root_dir, "resnet20_v1") - zoo_model = Zoo.load_model( + + # Load base model to prune + base_zoo_model = Zoo.load_model( domain="cv", sub_domain="classification", architecture="resnet_v1", @@ -74,16 +76,33 @@ def download_model_and_recipe(root_dir: str): repo="sparseml", dataset="cifar_10", training_scheme=None, - optim_name="pruned", - optim_category="conservative", + optim_name="base", + optim_category=None, optim_target=None, override_parent_path=model_dir, ) - zoo_model.download() - model_file_path = zoo_model.framework_files[0].downloaded_path() + base_zoo_model.download() + model_file_path = base_zoo_model.framework_files[0].downloaded_path() if not os.path.exists(model_file_path) or not model_file_path.endswith(".h5"): raise RuntimeError("Model file not found: {}".format(model_file_path)) - recipe_file_path = zoo_model.recipes[0].downloaded_path() + + # Load recipe to use + pruned_zoo_model = Zoo.load_model( + domain="cv", + sub_domain="classification", + architecture="resnet_v1", + sub_architecture=20, + framework="keras", + repo="sparseml", + dataset="cifar_10", + training_scheme=None, + optim_name="pruned", + optim_category="conservative", + optim_target=None, + override_parent_path=model_dir, + ) + pruned_zoo_model.download() + recipe_file_path = pruned_zoo_model.recipes[0].downloaded_path() if not os.path.exists(recipe_file_path): raise RuntimeError("Recipe file not found: {}".format(recipe_file_path)) return model_file_path, recipe_file_path From 64755caec2068050297860016b8e6695b9ad78b2 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 31 Mar 2021 13:09:56 -0400 Subject: [PATCH 2/3] Load pruned recipe --- integrations/keras/prune_resnet20.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/integrations/keras/prune_resnet20.py b/integrations/keras/prune_resnet20.py index 35e0414d4e3..385028896c9 100644 --- a/integrations/keras/prune_resnet20.py +++ b/integrations/keras/prune_resnet20.py @@ -87,7 +87,7 @@ def download_model_and_recipe(root_dir: str): raise RuntimeError("Model file not found: {}".format(model_file_path)) # Load recipe to use - pruned_zoo_model = Zoo.load_model( + pruned_recipe = Zoo.search_recipes( domain="cv", sub_domain="classification", architecture="resnet_v1", @@ -101,8 +101,7 @@ def download_model_and_recipe(root_dir: str): optim_target=None, override_parent_path=model_dir, ) - pruned_zoo_model.download() - recipe_file_path = pruned_zoo_model.recipes[0].downloaded_path() + recipe_file_path = pruned_recipe[0].downloaded_path() if not os.path.exists(recipe_file_path): raise RuntimeError("Recipe file not found: {}".format(recipe_file_path)) return model_file_path, recipe_file_path From 1ef3362a01d8d68f2f7dc2b1343724f41260014a Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 31 Mar 2021 13:46:43 -0400 Subject: [PATCH 3/3] Using recipe stub --- integrations/keras/prune_resnet20.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/integrations/keras/prune_resnet20.py b/integrations/keras/prune_resnet20.py index 385028896c9..89b468d48fd 100644 --- a/integrations/keras/prune_resnet20.py +++ b/integrations/keras/prune_resnet20.py @@ -86,24 +86,11 @@ def download_model_and_recipe(root_dir: str): if not os.path.exists(model_file_path) or not model_file_path.endswith(".h5"): raise RuntimeError("Model file not found: {}".format(model_file_path)) - # Load recipe to use - pruned_recipe = Zoo.search_recipes( - domain="cv", - sub_domain="classification", - architecture="resnet_v1", - sub_architecture=20, - framework="keras", - repo="sparseml", - dataset="cifar_10", - training_scheme=None, - optim_name="pruned", - optim_category="conservative", - optim_target=None, - override_parent_path=model_dir, + # Simply use the recipe stub + recipe_file_path = ( + "zoo:cv/classification/resnet_v1-20/keras/sparseml/cifar_10/pruned-conservative" ) - recipe_file_path = pruned_recipe[0].downloaded_path() - if not os.path.exists(recipe_file_path): - raise RuntimeError("Recipe file not found: {}".format(recipe_file_path)) + return model_file_path, recipe_file_path @@ -150,6 +137,7 @@ def main(): (X_train, y_train), (X_test, y_test) = load_and_normalize_cifar10() model_file_path, recipe_file_path = download_model_and_recipe(root_dir) + print("Load pretrained model") base_model = tf.keras.models.load_model(model_file_path) base_model.summary()