Skip to content

Commit

Permalink
module_import_helper
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenliang123 committed Mar 20, 2024
1 parent 86e474f commit ae46119
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 77 deletions.
16 changes: 4 additions & 12 deletions api/core/extension/extensible.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import enum
import importlib.util
import json
import logging
import os
from typing import Any, Optional

from pydantic import BaseModel

from core.utils.module_import_helper import load_single_subclass_from_source
from core.utils.position_helper import sort_to_dict_by_position_map


Expand Down Expand Up @@ -73,17 +73,9 @@ def scan_extensions(cls):

# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py')
spec = importlib.util.spec_from_file_location(extension_name, py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj
break

if not extension_class:
try:
extension_class = load_single_subclass_from_source(extension_name, py_path, cls)
except Exception:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue

Expand Down
17 changes: 5 additions & 12 deletions api/core/model_runtime/model_providers/__base/model_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import os
from abc import ABC, abstractmethod

Expand All @@ -7,6 +6,7 @@
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source


class ModelProvider(ABC):
Expand Down Expand Up @@ -104,17 +104,10 @@ def get_model_instance(self, model_type: ModelType) -> AIModel:

# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

model_class = None
for name, obj in vars(mod).items():
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
and obj != AIModel and obj.__module__ == mod.__name__):
model_class = obj
break

mod = import_module_from_source(
f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
get_subclasses_from_module(mod, AIModel)), None)
if not model_class:
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')

Expand Down
15 changes: 5 additions & 10 deletions api/core/model_runtime/model_providers/model_provider_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import logging
import os
from typing import Optional
Expand All @@ -10,6 +9,7 @@
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.utils.module_import_helper import load_single_subclass_from_source
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -229,15 +229,10 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:

# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
spec = importlib.util.spec_from_file_location(f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

model_provider_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, ModelProvider) and obj != ModelProvider:
model_provider_class = obj
break
model_provider_class = load_single_subclass_from_source(
module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
script_path=py_path,
parent_type=ModelProvider)

if not model_provider_class:
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
Expand Down
17 changes: 6 additions & 11 deletions api/core/tools/provider/builtin_tool_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
from abc import abstractmethod
from os import listdir, path
from typing import Any
Expand All @@ -16,6 +15,7 @@
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.utils.module_import_helper import load_single_subclass_from_source


class BuiltinToolProviderController(ToolProviderController):
Expand Down Expand Up @@ -63,16 +63,11 @@ def _get_builtin_tools(self) -> list[Tool]:
tool_name = tool_file.split(".")[0]
tool = load(f.read(), FullLoader)
# get tool class, import the module
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# get all the classes in the module
classes = [x for _, x in vars(mod).items()
if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool)
]
assistant_tool_class = classes[0]
assistant_tool_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tools.append(assistant_tool_class(**tool))

self.tools = tools
Expand Down
43 changes: 11 additions & 32 deletions api/core/tools/tool_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import json
import logging
import mimetypes
Expand Down Expand Up @@ -34,6 +33,7 @@
ToolParameterConfigurationManager,
)
from core.tools.utils.encoder import serialize_base_model_dict
from core.utils.module_import_helper import load_single_subclass_from_source
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider

Expand Down Expand Up @@ -72,21 +72,11 @@ def invoke(

if provider_entity is None:
# fetch the provider from .provider.builtin
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# get all the classes in the module
classes = [ x for _, x in vars(mod).items()
if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController)
]
if len(classes) == 0:
raise ToolProviderNotFoundError(f'provider {provider} not found')
if len(classes) > 1:
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')

provider_entity = classes[0]()
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'),
parent_type=ToolProviderController)
provider_entity = provider_class()

return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)

Expand Down Expand Up @@ -330,23 +320,12 @@ def list_builtin_providers() -> list[BuiltinToolProviderController]:
if provider.startswith('__'):
continue

py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# load all classes
classes = [
obj for name, obj in vars(mod).items()
if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController)
]
if len(classes) == 0:
raise ToolProviderNotFoundError(f'provider {provider} not found')
if len(classes) > 1:
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')

# init provider
provider_class = classes[0]
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
builtin_providers.append(provider_class())

# cache the builtin providers
Expand Down
61 changes: 61 additions & 0 deletions api/core/utils/module_import_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import importlib.util
import logging
import sys
from types import ModuleType
from typing import AnyStr


def import_module_from_source(
module_name: str,
py_file_path: AnyStr,
use_lazy_loader: bool = False
) -> ModuleType:
"""
Importing a module from the source file directly
"""
try:
existed_spec = importlib.util.find_spec(module_name)
if existed_spec:
spec = existed_spec
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
if use_lazy_loader:
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
spec.loader = importlib.util.LazyLoader(spec.loader)
module = importlib.util.module_from_spec(spec)
if not existed_spec:
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
except Exception as e:
logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
raise e


def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]:
"""
Get all the subclasses of the parent type from the module
"""
classes = [x for _, x in vars(mod).items()
if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)]
return classes


def load_single_subclass_from_source(
module_name: str,
script_path: AnyStr,
parent_type: type,
) -> type:
"""
Load a single subclass from the source
"""
module = import_module_from_source(module_name, script_path)
subclasses = get_subclasses_from_module(module, parent_type)
match len(subclasses):
case 1:
return subclasses[0]
case 0:
raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}')
case _:
raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}')

0 comments on commit ae46119

Please sign in to comment.