Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions src/sparsezoo/models/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,12 +541,28 @@ def download_recipe_base_framework_files(
stub, args = parse_zoo_stub(stub, valid_params=["recipe_type"])
model = Zoo.load_model_from_stub(stub)

if (
recipe_type in args
and args["recipe_type"] == OptimizationRecipeTypes.TRANSFER_LEARN.value
):
# check recipe type, default to original, and validate
recipe_type = args.get("recipe_type", OptimizationRecipeTypes.ORIGINAL.value)
valid_recipe_types = list(map(lambda typ: typ.value, OptimizationRecipeTypes))
if recipe_type not in valid_recipe_types:
raise ValueError(
f"Invalid recipe_type {recipe_type}. "
f"Valid recipe types: {valid_recipe_types}"
)

if recipe_type == OptimizationRecipeTypes.TRANSFER_LEARN.value:
# return final model's optimized weights for sparse transfer learning
return model.download_framework_files(extensions=extensions)
framework_files = model.download_framework_files(extensions=extensions)

# download only pre-quantized weights if available
checkpoint_framework_files = [
framework_file
for framework_file in framework_files
if ".ckpt" in framework_file
]

# return non-empty list, preferring filtered list
return checkpoint_framework_files or framework_files
else:
# search for base model, and return those weights as a starting checkpoint
base_model = [
Expand All @@ -556,4 +572,16 @@ def download_recipe_base_framework_files(
]
if not base_model:
raise ValueError(f"Could not find base model for model {model}")
return base_model[0].download_framework_files(extensions=extensions)
framework_files = base_model[0].download_framework_files(
extensions=extensions
)

# filter out checkpoint weights if any exist
base_framework_files = [
framework_file
for framework_file in framework_files
if ".ckpt" not in framework_file
]

# return non-empty list, preferring filtered list
return base_framework_files or framework_files