diff --git a/src/sparsezoo/objects/recipe.py b/src/sparsezoo/objects/recipe.py index 213ff74f..138ac182 100644 --- a/src/sparsezoo/objects/recipe.py +++ b/src/sparsezoo/objects/recipe.py @@ -50,6 +50,8 @@ class RecipeTypes(Enum): """ ORIGINAL = "original" + SPARSE = "sparse" + TRANSFER = "transfer" TRANSFER_LEARN = "transfer_learn" @@ -455,9 +457,6 @@ def search_sparse_recipes( """ from sparsezoo.objects.model import Model - if isinstance(recipe_type, str): - recipe_type = RecipeTypes(recipe_type).value - if not isinstance(model, Model): model = Model.load_model_from_stub(model) @@ -508,7 +507,10 @@ def recipe_type_original(self) -> bool: :return: True if this is the original recipe that created the model, False otherwise """ - return self.recipe_type == RecipeTypes.ORIGINAL.value + return any( + self.recipe_type.startswith(start) + for start in [RecipeTypes.ORIGINAL.value, RecipeTypes.SPARSE.value] + ) @property def recipe_type_transfer_learn(self) -> bool: @@ -516,7 +518,10 @@ def recipe_type_transfer_learn(self) -> bool: :return: True if this is a recipe for transfer learning from the created model, False otherwise """ - return self.recipe_type == RecipeTypes.TRANSFER_LEARN.value + return any( + self.recipe_type.startswith(start) + for start in [RecipeTypes.TRANSFER.value, RecipeTypes.TRANSFER_LEARN.value] + ) @property def display_name(self): @@ -653,15 +658,17 @@ def download_base_framework_files( return base_framework_files or framework_files -def _get_stub_args_recipe_type(stub_args: Dict[str, str]) -> str: +def _get_stub_args_recipe_type(stub_args: Dict[str, str]) -> Optional[str]: # check recipe type, default to original, and validate recipe_type = stub_args.get("recipe_type") - - # validate valid_recipe_types = list(map(lambda typ: typ.value, RecipeTypes)) - if recipe_type not in valid_recipe_types and recipe_type is not None: + + if recipe_type is not None and not any( + recipe_type.startswith(start) for start in valid_recipe_types + ): raise ValueError( f"Invalid recipe_type: '{recipe_type}'. " - f"Valid recipe types: {valid_recipe_types}" + f"Valid recipes must start with one of: {valid_recipe_types}" ) + return recipe_type