Skip to content

Commit

Permalink
Merge pull request #1038 from google/google_sync
Browse files Browse the repository at this point in the history
Google sync
  • Loading branch information
rchen152 committed Oct 28, 2021
2 parents 60e8346 + ab53204 commit 926bdc7
Show file tree
Hide file tree
Showing 19 changed files with 233 additions and 133 deletions.
5 changes: 0 additions & 5 deletions pytype/config.py
Expand Up @@ -154,11 +154,6 @@ def add_basic_options(o):
"invalid function calls."))
temporary = ("This flag is temporary and will be removed once this behavior "
"is enabled by default.")
o.add_argument(
"--bind-properties", action="store_true",
dest="bind_properties", default=False,
help=("Bind @property methods to the classes they're defined on for more "
"precise type-checking. " + temporary))
o.add_argument(
"--use-enum-overlay", action="store_true",
dest="use_enum_overlay", default=False,
Expand Down
5 changes: 4 additions & 1 deletion pytype/convert.py
Expand Up @@ -521,7 +521,7 @@ def _load_late_type(self, late_type):
t = pytd.AnythingType()
else:
try:
cls = pytd_utils.LookupItemRecursive(ast, attr_name)
cls = pytd.LookupItemRecursive(ast, attr_name)
except KeyError:
if "__getattr__" not in ast:
log.warning("Couldn't resolve %s", late_type.name)
Expand Down Expand Up @@ -843,6 +843,9 @@ def _constant_to_value(self, pyval, subst, get_node):
ret = attr_overlay.AttribInstance.from_metadata(
self.ctx, self.ctx.root_node, typ, md)
return ret
elif md["tag"] == "attr.s":
ret = attr_overlay.Attrs.from_metadata(self.ctx, md)
return ret
except (IndexError, ValueError, TypeError, KeyError):
details = "Wrong format for pytype_metadata."
self.ctx.errorlog.invalid_annotation(self.ctx.vm.frames,
Expand Down
14 changes: 13 additions & 1 deletion pytype/load_pytd_test.py
Expand Up @@ -417,7 +417,7 @@ def test_submodule_rename(self):
""")
loader = load_pytd.Loader(None, self.python_version, pythonpath=[d.path])
foo = loader.import_name("foo")
self.assertEqual(pytd_utils.Print(foo), "import foo.bar as foo.baz")
self.assertEqual(pytd_utils.Print(foo), "from foo import bar as foo.baz")

def test_typing_reexport(self):
with file_utils.Tempdir() as d:
Expand Down Expand Up @@ -478,6 +478,18 @@ class Bar(Foo): ...
loader = load_pytd.Loader(None, self.python_version, pythonpath=[d.path])
loader.import_name("foo.bar")

def test_module_alias(self):
ast = self._import(foo="""
import subprocess as _subprocess
x: _subprocess.Popen
""")
expected = textwrap.dedent("""
import subprocess as foo._subprocess
foo.x: _subprocess.Popen
""").strip()
self.assertMultiLineEqual(pytd_utils.Print(ast), expected)


class ImportTypeMacroTest(_LoaderTest):

Expand Down
4 changes: 4 additions & 0 deletions pytype/output.py
Expand Up @@ -359,6 +359,10 @@ def value_to_pytd_def(self, node, v, name):
assert isinstance(d, pytd.Function)
sigs = tuple(sig.Replace(params=sig.params[1:]) for sig in d.signatures)
return d.Replace(signatures=sigs)
elif isinstance(v, attr_overlay.Attrs):
ret = pytd.NamedType("typing.Callable")
md = metadata.to_pytd(v.to_metadata())
return pytd.Annotated(ret, ("'pytype_metadata'", md))
elif (isinstance(v, abstract.PyTDFunction) and
not isinstance(v, typing_overlay.TypeVar)):
return pytd.Function(
Expand Down
15 changes: 15 additions & 0 deletions pytype/overlays/attr_overlay.py
Expand Up @@ -184,6 +184,21 @@ def decorate(self, node, cls):
# Fix up type parameters in methods added by the decorator.
cls.update_method_type_params()

def to_metadata(self):
return {
"tag": "attr.s",
"init": self._current_args["init"],
"kw_only": self._current_args["kw_only"],
"auto_attribs": self._current_args["auto_attribs"]
}

@classmethod
def from_metadata(cls, ctx, metadata):
kwargs = {k: metadata[k] for k in ("init", "kw_only", "auto_attribs")}
ret = cls.make(ctx)
ret.set_current_args(kwargs)
return ret


class AttrsNextGenDefine(Attrs):
"""Implements the @attr.define decorator.
Expand Down
6 changes: 6 additions & 0 deletions pytype/overlays/classgen.py
Expand Up @@ -70,6 +70,7 @@ def decorate(self, node, cls):
"""Apply the decorator to cls."""

def update_kwargs(self, args):
"""Update current_args with the Args passed to the decorator."""
self._current_args = self._DEFAULT_ARGS.copy()
for k, v in args.namedargs.items():
if k in self._current_args:
Expand All @@ -79,6 +80,11 @@ def update_kwargs(self, args):
self.ctx.errorlog.not_supported_yet(
self.ctx.vm.frames, "Non-constant argument to decorator: %r" % k)

def set_current_args(self, kwargs):
"""Set current_args when constructing a class directly."""
self._current_args = self._DEFAULT_ARGS.copy()
self._current_args.update(kwargs)

def init_name(self, attr):
"""Attribute name as an __init__ keyword, could differ from attr.name."""
return attr.name
Expand Down
2 changes: 1 addition & 1 deletion pytype/pyi/definitions.py
Expand Up @@ -131,7 +131,7 @@ def _maybe_resolve_alias(alias, name_to_class, name_to_constant):
return pytd.Constant(
alias.name, pytdgen.pytd_type(pytd.NamedType(alias.type.name)))
elif isinstance(value, pytd.Function):
return pytd_utils.AliasMethod(
return pytd.AliasMethod(
value.Replace(name=alias.name),
from_constant=isinstance(prev_value, pytd.Constant))
else:
Expand Down
1 change: 1 addition & 0 deletions pytype/pytd/CMakeLists.txt
Expand Up @@ -193,6 +193,7 @@ py_library(
._pytd
.base_visitor
.pep484
pytype.utils
pytype.pytd.parse.parse
)

Expand Down
81 changes: 54 additions & 27 deletions pytype/pytd/printer.py
Expand Up @@ -5,6 +5,7 @@
import logging
import re

from pytype import utils
from pytype.pytd import base_visitor
from pytype.pytd import pep484
from pytype.pytd import pytd
Expand All @@ -27,11 +28,13 @@ def __init__(self, multiline_args=False):
self.in_alias = False
self.in_parameter = False
self.in_literal = False
self._unit_name = None
self.multiline_args = multiline_args

self._unit = None
self._local_names = {}
self._class_members = set()
self._typing_import_counts = collections.defaultdict(int)
self.multiline_args = multiline_args
self._module_aliases = {}

def Print(self, node):
return node.Visit(copy.deepcopy(self))
Expand Down Expand Up @@ -107,7 +110,8 @@ def _FormatTypeParams(self, type_params):
def _NameCollision(self, name):

def name_in(members):
return name in members or f"{self._unit_name}.{name}" in members
return name in members or (
self._unit and f"{self._unit.name}.{name}" in members)

return name_in(self._class_members) or name_in(self._local_names)

Expand All @@ -130,18 +134,30 @@ def _ImportTypingExtension(self, name):
else:
return self._FromTyping(name)

def _StripUnitPrefix(self, name):
if self._unit:
return utils.strip_prefix(name, f"{self._unit.name}.")
else:
return name

def EnterTypeDeclUnit(self, unit):
self._unit_name = unit.name
self._unit = unit
for definitions, label in [(unit.classes, "class"),
(unit.functions, "function"),
(unit.constants, "constant"),
(unit.type_params, "type_param"),
(unit.aliases, "alias")]:
for defn in definitions:
self._local_names[defn.name] = label
for alias in unit.aliases:
if isinstance(alias.type, pytd.Module):
name = alias.name
if unit.name and name.startswith(unit.name + "."):
name = name[len(unit.name) + 1:]
self._module_aliases[alias.type.module_name] = name

def LeaveTypeDeclUnit(self, _):
self._unit_name = None
self._unit = None
self._local_names = {}

def VisitTypeDeclUnit(self, node):
Expand Down Expand Up @@ -197,9 +213,7 @@ def VisitAlias(self, node):
suffix = ""
module, _, name = full_name.rpartition(".")
if module:
alias_name = self.old_node.name
if alias_name.startswith(f"{self._unit_name}."):
alias_name = alias_name[len(self._unit_name)+1:]
alias_name = self._StripUnitPrefix(self.old_node.name)
if name not in ("*", alias_name):
suffix += f" as {alias_name}"
self.imports = self.old_imports # undo unnecessary imports change
Expand Down Expand Up @@ -395,31 +409,39 @@ def VisitTemplateItem(self, node):
"""Convert a template to a string."""
return node.type_param

def _UseExistingModuleAlias(self, name):
prefix, suffix = name.rsplit(".", 1)
while prefix:
if prefix in self._module_aliases:
return f"{self._module_aliases[prefix]}.{suffix}"
prefix, _, remainder = prefix.rpartition(".")
suffix = f"{remainder}.{suffix}"
return None

def VisitNamedType(self, node):
"""Convert a type to a string."""
prefix, _, suffix = node.name.rpartition(".")
if self._IsBuiltin(prefix) and not self._NameCollision(suffix):
node_name = suffix
elif prefix == "typing":
node_name = self._FromTyping(suffix)
elif (prefix and
prefix != self._unit_name and
prefix not in self._local_names):
if self.class_names and "." in self.class_names[-1]:
# We've already fully qualified the class names.
class_prefix = self.class_names[-1]
else:
class_prefix = ".".join(self.class_names)
if self._unit_name:
class_prefixes = {class_prefix, f"{self._unit_name}.{class_prefix}"}
else:
class_prefixes = {class_prefix}
if prefix not in class_prefixes:
# If the prefix doesn't match the class scope, then it's an import.
self._RequireImport(prefix)
elif "." not in node.name:
node_name = node.name
else:
node_name = node.name
if self._unit:
try:
pytd.LookupItemRecursive(self._unit, self._StripUnitPrefix(node.name))
except KeyError:
aliased_name = self._UseExistingModuleAlias(node.name)
if aliased_name:
node_name = aliased_name
else:
self._RequireImport(prefix)
node_name = node.name
else:
node_name = node.name
else:
node_name = node.name
if node_name == "NoneType":
# PEP 484 allows this special abbreviation.
return "None"
Expand Down Expand Up @@ -449,10 +471,15 @@ def VisitTypeParameter(self, node):
return node.name

def VisitModule(self, node):
if node.is_aliased:
return f"import {node.module_name} as {node.name}"
else:
if not node.is_aliased:
return f"import {node.module_name}"
elif "." in node.module_name:
# `import x.y as z` and `from x import y as z` are equivalent, but the
# latter is a bit prettier.
prefix, suffix = node.module_name.rsplit(".", 1)
return f"from {prefix} import {suffix} as {node.name}"
else:
return f"import {node.module_name} as {node.name}"

def MaybeCapitalize(self, name):
"""Capitalize a generic type, if necessary."""
Expand Down
91 changes: 91 additions & 0 deletions pytype/pytd/pytd.py
Expand Up @@ -676,3 +676,94 @@ def ToType(item, allow_constants=False, allow_functions=False,
elif isinstance(item, Alias):
return item.type
raise NotImplementedError("Can't convert %s: %s" % (type(item), item))


def AliasMethod(func, from_constant):
"""Returns method func with its signature modified as if it has been aliased.
Args:
func: A pytd.Function.
from_constant: If True, func will be modified as if it has been aliased from
an instance of its defining class, e.g.,
class Foo:
def func(self): ...
const = ... # type: Foo
func = const.func
Otherwise, it will be modified as if aliased from the class itself:
class Foo:
def func(self): ...
func = Foo.func
Returns:
A pytd.Function, the aliased method.
"""
# We allow module-level aliases of methods from classes and class instances.
# When a static method is aliased, or a normal method is aliased from a class
# (not an instance), the entire method signature is copied. Otherwise, the
# first parameter ('self' or 'cls') is dropped.
new_func = func.Replace(kind=MethodTypes.METHOD)
if func.kind == MethodTypes.STATICMETHOD or (
func.kind == MethodTypes.METHOD and not from_constant):
return new_func
return new_func.Replace(signatures=tuple(
s.Replace(params=s.params[1:]) for s in new_func.signatures))


def LookupItemRecursive(module, name):
"""Recursively look up name in module."""
parts = name.split('.')
partial_name = module.name
prev_item = None
item = module

def ExtractClass(t):
if isinstance(t, ClassType):
return t.cls
t = module.Lookup(t.name) # may raise KeyError
if isinstance(t, Class):
return t
raise KeyError(t.name)

for part in parts:
prev_item = item
# Check the type of item and give up if we encounter a type we don't know
# how to handle.
if isinstance(item, Constant):
item = ExtractClass(item.type) # may raise KeyError
elif not isinstance(item, (TypeDeclUnit, Class)):
raise KeyError(name)
lookup_name = partial_name + '.' + part

def Lookup(item, *names):
for name in names:
try:
return item.Lookup(name)
except KeyError:
continue
raise KeyError(names[-1])

# Nested class names are fully qualified while function names are not, so
# we try lookup for both naming conventions.
try:
item = Lookup(item, lookup_name, part)
except KeyError:
if not isinstance(item, Class):
raise
for parent in item.parents:
parent_cls = ExtractClass(parent) # may raise KeyError
try:
item = Lookup(parent_cls, lookup_name, part)
except KeyError:
continue # continue up the MRO
else:
break # name found!
else:
raise # unresolved
if isinstance(item, Constant):
partial_name += '.' + item.name.rsplit('.', 1)[-1]
else:
partial_name = lookup_name
if isinstance(item, Function):
return AliasMethod(item, from_constant=isinstance(prev_item, Constant))
else:
return item

0 comments on commit 926bdc7

Please sign in to comment.