Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve inference from return type annotations in completer #14357

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
228 changes: 183 additions & 45 deletions IPython/core/guarded_eval.py
@@ -1,16 +1,23 @@
from inspect import signature, Signature
from inspect import isclass, signature, Signature
from typing import (
Any,
Annotated,
AnyStr,
Callable,
Dict,
Literal,
NamedTuple,
NewType,
Optional,
Protocol,
Set,
Sequence,
Tuple,
NamedTuple,
Type,
Literal,
TypeGuard,
Union,
TYPE_CHECKING,
get_args,
get_origin,
is_typeddict,
)
import ast
import builtins
Expand All @@ -21,15 +28,18 @@
from dataclasses import dataclass, field
from types import MethodDescriptorType, ModuleType

from IPython.utils.docs import GENERATING_DOCUMENTATION
from IPython.utils.decorators import undoc


if TYPE_CHECKING or GENERATING_DOCUMENTATION:
from typing_extensions import Protocol
if sys.version_info < (3, 11):
from typing_extensions import Self, LiteralString
else:
from typing import Self, LiteralString

if sys.version_info < (3, 12):
from typing_extensions import TypeAliasType
else:
# do not require on runtime
Protocol = object # requires Python >=3.8
from typing import TypeAliasType


@undoc
Expand Down Expand Up @@ -337,6 +347,7 @@ def __getitem__(self, key):
IDENTITY_SUBSCRIPT = _IdentitySubscript()
SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
UNKNOWN_SIGNATURE = Signature()
NOT_EVALUATED = object()


class GuardRejection(Exception):
Expand Down Expand Up @@ -417,9 +428,37 @@ def guarded_eval(code: str, context: EvaluationContext):
}


class Duck:
class ImpersonatingDuck:
"""A dummy class used to create objects of other classes without calling their ``__init__``"""

# no-op: override __class__ to impersonate


class _Duck:
"""A dummy class used to create objects pretending to have given attributes"""

def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
self.attributes = attributes or {}
self.items = items or {}

def __getattr__(self, attr: str):
return self.attributes[attr]

def __hasattr__(self, attr: str):
return attr in self.attributes

def __dir__(self):
return [*dir(super), *self.attributes]

def __getitem__(self, key: str):
return self.items[key]

def __hasitem__(self, key: str):
return self.items[key]

def _ipython_key_completions_(self):
return self.items.keys()


