Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions src/sparsezoo/models/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,20 +507,17 @@ def download_recipe_from_stub(
:return: file path of the downloaded recipe for that model
"""
stub, args = parse_zoo_stub(stub, valid_params=["recipe_type"])
recipe_type = _get_stub_args_recipe_type(args)
model = Zoo.load_model_from_stub(stub)

# parse recipe type and find matching recipe
recipe_type = (
args["recipe_type"]
if "recipe_type" in args
else OptimizationRecipeTypes.ORIGINAL.value
)
for recipe in model.recipes:
if recipe.recipe_type == recipe_type:
return recipe.downloaded_path()

found_recipe_types = [recipe.recipe_type for recipe in model.recipes]
raise RuntimeError(
f"No recipe with recipe_type {recipe_type} found for model {model}"
f"No recipe with recipe_type {recipe_type} found for model {model}. "
f"Found {len(model.recipes)} recipes with recipe types {found_recipe_types}"
)

@staticmethod
Expand All @@ -539,17 +536,9 @@ def download_recipe_base_framework_files(
base weights of this recipe
"""
stub, args = parse_zoo_stub(stub, valid_params=["recipe_type"])
recipe_type = _get_stub_args_recipe_type(args)
model = Zoo.load_model_from_stub(stub)

# check recipe type, default to original, and validate
recipe_type = args.get("recipe_type", OptimizationRecipeTypes.ORIGINAL.value)
valid_recipe_types = list(map(lambda typ: typ.value, OptimizationRecipeTypes))
if recipe_type not in valid_recipe_types:
raise ValueError(
f"Invalid recipe_type {recipe_type}. "
f"Valid recipe types: {valid_recipe_types}"
)

if recipe_type == OptimizationRecipeTypes.TRANSFER_LEARN.value:
# return final model's optimized weights for sparse transfer learning
framework_files = model.download_framework_files(extensions=extensions)
Expand Down Expand Up @@ -585,3 +574,17 @@ def download_recipe_base_framework_files(

# return non-empty list, preferring filtered list
return base_framework_files or framework_files


def _get_stub_args_recipe_type(stub_args: Dict[str, str]) -> str:
# check recipe type, default to original, and validate
recipe_type = stub_args.get("recipe_type", OptimizationRecipeTypes.ORIGINAL.value)

# validate
valid_recipe_types = list(map(lambda typ: typ.value, OptimizationRecipeTypes))
if recipe_type not in valid_recipe_types:
raise ValueError(
f"Invalid recipe_type: '{recipe_type}'. "
f"Valid recipe types: {valid_recipe_types}"
)
return recipe_type