From b4d15b2945ef5df61eb262d224b1a1dfc54947be Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 23 Feb 2021 20:25:38 -0500 Subject: [PATCH] load checkpoint file based on sparsezoo recipe in pytorch_vision script --- scripts/pytorch_vision.py | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/scripts/pytorch_vision.py b/scripts/pytorch_vision.py index d38fac15c66..73b2d6ff646 100644 --- a/scripts/pytorch_vision.py +++ b/scripts/pytorch_vision.py @@ -80,7 +80,9 @@ --checkpoint-path CHECKPOINT_PATH A path to a previous checkpoint to load the state from and resume the state for. If provided, pretrained will - be ignored + be ignored. If using a SparseZoo recipe, can also + provide 'zoo' to load the base weights associated with + that recipe --model-kwargs MODEL_KWARGS kew word arguments to be passed to model constructor, should be given as a json object @@ -112,9 +114,9 @@ in as a json object --recipe-path RECIPE_PATH The path to the yaml file containing the modifiers and - schedule to apply them with. If set to - 'transfer_learning', then will create a schedule to - enable sparse transfer learning + schedule to apply them with. Can also provide a + SparseZoo stub prefixed with 'zoo:' with an optional + '?recipe_type=' argument --sparse-transfer-learn Enable sparse transfer learning modifiers to enforce the sparsity for already sparse layers. The modifiers @@ -462,6 +464,7 @@ torch_distributed_zero_first, ) from sparseml.utils import convert_to_bool, create_dirs +from sparsezoo import Zoo LOGGER = get_main_logger() @@ -540,12 +543,20 @@ def parse_args(): "Default is None which will load the default dataset for the architecture." " Ex can be set to imagenet, cifar10, etc", ) + checkpoint_path_help = ( + "A path to a previous checkpoint to load the state from and " + "resume the state for. If provided, pretrained will be ignored" + ) + if par == train_parser: + checkpoint_path_help += ( + ". 
If using a SparseZoo recipe, can also provide 'zoo' to load " + "the base weights associated with that recipe" + ) par.add_argument( "--checkpoint-path", type=str, default=None, - help="A path to a previous checkpoint to load the state from and " - "resume the state for. If provided, pretrained will be ignored", + help=checkpoint_path_help, ) par.add_argument( "--model-kwargs", @@ -664,8 +675,8 @@ def parse_args(): type=str, default=None, help="The path to the yaml file containing the modifiers and " - "schedule to apply them with. If set to 'transfer_learning', " - "then will create a schedule to enable sparse transfer learning", + "schedule to apply them with. Can also provide a SparseZoo stub " + "prefixed with 'zoo:' with an optional '?recipe_type=' argument", ) par.add_argument( "--sparse-transfer-learn", @@ -1337,6 +1348,17 @@ def main(args): num_classes = dataset_attributes["num_classes"] with torch_distributed_zero_first(args.local_rank): # only download once locally + if args.checkpoint_path == "zoo": + if args.recipe_path and args.recipe_path.startswith("zoo:"): + args.checkpoint_path = Zoo.download_recipe_base_framework_files( + args.recipe_path, extensions=[".pth"] + )[0] + else: + raise ValueError( + "'zoo' provided as --checkpoint-path but a SparseZoo stub" + " prefixed by 'zoo:' not provided as --recipe-path" + ) + model = ModelRegistry.create( args.arch_key, args.pretrained,