diff --git a/src/sparsezoo/models/zoo.py b/src/sparsezoo/models/zoo.py index d7d34f37..6006ad92 100644 --- a/src/sparsezoo/models/zoo.py +++ b/src/sparsezoo/models/zoo.py @@ -507,20 +507,17 @@ def download_recipe_from_stub( :return: file path of the downloaded recipe for that model """ stub, args = parse_zoo_stub(stub, valid_params=["recipe_type"]) + recipe_type = _get_stub_args_recipe_type(args) model = Zoo.load_model_from_stub(stub) - # parse recipe type and find matching recipe - recipe_type = ( - args["recipe_type"] - if "recipe_type" in args - else OptimizationRecipeTypes.ORIGINAL.value - ) for recipe in model.recipes: if recipe.recipe_type == recipe_type: return recipe.downloaded_path() + found_recipe_types = [recipe.recipe_type for recipe in model.recipes] raise RuntimeError( - f"No recipe with recipe_type {recipe_type} found for model {model}" + f"No recipe with recipe_type {recipe_type} found for model {model}. " + f"Found {len(model.recipes)} recipes with recipe types {found_recipe_types}" ) @staticmethod @@ -539,17 +536,9 @@ def download_recipe_base_framework_files( base weights of this recipe """ stub, args = parse_zoo_stub(stub, valid_params=["recipe_type"]) + recipe_type = _get_stub_args_recipe_type(args) model = Zoo.load_model_from_stub(stub) - # check recipe type, default to original, and validate - recipe_type = args.get("recipe_type", OptimizationRecipeTypes.ORIGINAL.value) - valid_recipe_types = list(map(lambda typ: typ.value, OptimizationRecipeTypes)) - if recipe_type not in valid_recipe_types: - raise ValueError( - f"Invalid recipe_type {recipe_type}. " - f"Valid recipe types: {valid_recipe_types}" - ) - if recipe_type == OptimizationRecipeTypes.TRANSFER_LEARN.value: # return final model's optimized weights for sparse transfer learning framework_files = model.download_framework_files(extensions=extensions) @@ -585,3 +574,17 @@ def download_recipe_base_framework_files( # return non-empty list, preferring filtered list return base_framework_files or framework_files + + +def _get_stub_args_recipe_type(stub_args: Dict[str, str]) -> str: + # check recipe type, default to original, and validate + recipe_type = stub_args.get("recipe_type", OptimizationRecipeTypes.ORIGINAL.value) + + # validate + valid_recipe_types = list(map(lambda typ: typ.value, OptimizationRecipeTypes)) + if recipe_type not in valid_recipe_types: + raise ValueError( + f"Invalid recipe_type: '{recipe_type}'. " + f"Valid recipe types: {valid_recipe_types}" + ) + return recipe_type