diff --git a/models/export.py b/models/export.py index 32869c5e58bc..278fb476e783 100644 --- a/models/export.py +++ b/models/export.py @@ -50,7 +50,7 @@ def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs): def load_checkpoint(type_, weights, device, cfg=None, hyp=None, nc=None, recipe=None, resume=None, rank=-1): with torch_distributed_zero_first(rank): attempt_download(weights) # download if not found locally - check_download_sparsezoo_weights(weights) # download from sparsezoo if zoo stub + weights = check_download_sparsezoo_weights(weights) # download from sparsezoo if zoo stub ckpt = torch.load(weights[0] if isinstance(weights, list) or isinstance(weights, tuple) else weights, map_location=device) # load checkpoint start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0 diff --git a/utils/sparse.py b/utils/sparse.py index ceafe5dc68c7..80dc222e7dd0 100644 --- a/utils/sparse.py +++ b/utils/sparse.py @@ -33,10 +33,10 @@ def check_download_sparsezoo_weights(path): return path - if not isinstance(path, list): - raise ValueError(f"unknown type for path given: {path}") + if isinstance(path, list): + return [check_download_sparsezoo_weights(p) for p in path] - return [check_download_sparsezoo_weights(p) for p in path] + return path class SparseMLWrapper(object):