Navigation Menu

Skip to content

Commit

Permalink
reimplement locate and linting
Browse files Browse the repository at this point in the history
this is a reimplementation of the pydoc.locate function.  That function
would attempt to load a module from beginning to end and if any link in
the chain did not load properly then it would fail.  The reimplemented
function searches for the module to load from the end of the path to the
beginning.

Additionally, this reimplementation will raise exceptions instead of
returning None if the path cannot be located.

Signed-off-by: David Pollack <david@da3.net>
  • Loading branch information
dhpollack committed Mar 26, 2020
1 parent e9bc1a8 commit 6ebdefe
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 14 deletions.
67 changes: 58 additions & 9 deletions hydra/utils.py
Expand Up @@ -2,7 +2,8 @@
import logging.config
import warnings
from pathlib import Path
from typing import Any, Type
from typing import Any, Type, Optional, cast
from types import ModuleType

from omegaconf import DictConfig, OmegaConf, _utils

Expand All @@ -12,18 +13,66 @@
log = logging.getLogger(__name__)


def _safeimport(path: str) -> Optional[ModuleType]:
"""
Import a module; handle errors; return None if the module isn't found.
This is a typed simplified version of the `pydoc` function `safeimport`.
"""
import sys
from importlib import import_module
try:
module = import_module(path)
except ImportError as e:
if e.name == path:
return None
else:
log.error(f"Error importing module: {path}")
raise e
except Exception as e:
log.error(f"Non-ImportError while importing module {path}: {e}")
raise ValueError(f"Non-ImportError while importing module {path}: {e}")
for part in path.replace(module.__name__, "").split('.'):
if not hasattr(module, part):
break
module = getattr(module, part)
return module


def _locate(path: str) -> ModuleType:
"""
Locate an object by name or dotted path, importing as necessary.
This is similar to the pydoc function `locate`, except that it checks for
the module from the end of the path to the beginning.
"""
parts = [part for part in path.split('.') if part]
module = None
for n in reversed(range(len(parts))):
try:
module = _safeimport('.'.join(parts[:n]))
except:
continue
if module:
break
if module:
obj = module
else:
log.error(f"Module not found: {path}")
raise ValueError(f"Module not found: {path}")
for part in parts[n:]:
if not hasattr(obj, part):
log.error(f"Error finding attribute ({part}) in class ({obj.__name__}): {path}")
raise ValueError(f"Error finding attribute ({part}) in class ({obj.__name__}): {path}")
obj = getattr(obj, part)
return obj


def get_method(path: str) -> type:
return get_class(path)


def get_class(path: str) -> type:
try:
from pydoc import locate

klass = locate(path)
if not klass:
log.error(f"Error finding module in class path {path}")
raise ValueError(f"Error finding module in class path {path}")
klass = cast(type, _locate(path))
return klass
except Exception as e:
log.error(f"Error initializing class {path}")
Expand All @@ -32,7 +81,7 @@ def get_class(path: str) -> type:

def get_static_method(full_method_name: str) -> type:
try:
ret: type = get_class(full_method_name)
ret = get_class(full_method_name)
return ret
except Exception as e:
log.error(f"Error getting static method {full_method_name} : {e}")
Expand Down Expand Up @@ -108,7 +157,7 @@ def _instantiate_class(
params = config.params if "params" in config else OmegaConf.create()
assert isinstance(
params, DictConfig
), f"Input config params are expected to be a mapping, found {type(config.params)}"
), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"
primitives = {}
rest = {}
for k, v in kwargs.items():
Expand Down
16 changes: 11 additions & 5 deletions tests/test_utils.py
Expand Up @@ -61,9 +61,12 @@ def __eq__(self, other: Any) -> Any:

class Baz(Foo):
@classmethod
def class_method(self, y: int) -> None:
def class_method(self, y: int) -> Any:
return self(y + 1)

@staticmethod
def static_method() -> Any:
return 43

@pytest.mark.parametrize( # type: ignore
"path,expected_type", [("tests.test_utils.Bar", Bar)]
Expand Down Expand Up @@ -125,14 +128,17 @@ def test_get_static_method(path: str, return_value: Any) -> None:
Bar(10, 200, 200, 40),
),
(
{
"cls": "tests.test_utils.Baz.class_method",
"params": {"y": 10},
},
{"cls": "tests.test_utils.Baz.class_method", "params": {"y": 10}, },
None,
{},
Baz(11),
),
(
{"cls": "tests.test_utils.Baz.static_method", "params": {}, },
None,
{},
43,
),
# Check that default value is respected
(
{"cls": "tests.test_utils.Bar", "params": {"b": 200, "c": "${params.b}"}},
Expand Down

0 comments on commit 6ebdefe

Please sign in to comment.