Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/sparsezoo/objects/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class RecipeTypes(Enum):
"""

ORIGINAL = "original"
SPARSE = "sparse"
TRANSFER = "transfer"
TRANSFER_LEARN = "transfer_learn"


Expand Down Expand Up @@ -455,9 +457,6 @@ def search_sparse_recipes(
"""
from sparsezoo.objects.model import Model

if isinstance(recipe_type, str):
recipe_type = RecipeTypes(recipe_type).value

if not isinstance(model, Model):
model = Model.load_model_from_stub(model)

Expand Down Expand Up @@ -508,15 +507,21 @@ def recipe_type_original(self) -> bool:
:return: True if this is the original recipe that created the
model, False otherwise
"""
return self.recipe_type == RecipeTypes.ORIGINAL.value
return any(
self.recipe_type.startswith(start)
for start in [RecipeTypes.ORIGINAL.value, RecipeTypes.SPARSE.value]
)

@property
def recipe_type_transfer_learn(self) -> bool:
"""
:return: True if this is a recipe for transfer learning from the
created model, False otherwise
"""
return self.recipe_type == RecipeTypes.TRANSFER_LEARN.value
return any(
self.recipe_type.startswith(start)
for start in [RecipeTypes.TRANSFER.value, RecipeTypes.TRANSFER_LEARN.value]
)

@property
def display_name(self):
Expand Down Expand Up @@ -653,15 +658,17 @@ def download_base_framework_files(
return base_framework_files or framework_files


def _get_stub_args_recipe_type(stub_args: Dict[str, str]) -> str:
def _get_stub_args_recipe_type(stub_args: Dict[str, str]) -> Optional[str]:
# check recipe type, default to original, and validate
recipe_type = stub_args.get("recipe_type")

# validate
valid_recipe_types = list(map(lambda typ: typ.value, RecipeTypes))
if recipe_type not in valid_recipe_types and recipe_type is not None:

if recipe_type is not None and not any(
recipe_type.startswith(start) for start in valid_recipe_types
):
raise ValueError(
f"Invalid recipe_type: '{recipe_type}'. "
f"Valid recipe types: {valid_recipe_types}"
f"Valid recipes must start with one of: {valid_recipe_types}"
)

return recipe_type