From 7978855499ababedb730bbca9bd34d46131b2d56 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 15 Oct 2021 11:28:59 -0400 Subject: [PATCH 1/6] Add framework fallback ability to execute in sparseml --- src/sparseml/base.py | 120 ++++++++++++++++++++++++++++++------------- 1 file changed, 85 insertions(+), 35 deletions(-) diff --git a/src/sparseml/base.py b/src/sparseml/base.py index 87da5e2bcc1..56e7274d896 100644 --- a/src/sparseml/base.py +++ b/src/sparseml/base.py @@ -15,6 +15,7 @@ import importlib import logging +from collections import OrderedDict from enum import Enum from typing import Any, List, Optional @@ -48,7 +49,22 @@ class Framework(Enum): tensorflow_v1 = "tensorflow_v1" -def detect_framework(item: Any) -> Framework: +def _execute_sparseml_package_function( + framework: Framework, function_name: str, *args, **kwargs +): + try: + module = importlib.import_module(f"sparseml.{framework.value}") + function = getattr(module, function_name) + except Exception as err: + raise ValueError( + f"unknown or unsupported framework {framework}, " + f"cannot call function {function_name}" + ) + + return function(*args, **kwargs) + + +def detect_framework(item: Any, attempt_num: int = 0) -> Framework: """ Detect the supported ML framework for a given item. Supported input types are the following: @@ -61,31 +77,55 @@ def detect_framework(item: Any) -> Framework: If the framework cannot be determined, will return Framework.unknown :param item: The item to detect the ML framework for :type item: Any + :param attempt_num: The number of detection attempts so far. + Because multiple frameworks can run the same type of file, + this enables fall backs until the correct one is able to run. + :type attempt_num: int :return: The detected framework from the given item :rtype: Framework """ _LOGGER.debug("detecting framework for %s", item) - framework = Framework.unknown if isinstance(item, Framework): _LOGGER.debug("framework detected from Framework instance") - framework = item - elif isinstance(item, str) and item.lower().strip() in Framework.__members__: + + return item if attempt_num == 0 else Framework.unknown + + if isinstance(item, str) and item.lower().strip() in Framework.__members__: _LOGGER.debug("framework detected from Framework string instance") - framework = Framework[item.lower().strip()] - else: - _LOGGER.debug("detecting framework by calling into supported frameworks") - - for test in Framework: - try: - framework = execute_in_sparseml_framework( - test, "detect_framework", item - ) - except Exception as err: - # errors are expected if the framework is not installed, log as debug - _LOGGER.debug(f"error while calling detect_framework for {test}: {err}") - if framework != Framework.unknown: + return ( + Framework[item.lower().strip()] if attempt_num == 0 else Framework.unknown + ) + + _LOGGER.debug( + "detecting framework by calling into supported frameworks " + "for attempt_num %s", + attempt_num, + ) + framework = Framework.unknown + found_count = 0 + + for test in Framework: + if test == Framework.unknown: + continue + + try: + framework = _execute_sparseml_package_function( + test, "detect_framework", item + ) + except Exception as err: + # errors are expected if the framework is not installed, log as debug + _LOGGER.debug("error while calling detect_framework for %s: %s", test, err) + + if framework != Framework.unknown and framework is not None: + found_count += 1 + + if found_count <= attempt_num: + _LOGGER.debug( + "skipping framework %s with attempt_num=%s", framework, attempt_num + ) + else: break _LOGGER.info("detected framework of %s from %s", framework, item) @@ -94,7 +134,7 @@ def detect_framework(item: Any) -> Framework: def execute_in_sparseml_framework( - framework: Framework, function_name: str, *args, **kwargs + framework: Any, function_name: str, *args, **kwargs ) -> Any: """ Execute a general function that is callable from the root of the frameworks @@ -102,7 +142,7 @@ def execute_in_sparseml_framework( Useful for benchmarking, analyzing, etc. Will pass the args and kwargs to the callable function. :param framework: The ML framework to run the function under in SparseML. - :type framework: Framework + :type framework: Any :param function_name: The name of the function in SparseML that should be run with the given args and kwargs. :type function_name: str @@ -119,25 +159,35 @@ def execute_in_sparseml_framework( kwargs, ) - if not isinstance(framework, Framework): - framework = detect_framework(framework) + framework_errs = OrderedDict() + ret = None + test_framework = detect_framework(framework) + attempt_num = 1 - if framework == Framework.unknown: - raise ValueError( - f"unknown or unsupported framework {framework}, " - f"cannot call function {function_name}" - ) + while test_framework != Framework.unknown: + function = None - try: - module = importlib.import_module(f"sparseml.{framework.value}") - function = getattr(module, function_name) - except Exception as err: - raise ValueError( - f"could not find function_name {function_name} in framework {framework}: " - f"{err}" - ) + try: + module = importlib.import_module(f"sparseml.{test_framework.value}") + function = getattr(module, function_name) - return function(*args, **kwargs) + return function(*args, **kwargs) + except Exception as err: + framework_errs[framework] = err + + test_framework = detect_framework(framework, attempt_num) + attempt_num += 1 + + if len(framework_errs) == 1: + raise list(framework_errs.values())[0] + + if len(framework_errs) > 1: + raise RuntimeError(str(framework_errs)) + + raise ValueError( + f"unknown or unsupported framework {framework}, " + f"cannot call function {function_name}" + ) def get_version( From 206045003cfb555a4fdb8965602e69c5ccc12b48 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 15 Oct 2021 11:31:00 -0400 Subject: [PATCH 2/6] remove unused variable --- src/sparseml/base.py | 1 - src/sparseml/benchmark/info.py | 20 +++----------------- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/src/sparseml/base.py b/src/sparseml/base.py index 56e7274d896..51bc19bdaf6 100644 --- a/src/sparseml/base.py +++ b/src/sparseml/base.py @@ -160,7 +160,6 @@ def execute_in_sparseml_framework( ) framework_errs = OrderedDict() - ret = None test_framework = detect_framework(framework) attempt_num = 1 diff --git a/src/sparseml/benchmark/info.py b/src/sparseml/benchmark/info.py index 230d2a00d7b..48ffc32e230 100644 --- a/src/sparseml/benchmark/info.py +++ b/src/sparseml/benchmark/info.py @@ -96,7 +96,7 @@ from tqdm import auto -from sparseml.base import Framework, detect_framework, execute_in_sparseml_framework +from sparseml.base import Framework, execute_in_sparseml_framework from sparseml.benchmark.serialization import ( BatchBenchmarkResult, BenchmarkConfig, @@ -369,13 +369,8 @@ def save_benchmark_results( pass to the runner :param show_progress: True to show a tqdm bar when running, False otherwise """ - if framework is None: - framework = detect_framework(model) - else: - framework = detect_framework(framework) - results = execute_in_sparseml_framework( - framework, + framework if framework is not None else model, "run_benchmark", model, data, @@ -442,18 +437,9 @@ def load_and_run_benchmark( :param save_path: path to save the new benchmark results """ _LOGGER.info(f"rerunning benchmark {load}") - info = load_benchmark_info(load) - - framework = info.framework - - if framework is None: - framework = detect_framework(model) - else: - framework = detect_framework(framework) - save_benchmark_results( - model, + info.framework if info.framework is not None else model, data, batch_size=info.config.batch_size, iterations=info.config.iterations, From 8594772f3207563c0adbdab42a4258df6da6440b Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 19 Oct 2021 14:30:06 -0400 Subject: [PATCH 3/6] decrease complexity of falling back on framework execution --- src/sparseml/base.py | 80 ++++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/src/sparseml/base.py b/src/sparseml/base.py index 51bc19bdaf6..f6cb4abfc08 100644 --- a/src/sparseml/base.py +++ b/src/sparseml/base.py @@ -26,6 +26,7 @@ __all__ = [ "Framework", + "detect_frameworks", "detect_framework", "execute_in_sparseml_framework", "get_version", @@ -64,9 +65,9 @@ def _execute_sparseml_package_function( return function(*args, **kwargs) -def detect_framework(item: Any, attempt_num: int = 0) -> Framework: +def detect_frameworks(item: Any) -> List[Framework]: """ - Detect the supported ML framework for a given item. + Detects the supported ML frameworks for a given item. Supported input types are the following: - A Framework enum - A string of any case representing the name of the framework @@ -74,63 +75,67 @@ def detect_framework(item: Any, attempt_num: int = 0) -> Framework: - A supported file type within the framework such as model files: (onnx, pth, h5, pb) - An object from a supported ML framework such as a model instance - If the framework cannot be determined, will return Framework.unknown + If the framework cannot be determined, an empty list will be returned + :param item: The item to detect the ML framework for :type item: Any - :param attempt_num: The number of detection attempts so far. - Because multiple frameworks can run the same type of file, - this enables fall backs until the correct one is able to run. - :type attempt_num: int - :return: The detected framework from the given item - :rtype: Framework + :return: The detected ML frameworks from the given item + :rtype: List[Framework] """ - _LOGGER.debug("detecting framework for %s", item) + _LOGGER.debug("detecting frameworks for %s", item) if isinstance(item, Framework): _LOGGER.debug("framework detected from Framework instance") - return item if attempt_num == 0 else Framework.unknown + return [item] if isinstance(item, str) and item.lower().strip() in Framework.__members__: _LOGGER.debug("framework detected from Framework string instance") - return ( - Framework[item.lower().strip()] if attempt_num == 0 else Framework.unknown - ) + return [Framework[item.lower().strip()]] - _LOGGER.debug( - "detecting framework by calling into supported frameworks " - "for attempt_num %s", - attempt_num, - ) - framework = Framework.unknown - found_count = 0 + _LOGGER.debug("detecting frameworks by calling into supported frameworks") + frameworks = [] for test in Framework: if test == Framework.unknown: continue try: - framework = _execute_sparseml_package_function( + detected = _execute_sparseml_package_function( test, "detect_framework", item ) + frameworks.append(detected) except Exception as err: # errors are expected if the framework is not installed, log as debug _LOGGER.debug("error while calling detect_framework for %s: %s", test, err) - if framework != Framework.unknown and framework is not None: - found_count += 1 + _LOGGER.info("detected frameworks of %s from %s", frameworks, item) + + return frameworks - if found_count <= attempt_num: - _LOGGER.debug( - "skipping framework %s with attempt_num=%s", framework, attempt_num - ) - else: - break - _LOGGER.info("detected framework of %s from %s", framework, item) +def detect_framework(item: Any) -> Framework: + """ + Detect the supported ML framework for a given item. + Supported input types are the following: + - A Framework enum + - A string of any case representing the name of the framework + (deepsparse, onnx, keras, pytorch, tensorflow_v1) + - A supported file type within the framework such as model files: + (onnx, pth, h5, pb) + - An object from a supported ML framework such as a model instance + If the framework cannot be determined, will return Framework.unknown + + :param item: The item to detect the ML framework for + :type item: Any + :return: The detected framework from the given item + :rtype: Framework + """ + _LOGGER.debug("detecting framework for %s", item) + frameworks = detect_frameworks(item) - return framework + return frameworks[0] if len(frameworks) > 0 else Framework.unknown def execute_in_sparseml_framework( @@ -141,6 +146,7 @@ def execute_in_sparseml_framework( package under SparseML such as sparseml.pytorch. Useful for benchmarking, analyzing, etc. Will pass the args and kwargs to the callable function. + :param framework: The ML framework to run the function under in SparseML. :type framework: Any :param function_name: The name of the function in SparseML that should be run @@ -160,12 +166,9 @@ def execute_in_sparseml_framework( ) framework_errs = OrderedDict() - test_framework = detect_framework(framework) - attempt_num = 1 - - while test_framework != Framework.unknown: - function = None + test_frameworks = detect_frameworks(framework) + for test_framework in test_frameworks: try: module = importlib.import_module(f"sparseml.{test_framework.value}") function = getattr(module, function_name) @@ -174,9 +177,6 @@ def execute_in_sparseml_framework( except Exception as err: framework_errs[framework] = err - test_framework = detect_framework(framework, attempt_num) - attempt_num += 1 - if len(framework_errs) == 1: raise list(framework_errs.values())[0] From 9a9924240e35c442c08926a2e852cd2e2ca9401b Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 19 Oct 2021 14:57:19 -0400 Subject: [PATCH 4/6] quality and test fixes --- src/sparseml/base.py | 45 +++++++++++++++++++------------------ tests/sparseml/test_base.py | 2 +- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/sparseml/base.py b/src/sparseml/base.py index f6cb4abfc08..633c15f2d2b 100644 --- a/src/sparseml/base.py +++ b/src/sparseml/base.py @@ -59,7 +59,7 @@ def _execute_sparseml_package_function( except Exception as err: raise ValueError( f"unknown or unsupported framework {framework}, " - f"cannot call function {function_name}" + f"cannot call function {function_name}: {err}" ) return function(*args, **kwargs) @@ -83,32 +83,33 @@ def detect_frameworks(item: Any) -> List[Framework]: :rtype: List[Framework] """ _LOGGER.debug("detecting frameworks for %s", item) - - if isinstance(item, Framework): - _LOGGER.debug("framework detected from Framework instance") - - return [item] + frameworks = [] if isinstance(item, str) and item.lower().strip() in Framework.__members__: _LOGGER.debug("framework detected from Framework string instance") + item = Framework[item.lower().strip()] - return [Framework[item.lower().strip()]] - - _LOGGER.debug("detecting frameworks by calling into supported frameworks") - frameworks = [] - - for test in Framework: - if test == Framework.unknown: - continue + if isinstance(item, Framework): + _LOGGER.debug("framework detected from Framework instance") - try: - detected = _execute_sparseml_package_function( - test, "detect_framework", item - ) - frameworks.append(detected) - except Exception as err: - # errors are expected if the framework is not installed, log as debug - _LOGGER.debug("error while calling detect_framework for %s: %s", test, err) + if item != Framework.unknown: + frameworks.append(item) + else: + _LOGGER.debug("detecting frameworks by calling into supported frameworks") + frameworks = [] + + for test in Framework: + if test == Framework.unknown: + continue + + try: + detected = _execute_sparseml_package_function( + test, "detect_framework", item + ) + frameworks.append(detected) + except Exception as err: + # errors are expected if the framework is not installed, log as debug + _LOGGER.debug("error while calling detect_framework for %s: %s", test, err) _LOGGER.info("detected frameworks of %s from %s", frameworks, item) diff --git a/tests/sparseml/test_base.py b/tests/sparseml/test_base.py index a37d970aaed..24dadb3e0b1 100644 --- a/tests/sparseml/test_base.py +++ b/tests/sparseml/test_base.py @@ -62,7 +62,7 @@ def test_execute_in_sparseml_framework(): with pytest.raises(ValueError): execute_in_sparseml_framework(Framework.unknown, "unknown") - with pytest.raises(ValueError): + with pytest.raises(ImportError): execute_in_sparseml_framework(Framework.onnx, "unknown") # TODO: fill in with sample functions to execute in frameworks once available From d74f6c388da90c2e3c51d93af3bed34828a1f120 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 19 Oct 2021 14:59:48 -0400 Subject: [PATCH 5/6] update docs --- src/sparseml/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sparseml/base.py b/src/sparseml/base.py index 633c15f2d2b..6440dbad35c 100644 --- a/src/sparseml/base.py +++ b/src/sparseml/base.py @@ -148,7 +148,8 @@ def execute_in_sparseml_framework( Useful for benchmarking, analyzing, etc. Will pass the args and kwargs to the callable function. - :param framework: The ML framework to run the function under in SparseML. + :param framework: The item to detect the ML framework for to run the function under, + see detect_frameworks for more details on acceptible inputs :type framework: Any :param function_name: The name of the function in SparseML that should be run with the given args and kwargs. From 643b41129d42bb4bd259fb3509c42ce1e2b267ae Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 19 Oct 2021 16:01:52 -0400 Subject: [PATCH 6/6] fix tests --- src/sparseml/base.py | 4 +++- tests/sparseml/test_base.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sparseml/base.py b/src/sparseml/base.py index 6440dbad35c..97f66a42899 100644 --- a/src/sparseml/base.py +++ b/src/sparseml/base.py @@ -109,7 +109,9 @@ def detect_frameworks(item: Any) -> List[Framework]: frameworks.append(detected) except Exception as err: # errors are expected if the framework is not installed, log as debug - _LOGGER.debug("error while calling detect_framework for %s: %s", test, err) + _LOGGER.debug( + "error while calling detect_framework for %s: %s", test, err + ) _LOGGER.info("detected frameworks of %s from %s", frameworks, item) diff --git a/tests/sparseml/test_base.py b/tests/sparseml/test_base.py index 24dadb3e0b1..6804f9e49dd 100644 --- a/tests/sparseml/test_base.py +++ b/tests/sparseml/test_base.py @@ -62,7 +62,7 @@ def test_execute_in_sparseml_framework(): with pytest.raises(ValueError): execute_in_sparseml_framework(Framework.unknown, "unknown") - with pytest.raises(ImportError): + with pytest.raises(Exception): execute_in_sparseml_framework(Framework.onnx, "unknown") # TODO: fill in with sample functions to execute in frameworks once available