Skip to content

Commit

Permalink
feat(common): support Self annotations for Annotable
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Oct 9, 2023
1 parent 8ed313c commit 0c60146
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 9 deletions.
2 changes: 1 addition & 1 deletion ibis/common/grounds.py
Expand Up @@ -50,7 +50,7 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
annotations = dct.get("__annotations__", {})

# TODO(kszucs): pass dct as localns to evaluate_annotations
typehints = evaluate_annotations(annotations, module)
typehints = evaluate_annotations(annotations, module, clsname)
for name, typehint in typehints.items():
if get_origin(typehint) is ClassVar:
continue
Expand Down
4 changes: 3 additions & 1 deletion ibis/common/patterns.py
Expand Up @@ -119,9 +119,11 @@ def from_typehint(cls, annot: type, allow_coercion: bool = True) -> Pattern:
elif isinstance(annot, Enum):
# for enums we check the value against the enum values
return EqualTo(annot)
elif isinstance(annot, (str, ForwardRef)):
elif isinstance(annot, str):
# for strings and forward references we check in a lazy way
return LazyInstanceOf(annot)
elif isinstance(annot, ForwardRef):
return LazyInstanceOf(annot.__forward_arg__)
else:
raise TypeError(f"Cannot create validator from annotation {annot!r}")
elif origin is CoercedTo:
Expand Down
19 changes: 19 additions & 0 deletions ibis/common/tests/test_graph_benchmarks.py
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from ibis.common.collections import frozendict # noqa: TCH001
from ibis.common.graph import Node
from ibis.common.grounds import Concrete

if TYPE_CHECKING:
from typing_extensions import Self


class MyNode(Node, Concrete):
a: int
b: str
c: tuple[int, ...]
d: frozendict[str, int]
e: Self
f: tuple[Self, ...]
20 changes: 19 additions & 1 deletion ibis/common/tests/test_grounds.py
Expand Up @@ -5,7 +5,7 @@
import sys
import weakref
from abc import ABCMeta
from typing import Callable, Generic, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union

import pytest

Expand Down Expand Up @@ -42,6 +42,9 @@
)
from ibis.tests.util import assert_pickle_roundtrip

if TYPE_CHECKING:
from typing_extensions import Self

is_any = InstanceOf(object)
is_bool = InstanceOf(bool)
is_float = InstanceOf(float)
Expand Down Expand Up @@ -314,6 +317,21 @@ class Op2(Annotable):
Op2()


class RecursiveNode(Annotable):
child: Optional[Self] = None


def test_annotable_with_self_typehint() -> None:
node = RecursiveNode(RecursiveNode(RecursiveNode(None)))
assert isinstance(node, RecursiveNode)
assert isinstance(node.child, RecursiveNode)
assert isinstance(node.child.child, RecursiveNode)
assert node.child.child.child is None

with pytest.raises(ValidationError):
RecursiveNode(1)


def test_annotable_with_recursive_generic_type_annotations():
# testing cons list
pattern = Pattern.from_typehint(List[Integer])
Expand Down
13 changes: 10 additions & 3 deletions ibis/common/tests/test_typing.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Generic, Optional, Union
from typing import ForwardRef, Generic, Optional, Union

from typing_extensions import TypeVar

Expand Down Expand Up @@ -41,11 +41,18 @@ def example(a: int, b: str) -> str: # type: ignore


def test_evaluate_annotations() -> None:
annotations = {"a": "Union[int, str]", "b": "Optional[str]"}
hints = evaluate_annotations(annotations, module_name=__name__)
annots = {"a": "Union[int, str]", "b": "Optional[str]"}
hints = evaluate_annotations(annots, module_name=__name__)
assert hints == {"a": Union[int, str], "b": Optional[str]}


def test_evaluate_annotations_with_self() -> None:
annots = {"a": "Union[int, Self]", "b": "Optional[Self]"}
myhint = ForwardRef(f"{__name__}.My")
hints = evaluate_annotations(annots, module_name=__name__, class_name="My")
assert hints == {"a": Union[int, myhint], "b": Optional[myhint]}


def test_get_type_hints() -> None:
hints = get_type_hints(My)
assert hints == {"a": T, "b": S, "c": str}
Expand Down
13 changes: 10 additions & 3 deletions ibis/common/typing.py
Expand Up @@ -167,7 +167,9 @@ def get_bound_typevars(obj: Any) -> dict[TypeVar, tuple[str, type]]:


def evaluate_annotations(
annots: dict[str, str], module_name: str, localns: Optional[Namespace] = None
annots: dict[str, str],
module_name: str,
class_name: Optional[str] = None,
) -> dict[str, Any]:
"""Evaluate type annotations that are strings.
Expand All @@ -178,8 +180,9 @@ def evaluate_annotations(
module_name
The name of the module that the annotations are defined in, hence
providing global scope.
localns
The local namespace to use for evaluation.
class_name
The name of the class that the annotations are defined in, hence
providing Self type.
Returns
-------
Expand All @@ -193,6 +196,10 @@ def evaluate_annotations(
"""
module = sys.modules.get(module_name, None)
globalns = getattr(module, "__dict__", None)
if class_name is None:
localns = None
else:
localns = dict(Self=f"{module_name}.{class_name}")
return {
k: eval(v, globalns, localns) if isinstance(v, str) else v # noqa: PGH001
for k, v in annots.items()
Expand Down

0 comments on commit 0c60146

Please sign in to comment.