diff --git a/hydra/utils.py b/hydra/utils.py index c982c78760..edc2442fd0 100644 --- a/hydra/utils.py +++ b/hydra/utils.py @@ -2,7 +2,8 @@ import logging.config import warnings from pathlib import Path -from typing import Any, Type +from typing import Any, Type, Optional, cast +from types import ModuleType from omegaconf import DictConfig, OmegaConf, _utils @@ -12,18 +13,66 @@ log = logging.getLogger(__name__) +def _safeimport(path: str) -> Optional[ModuleType]: + """ + Import a module; handle errors; return None if the module isn't found. + This is a typed simplified version of the `pydoc` function `safeimport`. + """ + import sys + from importlib import import_module + try: + module = import_module(path) + except ImportError as e: + if e.name == path: + return None + else: + log.error(f"Error importing module: {path}") + raise e + except Exception as e: + log.error(f"Non-ImportError while importing module {path}: {e}") + raise ValueError(f"Non-ImportError while importing module {path}: {e}") + for part in path.replace(module.__name__, "").split('.'): + if not hasattr(module, part): + break + module = getattr(module, part) + return module + + +def _locate(path: str) -> ModuleType: + """ + Locate an object by name or dotted path, importing as necessary. + This is similar to the pydoc function `locate`, except that it checks for + the module from the end of the path to the beginning. + """ + parts = [part for part in path.split('.') if part] + module = None + for n in reversed(range(len(parts))): + try: + module = _safeimport('.'.join(parts[:n])) + except: + continue + if module: + break + if module: + obj = module + else: + log.error(f"Module not found: {path}") + raise ValueError(f"Module not found: {path}") + for part in parts[n:]: + if not hasattr(obj, part): + log.error(f"Error finding attribute ({part}) in class ({obj.__name__}): {path}") + raise ValueError(f"Error finding attribute ({part}) in class ({obj.__name__}): {path}") + obj = getattr(obj, part) + return obj + + def get_method(path: str) -> type: return get_class(path) def get_class(path: str) -> type: try: - from pydoc import locate - - klass = locate(path) - if not klass: - log.error(f"Error finding module in class path {path}") - raise ValueError(f"Error finding module in class path {path}") + klass = cast(type, _locate(path)) return klass except Exception as e: log.error(f"Error initializing class {path}") @@ -32,7 +81,7 @@ def get_class(path: str) -> type: def get_static_method(full_method_name: str) -> type: try: - ret: type = get_class(full_method_name) + ret = get_class(full_method_name) return ret except Exception as e: log.error(f"Error getting static method {full_method_name} : {e}") @@ -108,7 +157,7 @@ def _instantiate_class( params = config.params if "params" in config else OmegaConf.create() assert isinstance( params, DictConfig - ), f"Input config params are expected to be a mapping, found {type(config.params)}" + ), f"Input config params are expected to be a mapping, found {type(config.params).__name__}" primitives = {} rest = {} for k, v in kwargs.items(): diff --git a/tests/test_utils.py b/tests/test_utils.py index f996a6ad55..1d4366771f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -61,9 +61,12 @@ def __eq__(self, other: Any) -> Any: class Baz(Foo): @classmethod - def class_method(self, y: int) -> None: + def class_method(self, y: int) -> Any: return self(y + 1) + @staticmethod + def static_method() -> Any: + return 43 @pytest.mark.parametrize( # type: ignore "path,expected_type", [("tests.test_utils.Bar", Bar)] @@ -125,14 +128,17 @@ def test_get_static_method(path: str, return_value: Any) -> None: Bar(10, 200, 200, 40), ), ( - { - "cls": "tests.test_utils.Baz.class_method", - "params": {"y": 10}, - }, + {"cls": "tests.test_utils.Baz.class_method", "params": {"y": 10}, }, None, {}, Baz(11), ), + ( + {"cls": "tests.test_utils.Baz.static_method", "params": {}, }, + None, + {}, + 43, + ), # Check that default value is respected ( {"cls": "tests.test_utils.Bar", "params": {"b": 200, "c": "${params.b}"}},