diff --git a/src/sparseml/base.py b/src/sparseml/base.py index 87da5e2bcc1..97f66a42899 100644 --- a/src/sparseml/base.py +++ b/src/sparseml/base.py @@ -15,6 +15,7 @@ import importlib import logging +from collections import OrderedDict from enum import Enum from typing import Any, List, Optional @@ -25,6 +26,7 @@ __all__ = [ "Framework", + "detect_frameworks", "detect_framework", "execute_in_sparseml_framework", "get_version", @@ -48,9 +50,24 @@ class Framework(Enum): tensorflow_v1 = "tensorflow_v1" -def detect_framework(item: Any) -> Framework: +def _execute_sparseml_package_function( + framework: Framework, function_name: str, *args, **kwargs +): + try: + module = importlib.import_module(f"sparseml.{framework.value}") + function = getattr(module, function_name) + except Exception as err: + raise ValueError( + f"unknown or unsupported framework {framework}, " + f"cannot call function {function_name}: {err}" + ) + + return function(*args, **kwargs) + + +def detect_frameworks(item: Any) -> List[Framework]: """ - Detect the supported ML framework for a given item. + Detects the supported ML frameworks for a given item. Supported input types are the following: - A Framework enum - A string of any case representing the name of the framework @@ -58,51 +75,84 @@ def detect_framework(item: Any) -> Framework: - A supported file type within the framework such as model files: (onnx, pth, h5, pb) - An object from a supported ML framework such as a model instance - If the framework cannot be determined, will return Framework.unknown + If the framework cannot be determined, an empty list will be returned + :param item: The item to detect the ML framework for :type item: Any - :return: The detected framework from the given item - :rtype: Framework + :return: The detected ML frameworks from the given item + :rtype: List[Framework] """ - _LOGGER.debug("detecting framework for %s", item) - framework = Framework.unknown + _LOGGER.debug("detecting frameworks for %s", item) + frameworks = [] + + if isinstance(item, str) and item.lower().strip() in Framework.__members__: + _LOGGER.debug("framework detected from Framework string instance") + item = Framework[item.lower().strip()] if isinstance(item, Framework): _LOGGER.debug("framework detected from Framework instance") - framework = item - elif isinstance(item, str) and item.lower().strip() in Framework.__members__: - _LOGGER.debug("framework detected from Framework string instance") - framework = Framework[item.lower().strip()] + + if item != Framework.unknown: + frameworks.append(item) else: - _LOGGER.debug("detecting framework by calling into supported frameworks") + _LOGGER.debug("detecting frameworks by calling into supported frameworks") + frameworks = [] for test in Framework: + if test == Framework.unknown: + continue + try: - framework = execute_in_sparseml_framework( + detected = _execute_sparseml_package_function( test, "detect_framework", item ) + frameworks.append(detected) except Exception as err: # errors are expected if the framework is not installed, log as debug - _LOGGER.debug(f"error while calling detect_framework for {test}: {err}") + _LOGGER.debug( + "error while calling detect_framework for %s: %s", test, err + ) + + _LOGGER.info("detected frameworks of %s from %s", frameworks, item) + + return frameworks - if framework != Framework.unknown: - break - _LOGGER.info("detected framework of %s from %s", framework, item) +def detect_framework(item: Any) -> Framework: + """ + Detect the supported ML framework for a given item. + Supported input types are the following: + - A Framework enum + - A string of any case representing the name of the framework + (deepsparse, onnx, keras, pytorch, tensorflow_v1) + - A supported file type within the framework such as model files: + (onnx, pth, h5, pb) + - An object from a supported ML framework such as a model instance + If the framework cannot be determined, will return Framework.unknown + + :param item: The item to detect the ML framework for + :type item: Any + :return: The detected framework from the given item + :rtype: Framework + """ + _LOGGER.debug("detecting framework for %s", item) + frameworks = detect_frameworks(item) - return framework + return frameworks[0] if len(frameworks) > 0 else Framework.unknown def execute_in_sparseml_framework( - framework: Framework, function_name: str, *args, **kwargs + framework: Any, function_name: str, *args, **kwargs ) -> Any: """ Execute a general function that is callable from the root of the frameworks package under SparseML such as sparseml.pytorch. Useful for benchmarking, analyzing, etc. Will pass the args and kwargs to the callable function. - :param framework: The ML framework to run the function under in SparseML. - :type framework: Framework + + :param framework: The item to detect the ML framework for to run the function under, + see detect_frameworks for more details on acceptible inputs + :type framework: Any :param function_name: The name of the function in SparseML that should be run with the given args and kwargs. :type function_name: str @@ -119,25 +169,28 @@ def execute_in_sparseml_framework( kwargs, ) - if not isinstance(framework, Framework): - framework = detect_framework(framework) + framework_errs = OrderedDict() + test_frameworks = detect_frameworks(framework) - if framework == Framework.unknown: - raise ValueError( - f"unknown or unsupported framework {framework}, " - f"cannot call function {function_name}" - ) + for test_framework in test_frameworks: + try: + module = importlib.import_module(f"sparseml.{test_framework.value}") + function = getattr(module, function_name) - try: - module = importlib.import_module(f"sparseml.{framework.value}") - function = getattr(module, function_name) - except Exception as err: - raise ValueError( - f"could not find function_name {function_name} in framework {framework}: " - f"{err}" - ) + return function(*args, **kwargs) + except Exception as err: + framework_errs[framework] = err - return function(*args, **kwargs) + if len(framework_errs) == 1: + raise list(framework_errs.values())[0] + + if len(framework_errs) > 1: + raise RuntimeError(str(framework_errs)) + + raise ValueError( + f"unknown or unsupported framework {framework}, " + f"cannot call function {function_name}" + ) def get_version( diff --git a/src/sparseml/benchmark/info.py b/src/sparseml/benchmark/info.py index 230d2a00d7b..48ffc32e230 100644 --- a/src/sparseml/benchmark/info.py +++ b/src/sparseml/benchmark/info.py @@ -96,7 +96,7 @@ from tqdm import auto -from sparseml.base import Framework, detect_framework, execute_in_sparseml_framework +from sparseml.base import Framework, execute_in_sparseml_framework from sparseml.benchmark.serialization import ( BatchBenchmarkResult, BenchmarkConfig, @@ -369,13 +369,8 @@ def save_benchmark_results( pass to the runner :param show_progress: True to show a tqdm bar when running, False otherwise """ - if framework is None: - framework = detect_framework(model) - else: - framework = detect_framework(framework) - results = execute_in_sparseml_framework( - framework, + framework if framework is not None else model, "run_benchmark", model, data, @@ -442,18 +437,9 @@ def load_and_run_benchmark( :param save_path: path to save the new benchmark results """ _LOGGER.info(f"rerunning benchmark {load}") - info = load_benchmark_info(load) - - framework = info.framework - - if framework is None: - framework = detect_framework(model) - else: - framework = detect_framework(framework) - save_benchmark_results( - model, + info.framework if info.framework is not None else model, data, batch_size=info.config.batch_size, iterations=info.config.iterations, diff --git a/tests/sparseml/test_base.py b/tests/sparseml/test_base.py index a37d970aaed..6804f9e49dd 100644 --- a/tests/sparseml/test_base.py +++ b/tests/sparseml/test_base.py @@ -62,7 +62,7 @@ def test_execute_in_sparseml_framework(): with pytest.raises(ValueError): execute_in_sparseml_framework(Framework.unknown, "unknown") - with pytest.raises(ValueError): + with pytest.raises(Exception): execute_in_sparseml_framework(Framework.onnx, "unknown") # TODO: fill in with sample functions to execute in frameworks once available