From c792c2e8568b96a9d5a6b4ebc4955011f7d7c054 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 30 Mar 2021 13:43:43 -0400 Subject: [PATCH 1/2] YOLOv3 integration sparsezoo weights fixes --- integrations/ultralytics/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/ultralytics/train.py b/integrations/ultralytics/train.py index c3bec048991..f4b67ebb5d8 100644 --- a/integrations/ultralytics/train.py +++ b/integrations/ultralytics/train.py @@ -104,7 +104,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check # Model - pretrained = weights.endswith('.pt') + pretrained = weights.endswith('.pt') or weights.endswith('.pth') # SparseML integration if pretrained: with torch_distributed_zero_first(rank): attempt_download(weights) # download if not found locally @@ -643,14 +643,14 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): "Attempting to load weights from SparseZoo recipe, but not given a " "SparseZoo recipe stub. When --weights is set to 'zoo'. " "sparseml-recipe must start with 'zoo:' and be a SparseZoo model " - f"stub. sparseml-recipe was set to {args.sparseml_recipe}" + f"stub. sparseml-recipe was set to {opt.sparseml_recipe}" ) elif opt.weights.startswith("zoo:"): # Load weights from a SparseZoo model stub zoo_model = Zoo.load_model_from_stub(opt.weights) - args.initial_checkpoint = zoo_model.download_framework_files( + opt.weights = zoo_model.download_framework_files( extensions=[".pt", ".pth"] - ) + )[0] #################################################################################### # End - SparseML optional load weights from SparseZoo #################################################################################### From 22e8b26b0d822ae064d90cd16ee62c2df1ab9a96 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 30 Mar 2021 13:45:43 -0400 Subject: [PATCH 2/2] skip git status check in train.py --- integrations/ultralytics/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/ultralytics/train.py b/integrations/ultralytics/train.py index f4b67ebb5d8..c0d949246b6 100644 --- a/integrations/ultralytics/train.py +++ b/integrations/ultralytics/train.py @@ -623,7 +623,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1 set_logging(opt.global_rank) if opt.global_rank in [-1, 0]: - check_git_status() + # check_git_status() SparseML integration, will be out of sync with master check_requirements() ####################################################################################