38 changes: 30 additions & 8 deletions scripts/pytorch_vision.py
@@ -80,7 +80,9 @@
--checkpoint-path CHECKPOINT_PATH
A path to a previous checkpoint to load the state from
and resume the state for. If provided, pretrained will
be ignored
be ignored. If using a SparseZoo recipe, can also
provide 'zoo' to load the base weights associated with
that recipe
--model-kwargs MODEL_KWARGS
key word arguments to be passed to model constructor,
should be given as a json object
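As a quick illustration of the updated --checkpoint-path help text above, the sketch below shows the two kinds of values the option now accepts. This is a hedged example: the file path is illustrative, and the 'zoo' shortcut only applies when --recipe-path is a SparseZoo stub (see the next hunk).

# Sketch of the values --checkpoint-path accepts after this change.
# As before, a filesystem checkpoint can be given (path is illustrative):
checkpoint_path = "/path/to/checkpoint.pth"

# New: the literal string 'zoo' asks the training flow to download the base
# weights attached to the SparseZoo recipe passed via --recipe-path.
checkpoint_path = "zoo"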
@@ -112,9 +114,9 @@
in as a json object
--recipe-path RECIPE_PATH
The path to the yaml file containing the modifiers and
schedule to apply them with. If set to
'transfer_learning', then will create a schedule to
enable sparse transfer learning
schedule to apply them with. Can also provide a
SparseZoo stub prefixed with 'zoo:' with an optional
'?recipe_type=' argument
--sparse-transfer-learn
Enable sparse transfer learning modifiers to enforce
the sparsity for already sparse layers. The modifiers
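Similarly, the --recipe-path help text above now accepts a SparseZoo stub in addition to a local YAML file. The sketch below is illustrative only: the stub and the recipe_type value are hypothetical placeholders, not verified SparseZoo entries.

# Sketch of the values --recipe-path accepts after this change.
# A local YAML recipe works as before (path is illustrative):
recipe_path = "recipes/pruning_resnet50.yaml"

# New: a SparseZoo stub prefixed with 'zoo:', with an optional '?recipe_type='
# query selecting a specific recipe attached to that model (placeholders below).
recipe_path = (
    "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-moderate"
    "?recipe_type=transfer_learn"
)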
@@ -462,6 +464,7 @@
torch_distributed_zero_first,
)
from sparseml.utils import convert_to_bool, create_dirs
from sparsezoo import Zoo


LOGGER = get_main_logger()
@@ -540,12 +543,20 @@ def parse_args():
"Default is None which will load the default dataset for the architecture."
" Ex can be set to imagenet, cifar10, etc",
)
checkpoint_path_help = (
"A path to a previous checkpoint to load the state from and "
"resume the state for. If provided, pretrained will be ignored"
)
if par == train_parser:
checkpoint_path_help += (
". If using a SparseZoo recipe, can also provide 'zoo' to load "
"the base weights associated with that recipe"
)
par.add_argument(
"--checkpoint-path",
type=str,
default=None,
help="A path to a previous checkpoint to load the state from and "
"resume the state for. If provided, pretrained will be ignored",
help=checkpoint_path_help,
)
par.add_argument(
"--model-kwargs",
@@ -664,8 +675,8 @@ def parse_args():
type=str,
default=None,
help="The path to the yaml file containing the modifiers and "
"schedule to apply them with. If set to 'transfer_learning', "
"then will create a schedule to enable sparse transfer learning",
"schedule to apply them with. Can also provide a SparseZoo stub "
"prefixed with 'zoo:' with an optional '?recipe_type=' argument",
)
par.add_argument(
"--sparse-transfer-learn",
@@ -1337,6 +1348,17 @@ def main(args):
num_classes = dataset_attributes["num_classes"]

with torch_distributed_zero_first(args.local_rank): # only download once locally
if args.checkpoint_path == "zoo":
if args.recipe_path and args.recipe_path.startswith("zoo:"):
args.checkpoint_path = Zoo.download_recipe_base_framework_files(
args.recipe_path, extensions=[".pth"]
)[0]
else:
raise ValueError(
"'zoo' provided as --checkpoint-path but a SparseZoo stub"
" prefixed by 'zoo:' not provided as --recipe-path"
)

model = ModelRegistry.create(
args.arch_key,
args.pretrained,
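To read the new resolution step in main() in isolation, here is a minimal sketch that factors the same logic into a standalone helper. It relies only on the Zoo.download_recipe_base_framework_files call already used in the diff; resolve_checkpoint is a hypothetical name introduced for illustration, not part of the script.

# Minimal sketch of the checkpoint resolution added in main() above;
# `resolve_checkpoint` is a hypothetical helper, not part of the script.
from sparsezoo import Zoo


def resolve_checkpoint(checkpoint_path, recipe_path):
    """Translate the 'zoo' shortcut into a concrete .pth checkpoint path."""
    if checkpoint_path != "zoo":
        # A regular filesystem checkpoint (or None) passes through unchanged.
        return checkpoint_path
    if not (recipe_path and recipe_path.startswith("zoo:")):
        raise ValueError(
            "'zoo' provided as --checkpoint-path but a SparseZoo stub"
            " prefixed by 'zoo:' not provided as --recipe-path"
        )
    # Download the framework files attached to the recipe's base model and use
    # the first PyTorch state dict (.pth) returned.
    return Zoo.download_recipe_base_framework_files(
        recipe_path, extensions=[".pth"]
    )[0]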