diff --git a/src/sparsezoo/models/zoo.py b/src/sparsezoo/models/zoo.py index 05575ca2..d7d34f37 100644 --- a/src/sparsezoo/models/zoo.py +++ b/src/sparsezoo/models/zoo.py @@ -541,12 +541,28 @@ def download_recipe_base_framework_files( stub, args = parse_zoo_stub(stub, valid_params=["recipe_type"]) model = Zoo.load_model_from_stub(stub) - if ( - recipe_type in args - and args["recipe_type"] == OptimizationRecipeTypes.TRANSFER_LEARN.value - ): + # check recipe type, default to original, and validate + recipe_type = args.get("recipe_type", OptimizationRecipeTypes.ORIGINAL.value) + valid_recipe_types = list(map(lambda typ: typ.value, OptimizationRecipeTypes)) + if recipe_type not in valid_recipe_types: + raise ValueError( + f"Invalid recipe_type {recipe_type}. " + f"Valid recipe types: {valid_recipe_types}" + ) + + if recipe_type == OptimizationRecipeTypes.TRANSFER_LEARN.value: # return final model's optimized weights for sparse transfer learning - return model.download_framework_files(extensions=extensions) + framework_files = model.download_framework_files(extensions=extensions) + + # download only pre-quantized weights if available + checkpoint_framework_files = [ + framework_file + for framework_file in framework_files + if ".ckpt" in framework_file + ] + + # return non-empty list, preferring filtered list + return checkpoint_framework_files or framework_files else: # search for base model, and return those weights as a starting checkpoint base_model = [ @@ -556,4 +572,16 @@ def download_recipe_base_framework_files( ] if not base_model: raise ValueError(f"Could not find base model for model {model}") - return base_model[0].download_framework_files(extensions=extensions) + framework_files = base_model[0].download_framework_files( + extensions=extensions + ) + + # filter out checkpoint weights if any exist + base_framework_files = [ + framework_file + for framework_file in framework_files + if ".ckpt" not in framework_file + ] + + # return non-empty list, preferring filtered list + return base_framework_files or framework_files