Skip to content

Commit

Permalink
Allow optimum to discover and load subpackages (#1894)
Browse files Browse the repository at this point in the history
As an alternative to directly adding their commands in a register.py file under
the root optimum directory, this adds a decorator to declare a subcommand that
can be used by subpackages when they are loaded.
This will fix the issue of subcommands 'disappearing' when optimum is upgraded
without reinstalling the subpackage.

The onnxruntime commands are moved into a subpackage loader directory.
This subpackage directory is only loaded (and its commands added) when
the onnxruntime is available.
This avoids wrongly indicating that the onnxruntime commands are available
when the package is actually not installed.
  • Loading branch information
dacorvo committed Jun 10, 2024
1 parent 113b645 commit f33f2f1
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 16 deletions.
3 changes: 1 addition & 2 deletions optimum/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@
from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
from .env import EnvironmentCommand
from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand
from .onnxruntime import ONNXRuntimeCommand, ONNXRuntimeOptimizeCommand, ONNXRuntimeQuantizeCommand
from .optimum_cli import register_optimum_cli_subcommand
from .optimum_cli import optimum_cli_subcommand
57 changes: 51 additions & 6 deletions optimum/commands/optimum_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,57 @@
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union

from ..subpackages import load_subpackages
from ..utils import logging
from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
from .env import EnvironmentCommand
from .export import ExportCommand
from .onnxruntime import ONNXRuntimeCommand


logger = logging.get_logger()

OPTIMUM_CLI_SUBCOMMANDS = [ExportCommand, EnvironmentCommand, ONNXRuntimeCommand]
# The table below contains the optimum-cli root subcommands provided by the optimum package
OPTIMUM_CLI_ROOT_SUBCOMMANDS = [ExportCommand, EnvironmentCommand]

# The table below is dynamically populated when loading subpackages
_OPTIMUM_CLI_SUBCOMMANDS = []


def optimum_cli_subcommand(parent_command: Optional[Type[BaseOptimumCLICommand]] = None):
"""
A decorator to declare optimum-cli subcommands.
The declaration of an optimum-cli subcommand looks like this:
```
@optimum_cli_subcommand()
class MySubcommand(BaseOptimumCLICommand):
<implementation>
```
or
```
@optimum_cli_subcommand(ExportCommand)
class MySubcommand(BaseOptimumCLICommand):
<implementation>
```
Args:
parent_command: (`Optional[Type[BaseOptimumCLICommand]]`):
The class of the parent command or None if this is a top-level command. Defaults to None.
"""

if parent_command is not None and not issubclass(parent_command, BaseOptimumCLICommand):
raise ValueError(f"The parent command {parent_command} must be a subclass of BaseOptimumCLICommand")

def wrapper(subcommand):
if not issubclass(subcommand, BaseOptimumCLICommand):
raise ValueError(f"The subcommand {subcommand} must be a subclass of BaseOptimumCLICommand")
_OPTIMUM_CLI_SUBCOMMANDS.append((subcommand, parent_command))

return wrapper


def resolve_command_to_command_instance(
Expand Down Expand Up @@ -137,15 +178,19 @@ def main():
root = RootOptimumCLICommand("Optimum CLI tool", usage="optimum-cli")
parser = root.parser

for subcommand_cls in OPTIMUM_CLI_SUBCOMMANDS:
for subcommand_cls in OPTIMUM_CLI_ROOT_SUBCOMMANDS:
register_optimum_cli_subcommand(subcommand_cls, parent_command=root)

commands_in_register = dynamic_load_commands_in_register()
# Load subpackages to give them a chance to declare their own subcommands
load_subpackages()

# Register subcommands declared by the subpackages or found in the register files under commands/register
commands_to_register = _OPTIMUM_CLI_SUBCOMMANDS + dynamic_load_commands_in_register()
command2command_instance = resolve_command_to_command_instance(
root, [parent_command_cls for _, parent_command_cls in commands_in_register if parent_command_cls is not None]
root, [parent_command_cls for _, parent_command_cls in commands_to_register if parent_command_cls is not None]
)

for command_or_command_info, parent_command in commands_in_register:
for command_or_command_info, parent_command in commands_to_register:
if parent_command is None:
parent_command_instance = root
else:
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/subpackage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .commands import ONNXRuntimeCommand
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,3 @@
# limitations under the License.

from .base import ONNXRuntimeCommand
from .optimize import ONNXRuntimeOptimizeCommand
from .quantize import ONNXRuntimeQuantizeCommand
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# limitations under the License.
"""optimum.onnxruntime command-line interface base classes."""

from .. import BaseOptimumCLICommand, CommandInfo
from optimum.commands import BaseOptimumCLICommand, CommandInfo, optimum_cli_subcommand

from .optimize import ONNXRuntimeOptimizeCommand
from .quantize import ONNXRuntimeQuantizeCommand


@optimum_cli_subcommand()
class ONNXRuntimeCommand(BaseOptimumCLICommand):
COMMAND = CommandInfo(
name="onnxruntime",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def parse_args(parser: "ArgumentParser"):
return parse_args_onnxruntime_optimize(parser)

def run(self):
from ...onnxruntime.configuration import AutoOptimizationConfig, ORTConfig
from ...onnxruntime.optimization import ORTOptimizer
from ...configuration import AutoOptimizationConfig, ORTConfig
from ...optimization import ORTOptimizer

if self.args.output == self.args.onnx_model:
raise ValueError("The output directory must be different than the directory hosting the ONNX model.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pathlib import Path
from typing import TYPE_CHECKING

from .. import BaseOptimumCLICommand
from optimum.commands import BaseOptimumCLICommand


if TYPE_CHECKING:
Expand Down Expand Up @@ -69,8 +69,8 @@ def parse_args(parser: "ArgumentParser"):
return parse_args_onnxruntime_quantize(parser)

def run(self):
from ...onnxruntime.configuration import AutoQuantizationConfig, ORTConfig
from ...onnxruntime.quantization import ORTQuantizer
from ...configuration import AutoQuantizationConfig, ORTConfig
from ...quantization import ORTQuantizer

if self.args.output == self.args.onnx_model:
raise ValueError("The output directory must be different than the directory hosting the ONNX model.")
Expand Down
81 changes: 81 additions & 0 deletions optimum/subpackages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import logging
import sys


if sys.version_info >= (3, 8):
from importlib import metadata as importlib_metadata
else:
import importlib_metadata
from importlib.util import find_spec, module_from_spec

from .utils import is_onnxruntime_available


logger = logging.getLogger(__name__)


def load_namespace_modules(namespace: str, module: str):
"""Load modules with a specific name inside a namespace
This method operates on namespace packages:
https://packaging.python.org/en/latest/guides/packaging-namespace-packages/
For each package inside the specified `namespace`, it looks for the specified `module` and loads it.
Args:
namespace (`str`):
The namespace containing modules to be loaded.
module (`str`):
The name of the module to load in each namespace package.
"""
for dist in importlib_metadata.distributions():
dist_name = dist.metadata["Name"]
if not dist_name.startswith(f"{namespace}-"):
continue
package_import_name = dist_name.replace("-", ".")
module_import_name = f"{package_import_name}.{module}"
if module_import_name in sys.modules:
# Module already loaded
continue
backend_spec = find_spec(module_import_name)
if backend_spec is None:
continue
try:
imported_module = module_from_spec(backend_spec)
sys.modules[module_import_name] = imported_module
backend_spec.loader.exec_module(imported_module)
logger.debug(f"Successfully loaded {module_import_name}")
except Exception as e:
logger.error(f"An exception occured while loading {module_import_name}: {e}.")


def load_subpackages():
"""Load optimum subpackages
This method goes through packages inside the `optimum` namespace and loads the `subpackage` module if it exists.
This module is then in charge of registering the subpackage commands.
"""
SUBPACKAGE_LOADER = "subpackage"
load_namespace_modules("optimum", SUBPACKAGE_LOADER)

# Load subpackages from internal modules not explicitly defined as namespace packages
loader_name = "." + SUBPACKAGE_LOADER
if is_onnxruntime_available():
importlib.import_module(loader_name, package="optimum.onnxruntime")

0 comments on commit f33f2f1

Please sign in to comment.