diff --git a/docs/conf.py b/docs/conf.py index 7d0e7861a6d..d14150a9a41 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -52,6 +52,7 @@ "sphinx_copybutton", "sphinx_markdown_tables", "sphinx_multiversion", + "sphinx-pydantic", "sphinx_rtd_theme", "recommonmark", ] @@ -60,19 +61,19 @@ templates_path = ["_templates"] # Whitelist pattern for tags (set to None to ignore all tags) -smv_tag_whitelist = r'^v.*$' +smv_tag_whitelist = r"^v.*$" # Whitelist pattern for branches (set to None to ignore all branches) -smv_branch_whitelist = r'^main$' +smv_branch_whitelist = r"^main$" # Whitelist pattern for remotes (set to None to use local branches only) -smv_remote_whitelist = r'^.*$' +smv_remote_whitelist = r"^.*$" # Pattern for released versions -smv_released_pattern = r'^tags/v.*$' +smv_released_pattern = r"^tags/v.*$" # Format for versioned output directories inside the build directory -smv_outputdir_format = '{ref.name}' +smv_outputdir_format = "{ref.name}" # Determines whether remote or local git branches/tags are preferred if their output dirs conflict smv_prefer_remote_refs = False @@ -111,8 +112,8 @@ html_logo = "source/icon-sparseml.png" html_theme_options = { - 'analytics_id': 'UA-128364174-1', # Provided by Google in your dashboard - 'analytics_anonymize_ip': False, + "analytics_id": "UA-128364174-1", # Provided by Google in your dashboard + "analytics_anonymize_ip": False, } # Add any paths that contain custom static files (such as style sheets) here, @@ -153,7 +154,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, "sparseml.tex", "SparseML Documentation", [author], "manual",), + ( + master_doc, + "sparseml.tex", + "SparseML Documentation", + [author], + "manual", + ), ] # -- Options for manual page output ------------------------------------------ diff --git a/setup.cfg b/setup.cfg index 8bf94b9220c..00836589661 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,7 +5,7 @@ ensure_newline_before_comments = True force_grid_wrap = 0 include_trailing_comma = True known_first_party = sparseml,sparsezoo,tests -known_third_party = bs4,requests,packaging,yaml,tqdm,numpy,onnx,onnxruntime,pandas,PIL,psutil,scipy,toposort,pytest,torch,torchvision,keras,tensorflow,merge-args,cv2 +known_third_party = bs4,requests,packaging,yaml,pydantic,tqdm,numpy,onnx,onnxruntime,pandas,PIL,psutil,scipy,toposort,pytest,torch,torchvision,keras,tensorflow,merge-args,cv2 sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER line_length = 88 diff --git a/setup.py b/setup.py index 460bb2ff0e5..e1835dff998 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,9 @@ "onnx>=1.5.0,<1.8.0", "onnxruntime>=1.0.0", "pandas<1.0.0", + "packaging>=20.0", "psutil>=5.0.0", + "pydantic>=1.0.0", "requests>=2.0.0", "scikit-image>=0.15.0", "scipy>=1.0.0", @@ -80,6 +82,8 @@ "sphinx-copybutton>=0.3.0", "sphinx-markdown-tables>=0.0.15", "sphinx-multiversion==0.2.4", + "sphinx-pydantic>=0.1.0", + "sphinx-rtd-theme>=0.5.0", "wheel>=0.36.2", "pytest>=6.0.0", "flaky>=3.0.0", @@ -114,7 +118,12 @@ def _setup_extras() -> Dict: def _setup_entry_points() -> Dict: - return {} + return { + "console_scripts": [ + "sparseml.framework=sparseml.framework.info:_main", + "sparseml.sparsification=sparseml.sparsification.info:_main", + ] + } def _setup_long_description() -> Tuple[str, str]: diff --git a/src/sparseml/__init__.py b/src/sparseml/__init__.py index b482253b7ef..9e93adcd59f 100644 --- a/src/sparseml/__init__.py +++ b/src/sparseml/__init__.py @@ -19,8 +19,27 @@ # flake8: noqa # isort: skip_file -from .version import * - # be sure to import all logging first and at the root # this keeps other loggers in nested files creating from the root logger setups from .log import * +from .version import * + +from .base import ( + Framework, + check_version, + detect_framework, + execute_in_sparseml_framework, +) +from .framework import ( + FrameworkInferenceProviderInfo, + FrameworkInfo, + framework_info, + save_framework_info, + load_framework_info, +) +from .sparsification import ( + SparsificationInfo, + sparsification_info, + save_sparsification_info, + load_sparsification_info, +) diff --git a/src/sparseml/base.py b/src/sparseml/base.py new file mode 100644 index 00000000000..d638dbe3ec9 --- /dev/null +++ b/src/sparseml/base.py @@ -0,0 +1,214 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib +import logging +from enum import Enum +from typing import Any, Optional + +from packaging import version + +import pkg_resources + + +__all__ = [ + "Framework", + "detect_framework", + "execute_in_sparseml_framework", + "get_version", + "check_version", +] + + +_LOGGER = logging.getLogger(__name__) + + +class Framework(Enum): + """ + Framework types known of/supported within the sparseml/deepsparse ecosystem + """ + + unknown = "unknown" + deepsparse = "deepsparse" + onnx = "onnx" + keras = "keras" + pytorch = "pytorch" + tensorflow_v1 = "tensorflow_v1" + + +def detect_framework(item: Any) -> Framework: + """ + Detect the supported ML framework for a given item. + Supported input types are the following: + - A Framework enum + - A string of any case representing the name of the framework + (deepsparse, onnx, keras, pytorch, tensorflow_v1) + - A supported file type within the framework such as model files: + (onnx, pth, h5, pb) + - An object from a supported ML framework such as a model instance + If the framework cannot be determined, will return Framework.unknown + :param item: The item to detect the ML framework for + :type item: Any + :return: The detected framework from the given item + :rtype: Framework + """ + _LOGGER.debug("detecting framework for %s", item) + framework = Framework.unknown + + if isinstance(item, Framework): + _LOGGER.debug("framework detected from Framework instance") + framework = item + elif isinstance(item, str) and item.lower().strip() in Framework.__members__: + _LOGGER.debug("framework detected from Framework string instance") + framework = Framework[item.lower().strip()] + else: + _LOGGER.debug("detecting framework by calling into supported frameworks") + + for test in Framework: + try: + framework = execute_in_sparseml_framework( + test, "detect_framework", item + ) + except Exception as err: + # errors are expected if the framework is not installed, log as debug + _LOGGER.debug(f"error while calling detect_framework for {test}: {err}") + + if framework != Framework.unknown: + break + + _LOGGER.info("detected framework of %s from %s", framework, item) + + return framework + + +def execute_in_sparseml_framework( + framework: Framework, function_name: str, *args, **kwargs +) -> Any: + """ + Execute a general function that is callable from the root of the frameworks + package under SparseML such as sparseml.pytorch. + Useful for benchmarking, analyzing, etc. + Will pass the args and kwargs to the callable function. + :param framework: The ML framework to run the function under in SparseML. + :type framework: Framework + :param function_name: The name of the function in SparseML that should be run + with the given args and kwargs. + :type function_name: str + :param args: Any positional args to be passed into the function. + :param kwargs: Any key word args to be passed into the function. + :return: The return value from the executed function. + :rtype: Any + """ + _LOGGER.debug( + "executing function with name %s for framework %s, args %s, kwargs %s", + function_name, + framework, + args, + kwargs, + ) + + if not isinstance(framework, Framework): + framework = detect_framework(framework) + + if framework == Framework.unknown: + raise ValueError( + f"unknown or unsupported framework {framework}, " + f"cannot call function {function_name}" + ) + + try: + module = importlib.import_module(f"sparseml.{framework.value}") + function = getattr(module, function_name) + except Exception as err: + raise ValueError( + f"could not find function_name {function_name} in framework {framework}: " + f"{err}" + ) + + return function(*args, **kwargs) + + +def get_version(package_name: str, raise_on_error: bool) -> Optional[str]: + """ + :param package_name: The name of the full package, as it would be imported, + to get the version for + :type package_name: str + :param raise_on_error: True to raise an error if package is not installed + or couldn't be imported, False to return None + :return: the version of the desired package if detected, otherwise raises an error + :rtype: str + """ + + try: + current_version: str = pkg_resources.get_distribution(package_name).version + except Exception as err: + if raise_on_error: + raise ImportError( + f"error while getting current version for {package_name}: {err}" + ) + + return None + + return current_version + + +def check_version( + package_name: str, + min_version: Optional[str] = None, + max_version: Optional[str] = None, + raise_on_error: bool = True, +) -> bool: + """ + :param package_name: the name of the package to check the version of + :type package_name: str + :param min_version: The minimum version for the package that it must be greater than + or equal to, if unset will require no minimum version + :type min_version: str + :param max_version: The maximum version for the package that it must be less than + or equal to, if unset will require no maximum version. + :type max_version: str + :param raise_on_error: True to raise any issues such as not installed, + minimum version, or maximum version as ImportError. False to return the result. + :type raise_on_error: bool + :return: If raise_on_error, will return False if the package is not installed + or the version is outside the accepted bounds and True if everything is correct. + :rtype: bool + """ + current_version = get_version(package_name, raise_on_error) + + if not current_version: + return False + + current_version = version.parse(current_version) + min_version = version.parse(min_version) if min_version else None + max_version = version.parse(max_version) if max_version else None + + if min_version and current_version < min_version: + if raise_on_error: + raise ImportError( + f"required min {package_name} version {min_version}, " + f"found {current_version}" + ) + return False + + if max_version and current_version > max_version: + if raise_on_error: + raise ImportError( + f"required min {package_name} version {min_version}, " + f"found {current_version}" + ) + return False + + return True diff --git a/src/sparseml/framework/__init__.py b/src/sparseml/framework/__init__.py new file mode 100644 index 00000000000..0d0760dce22 --- /dev/null +++ b/src/sparseml/framework/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functionality related to integrating with, detecting, and getting information for +support and sparsification in ML frameworks. +""" + +# flake8: noqa + +from .info import * diff --git a/src/sparseml/framework/info.py b/src/sparseml/framework/info.py new file mode 100644 index 00000000000..fcee41c806e --- /dev/null +++ b/src/sparseml/framework/info.py @@ -0,0 +1,299 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functionality related to integrating with, detecting, and getting information for +support and sparsification in ML frameworks. + +The file is executable and will get the framework info for a given framework: + +########## +Command help: +usage: info.py [-h] [--path PATH] framework + +Compile the available setup and information for a given framework. + +positional arguments: + framework the ML framework or path to a framework file to load the + framework info for + +optional arguments: + -h, --help show this help message and exit + --path PATH A full file path to save the framework info to. If not + supplied, will print out the framework info to the + console. + +######### +EXAMPLES +######### + +########## +Example command for getting the framework info for pytorch. +python src/sparseml/framework/info.py pytorch +""" + +import argparse +import logging +import os +from collections import OrderedDict +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + +from sparseml.base import Framework, execute_in_sparseml_framework +from sparseml.sparsification.info import SparsificationInfo +from sparseml.utils import clean_path, create_parent_dirs + + +__all__ = [ + "FrameworkInferenceProviderInfo", + "FrameworkInfo", + "framework_info", + "save_framework_info", + "load_framework_info", +] + + +_LOGGER = logging.getLogger(__name__) + + +class FrameworkInferenceProviderInfo(BaseModel): + """ + Class for storing information for an inference provider within a frameworks engine. + For example, the gpu provider within PyTorch. + Extends pydantics BaseModel class for serialization to and from json + in addition to proper type checking on construction. + """ + + name: str = Field(title="name", description="The name/id of the provider.") + description: str = Field( + title="description", + description="A description for the provider to offer more detail.", + ) + device: str = Field( + title="device", description="The device the provider is for such as cpu or gpu." + ) + supported_sparsification: Optional[SparsificationInfo] = Field( + default=None, + title="supported_sparsification", + description=( + "The supported sparsification information for support " + "for inference speedup in the provider." + ), + ) + available: bool = Field( + default=False, + title="available", + description="True if the provider is available on the system, False otherwise.", + ) + properties: Dict[str, Any] = Field( + default=OrderedDict(), + title="properties", + description="Any properties for the given provider.", + ) + warnings: List[str] = Field( + default=[], + title="warnings", + description="Any warnings/restrictions for the provider on the given system.", + ) + + +class FrameworkInfo(BaseModel): + """ + Class for storing the information for an ML frameworks info and availability + on the current system. + Extends pydantics BaseModel class for serialization to and from json + in addition to proper type checking on construction. + """ + + framework: Framework = Field( + title="framework", description="The framework the system info is for." + ) + package_versions: Dict[str, Optional[str]] = Field( + title="package_versions", + description=( + "A mapping of the package and supporting packages for a given framework " + "to the detected versions on the system currently. " + "If the package is not detected, will be set to None." + ), + ) + sparsification: Optional[SparsificationInfo] = Field( + default=None, + title="sparsification", + description=( + "True if inference for a model is available on the system " + "for the given framework, False otherwise." + ), + ) + inference_providers: List[FrameworkInferenceProviderInfo] = Field( + default=[], + title="inference_providers", + description=( + "True if inference for a model is available on the system " + "for the given framework, False otherwise." + ), + ) + properties: Dict[str, Any] = Field( + default={}, + title="properties", + description="Any additional properties for the framework.", + ) + training_available: bool = Field( + default=False, + title="training_available", + description=( + "True if training/editing a model is available on the system " + "for the given framework, False otherwise." + ), + ) + sparsification_available: bool = Field( + default=False, + title="sparsification_available", + description=( + "True if sparsifying a model is available on the system " + "for the given framework, False otherwise." + ), + ) + exporting_onnx_available: bool = Field( + default=False, + title="exporting_onnx_available", + description=( + "True if exporting a model in the ONNX format is available on the system " + "for the given framework, False otherwise." + ), + ) + inference_available: bool = Field( + default=False, + title="inference_available", + description=( + "True if inference for a model is available on the system " + "for the given framework, False otherwise." + ), + ) + + +def framework_info(framework: Any) -> FrameworkInfo: + """ + Detect the information for the given ML framework such as package versions, + availability for core actions such as training and inference, + sparsification support, and inference provider support. + + :param framework: The item to detect the ML framework for. + See :func:`detect_framework` for more information. + :type framework: Any + :return: The framework info for the given framework + :rtype: FrameworkInfo + """ + _LOGGER.debug("getting system info for framework %s", framework) + info: FrameworkInfo = execute_in_sparseml_framework(framework, "framework_info") + _LOGGER.info("retrieved system info for framework %s: %s", framework, info) + + return info + + +def save_framework_info(framework: Any, path: Optional[str] = None): + """ + Save the framework info for a given framework. + If path is provided, will save to a json file at that path. + If path is not provided, will print out the info. + + :param framework: The item to detect the ML framework for. + See :func:`detect_framework` for more information. + :type framework: Any + :param path: The path, if any, to save the info to in json format. + If not provided will print out the info. + :type path: Optional[str] + """ + _LOGGER.debug( + "saving framework info for framework %s to %s", + framework, + path if path else "sys.out", + ) + info = ( + framework_info(framework) + if not isinstance(framework, FrameworkInfo) + else framework + ) + + if path: + path = clean_path(path) + create_parent_dirs(path) + + with open(path, "w") as file: + file.write(info.json()) + + _LOGGER.info( + "saved framework info for framework %s in file at %s", framework, path + ), + else: + print(info.json(indent=4)) + _LOGGER.info("printed out framework info for framework %s", framework) + + +def load_framework_info(load: str) -> FrameworkInfo: + """ + Load the framework info from a file or raw json. + If load exists as a path, will read from the file and use that. + Otherwise will try to parse the input as a raw json str. + + :param load: Either a file path to a json file or a raw json string. + :type load: str + :return: The loaded framework info. + :rtype: FrameworkInfo + """ + loaded_path = clean_path(load) + + if os.path.exists(loaded_path): + with open(loaded_path, "r") as file: + load = file.read() + + info = FrameworkInfo.parse_raw(load) + + return info + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=( + "Compile the available setup and information for a given framework." + ) + ) + parser.add_argument( + "framework", + type=str, + help=( + "the ML framework or path to a framework file to load the " + "framework info for" + ), + ) + parser.add_argument( + "--path", + type=str, + default=None, + help=( + "A full file path to save the framework info to. " + "If not supplied, will print out the framework info to the console." + ), + ) + + return parser.parse_args() + + +def _main(): + args = _parse_args() + save_framework_info(args.framework, args.path) + + +if __name__ == "__main__": + _main() diff --git a/src/sparseml/sparsification/__init__.py b/src/sparseml/sparsification/__init__.py new file mode 100644 index 00000000000..f7853a24393 --- /dev/null +++ b/src/sparseml/sparsification/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functionality related to applying, describing, and supporting sparsification +algorithms to models within in ML frameworks. +""" + +# flake8: noqa + +from .info import * diff --git a/src/sparseml/sparsification/info.py b/src/sparseml/sparsification/info.py new file mode 100644 index 00000000000..8aa2af8a29d --- /dev/null +++ b/src/sparseml/sparsification/info.py @@ -0,0 +1,304 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functionality related to describing availability and information of sparsification +algorithms to models within in the ML frameworks. + +The file is executable and will get the sparsification info for a given framework: + +########## +Command help: +usage: info.py [-h] [--path PATH] framework + +Compile the available setup and information for the sparsification of a model +in a given framework. + +positional arguments: + framework the ML framework or path to a framework file to load the + sparsification info for + +optional arguments: + -h, --help show this help message and exit + --path PATH A full file path to save the sparsification info to. If not + supplied, will print out the sparsification info to the + console. + +######### +EXAMPLES +######### + +########## +Example command for getting the sparsification info for pytorch. +python src/sparseml/sparsification/info.py pytorch +""" + +import argparse +import logging +import os +from enum import Enum +from typing import Any, List, Optional + +from pydantic import BaseModel, Field + +from sparseml.base import execute_in_sparseml_framework +from sparseml.utils import clean_path, create_parent_dirs + + +__all__ = [ + "ModifierType", + "ModifierPropInfo", + "ModifierInfo", + "SparsificationInfo", + "sparsification_info", + "save_sparsification_info", + "load_sparsification_info", +] + + +_LOGGER = logging.getLogger(__name__) + + +class ModifierType(Enum): + """ + Types of modifiers for grouping what functionality a Modifier falls under. + """ + + general = "general" + training = "training" + pruning = "pruning" + quantization = "quantization" + act_sparsity = "act_sparsity" + misc = "misc" + + +class ModifierPropInfo(BaseModel): + """ + Class for storing information and associated metadata for a + property on a given Modifier. + Extends pydantics BaseModel class for serialization to and from json + in addition to proper type checking on construction. + """ + + name: str = Field( + title="name", + description=( + "Name of the property for a Modifier. " + "It can be accessed by this name on the modifier instance." + ), + ) + description: str = Field( + title="description", + description="Description and information for the property for a Modifier.", + ) + type_: str = Field( + title="type_", + description=( + "The format type for the property for a Modifier such as " + "int, float, str, etc." + ), + ) + restrictions: Optional[List[Any]] = Field( + default=None, + title="restrictions", + description=( + "Value restrictions for the property for a Modifier. " + "If set, restrict the set value to one of the contained restrictions." + ), + ) + + +class ModifierInfo(BaseModel): + """ + Class for storing information and associated metadata for a given Modifier. + Extends pydantics BaseModel class for serialization to and from json + in addition to proper type checking on construction. + """ + + name: str = Field( + title="name", + description=( + "Name/class of the Modifier to be used for construction and identification." + ), + ) + description: str = Field( + title="description", + description="Description and info for the Modifier and what its used for.", + ) + type_: ModifierType = Field( + default=ModifierType.misc, + title="type_", + description=( + "The type the given Modifier is for grouping by similar functionality." + ), + ) + props: List[ModifierPropInfo] = Field( + default=[], + title="props", + description="The properties for the Modifier that can be set and controlled.", + ) + warnings: Optional[List[str]] = Field( + default=None, + title="warnings", + description=( + "Any warnings that apply for the Modifier and using it within a system" + ), + ) + + +class SparsificationInfo(BaseModel): + """ + Class for storing the information for sparsifying in a given framework. + Extends pydantics BaseModel class for serialization to and from json + in addition to proper type checking on construction. + """ + + modifiers: List[ModifierInfo] = Field( + default=[], + title="modifiers", + description="A list of the information for the available modifiers", + ) + + def type_modifiers(self, type_: ModifierType) -> List[ModifierInfo]: + """ + Get the contained Modifiers for a specific ModifierType. + :param type_: The ModifierType to filter the returned list of Modifiers by. + :type type_: ModifierType + :return: The filtered list of Modifiers that match the given type_. + :rtype: List[ModifierInfo] + """ + modifiers = [] + + for mod in self.modifiers: + if mod.type_ == type_: + modifiers.append(mod) + + return modifiers + + +def sparsification_info(framework: Any) -> SparsificationInfo: + """ + Get the available setup for sparsifying model in the given framework. + + :param framework: The item to detect the ML framework for. + See :func:`detect_framework` for more information. + :type framework: Any + :return: The sparsification info for the given framework + :rtype: SparsificationInfo + """ + _LOGGER.debug("getting sparsification info for framework %s", framework) + info: SparsificationInfo = execute_in_sparseml_framework( + framework, "sparsification_info" + ) + _LOGGER.info("retrieved sparsification info for framework %s: %s", framework, info) + + return info + + +def save_sparsification_info(framework: Any, path: Optional[str] = None): + """ + Save the sparsification info for a given framework. + If path is provided, will save to a json file at that path. + If path is not provided, will print out the info. + + :param framework: The item to detect the ML framework for. + See :func:`detect_framework` for more information. + :type framework: Any + :param path: The path, if any, to save the info to in json format. + If not provided will print out the info. + :type path: Optional[str] + """ + _LOGGER.debug( + "saving sparsification info for framework %s to %s", + framework, + path if path else "sys.out", + ) + info = ( + sparsification_info(framework) + if not isinstance(framework, SparsificationInfo) + else framework + ) + + if path: + path = clean_path(path) + create_parent_dirs(path) + + with open(path, "w") as file: + file.write(info.json()) + + _LOGGER.info( + "saved sparsification info for framework %s in file at %s", framework, path + ), + else: + print(info.json(indent=4)) + _LOGGER.info("printed out sparsification info for framework %s", framework) + + +def load_sparsification_info(load: str) -> SparsificationInfo: + """ + Load the sparsification info from a file or raw json. + If load exists as a path, will read from the file and use that. + Otherwise will try to parse the input as a raw json str. + + :param load: Either a file path to a json file or a raw json string. + :type load: str + :return: The loaded sparsification info. + :rtype: SparsificationInfo + """ + load_path = clean_path(load) + + if os.path.exists(load_path): + with open(load_path, "r") as file: + load = file.read() + + info = SparsificationInfo.parse_raw(load) + + return info + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=( + "Compile the available setup and information for the sparsification " + "of a model in a given framework." + ) + ) + parser.add_argument( + "framework", + type=str, + help=( + "the ML framework or path to a framework file to load the " + "sparsification info for" + ), + ) + parser.add_argument( + "--path", + type=str, + default=None, + help=( + "A full file path to save the sparsification info to. " + "If not supplied, will print out the sparsification info to the console." + ), + ) + + return parser.parse_args() + + +def _main(): + args = _parse_args() + save_sparsification_info(args.framework, args.path) + + +if __name__ == "__main__": + _main() diff --git a/tests/sparseml/framework/__init__.py b/tests/sparseml/framework/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/framework/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/framework/test_info.py b/tests/sparseml/framework/test_info.py new file mode 100644 index 00000000000..335c8868ca7 --- /dev/null +++ b/tests/sparseml/framework/test_info.py @@ -0,0 +1,119 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile + +import pytest + +from sparseml.base import Framework +from sparseml.framework import ( + FrameworkInferenceProviderInfo, + FrameworkInfo, + framework_info, + load_framework_info, + save_framework_info, +) +from sparseml.sparsification import SparsificationInfo + + +@pytest.mark.parametrize( + "const_args", + [ + { + "name": "test name", + "description": "test description", + "device": "test device", + }, + { + "name": "test name", + "description": "test description", + "device": "test device", + "supported_sparsification": SparsificationInfo(), + "available": True, + "properties": {"prop_key": "prop_val"}, + "warnings": ["test warning"], + }, + ], +) +def test_framework_inference_provider_info_lifecycle(const_args): + # test construction + info = FrameworkInferenceProviderInfo(**const_args) + assert info, "No object returned for info constructor" + + # test serialization + info_str = info.json() + assert info_str, "No json returned for info" + + # test deserialization + info_reconst = FrameworkInferenceProviderInfo.parse_raw(info_str) + assert info == info_reconst, "Reconstructed does not equal original" + + +@pytest.mark.parametrize( + "const_args", + [ + { + "framework": Framework.unknown, + "package_versions": {"test": "0.1.0"}, + }, + { + "framework": Framework.unknown, + "package_versions": {"test": "0.1.0"}, + "sparsification": SparsificationInfo(), + "inference_providers": [ + FrameworkInferenceProviderInfo( + name="test", description="test", device="test" + ) + ], + "properties": {"test_prop": "val"}, + "training_available": True, + "sparsification_available": True, + "exporting_onnx_available": True, + "inference_available": True, + }, + ], +) +def test_framework_info_lifecycle(const_args): + # test construction + info = FrameworkInfo(**const_args) + assert info, "No object returned for info constructor" + + # test serialization + info_str = info.json() + assert info_str, "No json returned for info" + + # test deserialization + info_reconst = FrameworkInfo.parse_raw(info_str) + assert info == info_reconst, "Reconstructed does not equal original" + + +def test_framework_info(): + # test that unknown raises an exception, + # other frameworks will test in their packages + with pytest.raises(ValueError): + framework_info(Framework.unknown) + + +def test_save_load_framework_info(): + info = FrameworkInfo( + framework=Framework.unknown, package_versions={"unknown": "0.0.1"} + ) + save_framework_info(info) + loaded_json = load_framework_info(info.json()) + assert info == loaded_json + + test_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name + save_framework_info(info, test_path) + loaded_path = load_framework_info(test_path) + assert info == loaded_path diff --git a/tests/sparseml/sparsification/__init__.py b/tests/sparseml/sparsification/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/sparsification/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/sparsification/test_info.py b/tests/sparseml/sparsification/test_info.py new file mode 100644 index 00000000000..cf2a0a5071d --- /dev/null +++ b/tests/sparseml/sparsification/test_info.py @@ -0,0 +1,164 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile + +import pytest + +from sparseml.base import Framework +from sparseml.sparsification import ( + ModifierInfo, + ModifierPropInfo, + ModifierType, + SparsificationInfo, + load_sparsification_info, + save_sparsification_info, + sparsification_info, +) + + +def test_modifier_type(): + assert len(ModifierType) == 6 + assert ModifierType.general + assert ModifierType.training + assert ModifierType.pruning + assert ModifierType.quantization + assert ModifierType.act_sparsity + assert ModifierType.misc + + +@pytest.mark.parametrize( + "const_args", + [ + { + "name": "test name", + "description": "test description", + "type_": "str", + }, + { + "name": "test name", + "description": "test description", + "type_": "str", + "restrictions": ["restriction"], + }, + ], +) +def test_modifier_prop_info_lifecycle(const_args): + # test construction + info = ModifierPropInfo(**const_args) + assert info, "No object returned for info constructor" + + # test serialization + info_str = info.json() + assert info_str, "No json returned for info" + + # test deserialization + info_reconst = ModifierPropInfo.parse_raw(info_str) + assert info == info_reconst, "Reconstructed does not equal original" + + +@pytest.mark.parametrize( + "const_args", + [ + { + "name": "test name", + "description": "test description", + }, + { + "name": "test name", + "description": "test description", + "type_": ModifierType.general, + "props": [ + ModifierPropInfo(name="name", description="description", type_="str") + ], + "warnings": ["warning"], + }, + ], +) +def test_modifier_info_lifecycle(const_args): + # test construction + info = ModifierInfo(**const_args) + assert info, "No object returned for info constructor" + + # test serialization + info_str = info.json() + assert info_str, "No json returned for info" + + # test deserialization + info_reconst = ModifierInfo.parse_raw(info_str) + assert info == info_reconst, "Reconstructed does not equal original" + + +@pytest.mark.parametrize( + "const_args", + [ + {}, + { + "modifiers": [ + ModifierInfo( + name="name", + description="description", + props=[ + ModifierPropInfo( + name="name", description="description", type_="str" + ) + ], + ) + ] + }, + ], +) +def test_sparsification_info_lifecycle(const_args): + # test construction + info = SparsificationInfo(**const_args) + assert info, "No object returned for info constructor" + + # test serialization + info_str = info.json() + assert info_str, "No json returned for info" + + # test deserialization + info_reconst = SparsificationInfo.parse_raw(info_str) + assert info == info_reconst, "Reconstructed does not equal original" + + +def test_sparsification_info(): + # test that unknown raises an exception, + # other sparsifications will test in their packages + with pytest.raises(ValueError): + sparsification_info(Framework.unknown) + + +def test_save_load_sparsification_info(): + info = SparsificationInfo( + modifiers=[ + ModifierInfo( + name="name", + description="description", + props=[ + ModifierPropInfo( + name="name", description="description", type_="str" + ) + ], + ) + ] + ) + save_sparsification_info(info) + loaded_json = load_sparsification_info(info.json()) + assert info == loaded_json + + test_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name + save_sparsification_info(info, test_path) + loaded_path = load_sparsification_info(test_path) + assert info == loaded_path diff --git a/tests/sparseml/test_base.py b/tests/sparseml/test_base.py new file mode 100644 index 00000000000..2f4d825fc86 --- /dev/null +++ b/tests/sparseml/test_base.py @@ -0,0 +1,94 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import pytest + +from sparseml import __version__ +from sparseml.base import ( + Framework, + check_version, + detect_framework, + execute_in_sparseml_framework, + get_version, +) + + +def test_framework(): + assert len(Framework) == 6 + assert Framework.unknown + assert Framework.deepsparse + assert Framework.onnx + assert Framework.keras + assert Framework.pytorch + assert Framework.tensorflow_v1 + + +@pytest.mark.parametrize( + "inp,expected", + [ + ("unknown", Framework.unknown), + ("deepsparse", Framework.deepsparse), + ("onnx", Framework.onnx), + ("keras", Framework.keras), + ("pytorch", Framework.pytorch), + ("tensorflow_v1", Framework.tensorflow_v1), + (Framework.unknown, Framework.unknown), + (Framework.deepsparse, Framework.deepsparse), + (Framework.onnx, Framework.onnx), + (Framework.keras, Framework.keras), + (Framework.pytorch, Framework.pytorch), + (Framework.tensorflow_v1, Framework.tensorflow_v1), + ], +) +def test_detect_framework(inp: Any, expected: Framework): + detected = detect_framework(inp) + assert detected == expected + + +def test_execute_in_sparseml_framework(): + with pytest.raises(ValueError): + execute_in_sparseml_framework(Framework.unknown, "unknown") + + with pytest.raises(ValueError): + execute_in_sparseml_framework(Framework.onnx, "unknown") + + # TODO: fill in with sample functions to execute in frameworks once available + + +def test_get_version(): + version = get_version("sparseml", raise_on_error=True) + assert version == __version__ + + with pytest.raises(ImportError): + get_version("unknown", raise_on_error=True) + + assert not get_version("unknown", raise_on_error=False) + + +def test_check_version(): + assert check_version("sparseml") + + assert not check_version("sparseml", min_version="10.0.0", raise_on_error=False) + with pytest.raises(ImportError): + check_version("sparseml", min_version="10.0.0") + + assert not check_version("sparseml", max_version="0.0.1", raise_on_error=False) + with pytest.raises(ImportError): + check_version("sparseml", max_version="0.0.1") + + assert not check_version("unknown", raise_on_error=False) + with pytest.raises(ImportError): + check_version("unknown") diff --git a/tests/sparseml/test_imports.py b/tests/sparseml/test_imports.py new file mode 100644 index 00000000000..8ba1e405b1d --- /dev/null +++ b/tests/sparseml/test_imports.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def test_imports(): + # flake8: noqa + from sparseml import ( + Framework, + FrameworkInferenceProviderInfo, + FrameworkInfo, + SparsificationInfo, + check_version, + detect_framework, + execute_in_sparseml_framework, + framework_info, + get_main_logger, + get_nm_root_logger, + set_logging_level, + sparsification_info, + version, + version_bug, + version_major, + version_major_minor, + version_minor, + ) diff --git a/tests/sparseml/test_version.py b/tests/sparseml/test_version.py new file mode 100644 index 00000000000..5f424048bf6 --- /dev/null +++ b/tests/sparseml/test_version.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sparseml.version import ( + __version__, + version, + version_bug, + version_major, + version_major_minor, + version_minor, +) + + +def test_version(): + assert __version__ + assert version + assert version_major == version.split(".")[0] + assert version_minor == version.split(".")[1] + assert version_bug == version.split(".")[2] + assert version_major_minor == f"{version_major}.{version_minor}"