def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
dunder = None
Expand Down Expand Up @@ -557,19 +596,7 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
f" not allowed in {context.evaluation} mode",
)
if isinstance(node, ast.Name):
if policy.allow_locals_access and node.id in context.locals:
return context.locals[node.id]
if policy.allow_globals_access and node.id in context.globals:
return context.globals[node.id]
if policy.allow_builtins_access and hasattr(builtins, node.id):
# note: do not use __builtins__, it is implementation detail of cPython
return getattr(builtins, node.id)
if not policy.allow_globals_access and not policy.allow_locals_access:
raise GuardRejection(
f"Namespace access not allowed in {context.evaluation} mode"
)
else:
raise NameError(f"{node.id} not found in locals, globals, nor builtins")
return _eval_node_name(node.id, context)
if isinstance(node, ast.Attribute):
value = eval_node(node.value, context)
if policy.can_get_attr(value, node.attr):
Expand All @@ -590,27 +617,19 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
if policy.can_call(func) and not node.keywords:
args = [eval_node(arg, context) for arg in node.args]
return func(*args)
try:
sig = signature(func)
except ValueError:
sig = UNKNOWN_SIGNATURE
# if annotation was not stringized, or it was stringized
# but resolved by signature call we know the return type
not_empty = sig.return_annotation is not Signature.empty
not_stringized = not isinstance(sig.return_annotation, str)
if not_empty and not_stringized:
duck = Duck()
# if allow-listed builtin is on type annotation, instantiate it
if policy.can_call(sig.return_annotation) and not node.keywords:
args = [eval_node(arg, context) for arg in node.args]
return sig.return_annotation(*args)
try:
# if custom class is in type annotation, mock it;
# this only works for heap types, not builtins
duck.__class__ = sig.return_annotation
return duck
except TypeError:
pass
if isclass(func):
# this code path gets entered when calling class e.g. `MyClass()`
# or `my_instance.__class__()` - in both cases `func` is `MyClass`.
# Should return `MyClass` if `__new__` is not overridden,
# otherwise whatever `__new__` return type is.
overridden_return_type = _eval_return_type(func.__new__, node, context)
if overridden_return_type is not NOT_EVALUATED:
return overridden_return_type
return _create_duck_for_heap_type(func)
else:
return_type = _eval_return_type(func, node, context)
if return_type is not NOT_EVALUATED:
return return_type
raise GuardRejection(
"Call for",
func, # not joined to avoid calling `repr`
Expand All @@ -619,6 +638,125 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
raise ValueError("Unhandled node", ast.dump(node))


def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
"""Evaluate return type of a given callable function.

Returns the built-in type, a duck or NOT_EVALUATED sentinel.
"""
try:
sig = signature(func)
except ValueError:
sig = UNKNOWN_SIGNATURE
# if annotation was not stringized, or it was stringized
# but resolved by signature call we know the return type
not_empty = sig.return_annotation is not Signature.empty
if not_empty:
return _resolve_annotation(sig.return_annotation, sig, func, node, context)
return NOT_EVALUATED


def _resolve_annotation(
annotation,
sig: Signature,
func: Callable,
node: ast.Call,
context: EvaluationContext,
):
"""Resolve annotation created by user with `typing` module and custom objects."""
annotation = (
_eval_node_name(annotation, context)
if isinstance(annotation, str)
else annotation
)
origin = get_origin(annotation)
if annotation is Self and hasattr(func, "__self__"):
return func.__self__
elif origin is Literal:
type_args = get_args(annotation)
if len(type_args) == 1:
return type_args[0]
elif annotation is LiteralString:
return ""
elif annotation is AnyStr:
index = None
for i, (key, value) in enumerate(sig.parameters.items()):
if value.annotation is AnyStr:
index = i
break
if index is not None and index < len(node.args):
return eval_node(node.args[index], context)
elif origin is TypeGuard:
return bool()
elif origin is Union:
attributes = [
attr
for type_arg in get_args(annotation)
for attr in dir(_resolve_annotation(type_arg, sig, func, node, context))
]
return _Duck(attributes=dict.fromkeys(attributes))
elif is_typeddict(annotation):
return _Duck(
attributes=dict.fromkeys(dir(dict())),
items={
k: _resolve_annotation(v, sig, func, node, context)
for k, v in annotation.__annotations__.items()
},
)
elif hasattr(annotation, "_is_protocol"):
return _Duck(attributes=dict.fromkeys(dir(annotation)))
elif origin is Annotated:
type_arg = get_args(annotation)[0]
return _resolve_annotation(type_arg, sig, func, node, context)
elif isinstance(annotation, NewType):
return _eval_or_create_duck(annotation.__supertype__, node, context)
elif isinstance(annotation, TypeAliasType):
return _eval_or_create_duck(annotation.__value__, node, context)
else:
return _eval_or_create_duck(annotation, node, context)


def _eval_node_name(node_id: str, context: EvaluationContext):
policy = EVALUATION_POLICIES[context.evaluation]
if policy.allow_locals_access and node_id in context.locals:
return context.locals[node_id]
if policy.allow_globals_access and node_id in context.globals:
return context.globals[node_id]
if policy.allow_builtins_access and hasattr(builtins, node_id):
# note: do not use __builtins__, it is implementation detail of cPython
return getattr(builtins, node_id)
if not policy.allow_globals_access and not policy.allow_locals_access:
raise GuardRejection(
f"Namespace access not allowed in {context.evaluation} mode"
)
else:
raise NameError(f"{node_id} not found in locals, globals, nor builtins")


def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext):
policy = EVALUATION_POLICIES[context.evaluation]
# if allow-listed builtin is on type annotation, instantiate it
if policy.can_call(duck_type) and not node.keywords:
args = [eval_node(arg, context) for arg in node.args]
return duck_type(*args)
# if custom class is in type annotation, mock it
return _create_duck_for_heap_type(duck_type)


def _create_duck_for_heap_type(duck_type):
"""Create an imitation of an object of a given type (a duck).

Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
"""
duck = ImpersonatingDuck()
try:
# this only works for heap types, not builtins
duck.__class__ = duck_type
return duck
except TypeError:
pass
return NOT_EVALUATED


SUPPORTED_EXTERNAL_GETITEM = {
("pandas", "core", "indexing", "_iLocIndexer"),
("pandas", "core", "indexing", "_LocIndexer"),
Expand Down