diff --git a/Makefile b/Makefile index e1d58e11..4aed1709 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,7 @@ PYCHECKGLOBS := 'examples/**/*.py' 'scripts/**/*.py' 'src/**/*.py' 'tests/**/*.p DOCDIR := docs MDCHECKGLOBS := 'docs/**/*.md' 'docs/**/*.rst' 'examples/**/*.md' 'notebooks/**/*.md' 'scripts/**/*.md' MDCHECKFILES := CODE_OF_CONDUCT.md CONTRIBUTING.md DEVELOPING.md README.md +SPARSEZOO_TEST_MODE := "true" BUILD_ARGS := # set nightly to build nightly release TARGETS := "" # targets for running pytests: full,efficientnet,inception,resnet,vgg,ssd,yolo diff --git a/src/sparsezoo/requests/base.py b/src/sparsezoo/requests/base.py index aba30d6e..64235118 100644 --- a/src/sparsezoo/requests/base.py +++ b/src/sparsezoo/requests/base.py @@ -19,8 +19,12 @@ import os from typing import Any, List, Union +from sparsezoo.utils import convert_to_bool -__all__ = ["BASE_API_URL", "ModelArgs", "MODELS_API_URL"] + +__all__ = ["BASE_API_URL", "ModelArgs", "MODELS_API_URL", "SPARSEZOO_TEST_MODE"] + +SPARSEZOO_TEST_MODE = convert_to_bool(os.getenv("SPARSEZOO_TEST_MODE")) BASE_API_URL = ( os.getenv("SPARSEZOO_API_URL") diff --git a/src/sparsezoo/requests/download.py b/src/sparsezoo/requests/download.py index 9092772b..a00113e6 100644 --- a/src/sparsezoo/requests/download.py +++ b/src/sparsezoo/requests/download.py @@ -22,7 +22,7 @@ import requests from sparsezoo.requests.authentication import get_auth_header -from sparsezoo.requests.base import MODELS_API_URL, ModelArgs +from sparsezoo.requests.base import MODELS_API_URL, SPARSEZOO_TEST_MODE, ModelArgs __all__ = ["download_get_request", "DOWNLOAD_PATH"] @@ -52,11 +52,18 @@ def download_get_request( if file_name: url = f"{url}/{file_name}" + download_args = [] + if hasattr(args, "release_version") and args.release_version: - url = f"{url}?release_version={args.release_version}" + download_args.append(f"release_version={args.release_version}") - _LOGGER.debug(f"GET download from {url}") + if SPARSEZOO_TEST_MODE: + download_args.append("increment_download=False") + if download_args: + url = f"{url}?{'&'.join(download_args)}" + + _LOGGER.debug(f"GET download from {url}") response = requests.get(url=url, headers=header) response.raise_for_status() response_json = response.json() diff --git a/src/sparsezoo/requests/search.py b/src/sparsezoo/requests/search.py index ad43dded..8b93b00c 100644 --- a/src/sparsezoo/requests/search.py +++ b/src/sparsezoo/requests/search.py @@ -59,7 +59,7 @@ def search_get_request( search_args.extend([f"page={page}", f"page_length={page_length}"]) if args.release_version: - search_args.extend(f"release_version={args.release_version}") + search_args.append(f"release_version={args.release_version}") search_args = "&".join(search_args) url = f"{MODELS_API_URL}/{SEARCH_PATH}/{args.model_url_root}?{search_args}" diff --git a/src/sparsezoo/utils/helpers.py b/src/sparsezoo/utils/helpers.py index a3b87dc4..1e0c99dc 100644 --- a/src/sparsezoo/utils/helpers.py +++ b/src/sparsezoo/utils/helpers.py @@ -18,7 +18,7 @@ import errno import os -from typing import Union +from typing import Any, Union from tqdm import auto, tqdm, tqdm_notebook @@ -26,6 +26,7 @@ __all__ = [ "CACHE_DIR", "clean_path", + "convert_to_bool", "create_dirs", "create_parent_dirs", "create_tqdm_auto_constructor", @@ -43,6 +44,18 @@ def clean_path(path: str) -> str: return os.path.abspath(os.path.expanduser(path)) +def convert_to_bool(val: Any): + """ + :param val: a value + :return: False if value is a Falsy value e.g. 0, f, false, None, otherwise True. + """ + return ( + bool(val) + if not isinstance(val, str) + else bool(val) and "f" not in val.lower() and "0" not in val.lower() + ) + + def create_dirs(path: str): """ :param path: the directory path to try and create