Skip to content

Commit

Permalink
Merge pull request #1206 from google/google_sync
Browse files Browse the repository at this point in the history
Google sync
  • Loading branch information
rchen152 committed May 17, 2022
2 parents 5b617a5 + 6d02889 commit 559970f
Show file tree
Hide file tree
Showing 14 changed files with 328 additions and 92 deletions.
1 change: 1 addition & 0 deletions pytype/abstract/CMakeLists.txt
Expand Up @@ -219,6 +219,7 @@ py_library(
.abstract_utils
pytype.utils
pytype.pytd.pytd
pytype.typegraph.cfg_utils
)

py_library(
Expand Down
13 changes: 3 additions & 10 deletions pytype/abstract/_interpreter_function.py
Expand Up @@ -172,7 +172,8 @@ def _map_args(self, node, args):
kwnames = set(kws)
extra_kws = kwnames.difference(sig.param_names + sig.kwonly_params)
if extra_kws and not sig.kwargs_name:
raise function.WrongKeywordArgs(sig, args, self.ctx, extra_kws)
if function.has_visible_namedarg(node, args, extra_kws):
raise function.WrongKeywordArgs(sig, args, self.ctx, extra_kws)
posonly_kws = kwnames & posonly_names
# If a function has a **kwargs parameter, then keyword arguments with the
# same name as a positional-only argument are allowed, e.g.:
Expand Down Expand Up @@ -235,14 +236,6 @@ def _match_view(self, node, args, view, alias_map=None):
return subst

def _match_args_sequentially(self, node, args, alias_map, match_all_views):
def match_succeeded(match_result):
bad_matches, any_match = match_result
if not bad_matches:
return True
if match_all_views or self.ctx.options.strict_parameter_checks:
return False
return any_match

for name, arg, formal in self.signature.iter_args(args):
if formal is None:
continue
Expand All @@ -251,7 +244,7 @@ def match_succeeded(match_result):
# Iterable or Mapping.
formal = self.ctx.convert.widen_type(formal)
match_result = self.ctx.matcher(node).bad_matches(arg, formal)
if not match_succeeded(match_result):
if not function.match_succeeded(match_result, match_all_views, self.ctx):
bad_arg = function.BadParam(
name=name, expected=formal, error_details=match_result[0][0][1])
raise function.WrongArgTypes(
Expand Down
205 changes: 152 additions & 53 deletions pytype/abstract/_pytd_function.py

Large diffs are not rendered by default.

30 changes: 26 additions & 4 deletions pytype/abstract/function.py
Expand Up @@ -8,6 +8,7 @@
from pytype.abstract import abstract_utils
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
from pytype.typegraph import cfg_utils

log = logging.getLogger(__name__)
_isinstance = abstract_utils._isinstance # pylint: disable=protected-access
Expand Down Expand Up @@ -250,18 +251,18 @@ def has_param(self, name):
return name in self.param_names or name in self.kwonly_params or (
name == self.varargs_name or name == self.kwargs_name)

def insert_varargs_and_kwargs(self, arg_dict):
"""Insert varargs and kwargs from arg_dict into the signature.
def insert_varargs_and_kwargs(self, args):
"""Insert varargs and kwargs from args into the signature.
Args:
arg_dict: A name->binding dictionary of passed args.
args: An iterable of passed arg names.
Returns:
A copy of this signature with the passed varargs and kwargs inserted.
"""
varargs_names = []
kwargs_names = []
for name in arg_dict:
for name in args:
if self.has_param(name):
continue
if pytd_utils.ANON_PARAM.match(name):
Expand Down Expand Up @@ -884,3 +885,24 @@ def match_all_args(ctx, node, func, args):
needs_checking = False

return args, errors


def match_succeeded(match_result, match_all_views, ctx):
bad_matches, any_match = match_result
if not bad_matches:
return True
if match_all_views or ctx.options.strict_parameter_checks:
return False
return any_match


def has_visible_namedarg(node, args, names):
# Note: this method should be called judiciously, as HasCombination is
# potentially very expensive.
namedargs = {args.namedargs[name] for name in names}
variables = [v for v in args.get_variables() if v not in namedargs]
for name in names:
for view in cfg_utils.variable_product(variables + [args.namedargs[name]]):
if node.HasCombination(list(view)):
return True
return False
6 changes: 5 additions & 1 deletion pytype/load_pytd.py
Expand Up @@ -624,13 +624,17 @@ def finish_and_verify_ast(self, mod_ast):
if mod_ast:
try:
self._resolver.verify(mod_ast)
except (BadDependencyError, visitors.ContainerError):
except (BadDependencyError, visitors.ContainerError) as e:
# In the case of a circular import, an external type may be left
# unresolved, so we re-resolve lookups in this module and its direct
# dependencies. Technically speaking, we should re-resolve all
# transitive imports, but lookups are expensive.
dependencies = self._resolver.collect_dependencies(mod_ast)
for k in dependencies:
if k not in self._modules:
raise (
BadDependencyError("Can't find pyi for %r" % k, mod_ast.name)
) from e
self._modules[k].ast = self._resolve_external_types(
self._modules[k].ast)
mod_ast = self._resolve_external_types(mod_ast)
Expand Down
3 changes: 2 additions & 1 deletion pytype/overlays/subprocess_overlay.py
Expand Up @@ -89,5 +89,6 @@ def _yield_matching_signatures(self, node, args, view, alias_map):
node, args, view, alias_map):
yield sig_info
return
arg_dict, subst = sig.substitute_formal_args(node, args, view, alias_map)
arg_dict, subst = sig.substitute_formal_args_old(
node, args, view, alias_map)
yield sig, arg_dict, subst
19 changes: 13 additions & 6 deletions pytype/pyi/parser.py
Expand Up @@ -383,15 +383,18 @@ def visit_AsyncFunctionDef(self, node):
self._preprocess_function(node)
return function.NameAndSig.from_function(node, True)

def _read_str_list(self, name, value):
if not (isinstance(value, (ast3.List, ast3.Tuple)) and
all(types.Pyval.is_str(x) for x in value.elts)):
raise ParseError(f"{name} must be a list of strings")
return tuple(x.value for x in value.elts)

def new_alias_or_constant(self, name, value):
"""Build an alias or constant."""
# This is here rather than in _Definitions because we need to build a
# constant or alias from a partially converted typed_ast subtree.
if name == "__slots__":
if not (isinstance(value, ast3.List) and
all(types.Pyval.is_str(x) for x in value.elts)):
raise ParseError("__slots__ must be a list of strings")
return types.SlotDecl(tuple(x.value for x in value.elts))
return types.SlotDecl(self._read_str_list(name, value))
elif isinstance(value, types.Pyval):
return pytd.Constant(name, value.to_pytd())
elif isinstance(value, types.Ellipsis):
Expand All @@ -402,11 +405,15 @@ def new_alias_or_constant(self, name, value):
elif isinstance(value, ast3.List):
if name != "__all__":
raise ParseError("Only __slots__ and __all__ can be literal lists")
return pytd.Constant(name, pytdgen.pytd_list("str"))
pyval = self._read_str_list(name, value)
return pytd.Constant(name, pytdgen.pytd_list("str"), pyval)
elif isinstance(value, ast3.Tuple):
pyval = None
if name == "__all__":
pyval = self._read_str_list(name, value)
# TODO(mdemello): Consistent with the current parser, but should it
# properly be Tuple[Type]?
return pytd.Constant(name, pytd.NamedType("tuple"))
return pytd.Constant(name, pytd.NamedType("tuple"), pyval)
elif isinstance(value, ast3.Name):
value = self.defs.resolve_type(value.id)
return pytd.Alias(name, value)
Expand Down
2 changes: 1 addition & 1 deletion pytype/pyi/parser_test.py
Expand Up @@ -530,7 +530,7 @@ def test_all(self):
""", """
from typing import List
__all__: List[str]
__all__: List[str] = ...
""")

def test_invalid_constructor(self):
Expand Down
13 changes: 12 additions & 1 deletion pytype/pytd/visitors.py
Expand Up @@ -514,10 +514,21 @@ def _ImportAll(self, module):
getattrs = set()
ast = self._module_map[module]
type_param_names = set()
if module == "http.client":
# NOTE: http.client adds symbols to globals() at runtime, which is not
# reflected in its typeshed pyi file. The simplest fix is to ignore
# __all__ for star-imports of that file.
exports = None
else:
exports = [x for x in ast.constants if x.name.endswith(".__all__")]
if exports:
exports = exports[0].value
for member in sum((ast.constants, ast.type_params, ast.classes,
ast.functions, ast.aliases), ()):
_, _, member_name = member.name.rpartition(".")
if member_name == "__all__":
if exports and member_name not in exports:
# Not considering the edge case `__all__ = []` since that makes no
# sense in practice.
continue
new_name = self._ModulePrefix() + member_name
if isinstance(member, pytd.Function) and member_name == "__getattr__":
Expand Down
33 changes: 33 additions & 0 deletions pytype/tests/test_functions2.py
Expand Up @@ -443,6 +443,39 @@ def h(x):
return g(x, *f())
""")

def test_namedargs_split(self):
self.Check("""
def f(x):
pass
def g(y):
pass
def h():
kws = {}
if __random__:
kws['x'] = 0
f(**kws)
else:
kws['y'] = 0
g(**kws)
""")

def test_namedargs_split_pyi(self):
with self.DepTree([("foo.pyi", """
def f(x): ...
def g(y): ...
""")]):
self.Check("""
import foo
def h():
kws = {}
if __random__:
kws['x'] = 0
foo.f(**kws)
else:
kws['y'] = 0
foo.g(**kws)
""")


class TestFunctionsPython3Feature(test_base.BaseTest):
"""Tests for functions."""
Expand Down
4 changes: 2 additions & 2 deletions pytype/tests/test_protocol_inference.py
Expand Up @@ -134,10 +134,10 @@ def f(x):
return a & x
""")
self.assertTypesMatchPytd(ty, """
from typing import Iterable, Set
from typing import Set
a = ... # type: Set[int]
def f(x: Iterable) -> Set[int]: ...
def f(x) -> Set[int]: ...
""")

