diff --git a/README.md b/README.md index 4de5e6a..2d7d89d 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,10 @@ pip install triad ## Release History +### 0.5.0 + +* Fix to_type on full type path + ### 0.4.9 * Fix numpy warning diff --git a/tests/utils/test_convert.py b/tests/utils/test_convert.py index a1c609e..3e46f9a 100644 --- a/tests/utils/test_convert.py +++ b/tests/utils/test_convert.py @@ -22,6 +22,8 @@ to_timedelta, to_type, ) +import urllib # must keep for testing purpose +import urllib.request # must keep for testing purpose def test_to_size(): @@ -94,6 +96,24 @@ def test_str_to_type(): assert RuntimeError == str_to_type("RuntimeError", Exception) raises(TypeError, lambda: str_to_type("RuntimeError", int)) + # test a full type path that only root was imported + str_to_type("urllib.request.OpenerDirector") + + # test a full type path that was never imported + str_to_type("shutil.Error") + str_to_type("http.HTTPStatus") + + # class and subclass + class T(object): + def __init__(self): + self.x = 10 + + class _TS(object): + pass + + assert T == str_to_type("T") + assert T._TS == str_to_type("T._TS") + def test_str_to_instance(): i = str_to_instance("tests.utils.Class2") diff --git a/triad/utils/convert.py b/triad/utils/convert.py index 24733a0..99a18a5 100644 --- a/triad/utils/convert.py +++ b/triad/utils/convert.py @@ -1,7 +1,7 @@ import datetime import importlib import inspect -from importlib import util as importlib_util +from types import ModuleType from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -96,15 +96,14 @@ def __init__(self, x=1): _globals, _locals = get_caller_global_local_vars(global_vars, local_vars) if "." not in expr: return eval(expr, _globals, _locals) - root = expr.split(".")[0] - if root not in _globals and root not in _locals: - spec = importlib_util.find_spec(root) - assert_or_throw(spec is not None, ValueError(expr)) - _locals = dict(_locals) - _locals[root] = importlib.import_module(root) - return eval(expr, _globals, _locals) + parts = expr.split(".") + v = _locals.get(parts[0], _globals.get(parts[0], None)) + if v is not None and not isinstance(v, ModuleType): + return eval(expr, _globals, _locals) + root = ".".join(parts[:-1]) + return getattr(importlib.import_module(root), parts[-1]) except ValueError: - raise + raise # pragma: no cover except Exception: raise ValueError(expr) diff --git a/triad_version/__init__.py b/triad_version/__init__.py index 574c066..3d18726 100644 --- a/triad_version/__init__.py +++ b/triad_version/__init__.py @@ -1 +1 @@ -__version__ = "0.4.9" +__version__ = "0.5.0"