def test_supports_lower(self):
Expand Down
13 changes: 13 additions & 0 deletions pytype/tests/test_pyi1.py
Expand Up @@ -898,6 +898,19 @@ class X:
import bad # pyi-error
""", pythonpath=[d.path])

def test_nonexistent_import(self):
with file_utils.Tempdir() as d:
d.create_file("bad.pyi", """
import nonexistent
x = nonexistent.x
""")
err = self.CheckWithErrors("""
import bad # pyi-error[e]
""", pythonpath=[d.path])
self.assertErrorSequences(err, {
"e": ["Couldn't import pyi", "nonexistent", "referenced from", "bad"]
})


if __name__ == "__main__":
test_base.main()
28 changes: 28 additions & 0 deletions pytype/tests/test_pyi2.py
Expand Up @@ -178,5 +178,33 @@ def test_invalid_pytype_metadata(self):
self.assertErrorSequences(err, {"e": ["pytype_metadata"]})


class PYITestAll(test_base.BaseTest):
"""Tests for __all__."""

def test_star_import(self):
with self.DepTree([("foo.pyi", """
import datetime
__all__ = ['f', 'g']
def f(x): ...
def h(x): ...
"""), ("bar.pyi", """
from foo import *
""")]):
self.CheckWithErrors("""
import bar
a = bar.datetime # module-attr
b = bar.f(1)
c = bar.h(1) # module-attr
""")

def test_http_client(self):
"""Check that we can get unexported symbols from http.client."""
self.Check("""
from http import client
from six.moves import http_client
status = http_client.FOUND or client.FOUND
""")


if __name__ == "__main__":
test_base.main()

0 comments on commit 559970f

Please sign in to comment.