diff --git a/CHANGELOG b/CHANGELOG index 5e392efc1..8806f1c7c 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,19 @@ +Version 2021.11.02: + +New features and updates: +* Remove the --bind-properties flag. Its behavior has been made the default. +* Take advantage of module aliases to print prettier stub files. +* Add support for cross-module attr.s wrappers. +* Add a feature flag, --gen-stub-imports, to improve pyi import handling. +* Add a bit more support for PEP 612 in stubs. + +Bug fixes: +* Add remove{prefix,suffix} methods for bytes, bytearray. +* Fix a bug where Errorlog.copy_from() duplicated error details. +* Fix some issues with handling module aliases in stub files. +* Support a [not-supported-yet] case in a generic class TypeVar renaming check. +* Add `__init__` attributes to canonical enum members. + Version 2021.10.25: New features and updates: diff --git a/pytype/__version__.py b/pytype/__version__.py index 507ec1ce7..b1d5f26c0 100644 --- a/pytype/__version__.py +++ b/pytype/__version__.py @@ -1,2 +1,2 @@ # pylint: skip-file -__version__ = '2021.10.25' +__version__ = '2021.11.02' diff --git a/pytype/abstract/abstract.py b/pytype/abstract/abstract.py index ffdb3f024..632273dad 100644 --- a/pytype/abstract/abstract.py +++ b/pytype/abstract/abstract.py @@ -1557,23 +1557,36 @@ def _build_value(self, node, inner, ellipses): # For user-defined generic types, check if its type parameter matches # its corresponding concrete type if isinstance(base_cls, InterpreterClass) and base_cls.template: - for formal in base_cls.template: - if (isinstance(formal, TypeParameter) and not formal.is_generic() and - isinstance(params[formal.name], TypeParameter)): - if formal.name != params[formal.name].name: - self.ctx.errorlog.not_supported_yet( - self.ctx.vm.frames, - "Renaming TypeVar `%s` with constraints or bound" % formal.name) + for formal_param in base_cls.template: + root_node = self.ctx.root_node + param_value = params[formal_param.name] + if (isinstance(formal_param, TypeParameter) and + not formal_param.is_generic() and + isinstance(param_value, TypeParameter)): + if formal_param.name == param_value.name: + # We don't need to check if a TypeParameter matches itself. + continue + else: + actual = param_value.instantiate( + root_node, container=abstract_utils.DUMMY_CONTAINER) else: - root_node = self.ctx.root_node - actual = params[formal.name].instantiate(root_node) - bad = self.ctx.matcher(root_node).bad_matches(actual, formal) - if bad: - formal = self.ctx.annotation_utils.sub_one_annotation( - root_node, formal, [{}]) - self.ctx.errorlog.bad_concrete_type(self.ctx.vm.frames, root_node, - formal, actual, bad) - return self.ctx.convert.unsolvable + actual = param_value.instantiate(root_node) + bad = self.ctx.matcher(root_node).bad_matches(actual, formal_param) + if bad: + if not isinstance(param_value, TypeParameter): + # If param_value is not a TypeVar, we substitute in TypeVar bounds + # and constraints in formal_param for a more helpful error message. + formal_param = self.ctx.annotation_utils.sub_one_annotation( + root_node, formal_param, [{}]) + details = None + elif isinstance(formal_param, TypeParameter): + details = (f"TypeVars {formal_param.name} and {param_value.name} " + "have incompatible bounds or constraints.") + else: + details = None + self.ctx.errorlog.bad_concrete_type( + self.ctx.vm.frames, root_node, formal_param, actual, bad, details) + return self.ctx.convert.unsolvable try: return abstract_class(base_cls, params, self.ctx, template_params) diff --git a/pytype/analyze.py b/pytype/analyze.py index 7964eb8d8..1ac88dacc 100644 --- a/pytype/analyze.py +++ b/pytype/analyze.py @@ -106,7 +106,8 @@ def infer_types(src, if ctx.vm.has_unknown_wildcard_imports or any( a in defs for a in abstract_utils.DYNAMIC_ATTRIBUTE_MARKERS): if "__getattr__" not in ast: - ast = pytd_utils.Concat(ast, builtins.GetDefaultAst()) + ast = pytd_utils.Concat( + ast, builtins.GetDefaultAst(options.gen_stub_imports)) # If merged with other if statement, triggers a ValueError: Unresolved class # when attempts to load from the protocols file if options.protocols: diff --git a/pytype/config.py b/pytype/config.py index 66846a418..e0cac3674 100644 --- a/pytype/config.py +++ b/pytype/config.py @@ -172,6 +172,11 @@ def add_basic_options(o): help=( "Enable stricter namedtuple checks, such as unpacking and " "'typing.Tuple' compatibility. ") + temporary) + o.add_argument( + "--gen-stub-imports", action="store_true", + dest="gen_stub_imports", default=False, + help=("Generate import statements (`import x`) rather than constants " + "(`x: module`) for module names in stub files. ") + temporary) def add_subtools(o): diff --git a/pytype/convert.py b/pytype/convert.py index 7caa48b60..1c4444b8b 100644 --- a/pytype/convert.py +++ b/pytype/convert.py @@ -514,20 +514,25 @@ def _load_late_type_module(self, late_type): def _load_late_type(self, late_type): """Resolve a late type, possibly by loading a module.""" if late_type.name not in self._resolved_late_types: - ast, attr_name = self._load_late_type_module(late_type) - if ast is None: - log.error("During dependency resolution, couldn't resolve late type %r", - late_type.name) - t = pytd.AnythingType() + ast = self.ctx.loader.import_name(late_type.name) + if ast: + t = pytd.Module(name=late_type.name, module_name=late_type.name) else: - try: - cls = pytd.LookupItemRecursive(ast, attr_name) - except KeyError: - if "__getattr__" not in ast: - log.warning("Couldn't resolve %s", late_type.name) + ast, attr_name = self._load_late_type_module(late_type) + if ast is None: + log.error( + "During dependency resolution, couldn't resolve late type %r", + late_type.name) t = pytd.AnythingType() else: - t = pytd.ToType(cls, allow_functions=True) + try: + cls = pytd.LookupItemRecursive(ast, attr_name) + except KeyError: + if "__getattr__" not in ast: + log.warning("Couldn't resolve %s", late_type.name) + t = pytd.AnythingType() + else: + t = pytd.ToType(cls, allow_functions=True) self._resolved_late_types[late_type.name] = t return self._resolved_late_types[late_type.name] @@ -536,7 +541,10 @@ def _create_module(self, ast): raise abstract_utils.ModuleLoadError() data = (ast.constants + ast.type_params + ast.classes + ast.functions + ast.aliases) - members = {val.name.rsplit(".")[-1]: val for val in data} + members = {} + for val in data: + name = utils.strip_prefix(val.name, f"{ast.name}.") + members[name] = val return abstract.Module(self.ctx, ast.name, members, ast) def _get_literal_value(self, pyval): diff --git a/pytype/errors.py b/pytype/errors.py index 0ca3b6930..ad3f15c67 100644 --- a/pytype/errors.py +++ b/pytype/errors.py @@ -347,7 +347,7 @@ def __getitem__(self, index): def copy_from(self, errors, stack): for e in errors: with _CURRENT_ERROR_NAME.bind(e.name): - self.error(stack, e.message, e.details, e.keyword, e.bad_call, + self.error(stack, e._message, e.details, e.keyword, e.bad_call, # pylint: disable=protected-access e.keyword_context) def is_valid_error_name(self, name): @@ -923,14 +923,16 @@ def bad_yield_annotation(self, stack, name, annot, is_async): self.error(stack, message, details) @_error_name("bad-concrete-type") - def bad_concrete_type(self, stack, node, formal, actual, bad): + def bad_concrete_type(self, stack, node, formal, actual, bad, details=None): expected, actual, _, protocol_details, nis_details = ( self._print_as_return_types(node, formal, actual, bad)) - details = [" Expected: ", expected, "\n", - "Actually passed: ", actual] - details.extend(protocol_details + nis_details) + full_details = [" Expected: ", expected, "\n", + "Actually passed: ", actual] + if details: + full_details.append("\n" + details) + full_details.extend(protocol_details + nis_details) self.error( - stack, "Invalid instantiation of generic class", "".join(details)) + stack, "Invalid instantiation of generic class", "".join(full_details)) def _show_variable(self, var): """Show variable as 'name: typ' or 'pyval: typ' if available.""" diff --git a/pytype/io.py b/pytype/io.py index fd28adaa7..d2459349c 100644 --- a/pytype/io.py +++ b/pytype/io.py @@ -136,7 +136,7 @@ def check_or_generate_pyi(options, loader=None): errorlog = errors.ErrorLog() result = pytd_builtins.DEFAULT_SRC - ast = pytd_builtins.GetDefaultAst() + ast = pytd_builtins.GetDefaultAst(options.gen_stub_imports) try: src = read_source_file(options.input, options.open_function) if options.check: @@ -235,7 +235,7 @@ def write_pickle(ast, options, loader=None): if options.nofail: ast = serialize_ast.PrepareForExport( options.module_name, - pytd_builtins.GetDefaultAst(), loader) + pytd_builtins.GetDefaultAst(options.gen_stub_imports), loader) log.warning("***Caught exception: %s", str(e), exc_info=True) else: raise diff --git a/pytype/io_test.py b/pytype/io_test.py index 664d35019..27a6a18a3 100644 --- a/pytype/io_test.py +++ b/pytype/io_test.py @@ -80,9 +80,11 @@ def test_generate_pyi_with_options(self): with self._tmpfile( "{mod} {path}".format(mod=pyi_name, path=pyi.name)) as imports_map: src = "import {mod}; y = {mod}.x".format(mod=pyi_name) - options = config.Options.create(imports_map=imports_map.name) + options = config.Options.create(imports_map=imports_map.name, + gen_stub_imports=True) _, pyi_string, _ = io.generate_pyi(src, options) - self.assertEqual(pyi_string, "{mod}: module\ny: int\n".format(mod=pyi_name)) + self.assertEqual(pyi_string, + "import {mod}\n\ny: int\n".format(mod=pyi_name)) def test_check_or_generate_pyi__check(self): with self._tmpfile("") as f: diff --git a/pytype/load_pytd.py b/pytype/load_pytd.py index fe13a2fbc..0749c522f 100644 --- a/pytype/load_pytd.py +++ b/pytype/load_pytd.py @@ -27,6 +27,7 @@ "python_version": "python_version", "pythonpath": "pythonpath", "use_typeshed": "typeshed", + "gen_stub_imports": "gen_stub_imports", } @@ -129,8 +130,9 @@ class _ModuleMap: PREFIX = "pytd:" # for pytd files that ship with pytype - def __init__(self, python_version, modules=None): + def __init__(self, python_version, modules, gen_stub_imports): self.python_version = python_version + self.gen_stub_imports = gen_stub_imports self._modules: Dict[str, Module] = modules or self._base_modules() if self._modules["builtins"].needs_unpickling(): self._unpickle_module(self._modules["builtins"]) @@ -186,7 +188,7 @@ def get_resolved_modules(self) -> Dict[str, ResolvedModule]: return resolved_modules def _base_modules(self): - bltins, typing = builtins.GetBuiltinsAndTyping() + bltins, typing = builtins.GetBuiltinsAndTyping(self.gen_stub_imports) return { "builtins": Module("builtins", self.PREFIX + "builtins", bltins, @@ -371,8 +373,9 @@ def collect_dependencies(cls, mod_ast): class _BuiltinLoader: """Load builtins from the pytype source tree.""" - def __init__(self, python_version): + def __init__(self, python_version, gen_stub_imports): self.python_version = python_version + self.gen_stub_imports = gen_stub_imports def _parse_predefined(self, pytd_subdir, module, as_package=False): """Parse a pyi/pytd file in the pytype source tree.""" @@ -382,7 +385,8 @@ def _parse_predefined(self, pytd_subdir, module, as_package=False): except IOError: return None ast = parser.parse_string(src, filename=filename, name=module, - python_version=self.python_version) + python_version=self.python_version, + gen_stub_imports=self.gen_stub_imports) assert ast.name == module return ast @@ -415,6 +419,7 @@ class Loader: imports_map: A short_path -> full_name mapping for imports. use_typeshed: Whether to use https://github.com/python/typeshed. open_function: A custom file opening function. + gen_stub_imports: Temporary flag for releasing --gen-stub-imports. """ PREFIX = "pytd:" # for pytd files that ship with pytype @@ -426,20 +431,22 @@ def __init__(self, imports_map=None, use_typeshed=True, modules=None, - open_function=open): + open_function=open, + gen_stub_imports=True): self.python_version = utils.normalize_version(python_version) - self._modules = _ModuleMap(self.python_version, modules) + self._modules = _ModuleMap(self.python_version, modules, gen_stub_imports) self.builtins = self._modules["builtins"].ast self.typing = self._modules["typing"].ast self.base_module = base_module self._path_finder = _PathFinder(imports_map, pythonpath) - self._builtin_loader = _BuiltinLoader(self.python_version) + self._builtin_loader = _BuiltinLoader(self.python_version, gen_stub_imports) self._resolver = _Resolver(self.builtins) self.use_typeshed = use_typeshed self.open_function = open_function self._import_name_cache = {} # performance cache self._aliases = {} self._prefixes = set() + self.gen_stub_imports = gen_stub_imports # Paranoid verification that pytype.main properly checked the flags: if imports_map is not None: assert pythonpath == [""], pythonpath @@ -505,7 +512,8 @@ def load_file(self, module_name, filename, mod_ast=None): with self.open_function(filename, "r") as f: mod_ast = parser.parse_string( f.read(), filename=filename, name=module_name, - python_version=self.python_version) + python_version=self.python_version, + gen_stub_imports=self.gen_stub_imports) return self._process_module(module_name, filename, mod_ast) def _process_module(self, module_name, filename, mod_ast): @@ -721,7 +729,7 @@ def _load_builtin(self, subdir, module_name, third_party_only=False): def _load_typeshed_builtin(self, subdir, module_name): """Load a pyi from typeshed.""" loaded = typeshed.parse_type_definition( - subdir, module_name, self.python_version) + subdir, module_name, self.python_version, self.gen_stub_imports) if loaded: filename, mod_ast = loaded return self.load_file(filename=self.PREFIX + filename, diff --git a/pytype/matcher.py b/pytype/matcher.py index 476149f77..6df7c68b5 100644 --- a/pytype/matcher.py +++ b/pytype/matcher.py @@ -292,7 +292,8 @@ def _match_value_against_type(self, value, other_type, subst, view): if isinstance(left, abstract.TypeParameterInstance) and ( isinstance(left.instance, (abstract.CallableClass, - function.Signature))): + function.Signature)) or + left.instance is abstract_utils.DUMMY_CONTAINER): if isinstance(other_type, abstract.TypeParameter): new_subst = self._match_type_param_against_type_param( left.param, other_type, subst, view) diff --git a/pytype/output.py b/pytype/output.py index c801bd23d..2a87dd4ad 100644 --- a/pytype/output.py +++ b/pytype/output.py @@ -218,7 +218,8 @@ def value_to_pytd_type(self, node, v, seen, view): if isinstance(v, (abstract.Empty, typing_overlay.NoReturn)): return pytd.NothingType() elif isinstance(v, abstract.TypeParameterInstance): - if v.module in self._scopes: + if (v.module in self._scopes or + v.instance is abstract_utils.DUMMY_CONTAINER): return self._typeparam_to_def(node, v.param, v.param.name) elif v.instance.get_instance_type_parameter(v.full_name).bindings: # The type parameter was initialized. Set the view to None, since we @@ -271,7 +272,10 @@ def value_to_pytd_type(self, node, v, seen, view): return pytd.GenericType(base_type=pytd.NamedType("builtins.type"), parameters=(param,)) elif isinstance(v, abstract.Module): - return pytd.NamedType("builtins.module") + if self.ctx.options.gen_stub_imports: + return pytd.Alias(v.name, pytd.Module(v.name, module_name=v.full_name)) + else: + return pytd.NamedType("builtins.module") elif (self._output_mode >= Converter.OutputMode.LITERAL and isinstance(v, abstract.ConcreteValue) and isinstance(v.pyval, (int, str, bytes))): @@ -354,7 +358,9 @@ def value_to_pytd_def(self, node, v, name): Returns: A PyTD definition. """ - if isinstance(v, abstract.BoundFunction): + if self.ctx.options.gen_stub_imports and isinstance(v, abstract.Module): + return pytd.Alias(name, pytd.Module(name, module_name=v.full_name)) + elif isinstance(v, abstract.BoundFunction): d = self.value_to_pytd_def(node, v.underlying, name) assert isinstance(d, pytd.Function) sigs = tuple(sig.Replace(params=sig.params[1:]) for sig in d.signatures) diff --git a/pytype/overlays/collections_overlay.py b/pytype/overlays/collections_overlay.py index 810402154..616f19b41 100644 --- a/pytype/overlays/collections_overlay.py +++ b/pytype/overlays/collections_overlay.py @@ -22,7 +22,8 @@ def namedtuple_ast(name, fields, defaults, python_version=None, - strict_namedtuple_checks=True): + strict_namedtuple_checks=True, + gen_stub_imports=True): """Make an AST with a namedtuple definition for the given name and fields. Args: @@ -33,7 +34,8 @@ def namedtuple_ast(name, strict_namedtuple_checks: Whether to enable a stricter type annotation hierarchy for generated NamedType. e.g. Tuple[n*[Any]] instead of tuple. This should usually be set to the value of - ctx.options.strict_namedtuple_checks + ctx.options.strict_namedtuple_checks. + gen_stub_imports: Set this to the value of ctx.options.gen_stub_imports. Returns: A pytd.TypeDeclUnit with the namedtuple definition in its classes. @@ -79,7 +81,8 @@ def _replace(self: {typevar}, **kwds) -> {typevar}: ... repeat_any=_repeat_type("typing.Any", num_fields), fields_as_parameters=fields_as_parameters, field_names_as_strings=field_names_as_strings) - return parser.parse_string(nt, python_version=python_version) + return parser.parse_string(nt, python_version=python_version, + gen_stub_imports=gen_stub_imports) class CollectionsOverlay(overlay.Overlay): @@ -303,7 +306,8 @@ class have to be changed to match the number and names of the fields, we field_names, defaults, python_version=self.ctx.python_version, - strict_namedtuple_checks=self.ctx.options.strict_namedtuple_checks) + strict_namedtuple_checks=self.ctx.options.strict_namedtuple_checks, + gen_stub_imports=self.ctx.options.gen_stub_imports) mapping = self._get_known_types_mapping() # A truly well-formed pyi for the namedtuple will have references to the new diff --git a/pytype/overlays/enum_overlay.py b/pytype/overlays/enum_overlay.py index ee48e5049..5a44b2b47 100644 --- a/pytype/overlays/enum_overlay.py +++ b/pytype/overlays/enum_overlay.py @@ -22,6 +22,8 @@ into a proper enum. """ +import collections +import contextlib import logging from pytype import overlay @@ -195,14 +197,24 @@ class EnumInstance(abstract.InterpreterClass): def __init__(self, name, bases, members, cls, ctx): super().__init__(name, bases, members, cls, ctx) - # This is set by EnumMetaInit.setup_interpreterclass. + # These are set by EnumMetaInit.setup_interpreterclass. self.member_type = None + self.member_attrs = {} + self._instantiating = False + + @contextlib.contextmanager + def _is_instantiating(self): + old_instantiating = self._instantiating + self._instantiating = True + try: + yield + finally: + self._instantiating = old_instantiating def instantiate(self, node, container=None): # Instantiate creates a canonical enum member. This intended for when no # particular enum member is needed, e.g. during analysis. Real members have # these fields set during class creation. - # TODO(tsudol): Use the types of other members to set `value`. del container instance = abstract.Instance(self, self.ctx) instance.members["name"] = self.ctx.convert.build_nonatomic_string(node) @@ -214,6 +226,14 @@ def instantiate(self, node, container=None): # But there's no reason not to make sure this function is safe. value = self.ctx.new_unsolvable(node) instance.members["value"] = value + for attr_name, attr_type in self.member_attrs.items(): + # attr_type might refer back to self, so track whether we are + # instantiating to avoid infinite recursion. + if self._instantiating: + instance.members[attr_name] = self.ctx.new_unsolvable(node) + else: + with self._is_instantiating(): + instance.members[attr_name] = attr_type.instantiate(node) return instance.to_variable(node) def is_empty_enum(self): @@ -451,6 +471,7 @@ def _mark_dynamic_enum(self, cls): def _setup_interpreterclass(self, node, cls): member_types = [] + member_attrs = collections.defaultdict(list) base_type = self._get_base_type(cls.bases()) # Enum members are created by calling __new__ (of either the base type or # the first enum in MRO that defines its own __new__, or else object if @@ -501,6 +522,10 @@ def _setup_interpreterclass(self, node, cls): node = cls.call_init(node, cls.to_binding(node), init_args) member.members["value"] = member.members["_value_"] member.members["name"] = self.ctx.convert.build_string(node, name) + for attr_name in member.members: + if attr_name in ("name", "value"): + continue + member_attrs[attr_name].extend(member.members[attr_name].data) cls.members[name] = member.to_variable(node) member_types.extend(value.data) if base_type: @@ -509,7 +534,11 @@ def _setup_interpreterclass(self, node, cls): member_type = self.ctx.convert.merge_classes(member_types) else: member_type = self.ctx.convert.unsolvable + member_attrs = { + n: self.ctx.convert.merge_classes(ts) for n, ts in member_attrs.items() + } cls.member_type = member_type + cls.member_attrs = member_attrs # If cls has a __new__, save it for later. (See _get_member_new above.) # It needs to be marked as a classmethod, or else pytype will try to # pass an instance of cls instead of cls when analyzing it. diff --git a/pytype/overlays/typing_overlay.py b/pytype/overlays/typing_overlay.py index 3c463ce54..732609b34 100644 --- a/pytype/overlays/typing_overlay.py +++ b/pytype/overlays/typing_overlay.py @@ -1,10 +1,5 @@ """Implementation of the types in Python 3's typing.py.""" -# We should be able to enable pytype on this file once we switch to the -# typed_ast-based pyi parser. The current parser can't handle a constant called -# 'namedtuple': -# pytype: skip-file - # pylint's detection of this is error-prone: # pylint: disable=unpacking-non-sequence diff --git a/pytype/pyi/CMakeLists.txt b/pytype/pyi/CMakeLists.txt index e202aa6bb..1f296a1b1 100644 --- a/pytype/pyi/CMakeLists.txt +++ b/pytype/pyi/CMakeLists.txt @@ -147,6 +147,18 @@ toplevel_py_binary( pytype.pytd.pytd_for_parser ) +py_library( + NAME + parser_test_base + SRCS + parser_test_base.py + DEPS + ._ast_parser + pytype.utils + pytype.pytd.pytd + pytype.tests.test_base +) + py_test( NAME parser_test @@ -154,12 +166,22 @@ py_test( parser_test.py DEPS .parser - pytype.utils + .parser_test_base pytype.pytd.pytd_for_parser pytype.stubs.stubs pytype.tests.test_base ) +py_test( + NAME + entire_file_parser_test + SRCS + entire_file_parser_test.py + DEPS + .parser_test_base + pytype.pytd.pytd +) + py_test( NAME evaluator_test diff --git a/pytype/pyi/definitions.py b/pytype/pyi/definitions.py index 9fc5ad537..8667c9c46 100644 --- a/pytype/pyi/definitions.py +++ b/pytype/pyi/definitions.py @@ -450,7 +450,11 @@ def add_import(self, from_package, import_list): from_package != "typing" or self.module_info.module_name == "protocols"): self.aliases[t.new_name] = t.pytd_alias() - self.module_path_map[t.new_name] = t.qualified_name + if t.new_name != "typing": + # We don't allow the typing module to be mapped to another module, + # since that would lead to 'from typing import ...' statements to be + # resolved incorrectly. + self.module_path_map[t.new_name] = t.qualified_name else: # import a, b as c, ... for item in import_list: @@ -507,23 +511,31 @@ def _parameterized_type(self, base_type: Any, parameters): raise ParseError("[..., ...] not supported") return pytd.GenericType(base_type=base_type, parameters=(element_type,)) else: - parameters = tuple(pytd.AnythingType() if p is self.ELLIPSIS else p - for p in parameters) + processed_parameters = [] + # We do not yet support PEP 612, Parameter Specification Variables. + # To avoid blocking typeshed from adopting this PEP, we convert new + # features to approximations that only use supported features. + for p in parameters: + if p is self.ELLIPSIS: + processed = pytd.AnythingType() + elif (p in self.param_specs and + self._matches_full_name(base_type, "typing.Generic")): + # Replacing a ParamSpec with a TypeVar isn't correct, but it'll work + # for simple cases in which the filled value is also a ParamSpec. + self.type_params.append(pytd.TypeParameter(p.name)) + processed = p + elif (p in self.param_specs or + (isinstance(p, pytd.GenericType) and + self._matches_full_name(p, _CONCATENATE_TYPES))): + processed = pytd.AnythingType() + else: + processed = p + processed_parameters.append(processed) + parameters = tuple(processed_parameters) if self._matches_named_type(base_type, _TUPLE_TYPES): return pytdgen.heterogeneous_tuple(base_type, parameters) elif self._matches_named_type(base_type, _CALLABLE_TYPES): - callable_parameters = [] - for p in parameters: - # We do not yet support PEP 612, Parameter Specification Variables. - # To avoid blocking typeshed from adopting this PEP, we convert new - # features to Any. - if p in self.param_specs or ( - isinstance(p, pytd.GenericType) and - self._matches_full_name(p, _CONCATENATE_TYPES)): - callable_parameters.append(pytd.AnythingType()) - else: - callable_parameters.append(p) - return pytdgen.pytd_callable(base_type, tuple(callable_parameters)) + return pytdgen.pytd_callable(base_type, parameters) else: assert parameters return pytd.GenericType(base_type=base_type, parameters=parameters) @@ -554,7 +566,7 @@ def resolve_type(self, name: Union[str, pytd_node.Node]) -> pytd.Type: def new_type( self, name: Union[str, pytd_node.Node], - parameters: Optional[List[pytd_node.Node]] = None + parameters: Optional[List[pytd.Type]] = None ) -> pytd.Type: """Return the AST for a type. diff --git a/pytype/pyi/entire_file_parser_test.py b/pytype/pyi/entire_file_parser_test.py new file mode 100644 index 000000000..05fee323f --- /dev/null +++ b/pytype/pyi/entire_file_parser_test.py @@ -0,0 +1,17 @@ +"""Entire-file parsing test.""" + +from pytype.pyi import parser_test_base +from pytype.pytd import pytd_utils + +import unittest + + +class EntireFileTest(parser_test_base.ParserTestBase): + + def test_builtins(self): + _, builtins = pytd_utils.GetPredefinedFile("builtins", "builtins") + self.check(builtins, expected=parser_test_base.IGNORE) + + +if __name__ == "__main__": + unittest.main() diff --git a/pytype/pyi/modules.py b/pytype/pyi/modules.py index b2c22a705..75185cd72 100644 --- a/pytype/pyi/modules.py +++ b/pytype/pyi/modules.py @@ -7,7 +7,6 @@ from pytype import module_utils from pytype.pyi.types import ParseError # pylint: disable=g-importing-member from pytype.pytd import pytd -from pytype.pytd import visitors from pytype.pytd.parse import parser_constants # pylint: disable=g-importing-member @@ -27,12 +26,13 @@ def pytd_alias(self): class Module: """Module and package details.""" - def __init__(self, filename, module_name): + def __init__(self, filename, module_name, gen_stub_imports): self.filename = filename self.module_name = module_name is_package = file_utils.is_pyi_directory_init(filename) self.package_name = module_utils.get_package_name(module_name, is_package) self.parent_name = module_utils.get_package_name(self.package_name, False) + self.gen_stub_imports = gen_stub_imports def _qualify_name_with_special_dir(self, orig_name): """Handle the case of '.' and '..' as package names.""" @@ -58,8 +58,6 @@ def _qualify_name_with_special_dir(self, orig_name): def qualify_name(self, orig_name): """Qualify an import name.""" - # Doing the "builtins" rename here ensures that we catch alias names. - orig_name = visitors.RenameBuiltinsPrefixInName(orig_name) if not self.package_name: return orig_name rel_name = self._qualify_name_with_special_dir(orig_name) @@ -75,10 +73,17 @@ def qualify_name(self, orig_name): def process_import(self, item): """Process 'import a, b as c, ...'.""" - if not isinstance(item, tuple): + if isinstance(item, tuple): + name, new_name = item + elif self.gen_stub_imports: + name = new_name = item + else: # We don't care about imports that are not aliased. return None - name, new_name = item + if name == new_name == "__builtin__": + # 'import __builtin__' should be completely ignored; this is the PY2 name + # of the builtins module. + return None module_name = self.qualify_name(name) as_name = self.qualify_name(new_name) t = pytd.Module(name=as_name, module_name=module_name) diff --git a/pytype/pyi/parser.py b/pytype/pyi/parser.py index 66cdbf5da..2b902ffa7 100644 --- a/pytype/pyi/parser.py +++ b/pytype/pyi/parser.py @@ -267,8 +267,10 @@ def __repr__(self): class GeneratePytdVisitor(visitor.BaseVisitor): """Converts a typed_ast tree to a pytd tree.""" - def __init__(self, src, filename, module_name, version, platform): - defs = definitions.Definitions(modules.Module(filename, module_name)) + def __init__(self, src, filename, module_name, version, platform, + gen_stub_imports): + defs = definitions.Definitions( + modules.Module(filename, module_name, gen_stub_imports)) super().__init__(defs=defs, filename=filename) self.src_code = src self.module_name = module_name @@ -685,10 +687,12 @@ def parse_string( python_version: VersionType = 3, name: Optional[str] = None, filename: Optional[str] = None, - platform: Optional[str] = None + platform: Optional[str] = None, + gen_stub_imports: bool = True, ): return parse_pyi(src, filename=filename, module_name=name, - platform=platform, python_version=python_version) + platform=platform, python_version=python_version, + gen_stub_imports=gen_stub_imports) def parse_pyi( @@ -696,7 +700,8 @@ def parse_pyi( filename: Optional[str], module_name: str, python_version: VersionType = 3, - platform: Optional[str] = None + platform: Optional[str] = None, + gen_stub_imports: bool = True, ) -> pytd.TypeDeclUnit: """Parse a pyi string.""" filename = filename or "" @@ -704,7 +709,7 @@ def parse_pyi( python_version = utils.normalize_version(python_version) root = _parse(src, feature_version, filename) gen_pytd = GeneratePytdVisitor( - src, filename, module_name, python_version, platform) + src, filename, module_name, python_version, platform, gen_stub_imports) root = gen_pytd.visit(root) root = post_process_ast(root, src, module_name) return root @@ -715,7 +720,7 @@ def parse_pyi_debug( filename: str, module_name: str, python_version: VersionType = 3, - platform: Optional[str] = None + platform: Optional[str] = None, ) -> Tuple[pytd.TypeDeclUnit, GeneratePytdVisitor]: """Debug version of parse_pyi.""" feature_version = _feature_version(python_version) @@ -723,7 +728,7 @@ def parse_pyi_debug( root = _parse(src, feature_version, filename) print(debug.dump(root, ast3, include_attributes=False)) gen_pytd = GeneratePytdVisitor( - src, filename, module_name, python_version, platform) + src, filename, module_name, python_version, platform, True) root = gen_pytd.visit(root) print("---transformed parse tree--------------------") print(root) @@ -736,9 +741,11 @@ def parse_pyi_debug( return root, gen_pytd -def canonical_pyi(pyi, python_version=3, multiline_args=False): +def canonical_pyi(pyi, python_version=3, multiline_args=False, + gen_stub_imports=True): """Rewrite a pyi in canonical form.""" - ast = parse_string(pyi, python_version=python_version) + ast = parse_string(pyi, python_version=python_version, + gen_stub_imports=gen_stub_imports) ast = ast.Visit(visitors.ClassTypeToNamedType()) ast = ast.Visit(visitors.CanonicalOrderingVisitor(sort_signatures=True)) ast.Visit(visitors.VerifyVisitor()) diff --git a/pytype/pyi/parser_test.py b/pytype/pyi/parser_test.py index 6bcd44266..9baa0dc65 100644 --- a/pytype/pyi/parser_test.py +++ b/pytype/pyi/parser_test.py @@ -1,67 +1,14 @@ import hashlib -import re import sys import textwrap -from pytype import utils from pytype.pyi import parser +from pytype.pyi import parser_test_base from pytype.pytd import pytd -from pytype.pytd import pytd_utils from pytype.tests import test_base import unittest -IGNORE = object() - - -class _ParserTestBase(test_base.UnitTest): - - def check(self, src, expected=None, prologue=None, name=None, - version=None, platform=None): - """Check the parsing of src. - - This checks that parsing the source and then printing the resulting - AST results in the expected text. - - Args: - src: A source string. - expected: Optional expected result string. If not provided, src is - used instead. The special value IGNORE can be used to skip - checking the parsed results against expected text. - prologue: An optional prologue to be prepended to the expected text - before comparisson. Useful for imports that are introduced during - printing the AST. - name: The name of the module. - version: A python version tuple (None for default value). - platform: A platform string (None for default value). - - Returns: - The parsed pytd.TypeDeclUnit. - """ - version = version or self.python_version - src = textwrap.dedent(src).lstrip() - ast = parser.parse_string(src, name=name, python_version=version, - platform=platform) - actual = pytd_utils.Print(ast) - if expected != IGNORE: - if expected is None: - expected = src - else: - expected = textwrap.dedent(expected).lstrip() - if prologue: - expected = "%s\n\n%s" % (textwrap.dedent(prologue), expected) - # Allow blank lines at the end of `expected` for prettier tests. - self.assertMultiLineEqual(expected.rstrip(), actual) - return ast - - def check_error(self, src, expected_line, message): - """Check that parsing the src raises the expected error.""" - with self.assertRaises(parser.ParseError) as e: - parser.parse_string(textwrap.dedent(src).lstrip(), - python_version=self.python_version) - self.assertRegex(utils.message(e.exception), re.escape(message)) - self.assertEqual(expected_line, e.exception.line) - class ParseErrorTest(unittest.TestCase): @@ -107,7 +54,7 @@ def test_column_without_text(self): self.check(" ParseError: my message", "my message", column=5) -class ParserTest(_ParserTestBase): +class ParserTest(parser_test_base.ParserTestBase): def test_syntax_error(self): self.check_error("123", 1, "Unexpected expression") @@ -312,7 +259,7 @@ class B: ... """, "") def test_import(self): - self.check("import foo.bar.baz", "") + self.check("import foo.bar.baz") self.check("import a as b") self.check("from foo.bar import baz") self.check("from foo.bar import baz as abc") @@ -327,7 +274,8 @@ def test_import(self): "from foo import a\nfrom foo import b") def test_from_import(self): - ast = self.check("from foo import c\nclass Bar(c.X): ...", IGNORE) + ast = self.check("from foo import c\nclass Bar(c.X): ...", + parser_test_base.IGNORE) parent, = ast.Lookup("Bar").parents self.assertEqual(parent, pytd.NamedType("foo.c.X")) @@ -432,6 +380,7 @@ def test_same_named_alias(self): class Bar: Foo = somewhere.Foo """, """ + import somewhere from typing import Any class Bar: @@ -490,6 +439,7 @@ def test_typing_typevar(self): import typing T = typing.TypeVar('T') """, """ + import typing from typing import TypeVar T = TypeVar('T') @@ -573,7 +523,7 @@ def test_all(self): """) -class QuotedTypeTest(_ParserTestBase): +class QuotedTypeTest(parser_test_base.ParserTestBase): def test_annotation(self): self.check(""" @@ -598,7 +548,7 @@ def test_subscript(self): self.check_error("x: List['int']", 1, "List['int'] not supported") -class HomogeneousTypeTest(_ParserTestBase): +class HomogeneousTypeTest(parser_test_base.ParserTestBase): def test_callable_parameters(self): self.check(""" @@ -704,7 +654,7 @@ def test_type_tuple(self): "x: tuple") -class NamedTupleTest(_ParserTestBase): +class NamedTupleTest(parser_test_base.ParserTestBase): @unittest.skip("Constructors in type annotations not supported") def test_no_fields(self): @@ -1037,7 +987,7 @@ class NamedTuple: ... """) -class FunctionTest(_ParserTestBase): +class FunctionTest(parser_test_base.ParserTestBase): def test_params(self): self.check("def foo() -> int: ...") @@ -1407,7 +1357,7 @@ def test_async(self): prologue="from typing import Any, Coroutine") -class ClassTest(_ParserTestBase): +class ClassTest(parser_test_base.ParserTestBase): def test_no_parents(self): canonical = """ @@ -1626,7 +1576,7 @@ class Mapping: ... self.assertEqual(x.type.name, "typing.Mapping") -class IfTest(_ParserTestBase): +class IfTest(parser_test_base.ParserTestBase): def test_if_true(self): self.check(""" @@ -1802,7 +1752,7 @@ def test_conditional_typevar(self): T = TypeVar('T')""") -class ClassIfTest(_ParserTestBase): +class ClassIfTest(parser_test_base.ParserTestBase): # These tests assume that IfTest has already covered the inner workings of # peer's functions. Instead, they focus on verifying that if statements @@ -1873,7 +1823,7 @@ class Foo: """, 3, r"TypeVars need to be defined at module level") -class ConditionTest(_ParserTestBase): +class ConditionTest(parser_test_base.ParserTestBase): def check_cond(self, condition, expected, **kwargs): out = "x: int" if expected else "" @@ -2010,7 +1960,7 @@ def test_unsupported_condition(self): "Unsupported condition: 'foo.bar'") -class PropertyDecoratorTest(_ParserTestBase): +class PropertyDecoratorTest(parser_test_base.ParserTestBase): """Tests that cover _parse_signature_as_property().""" def test_property_with_type(self): @@ -2137,7 +2087,7 @@ def name(self) -> int: ... """, 1, "Invalid property decorators for method `name`") -class MergeSignaturesTest(_ParserTestBase): +class MergeSignaturesTest(parser_test_base.ParserTestBase): def test_property(self): self.check(""" @@ -2258,14 +2208,7 @@ def foo(x: int, y: int) -> str: ... "abstractmethod decorators") -class EntireFileTest(_ParserTestBase): - - def test_builtins(self): - _, builtins = pytd_utils.GetPredefinedFile("builtins", "builtins") - self.check(builtins, expected=IGNORE) - - -class AnyTest(_ParserTestBase): +class AnyTest(parser_test_base.ParserTestBase): def test_generic_any(self): self.check(""" @@ -2291,7 +2234,7 @@ def test_generic_any_alias(self): x: Any""") -class CanonicalPyiTest(_ParserTestBase): +class CanonicalPyiTest(parser_test_base.ParserTestBase): def test_canonical_version(self): src = textwrap.dedent(""" @@ -2311,7 +2254,7 @@ def foo(x: str) -> Any: ... parser.canonical_pyi(src, self.python_version), expected) -class TypeMacroTest(_ParserTestBase): +class TypeMacroTest(parser_test_base.ParserTestBase): def test_simple(self): self.check(""" @@ -2427,7 +2370,7 @@ def f(x: List[str]) -> None: ... """) -class ImportTypeIgnoreTest(_ParserTestBase): +class ImportTypeIgnoreTest(parser_test_base.ParserTestBase): def test_import(self): self.check(""" @@ -2466,7 +2409,7 @@ def f(x: attr) -> None: ... self.assertTrue(ast.Lookup("f")) -class LiteralTest(_ParserTestBase): +class LiteralTest(parser_test_base.ParserTestBase): def test_bool(self): self.check(""" @@ -2603,7 +2546,7 @@ def test_bad_value(self): """, 2, "Invalid type `float` in Literal[0.0].") -class TypedDictTest(_ParserTestBase): +class TypedDictTest(parser_test_base.ParserTestBase): def test_assign(self): self.check(""" @@ -2690,7 +2633,7 @@ class Foo(TypedDict, metaclass=Meta): ... """) -class NewTypeTest(_ParserTestBase): +class NewTypeTest(parser_test_base.ParserTestBase): def test_basic(self): self.check(""" @@ -2708,6 +2651,8 @@ def test_fullname(self): import typing X = typing.NewType('X', int) """, """ + import typing + X = newtype_X_0 class newtype_X_0(int): @@ -2715,7 +2660,7 @@ def __init__(self, val: int) -> None: ... """) -class MethodAliasTest(_ParserTestBase): +class MethodAliasTest(parser_test_base.ParserTestBase): def test_normal_method(self): self.check(""" @@ -2787,7 +2732,7 @@ def f(x: int) -> None: ... """) -class AnnotatedTest(_ParserTestBase): +class AnnotatedTest(parser_test_base.ParserTestBase): """Test typing.Annotated.""" def test_annotated(self): @@ -2870,7 +2815,7 @@ def test_feature_version(self): self.assertEqual(actual, expected) -class ParamSpecTest(_ParserTestBase): +class ParamSpecTest(parser_test_base.ParserTestBase): def test_from_typing(self): self.check(""" @@ -2906,7 +2851,6 @@ def f(x: Callable[P, R]) -> Callable[P, Awaitable[R]]: ... def f(x: Callable[..., R]) -> Callable[..., Awaitable[R]]: ... """) - @test_base.skip("ParamSpec in custom generic classes not supported yet") def test_custom_generic(self): self.check(""" from typing import Callable, Generic, ParamSpec, TypeVar @@ -2917,6 +2861,38 @@ def test_custom_generic(self): class X(Generic[T, P]): f: Callable[P, int] x: T + """, """ + from typing import Callable, Generic, TypeVar + + P = TypeVar('P') + T = TypeVar('T') + + class X(Generic[T, P]): + f: Callable[..., int] + x: T + """) + + def test_use_custom_generic(self): + self.check(""" + from typing import Callable, Generic, TypeVar + from typing_extensions import ParamSpec + + _T = TypeVar('_T') + _P = ParamSpec('_P') + + class Foo(Generic[_P, _T]): ... + + def f(x: Callable[_P, _T]) -> Foo[_P, _T]: ... + """, """ + from typing import Any, Callable, Generic, TypeVar + from typing_extensions import ParamSpec + + _P = TypeVar('_P') + _T = TypeVar('_T') + + class Foo(Generic[_P, _T]): ... + + def f(x: Callable[..., _T]) -> Foo[Any, _T]: ... """) @test_base.skip("ParamSpec in custom generic classes not supported yet") @@ -2958,7 +2934,7 @@ def f(x: Callable[..., T], *args, **kwargs) -> T: ... """) -class ConcatenateTest(_ParserTestBase): +class ConcatenateTest(parser_test_base.ParserTestBase): def test_from_typing(self): self.check(""" @@ -3004,7 +2980,7 @@ def f(x: Callable[..., R]) -> Callable[..., R]: ... """) -class UnionOrTest(_ParserTestBase): +class UnionOrTest(parser_test_base.ParserTestBase): def test_basic(self): self.check(""" @@ -3020,7 +2996,7 @@ def h(x: Optional[str]) -> None: ... """) -class TypeGuardTest(_ParserTestBase): +class TypeGuardTest(parser_test_base.ParserTestBase): def test_typing_extensions(self): self.check(""" diff --git a/pytype/pyi/parser_test_base.py b/pytype/pyi/parser_test_base.py new file mode 100644 index 000000000..41f2b08be --- /dev/null +++ b/pytype/pyi/parser_test_base.py @@ -0,0 +1,61 @@ +"""Base code for pyi parsing tests.""" + +import re +import textwrap + +from pytype import utils +from pytype.pyi import parser +from pytype.pytd import pytd_utils +from pytype.tests import test_base + +IGNORE = object() + + +class ParserTestBase(test_base.UnitTest): + """Base class for pyi parsing tests.""" + + def check(self, src, expected=None, prologue=None, name=None, + version=None, platform=None): + """Check the parsing of src. + + This checks that parsing the source and then printing the resulting + AST results in the expected text. + + Args: + src: A source string. + expected: Optional expected result string. If not provided, src is + used instead. The special value IGNORE can be used to skip + checking the parsed results against expected text. + prologue: An optional prologue to be prepended to the expected text + before comparisson. Useful for imports that are introduced during + printing the AST. + name: The name of the module. + version: A python version tuple (None for default value). + platform: A platform string (None for default value). + + Returns: + The parsed pytd.TypeDeclUnit. + """ + version = version or self.python_version + src = textwrap.dedent(src).lstrip() + ast = parser.parse_string(src, name=name, python_version=version, + platform=platform) + actual = pytd_utils.Print(ast) + if expected != IGNORE: + if expected is None: + expected = src + else: + expected = textwrap.dedent(expected).lstrip() + if prologue: + expected = "%s\n\n%s" % (textwrap.dedent(prologue), expected) + # Allow blank lines at the end of `expected` for prettier tests. + self.assertMultiLineEqual(expected.rstrip(), actual) + return ast + + def check_error(self, src, expected_line, message): + """Check that parsing the src raises the expected error.""" + with self.assertRaises(parser.ParseError) as e: + parser.parse_string(textwrap.dedent(src).lstrip(), + python_version=self.python_version) + self.assertRegex(utils.message(e.exception), re.escape(message)) + self.assertEqual(expected_line, e.exception.line) diff --git a/pytype/pytd/CMakeLists.txt b/pytype/pytd/CMakeLists.txt index 458acdc8e..739eb4b15 100644 --- a/pytype/pytd/CMakeLists.txt +++ b/pytype/pytd/CMakeLists.txt @@ -203,6 +203,7 @@ py_library( SRCS serialize_ast.py DEPS + ._pytd .pytd_utils .visitors pytype.utils diff --git a/pytype/pytd/builtins.py b/pytype/pytd/builtins.py index f01f932e0..6d6c38529 100644 --- a/pytype/pytd/builtins.py +++ b/pytype/pytd/builtins.py @@ -22,11 +22,14 @@ def InvalidateCache(): del _cached_builtins_pytd[0] -def GetBuiltinsAndTyping(): # Deprecated. Use load_pytd instead. +# Deprecated. Use load_pytd instead. +def GetBuiltinsAndTyping(gen_stub_imports): """Get builtins.pytd and typing.pytd.""" if not _cached_builtins_pytd: - t = parser.parse_string(_FindBuiltinFile("typing"), name="typing") - b = parser.parse_string(_FindBuiltinFile("builtins"), name="builtins") + t = parser.parse_string(_FindBuiltinFile("typing"), name="typing", + gen_stub_imports=gen_stub_imports) + b = parser.parse_string(_FindBuiltinFile("builtins"), name="builtins", + gen_stub_imports=gen_stub_imports) b = b.Visit(visitors.LookupExternalTypes({"typing": t}, self_name="builtins")) t = t.Visit(visitors.LookupBuiltins(b)) @@ -58,7 +61,7 @@ def GetBuiltinsPyTD(): # Deprecated. Use Loader.concat_all. A pytd.TypeDeclUnit instance. It'll directly contain the builtin classes and functions, and submodules for each of the standard library modules. """ - return pytd_utils.Concat(*GetBuiltinsAndTyping()) + return pytd_utils.Concat(*GetBuiltinsAndTyping(True)) # pyi for a catch-all module @@ -68,5 +71,5 @@ def __getattr__(name: Any) -> Any: ... """ -def GetDefaultAst(): - return parser.parse_string(src=DEFAULT_SRC) +def GetDefaultAst(gen_stub_imports): + return parser.parse_string(src=DEFAULT_SRC, gen_stub_imports=gen_stub_imports) diff --git a/pytype/pytd/printer.py b/pytype/pytd/printer.py index 3bb9aa98a..238ea2b03 100644 --- a/pytype/pytd/printer.py +++ b/pytype/pytd/printer.py @@ -150,11 +150,16 @@ def EnterTypeDeclUnit(self, unit): for defn in definitions: self._local_names[defn.name] = label for alias in unit.aliases: + # Modules are represented as NamedTypes in partially resolved asts and + # sometimes as LateTypes in asts modified for pickling. if isinstance(alias.type, pytd.Module): - name = alias.name - if unit.name and name.startswith(unit.name + "."): - name = name[len(unit.name) + 1:] - self._module_aliases[alias.type.module_name] = name + module_name = alias.type.module_name + elif isinstance(alias.type, (pytd.NamedType, pytd.LateType)): + module_name = alias.type.name + else: + continue + name = self._StripUnitPrefix(alias.name) + self._module_aliases[module_name] = name def LeaveTypeDeclUnit(self, _): self._unit = None @@ -200,15 +205,20 @@ def VisitConstant(self, node): assert module == "builtins", module assert name in ("True", "False"), name return name - else: - return f"{node.name}: {node.type}" + + node_type = node.type + if node_type.startswith(("import ", "from ")): + # TODO(slebedev): Use types.ModuleType instead. + node_type = "module" + return f"{node.name}: {node_type}" def EnterAlias(self, _): self.old_imports = self.imports.copy() def VisitAlias(self, node): """Convert an import or alias to a string.""" - if isinstance(self.old_node.type, (pytd.NamedType, pytd.ClassType)): + if isinstance(self.old_node.type, + (pytd.NamedType, pytd.ClassType, pytd.LateType)): full_name = self.old_node.type.name suffix = "" module, _, name = full_name.rpartition(".") @@ -320,6 +330,9 @@ def VisitSignature(self, node): if node.return_type == "nothing": return_type = "NoReturn" # a prettier alias for nothing self._FromTyping(return_type) + elif node.return_type.startswith(("import ", "from ")): + # TODO(slebedev): Use types.ModuleType instead. + return_type = "module" else: return_type = node.return_type ret = f" -> {return_type}" @@ -418,6 +431,17 @@ def _UseExistingModuleAlias(self, name): suffix = f"{remainder}.{suffix}" return None + def _GuessModule(self, maybe_module): + """Guess which part of the given name is the module prefix.""" + if "." not in maybe_module: + return maybe_module + prefix, suffix = maybe_module.rsplit(".", 1) + # Heuristic: modules are typically lowercase, classes uppercase. + if suffix[0].islower(): + return maybe_module + else: + return self._GuessModule(prefix) + def VisitNamedType(self, node): """Convert a type to a string.""" prefix, _, suffix = node.name.rpartition(".") @@ -436,7 +460,7 @@ def VisitNamedType(self, node): if aliased_name: node_name = aliased_name else: - self._RequireImport(prefix) + self._RequireImport(self._GuessModule(prefix)) node_name = node.name else: node_name = node.name @@ -477,7 +501,10 @@ def VisitModule(self, node): # `import x.y as z` and `from x import y as z` are equivalent, but the # latter is a bit prettier. prefix, suffix = node.module_name.rsplit(".", 1) - return f"from {prefix} import {suffix} as {node.name}" + imp = f"from {prefix} import {suffix}" + if node.name != suffix: + imp += f" as {node.name}" + return imp else: return f"import {node.module_name} as {node.name}" diff --git a/pytype/pytd/pytd_visitors.py b/pytype/pytd/pytd_visitors.py index 9e77ed915..d8cee50cf 100644 --- a/pytype/pytd/pytd_visitors.py +++ b/pytype/pytd/pytd_visitors.py @@ -181,4 +181,5 @@ def VisitTypeParameter(self, node): VisitClass = _ReplaceModuleName # pylint: disable=invalid-name VisitFunction = _ReplaceModuleName # pylint: disable=invalid-name VisitStrictType = _ReplaceModuleName # pylint: disable=invalid-name + VisitModule = _ReplaceModuleName # pylint: disable=invalid-name VisitNamedType = _ReplaceModuleName # pylint: disable=invalid-name diff --git a/pytype/pytd/serialize_ast.py b/pytype/pytd/serialize_ast.py index 47288494d..aecdf9746 100644 --- a/pytype/pytd/serialize_ast.py +++ b/pytype/pytd/serialize_ast.py @@ -9,6 +9,7 @@ from pytype import utils from pytype.pyi import parser +from pytype.pytd import pytd from pytype.pytd import pytd_utils from pytype.pytd import visitors @@ -28,6 +29,35 @@ def EnterClassType(self, n): self.class_type_nodes.append(n) +class UndoModuleAliasesVisitor(visitors.Visitor): + """Visitor to undo module aliases in late types. + + Since late types are loaded out of context, they need to contain the original + names of modules, not whatever they've been aliased to in the current module. + """ + + def __init__(self): + super().__init__() + self._module_aliases = {} + + def EnterTypeDeclUnit(self, node): + for alias in node.aliases: + if isinstance(alias.type, pytd.Module): + name = utils.strip_prefix(alias.name, f"{node.name}.") + self._module_aliases[name] = alias.type.module_name + + def VisitLateType(self, node): + if "." not in node.name: + return node + prefix, suffix = node.name.rsplit(".", 1) + while prefix: + if prefix in self._module_aliases: + return node.Replace(name=self._module_aliases[prefix] + "." + suffix) + prefix, _, remainder = prefix.rpartition(".") + suffix = f"{remainder}.{suffix}" + return node + + SerializableTupleClass = collections.namedtuple( "_", ["ast", "dependencies", "late_dependencies", "class_type_nodes"]) @@ -67,6 +97,7 @@ def StoreAst(ast, filename=None, open_function=open): if ast.name.endswith(".__init__"): ast = ast.Visit(visitors.RenameModuleVisitor( ast.name, ast.name.rsplit(".__init__", 1)[0])) + ast = ast.Visit(UndoModuleAliasesVisitor()) # Collect dependencies deps = visitors.CollectDependencies() ast.Visit(deps) @@ -229,7 +260,8 @@ def PrepareForExport(module_name, ast, loader): # their own visitors so they can be applied without printing. src = pytd_utils.Print(ast) ast = parser.parse_string(src=src, name=module_name, - python_version=loader.python_version) + python_version=loader.python_version, + gen_stub_imports=loader.gen_stub_imports) ast = ast.Visit(visitors.LookupBuiltins(loader.builtins, full_names=False)) ast = ast.Visit(visitors.ExpandCompatibleBuiltins(loader.builtins)) ast = ast.Visit(visitors.LookupLocalTypes()) diff --git a/pytype/pytd/slots.py b/pytype/pytd/slots.py index f02f86613..537993386 100644 --- a/pytype/pytd/slots.py +++ b/pytype/pytd/slots.py @@ -5,10 +5,7 @@ mappings. """ -# TODO(b/175443170): pytype takes too long on this file. Once the linked bug is -# fixed, check if we can remove the skip-file. -# pytype: skip-file - +from typing import List TYPEOBJECT_PREFIX = "tp_" NUMBER_PREFIX = "nb_" @@ -49,7 +46,7 @@ def __init__(self, python_name, c_name, function_type, index=None, self.symbol = symbol -SLOTS = [ +SLOTS: List[Slot] = [ # typeobject Slot("__new__", "tp_new", "new"), Slot("__init__", "tp_init", "init"), diff --git a/pytype/pytd/typeshed.py b/pytype/pytd/typeshed.py index d02031b47..5fa0c01bf 100644 --- a/pytype/pytd/typeshed.py +++ b/pytype/pytd/typeshed.py @@ -344,13 +344,14 @@ def _get_typeshed(): return _typeshed -def parse_type_definition(pyi_subdir, module, python_version): +def parse_type_definition(pyi_subdir, module, python_version, gen_stub_imports): """Load and parse a *.pyi from typeshed. Args: pyi_subdir: the directory where the module should be found. module: the module name (without any file extension) python_version: sys.version_info[:2] + gen_stub_imports: Temporary flag for releasing --gen-stub-imports. Returns: None if the module doesn't have a definition. @@ -365,5 +366,6 @@ def parse_type_definition(pyi_subdir, module, python_version): return None ast = parser.parse_string(src, filename=filename, name=module, - python_version=python_version) + python_version=python_version, + gen_stub_imports=gen_stub_imports) return filename, ast diff --git a/pytype/pytd/typeshed_test.py b/pytype/pytd/typeshed_test.py index ad70bea0f..4f7102642 100644 --- a/pytype/pytd/typeshed_test.py +++ b/pytype/pytd/typeshed_test.py @@ -30,7 +30,7 @@ def test_get_typeshed_dir(self): def test_parse_type_definition(self): filename, ast = typeshed.parse_type_definition( - "stdlib", "_random", self.python_version) + "stdlib", "_random", self.python_version, True) self.assertEqual(os.path.basename(filename), "_random.pyi") self.assertIn("_random.Random", [cls.name for cls in ast.classes]) diff --git a/pytype/pytd/visitors.py b/pytype/pytd/visitors.py index f37173753..9d85dae29 100644 --- a/pytype/pytd/visitors.py +++ b/pytype/pytd/visitors.py @@ -559,17 +559,25 @@ def _HandleDuplicates(self, new_aliases): Raises: KeyError: If there is a name clash. """ + def SameModuleName(a, b): + return ( + isinstance(a.type, pytd.Module) and + isinstance(b.type, pytd.Module) and + a.type.module_name == b.type.module_name + ) + name_to_alias = {} out = [] for a in new_aliases: - if a.name in name_to_alias: - existing = name_to_alias[a.name] - if existing != a: - raise KeyError("Duplicate top level items: %r, %r" % ( - existing.type.name, a.type.name)) - else: + if a.name not in name_to_alias: name_to_alias[a.name] = a out.append(a) + continue + existing = name_to_alias[a.name] + if existing == a or SameModuleName(existing, a): + continue + raise KeyError("Duplicate top level items: %r, %r" % ( + existing.type.name, a.type.name)) return out def EnterTypeDeclUnit(self, node): diff --git a/pytype/pytd/visitors_test.py b/pytype/pytd/visitors_test.py index adfb83542..17f1fe7fb 100644 --- a/pytype/pytd/visitors_test.py +++ b/pytype/pytd/visitors_test.py @@ -430,7 +430,7 @@ def f(x: Union[int, slice]) -> List[Any]: ... def g(x: foo.C.C2) -> None: ... """) expected = textwrap.dedent(""" - import foo.C + import foo from typing import Any, List, Union def f(x: Union[int, slice]) -> List[Any]: ... @@ -438,7 +438,7 @@ def g(x: foo.C.C2) -> None: ... """).strip() tree = self.Parse(src) res = pytd_utils.Print(tree) - self.AssertSourceEquals(res, src) + self.AssertSourceEquals(res, expected) self.assertMultiLineEqual(res, expected) def test_print_imports_named_type(self): diff --git a/pytype/stubs/builtins/attr/_compat.pytd b/pytype/stubs/builtins/attr/_compat.pytd index ceedca495..67a969775 100644 --- a/pytype/stubs/builtins/attr/_compat.pytd +++ b/pytype/stubs/builtins/attr/_compat.pytd @@ -1,11 +1,14 @@ import __future__ +import collections +import platform import sys +import types +import warnings if sys.version_info < (3,): import UserDict from typing import Mapping, NoReturn else: import types -import collections from typing import Any, Type OrderedDict: Type[collections.OrderedDict] @@ -15,12 +18,8 @@ TYPE: str absolute_import: __future__._Feature division: __future__._Feature ordered_dict: Type[dict] -platform: module print_function: __future__._Feature set_closure_cell: Any -sys: module -types: module -warnings: module if sys.version_info < (3,): IterableUserDict: Type[UserDict.IterableUserDict] diff --git a/pytype/stubs/builtins/attr/_make.pytd b/pytype/stubs/builtins/attr/_make.pytd index 621d09886..2507b7fa6 100644 --- a/pytype/stubs/builtins/attr/_make.pytd +++ b/pytype/stubs/builtins/attr/_make.pytd @@ -40,8 +40,6 @@ linecache: module ordered_dict: Type[dict] print_function: __future__._Feature set_closure_cell: Any -sys: module -threading: module warnings: module _T0 = TypeVar('_T0') diff --git a/pytype/tests/CMakeLists.txt b/pytype/tests/CMakeLists.txt index 7e9188e9d..27a6a284e 100644 --- a/pytype/tests/CMakeLists.txt +++ b/pytype/tests/CMakeLists.txt @@ -16,6 +16,7 @@ py_library( DEPS .test_utils pytype.libvm + pytype.utils pytype.pyi.parser pytype.pytd.pytd ) @@ -642,8 +643,6 @@ py_test( test_reingest2.py DEPS .test_base - pytype.utils - pytype.pytd.pytd ) py_test( diff --git a/pytype/tests/test_abc2.py b/pytype/tests/test_abc2.py index e3e1c2615..c9cfd7b5b 100644 --- a/pytype/tests/test_abc2.py +++ b/pytype/tests/test_abc2.py @@ -70,7 +70,6 @@ def foo(self): self.assertTypesMatchPytd(ty, """ import abc from typing import Annotated, Any - abc = ... # type: module v1 = ... # type: Any v2 = ... # type: int class Bar(Foo): diff --git a/pytype/tests/test_annotations.py b/pytype/tests/test_annotations.py index 7d90306d7..265d95490 100644 --- a/pytype/tests/test_annotations.py +++ b/pytype/tests/test_annotations.py @@ -53,8 +53,8 @@ def foo(x: typing.Union[int, float], y: int): return x + y """) self.assertTypesMatchPytd(ty, """ + import typing from typing import Union - typing = ... # type: module def foo(x: Union[int, float], y:int) -> Union[int, float]: ... """) @@ -163,8 +163,8 @@ def f(c: "calendar.Calendar") -> int: return c.getfirstweekday() """) self.assertTypesMatchPytd(ty, """ - typing = ... # type: module - calendar = ... # type: module + import calendar + import typing def f(c: calendar.Calendar) -> int: ... """) @@ -436,8 +436,8 @@ def f(x: A): return x.name """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import a from typing import Any - a = ... # type: module A = ... # type: Any def f(x) -> Any: ... """) @@ -887,9 +887,8 @@ class A: """, deep=False) self.assertTypesMatchPytd(ty, """ from typing import List - import typing - def f(x: typing.List[A]) -> int: ... + def f(x: List[A]) -> int: ... class A: ... """) @@ -906,7 +905,6 @@ class A: """, deep=False) self.assertTypesMatchPytd(ty, """ from typing import List - import typing ListA = ... # type: str TypeA = ... # type: str def f(x: typing.List[A]) -> int: ... diff --git a/pytype/tests/test_anystr1.py b/pytype/tests/test_anystr1.py index 081bab56a..1b3a0cc6d 100644 --- a/pytype/tests/test_anystr1.py +++ b/pytype/tests/test_anystr1.py @@ -21,7 +21,7 @@ def f(x: AnyStr) -> AnyStr: ... y = 3 """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: int y = ... # type: int """) diff --git a/pytype/tests/test_attr1.py b/pytype/tests/test_attr1.py index 561af5177..91aeb0031 100644 --- a/pytype/tests/test_attr1.py +++ b/pytype/tests/test_attr1.py @@ -18,8 +18,8 @@ class Foo: z = attr.ib(type=str) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: x: Any @@ -37,7 +37,7 @@ class Foo: x = attr.ib(type=A) """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr class A: ... @attr.s class Foo: @@ -54,8 +54,8 @@ class Foo: x = attr.ib(type=List[int]) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import List - attr: module @attr.s class Foo: x: List[int] @@ -71,8 +71,8 @@ class Foo: x = attr.ib(type=Union[str, int]) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Union - attr: module @attr.s class Foo: x: Union[str, int] @@ -89,8 +89,8 @@ class Foo: y = attr.ib(type=str) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Union - attr: module @attr.s class Foo: x: Union[str, int] @@ -107,7 +107,7 @@ class Foo: y = attr.ib() # type: str """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: Foo @@ -123,7 +123,7 @@ class Foo: x = attr.ib(type='Foo') """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: Foo @@ -140,7 +140,7 @@ class Foo: z = 1 # class var, should not be in __init__ """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: int @@ -180,7 +180,7 @@ class Foo: ___z = attr.ib(type=int) """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: _x: int @@ -200,7 +200,7 @@ class Foo: a = attr.ib(type=str, default=None) """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: int @@ -222,7 +222,7 @@ class Foo: y = attr.ib(default=42) # type: str # annotation-type-mismatch[e] """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: int @@ -242,8 +242,8 @@ class Foo: y = attr.ib(factory=CustomClass) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import List - attr: module class CustomClass: ... @attr.s class Foo: @@ -265,8 +265,8 @@ class Foo: y = attr.ib(factory=unannotated_func) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any, Dict - attr: module class CustomClass: ... def unannotated_func() -> CustomClass: ... @attr.s @@ -284,8 +284,8 @@ class Foo: x = attr.ib(default=attr.Factory(list)) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import List - attr: module @attr.s class Foo: x: list @@ -320,7 +320,7 @@ class Foo: x = attr.ib(default=attr.Factory(len, takes_self=True)) """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: int @@ -335,8 +335,8 @@ class Foo: x = attr.ib(default=None) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: x: Any @@ -353,7 +353,7 @@ class Foo: x = Foo([]).x """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: list @@ -399,7 +399,7 @@ class Foo: y = attr.ib() # type: int """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: str @@ -463,7 +463,7 @@ class C(A, B): c = attr.ib() # type: int """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class A: a: int @@ -493,7 +493,7 @@ class C(A, B): c = attr.ib() # type: int """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class A: a: int @@ -523,7 +523,7 @@ class C(A, B): c = attr.ib() # type: int """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class A: a: int @@ -546,8 +546,8 @@ class Foo(__any_object__): a = attr.ib() # type: int """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo(Any): a: int @@ -582,8 +582,8 @@ def default_c(self): return 10 """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: a: int @@ -621,8 +621,8 @@ def default_c(self): return self.b """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: a: int @@ -670,8 +670,8 @@ class Foo: Foo(x=0) # should not be an error """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any, List - attr: module FACTORIES: List[nothing] @attr.s class Foo: @@ -687,7 +687,7 @@ class Foo: x = attr.ib(default=()) """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: tuple @@ -789,8 +789,8 @@ class Foo: z = attr.ib(type=str) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: x: Any @@ -809,8 +809,8 @@ class Foo: z = attr.ib(type=str) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: x: Any diff --git a/pytype/tests/test_attr2.py b/pytype/tests/test_attr2.py index 079303ea9..bc2d019e0 100644 --- a/pytype/tests/test_attr2.py +++ b/pytype/tests/test_attr2.py @@ -20,7 +20,7 @@ class Foo: x = attr.ib(factory=annotated_func) """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr class CustomClass: ... def annotated_func() -> CustomClass: ... @attr.s @@ -241,7 +241,7 @@ class Foo: y = attr.ib(type=str) """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: int @@ -258,7 +258,7 @@ class Foo: y = attr.ib(type=str) """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: Foo @@ -276,7 +276,7 @@ class Foo: z : int = 1 # class var, should not be in __init__ """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: int @@ -304,7 +304,7 @@ class Foo: y: str = attr.ib(default=42) # annotation-type-mismatch[e] """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: int @@ -329,9 +329,9 @@ def decorate(cls: Type[Foo]) -> Type[Foo]: ... class Bar(foo.Foo): ... """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import attr + import foo from typing import Type - attr: module - foo: module Bar: Type[foo.Foo] """) @@ -364,8 +364,8 @@ class Foo: z = attr.ib(type=str, default="hello") """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: x: int @@ -389,8 +389,8 @@ class Foo(Generic[T]): x2, y2 = foo2.x, foo2.y """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Generic, TypeVar - attr: module T = TypeVar('T') @attr.s class Foo(Generic[T]): @@ -420,8 +420,8 @@ class Foo(Generic[T]): x2 = foo2.x """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Generic, TypeVar - attr: module T = TypeVar('T') @attr.s class Foo(Generic[T]): @@ -459,8 +459,8 @@ class Foo: z = attr.ib(type=str) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: x: Any @@ -477,8 +477,8 @@ class Foo: x = attr.ib(default=1) """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: x: int @@ -497,7 +497,7 @@ class Foo: a: str = 'hello' """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: int @@ -518,7 +518,7 @@ class Foo: x: str = 'hello' """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: y: int @@ -550,8 +550,8 @@ def f(self): pass """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any, Annotated - attr: module @attr.s class Foo: y: str @@ -577,8 +577,8 @@ class Foo: """) self.assertTypesMatchPytd( ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: x: Any @@ -614,7 +614,7 @@ def get_y(self): return self.y """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: bool @@ -641,8 +641,8 @@ class Bar: baz = attr.ib() """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: foo: str @@ -665,8 +665,8 @@ class Foo: y: str = 'hello' """) self.assertTypesMatchPytd(ty, """ + import attr from typing import ClassVar - attr: module @attr.s class Foo: y: str @@ -684,8 +684,8 @@ class Foo: x: int """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Callable - attr: module def s(*args, **kwargs) -> Callable: ... @attr.s class Foo: @@ -718,8 +718,8 @@ class Foo: """) self.assertTypesMatchPytd( ty, """ + import attr from typing import Any - attr: module @attr.s(auto_attribs=True) class Foo: x: Any @@ -749,8 +749,8 @@ class Foo: """) self.assertTypesMatchPytd( ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: y: int @@ -773,7 +773,7 @@ class Foo: a: str = 'hello' """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: int @@ -794,7 +794,7 @@ class Foo: x: str = 'hello' """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: y: int @@ -826,8 +826,8 @@ def f(self): pass """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any, Annotated - attr: module @attr.s class Foo: y: str @@ -856,7 +856,7 @@ def get_y(self): return self.y """) self.assertTypesMatchPytd(ty, """ - attr: module + import attr @attr.s class Foo: x: bool @@ -883,8 +883,8 @@ class Bar: baz = attr.ib() """) self.assertTypesMatchPytd(ty, """ + import attr from typing import Any - attr: module @attr.s class Foo: foo: str @@ -907,8 +907,8 @@ class Foo: y: str = 'hello' """) self.assertTypesMatchPytd(ty, """ + import attr from typing import ClassVar - attr: module @attr.s class Foo: y: str @@ -980,8 +980,8 @@ class Foo(foo.A): z: str = "hello" """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - attr: module - foo: module + import attr + import foo @attr.s class Foo(foo.A): z: str @@ -1009,8 +1009,8 @@ class Foo(foo.B): a: str = "hello" """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - attr: module - foo: module + import attr + import foo @attr.s class Foo(foo.B): a: str @@ -1041,8 +1041,8 @@ class Foo(foo.B): a: str = "hello" """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - attr: module - foo: module + import attr + import foo @attr.s class Foo(foo.B): a: str @@ -1067,8 +1067,8 @@ class Foo(foo.A): z: str """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - attr: module - foo: module + import attr + import foo @attr.s class Foo(foo.A): z: str @@ -1093,8 +1093,8 @@ class Foo: x: int """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Annotated, Callable - foo: module @attr.s class Foo: diff --git a/pytype/tests/test_attributes1.py b/pytype/tests/test_attributes1.py index 806beeaca..63ffa7cf8 100644 --- a/pytype/tests/test_attributes1.py +++ b/pytype/tests/test_attributes1.py @@ -780,8 +780,8 @@ def f(): return foo.f().x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Optional - foo: module def f() -> Optional[str]: ... """) diff --git a/pytype/tests/test_base.py b/pytype/tests/test_base.py index 26816795e..1be54ccd4 100644 --- a/pytype/tests/test_base.py +++ b/pytype/tests/test_base.py @@ -1,5 +1,6 @@ """Common methods for tests of analyze.py.""" +import contextlib import logging import sys import textwrap @@ -8,6 +9,7 @@ from pytype import analyze from pytype import config from pytype import directors +from pytype import file_utils from pytype import load_pytd from pytype.pyi import parser from pytype.pytd import optimize @@ -132,6 +134,7 @@ def setUp(self): self.options = config.Options.create(python_version=self.python_version, allow_recursive_types=True, build_dict_literals_from_kwargs=True, + gen_stub_imports=True, strict_namedtuple_checks=True, use_enum_overlay=True) @@ -148,6 +151,16 @@ def ConfigureOptions(self, **kwargs): "Individual tests cannot set the python_version of the config options.") self.options.tweak(**kwargs) + def _GetPythonpath(self, pythonpath, imports_map): + if pythonpath: + return pythonpath + elif imports_map: + return [""] + elif self.options.pythonpath: + return self.options.pythonpath + else: + return pythonpath + # For historical reasons (byterun), this method name is snakecase: # pylint: disable=invalid-name def Check(self, code, pythonpath=(), skip_repeat_calls=True, @@ -156,7 +169,7 @@ def Check(self, code, pythonpath=(), skip_repeat_calls=True, """Run an inference smoke test for the given code.""" self.ConfigureOptions( skip_repeat_calls=skip_repeat_calls, - pythonpath=[""] if (not pythonpath and imports_map) else pythonpath, + pythonpath=self._GetPythonpath(pythonpath, imports_map), quick=quick, imports_map=imports_map) try: src = _Format(code) @@ -180,7 +193,7 @@ def _SetUpErrorHandling(self, code, pythonpath, analyze_annotated, quick, code = _Format(code) errorlog = test_utils.TestErrorLog(code) self.ConfigureOptions( - pythonpath=[""] if (not pythonpath and imports_map) else pythonpath, + pythonpath=self._GetPythonpath(pythonpath, imports_map), analyze_annotated=analyze_annotated, quick=quick, imports_map=imports_map) return {"src": code, "errorlog": errorlog, "options": self.options, @@ -373,7 +386,7 @@ def _InferAndVerify( """ self.ConfigureOptions( module_name=module_name, quick=quick, use_pickled_files=True, - pythonpath=[""] if (not pythonpath and imports_map) else pythonpath, + pythonpath=self._GetPythonpath(pythonpath, imports_map), imports_map=imports_map, analyze_annotated=analyze_annotated) errorlog = test_utils.TestErrorLog(src) if errorlog.expected: @@ -417,6 +430,30 @@ def assertTypesMatchPytd(self, ty, pytd_src): # (In other words, display a change from "working" to "broken") self.assertMultiLineEqual(pytd_tree_src, ty_src) + @contextlib.contextmanager + def DepTree(self, deps): + old_pythonpath = self.options.pythonpath + try: + with file_utils.Tempdir() as d: + self.ConfigureOptions(pythonpath=[d.path]) + for dep in deps: + if len(dep) == 3: + path, contents, opts = dep + else: + path, contents = dep + opts = {} + if path.endswith(".pyi"): + d.create_file(path, contents) + elif path.endswith(".py"): + path = path + "i" + pyi = pytd_utils.Print(self.Infer(contents, **opts)) + d.create_file(path, pyi) + else: + raise ValueError(f"Unrecognised dependency type: {path}") + yield d + finally: + self.ConfigureOptions(pythonpath=old_pythonpath) + def _PrintErrorDebug(descr, value): log.error("=============== %s ===========", descr) diff --git a/pytype/tests/test_base_test.py b/pytype/tests/test_base_test.py index 210d1bff8..4228c2281 100644 --- a/pytype/tests/test_base_test.py +++ b/pytype/tests/test_base_test.py @@ -1,5 +1,7 @@ """Tests for our test framework.""" +import os + from pytype import file_utils from pytype.tests import test_base from pytype.tests import test_utils @@ -168,5 +170,25 @@ def test_skip_from_py(self): """) +class DepTreeTest(test_base.BaseTest): + + def test_dep_tree(self): + foo_pyi = """ + class A: pass + """ + bar_py = """ + import foo + x = foo.A() + """ + deps = [("foo.pyi", foo_pyi), ("bar.py", bar_py)] + with self.DepTree(deps) as d: + self.Check(""" + import foo + import bar + assert_type(bar.x, foo.A) + """) + self.assertCountEqual(os.listdir(d.path), ["foo.pyi", "bar.pyi"]) + + if __name__ == "__main__": test_base.main() diff --git a/pytype/tests/test_basic2.py b/pytype/tests/test_basic2.py index 90dd9dd4b..b1bdf02f9 100644 --- a/pytype/tests/test_basic2.py +++ b/pytype/tests/test_basic2.py @@ -23,8 +23,7 @@ def test_import_shadowed(self): "signal" ]: ty = self.Infer("import %s" % module) - expected = " %s = ... # type: module" % module - self.assertTypesMatchPytd(ty, expected) + self.assertTypesMatchPytd(ty, f"import {module}") def test_cleanup(self): ty = self.Infer(""" diff --git a/pytype/tests/test_builtins1.py b/pytype/tests/test_builtins1.py index 1947fb013..92e0cba9f 100644 --- a/pytype/tests/test_builtins1.py +++ b/pytype/tests/test_builtins1.py @@ -334,7 +334,7 @@ def f(): signal.signal(signal.SIGALRM, 0) """) self.assertTypesMatchPytd(ty, """ - signal = ... # type: module + import signal def f() -> NoneType: ... """) @@ -347,7 +347,7 @@ def args(): args() """, deep=False, show_library_calls=True) self.assertTypesMatchPytd(ty, """ - sys = ... # type: module + import sys def args() -> str: ... """) @@ -380,7 +380,7 @@ def __init__(self): self.bar = array.array('i', [1, 2, 3]) """) self.assertTypesMatchPytd(ty, """ - array = ... # type: module + import array class Foo: bar = ... # type: array.array[int] def __init__(self) -> None: ... @@ -426,8 +426,8 @@ def f(x): return 3j """) self.assertTypesMatchPytd(ty, """ + import time from typing import Union - time = ... # type: module def f(x) -> Union[complex, float]: ... """) @@ -469,7 +469,7 @@ def f(): return 'py%d' % sys.version_info[0] """) self.assertTypesMatchPytd(ty, """ - sys = ... # type: module + import sys def f() -> str: ... """) @@ -487,7 +487,7 @@ class Foo( self.python_version, strict_namedtuple_checks=self.options.strict_namedtuple_checks) expected = pytd_utils.Print(ast) + textwrap.dedent(""" - collections = ... # type: module + import collections class Foo({name}): ...""").format(name=name) self.assertTypesMatchPytd(ty, expected) @@ -508,7 +508,7 @@ def test_store_and_load_from_namedtuple(self): self.python_version, strict_namedtuple_checks=self.options.strict_namedtuple_checks) expected = pytd_utils.Print(ast) + textwrap.dedent(""" - collections = ... # type: module + import collections t = {name} x = ... # type: int y = ... # type: str @@ -532,8 +532,8 @@ def f(mod): return type(mod) == types.ModuleType """) self.assertTypesMatchPytd(ty, """ + import types from typing import Any - types = ... # type: module def f(mod) -> Any: ... """) @@ -545,8 +545,8 @@ def f(date): return date.ctime() """) self.assertTypesMatchPytd(ty, """ + import datetime from typing import Any - datetime = ... # type: module def f(date) -> Any: ... """) @@ -558,7 +558,7 @@ def f(tz): tz.fromutc(datetime.datetime(1929, 10, 29)) """) self.assertTypesMatchPytd(ty, """ - datetime = ... # type: module + import datetime def f(tz) -> NoneType: ... """) diff --git a/pytype/tests/test_builtins2.py b/pytype/tests/test_builtins2.py index 5b86baf54..628ea9f7b 100644 --- a/pytype/tests/test_builtins2.py +++ b/pytype/tests/test_builtins2.py @@ -28,7 +28,7 @@ def test_defaultdict(self): r[3] = 3 """) self.assertTypesMatchPytd(ty, """ - collections = ... # type: module + import collections r = ... # type: collections.defaultdict[int, int] """) @@ -47,7 +47,7 @@ def test_import_lib(self): import importlib """) self.assertTypesMatchPytd(ty, """ - importlib = ... # type: module + import importlib """) def test_set_union(self): @@ -185,8 +185,8 @@ def test_module(self): y = foo.x.baz """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo = ... # type: module x = ... # type: str y = ... # type: Any """) diff --git a/pytype/tests/test_builtins4.py b/pytype/tests/test_builtins4.py index 38566283b..3f0748119 100644 --- a/pytype/tests/test_builtins4.py +++ b/pytype/tests/test_builtins4.py @@ -157,8 +157,8 @@ def j(x: Type[super]): ... v = super """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any, Type - foo = ... # type: module def f(x) -> None: ... def g(x: object) -> None: ... def h(x: Any) -> None: ... @@ -207,8 +207,8 @@ def f2(x: 'collections.OrderedDict[int, str]'): return x.copy() """) self.assertTypesMatchPytd(ty, """ + import collections from typing import Dict - collections: module def f1(x: Dict[int, str]) -> Dict[int, str]: ... def f2( x: collections.OrderedDict[int, str] @@ -399,8 +399,8 @@ def f(x: int): x6 = filter(re.compile("").search, ("",)) """) self.assertTypesMatchPytd(ty, """ + import re from typing import Iterator - re: module def f(x: int) -> None: ... x1 = ... # type: Iterator[int] x2 = ... # type: Iterator[bool, ...] @@ -579,7 +579,7 @@ def t_testTobytes(): return array.array('B').tobytes() """) self.assertTypesMatchPytd(ty, """ - array = ... # type: module + import array def t_testTobytes() -> bytes: ... """) diff --git a/pytype/tests/test_chex_overlay.py b/pytype/tests/test_chex_overlay.py index a0c584389..e12b8cd21 100644 --- a/pytype/tests/test_chex_overlay.py +++ b/pytype/tests/test_chex_overlay.py @@ -44,9 +44,9 @@ class Foo: x: int """) self.assertTypesMatchPytd(ty, """ + import chex import dataclasses from typing import Dict, Mapping, TypeVar - chex: module _TFoo = TypeVar('_TFoo', bound=Foo) @dataclasses.dataclass class Foo(Mapping, object): @@ -67,9 +67,9 @@ class Foo: x: int """) self.assertTypesMatchPytd(ty, """ + import chex import dataclasses from typing import Dict, TypeVar - chex: module _TFoo = TypeVar('_TFoo', bound=Foo) @dataclasses.dataclass class Foo: @@ -105,9 +105,9 @@ class Foo: foo = Foo(0).replace(x=5) """) self.assertTypesMatchPytd(ty, """ + import chex import dataclasses from typing import Dict, Mapping, TypeVar - chex: module _TFoo = TypeVar('_TFoo', bound=Foo) @dataclasses.dataclass class Foo(Mapping, object): @@ -130,9 +130,9 @@ class Foo: foo = Foo.from_tuple((0,)) """) self.assertTypesMatchPytd(ty, """ + import chex import dataclasses from typing import Dict, Mapping, TypeVar - chex: module _TFoo = TypeVar('_TFoo', bound=Foo) @dataclasses.dataclass class Foo(Mapping, object): @@ -155,9 +155,9 @@ class Foo: tup = Foo(0).to_tuple() """) self.assertTypesMatchPytd(ty, """ + import chex import dataclasses from typing import Dict, Mapping, TypeVar - chex: module _TFoo = TypeVar('_TFoo', bound=Foo) @dataclasses.dataclass class Foo(Mapping, object): diff --git a/pytype/tests/test_classes1.py b/pytype/tests/test_classes1.py index 758de100a..089cead10 100644 --- a/pytype/tests/test_classes1.py +++ b/pytype/tests/test_classes1.py @@ -141,8 +141,8 @@ class Bar(Foo, a.A): x = Bar(duration=0) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import a from typing import Any - a = ... # type: module class Foo: pass class Bar(Foo, Any): @@ -550,7 +550,7 @@ def f(): return foo.Foo().foo """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo def f() -> str: ... """) @@ -578,7 +578,7 @@ def __getattribute__(self, name) -> int: ... x = a.A().x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: int """) @@ -595,7 +595,7 @@ class C(a.A): name = C.__name__ """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a class C(a.A): pass name = ... # type: str @@ -622,8 +622,8 @@ class B(menum.IntEnum): name2 = B.x.name """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import menum from typing import Any - menum = ... # type: module class A(menum.Enum): x = ... # type: int class B(menum.IntEnum): @@ -653,8 +653,8 @@ def h() -> Type[Union[int, B]]: ... x3 = a.h().x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import a from typing import Union - a = ... # type: module x1 = ... # type: int x2 = ... # type: Union[int, str] x3 = ... # type: str @@ -673,7 +673,7 @@ class B: x = a.B.MyA() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: a.A """) @@ -708,7 +708,7 @@ def __new__(cls): x3 = object.__new__(bool) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a class C: def __new__(cls) -> str: ... x1 = ... # type: a.B @@ -776,7 +776,7 @@ def __init__(self, x, y) -> None: ... x2 = a.B() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x1 = ... # type: a.A[int] x2 = ... # type: a.A[str] """) @@ -863,7 +863,7 @@ class X(metaclass=A): ... v = a.X.f() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a v = ... # type: float """) @@ -882,8 +882,8 @@ class B(metaclass=A): ... x = b.B.x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import b from typing import Any - b = ... # type: module x = ... # type: Any """) @@ -940,7 +940,7 @@ class C(six.with_metaclass(A, object)): x2 = C.x """) self.assertTypesMatchPytd(ty, """ - six: module + import six class A(type): x: int def __init__(self, name, bases, members) -> None: ... @@ -1275,7 +1275,7 @@ class DL(get_base()): pass """) self.assertTypesMatchPytd(ty, """ from typing import List - typing = ... # type: module + import typing class DL(List[str]): def get_len(self) -> int: ... def get_base() -> type: ... @@ -1361,9 +1361,8 @@ class Y: ... Y = foo.X.Y """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - from typing import Type import foo - foo: module + from typing import Type Y: Type[foo.X.Y] """) d.create_file("bar.pyi", pytd_utils.Print(ty)) @@ -1372,9 +1371,8 @@ class Y: ... Y = bar.Y """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import bar from typing import Type - import foo - bar: module Y: Type[foo.X.Y] """) @@ -1392,7 +1390,6 @@ class Y: ... self.assertTypesMatchPytd(ty, """ from typing import Type import foo - foo: module Z: Type[foo.X.Y] """) @@ -1410,7 +1407,6 @@ class Z: ... self.assertTypesMatchPytd(ty, """ from typing import Type import foo - foo: module Z: Type[foo.X.Y.Z] """) diff --git a/pytype/tests/test_classes2.py b/pytype/tests/test_classes2.py index 6d8b99834..01c89819b 100644 --- a/pytype/tests/test_classes2.py +++ b/pytype/tests/test_classes2.py @@ -325,7 +325,6 @@ def fooTest(self): """) self.assertTypesMatchPytd(ty, """ import unittest - unittest = ... # type: module class A(unittest.case.TestCase): x = ... # type: int def fooTest(self) -> int: ... @@ -344,7 +343,6 @@ def fooTest(self): """) self.assertTypesMatchPytd(ty, """ import unittest - unittest = ... # type: module class A(unittest.case.TestCase): x = ... # type: int def setUp(self) -> None : ... @@ -366,7 +364,6 @@ def fooTest(self): """) self.assertTypesMatchPytd(ty, """ import unittest - unittest = ... # type: module class A(unittest.case.TestCase): x = ... # type: int foo = ... # type: str @@ -422,8 +419,8 @@ def f(self): v = X().f() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Type - foo = ... # type: module class X(metaclass=foo.MyMeta): def f(self) -> int: ... v = ... # type: int diff --git a/pytype/tests/test_coroutine.py b/pytype/tests/test_coroutine.py index b6f3a448a..7fe243d4a 100644 --- a/pytype/tests/test_coroutine.py +++ b/pytype/tests/test_coroutine.py @@ -118,11 +118,10 @@ async def caller(): return x """) self.assertTypesMatchPytd(ty, """ + import asyncio + import types from typing import Any, Coroutine, Union - asyncio: module - types: module - def caller() -> Coroutine[Any, Any, Union[int, str]]: ... def f1() -> Coroutine[Any, Any, None]: ... def f2() -> Coroutine[Any, Any, None]: ... @@ -190,10 +189,9 @@ async def f3(): await f2(c2()) """) self.assertTypesMatchPytd(ty, """ + import types from typing import Any, Awaitable, Coroutine, TypeVar - types: module - _TBaseAwaitable = TypeVar('_TBaseAwaitable', bound=BaseAwaitable) class BaseAwaitable: @@ -436,10 +434,9 @@ async def func2(x: Coroutine[Any, Any, str]): func2(foo.f1()) """, deep=True, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any, Awaitable, Coroutine, List - foo: module - def func1(x: Awaitable[str]) -> Coroutine[Any, Any, List[str]]: ... def func2(x: Coroutine[Any, Any, str]) -> Coroutine[Any, Any, List[str]]: ... """) @@ -475,8 +472,8 @@ async def tcp_echo_client(message): return await asyncio.open_connection( '127.0.0.1', 8888) """) self.assertTypesMatchPytd(ty, """ + import asyncio from typing import Any, Coroutine, Tuple - asyncio: module def tcp_echo_client(message) -> Coroutine[ Any, Any, Tuple[asyncio.streams.StreamReader, asyncio.streams.StreamWriter]]: ... @@ -494,8 +491,8 @@ async def main(): worker(queue) """) self.assertTypesMatchPytd(ty, """ + import asyncio from typing import Any, Coroutine - asyncio: module def worker(queue) -> coroutine: ... def main() -> Coroutine[Any, Any, None]: ... """) @@ -512,9 +509,8 @@ async def call_foo(): return await future """) self.assertTypesMatchPytd(ty, """ - import asyncio.futures + import asyncio from typing import Any, Coroutine, Optional - asyncio: module def foo() -> Coroutine[Any, Any, int]: ... def call_foo() -> Coroutine[Any, Any, Optional[int]]: ... """) @@ -529,8 +525,8 @@ async def f() -> int: ... c = foo.f() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any, Coroutine - foo: module c: Coroutine[Any, Any, int] """) diff --git a/pytype/tests/test_dataclasses.py b/pytype/tests/test_dataclasses.py index 1b31a90c9..f5bcf7d37 100644 --- a/pytype/tests/test_dataclasses.py +++ b/pytype/tests/test_dataclasses.py @@ -17,8 +17,8 @@ class Foo: z: str """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class Foo: x: bool @@ -37,8 +37,8 @@ class Foo: y: str """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class Foo: x: Foo @@ -59,8 +59,8 @@ class Foo: y = 10 """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class Foo: x: str @@ -80,8 +80,8 @@ def x(self): # annotation-type-mismatch[e] return 10 """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class Foo: x: str @@ -102,8 +102,8 @@ class Foo: z: str """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class Foo: x: bool @@ -124,8 +124,8 @@ def __init__(self, a: bool): self.y = 0 """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict - dataclasses: module @dataclasses.dataclass class Foo: x: bool @@ -144,8 +144,8 @@ class Foo: y: List[int] = dataclasses.field(default_factory=list) """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, List, Union - dataclasses: module @dataclasses.dataclass class Foo: x: bool @@ -210,8 +210,8 @@ class Foo: y: int = dataclasses.field(init=False) """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict - dataclasses: module @dataclasses.dataclass class Foo: x: bool @@ -229,8 +229,8 @@ class Foo: y: int """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict - dataclasses: module @dataclasses.dataclass class Foo: x: bool @@ -287,8 +287,8 @@ def get_y(self): return self.y """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class Foo: w: float @@ -316,8 +316,8 @@ class Bar(Foo): z: bool = dataclasses.field(default=True) """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class Foo: w: float @@ -347,8 +347,8 @@ class C(B, A): c: int """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class A: a: int @@ -435,8 +435,8 @@ def get_value(x: Root): self.assertTypesMatchPytd(ty, """ from typing import Dict, Optional, Union + import dataclasses Node = Union[IntLeaf, StrLeaf, Tree] - dataclasses: module @dataclasses.dataclass class IntLeaf: @@ -490,8 +490,8 @@ class A: y: int = 10 """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class A: y: int @@ -509,8 +509,8 @@ class A: y: int = 10 """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class A: x: dataclasses.InitVar[str] @@ -533,8 +533,8 @@ class Foo: pass """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class A: x: dataclasses.InitVar[str] @@ -559,8 +559,8 @@ class B(A): z: dataclasses.InitVar[int] = 42 """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Union - dataclasses: module @dataclasses.dataclass class A: y: int @@ -585,8 +585,8 @@ class Foo: y: str = 'hello' """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import ClassVar, Dict - dataclasses: module @dataclasses.dataclass class Foo: y: str @@ -610,8 +610,8 @@ class Inner: Inner2 = Bar.Inner """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Type - dataclasses: module class Foo: Inner: Type[Inner1] class Bar: @@ -649,8 +649,8 @@ class Foo: y: int = field_wrapper(default=1) """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Any, Dict - dataclasses: module def field_wrapper(**kwargs) -> Any: ... @dataclasses.dataclass class Foo: @@ -672,8 +672,8 @@ def z(self) -> str: return "hello world" """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Annotated, Dict - dataclasses: module @dataclasses.dataclass class Foo: x: bool @@ -697,8 +697,8 @@ class Foo(Generic[T]): x2 = foo2.x """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Dict, Generic, TypeVar - dataclasses: module T = TypeVar('T') @dataclasses.dataclass class Foo(Generic[T]): @@ -786,9 +786,9 @@ class Foo(foo.A): z: str = "hello" """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import dataclasses + import foo from typing import Dict, Union - dataclasses: module - foo: module @dataclasses.dataclass class Foo(foo.A): z: str @@ -817,9 +817,9 @@ class Foo(foo.B): a: str = "hello" """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import dataclasses + import foo from typing import Dict, Union - dataclasses: module - foo: module @dataclasses.dataclass class Foo(foo.B): a: str @@ -851,9 +851,9 @@ class Foo(foo.B): a: str = "hello" """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import dataclasses + import foo from typing import Dict, Union - dataclasses: module - foo: module @dataclasses.dataclass class Foo(foo.B): a: str @@ -885,9 +885,9 @@ class Foo(foo.B): a: str = "hello" """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import dataclasses + import foo from typing import Dict, Union - dataclasses: module - foo: module @dataclasses.dataclass class Foo(foo.B): a: str @@ -913,9 +913,9 @@ class Foo(foo.A): z: str = "hello" """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import dataclasses + import foo from typing import Dict, Union - dataclasses: module - foo: module @dataclasses.dataclass class Foo(foo.A): z: str @@ -945,9 +945,9 @@ def b(self) -> int: return 42 """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import dataclasses + import foo from typing import Annotated, Dict, Union - dataclasses: module - foo: module @dataclasses.dataclass class Foo(foo.A): a: str @@ -976,8 +976,8 @@ class A: c = y.z """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any, Optional, List - foo: module x: foo.A y: foo.A a: Optional[foo.A] @@ -1004,9 +1004,9 @@ class B(foo.A): w: int """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import dataclasses + import foo from typing import Any, Dict, List, Union - dataclasses: module - foo: module @dataclasses.dataclass class B(foo.A): w: int @@ -1033,9 +1033,9 @@ class C(foo.A): w: int """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import dataclasses + import foo from typing import Any, Dict, List, Union - dataclasses: module - foo: module @dataclasses.dataclass class C(foo.A): w: int diff --git a/pytype/tests/test_decorators1.py b/pytype/tests/test_decorators1.py index c0e14e6f5..2daf6c5e9 100644 --- a/pytype/tests/test_decorators1.py +++ b/pytype/tests/test_decorators1.py @@ -268,8 +268,8 @@ def f(self, x=None): A().f(42) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any, Callable - foo = ... # type: module class A: f = ... # type: Callable """) diff --git a/pytype/tests/test_decorators2.py b/pytype/tests/test_decorators2.py index 7f72b79fe..d8a90ad62 100644 --- a/pytype/tests/test_decorators2.py +++ b/pytype/tests/test_decorators2.py @@ -48,7 +48,7 @@ class Bar: pass """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo def f(x: str) -> int: ... class Bar: ... """) @@ -110,7 +110,7 @@ class Bar: pass """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo def f() -> None: ... def g(x: int, y: int) -> int: ... class Foo: ... @@ -187,7 +187,7 @@ class Bar: pass """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo def f(x: float) -> str: ... def g(x: int, y: float) -> float: ... class Foo: ... diff --git a/pytype/tests/test_enums.py b/pytype/tests/test_enums.py index acb76c6e9..f74fcb208 100644 --- a/pytype/tests/test_enums.py +++ b/pytype/tests/test_enums.py @@ -39,7 +39,7 @@ class Colors(enum.Enum): BLUE = 3 """) self.assertTypesMatchPytd(ty, """ - enum: module + import enum class Colors(enum.Enum): BLUE: int GREEN: int @@ -64,7 +64,6 @@ class Colors(enum.Enum): def test_sunderscore_name_value(self): self.Check(""" - from typing import Any import enum class M(enum.Enum): A = 1 @@ -72,7 +71,7 @@ class M(enum.Enum): assert_type(M.A._value_, int) def f(m: M): assert_type(m._name_, str) - assert_type(m._value_, Any) + assert_type(m._value_, int) """) def test_sunderscore_name_value_pytd(self): @@ -95,7 +94,7 @@ def f(m: foo.M): def test_basic_enum_from_pyi(self): with file_utils.Tempdir() as d: d.create_file("e.pyi", """ - enum: module + import enum class Colors(enum.Enum): RED: int BLUE: int @@ -108,7 +107,7 @@ class Colors(enum.Enum): v = e.Colors.GREEN.value """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - e: module + import e c: e.Colors n: str v: int @@ -219,7 +218,7 @@ class M(enum.Enum): def test_name_lookup_pytd(self): with file_utils.Tempdir() as d: d.create_file("e.pyi", """ - enum: module + import enum a_string: str class M(enum.Enum): A: int @@ -268,7 +267,7 @@ class M(enum.Enum): def test_enum_pytd_named_name(self): with file_utils.Tempdir() as d: d.create_file("m.pyi", """ - enum: module + import enum class M(enum.Enum): name: int value: str @@ -306,7 +305,7 @@ class N(enum.Enum): def test_value_lookup_pytd(self): with file_utils.Tempdir() as d: d.create_file("m.pyi", """ - enum: module + import enum class M(enum.Enum): A: int class N(enum.Enum): @@ -383,7 +382,7 @@ class N(enum.Enum): def test_enum_pytd_eq(self): with file_utils.Tempdir() as d: d.create_file("m.pyi", """ - enum: module + import enum class M(enum.Enum): A: int class N(enum.Enum): @@ -438,7 +437,7 @@ class N(enum.Enum): def test_pytd_metaclass_methods(self): with file_utils.Tempdir() as d: d.create_file("m.pyi", """ - enum: module + import enum class M(enum.Enum): A: int """) @@ -990,6 +989,19 @@ def surface_gravity(self): assert_type(Planet.EARTH.surface_gravity, float) """) + def test_own_init_canonical(self): + self.Check(""" + import enum + + class Protocol(enum.Enum): + ssh = 22 + def __init__(self, port_number): + self.port_number = port_number + + def get_port(protocol: str) -> int: + return Protocol[protocol].port_number + """) + def test_own_init_errors(self): self.CheckWithErrors(""" import enum @@ -1002,10 +1014,9 @@ def __init__(self, a, b, c): def test_own_member_new(self): with file_utils.Tempdir() as d: d.create_file("foo.pyi", """ + import enum from typing import Annotated, Any, Type, TypeVar - enum: module - _TOrderedEnum = TypeVar('_TOrderedEnum', bound=OrderedEnum) class OrderedEnum(enum.Enum): @@ -1060,8 +1071,8 @@ def combo(self) -> str: return f"{self.str_v}+{self.value}" """) self.assertTypesMatchPytd(ty, """ + import enum from typing import Annotated - enum: module class M(enum.Enum): A: int combo: Annotated[str, 'property'] @@ -1119,6 +1130,23 @@ def take_m(m: M): return m.x """, pythonpath=[d.path]) + def test_instance_attrs_self_referential(self): + self.Check(""" + from dataclasses import dataclass + from enum import Enum + from typing import Optional + + @dataclass + class O: + thing: Optional["Thing"] = None + + class Thing(Enum): + A = O() + + def __init__(self, o: O): + self.b = o.thing + """) + def test_enum_bases(self): self.CheckWithErrors(""" import enum @@ -1141,8 +1169,8 @@ class M(enum.Enum): M.class_attr = 2 """) self.assertTypesMatchPytd(ty, """ + import enum from typing import ClassVar - enum: module class M(enum.Enum): A: int class_attr: ClassVar[int] diff --git a/pytype/tests/test_errors1.py b/pytype/tests/test_errors1.py index eae181329..b9d385190 100644 --- a/pytype/tests/test_errors1.py +++ b/pytype/tests/test_errors1.py @@ -612,8 +612,8 @@ def f(x: Type[A]) -> bool: ... error = ["Expected", "Type[a.A]", "Actual", "Type[a.C]"] self.assertErrorSequences(errors, {"e": error}) self.assertTypesMatchPytd(ty, """ + import a from typing import Any - a = ... # type: module x = ... # type: bool y = ... # type: bool z = ... # type: Any @@ -764,8 +764,8 @@ class B(A): ... x = v.x # No error because there is an Unsolvable in the MRO of a.A """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import a from typing import Any - a = ... # type: module v = ... # type: a.A x = ... # type: Any """) diff --git a/pytype/tests/test_exceptions1.py b/pytype/tests/test_exceptions1.py index 4f731287c..58749c336 100644 --- a/pytype/tests/test_exceptions1.py +++ b/pytype/tests/test_exceptions1.py @@ -223,7 +223,7 @@ def warn(): DeprecationWarning, stacklevel=2) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - warnings = ... # type: module + import warnings def warn() -> None: ... """) diff --git a/pytype/tests/test_flax_overlay.py b/pytype/tests/test_flax_overlay.py index 324eb9c28..707e063c0 100644 --- a/pytype/tests/test_flax_overlay.py +++ b/pytype/tests/test_flax_overlay.py @@ -26,8 +26,8 @@ class Foo: z: str """, pythonpath=[d.path], module_name="foo") self.assertTypesMatchPytd(ty, """ + import flax from typing import Dict, TypeVar, Union - flax: module _TFoo = TypeVar('_TFoo', bound=Foo) @@ -50,8 +50,8 @@ def field(**kwargs): return dataclasses.field(**kwargs) """) self.assertTypesMatchPytd(ty, """ + import dataclasses from typing import Any - dataclasses: module def field(**kwargs) -> Any: ... """) @@ -92,12 +92,11 @@ class Foo(nn.Module): y: int = 10 """, pythonpath=[d.path], module_name="foo") self.assertTypesMatchPytd(ty, """ - import flax.linen.module + from flax import linen as nn from typing import Dict, TypeVar - nn: module _TFoo = TypeVar('_TFoo', bound=Foo) @dataclasses.dataclass - class Foo(flax.linen.module.Module): + class Foo(nn.module.Module): x: bool y: int __dataclass_fields__: Dict[str, dataclasses.Field] @@ -115,13 +114,11 @@ class Foo(module.Module): y: int = 10 """, pythonpath=[d.path], module_name="foo") self.assertTypesMatchPytd(ty, """ - import builtins - import flax.linen.module + from flax.linen import module from typing import Dict, TypeVar - module: builtins.module _TFoo = TypeVar('_TFoo', bound=Foo) @dataclasses.dataclass - class Foo(flax.linen.module.Module): + class Foo(module.Module): x: bool y: int __dataclass_fields__: Dict[str, dataclasses.Field] @@ -162,12 +159,11 @@ class Foo(linen.Module): y: int = 10 """, pythonpath=[d.path], module_name="flax.linen.foo") self.assertTypesMatchPytd(ty, """ - import flax.linen.module - linen: module + from flax import linen from typing import Dict, TypeVar _TFoo = TypeVar('_TFoo', bound=Foo) @dataclasses.dataclass - class Foo(flax.linen.module.Module): + class Foo(linen.module.Module): x: bool y: int __dataclass_fields__: Dict[str, dataclasses.Field] @@ -227,8 +223,8 @@ class Bar(foo.Foo): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ import dataclasses + import foo from typing import Any, Dict, TypeVar - foo: module _TBar = TypeVar('_TBar', bound=Bar) @dataclasses.dataclass @@ -258,8 +254,8 @@ class Baz(Bar): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ import dataclasses + import foo from typing import Any, Dict, TypeVar - foo: module _TBar = TypeVar('_TBar', bound=Bar) @dataclasses.dataclass diff --git a/pytype/tests/test_flow1.py b/pytype/tests/test_flow1.py index 52bd25be7..b56491a08 100644 --- a/pytype/tests/test_flow1.py +++ b/pytype/tests/test_flow1.py @@ -330,8 +330,8 @@ def test_loop_over_list_of_lists(self): seq.append("foo") """, deep=False) self.assertTypesMatchPytd(ty, """ + import os from typing import List, Union - os = ... # type: module seq = ... # type: List[Union[int, str]] """) diff --git a/pytype/tests/test_functions1.py b/pytype/tests/test_functions1.py index 0c696f76f..89e267ebe 100644 --- a/pytype/tests/test_functions1.py +++ b/pytype/tests/test_functions1.py @@ -221,8 +221,8 @@ def f(x: T) -> T: ... x = foo.f() # missing-parameter """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo: module x: Any """) @@ -472,8 +472,8 @@ def f(x, y): return foo.f(x, y) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo = ... # type: module def f(x, y) -> list: ... """) @@ -491,8 +491,8 @@ def f(x): return foo.f(x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo = ... # type: module def f(x) -> Any: ... """) @@ -511,8 +511,8 @@ def f(y): return foo.f("", y) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import List - foo = ... # type: module def f(y) -> List[str]: ... """) @@ -530,8 +530,8 @@ def f(x): return foo.f(x, "") """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any, Union - foo = ... # type: module # TODO(rechen): def f(x: str or List[str]) -> List[str]: ... def f(x) -> list: ... """) @@ -556,8 +556,8 @@ def h(x): return ret """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any, List, MutableSequence - foo = ... # type: module # TODO(rechen): def f(x: unicode or List[unicode]) -> bool: ... def f(x) -> Any: ... def g(x) -> list: ... @@ -581,9 +581,7 @@ def compile() -> MyPattern[T]: ... x = foo.compile().match("") """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - import typing - - foo = ... # type: module + import foo x = ... # type: foo.MyMatch[str] """) @@ -598,7 +596,7 @@ def f(x: int, y: bool) -> int: ... x = foo.f(0, True) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo x = ... # type: int """) @@ -614,8 +612,8 @@ def f(x): return foo.f(x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo = ... # type: module def f(x) -> Any: ... """) @@ -631,8 +629,8 @@ def f(x): return foo.f(x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo = ... # type: module def f(x) -> Any: ... """) @@ -648,8 +646,8 @@ def f(x): return foo.f(y=x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo = ... # type: module def f(x) -> Any: ... """) @@ -702,8 +700,8 @@ def f(): pass w2 = type(f) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any, Callable, Tuple - foo = ... # type: module def f() -> None: ... v1 = ... # type: Tuple[Callable[[], None]] v2 = Callable @@ -723,8 +721,8 @@ def f(x: T) -> Tuple[Union[T, str], int]: ... v1, v2 = foo.f(42j) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo = ... # type: module v1 = ... # type: Union[str, complex] v2 = ... # type: int """) @@ -792,7 +790,7 @@ def b(x: int, y: int, z: int): ... c = a.b """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def c(x: int, y: int, z: int = ...): ... """) @@ -933,7 +931,7 @@ def test_infer_bound_pytd_func(self): int2byte = chr """) self.assertTypesMatchPytd(ty, """ - struct = ... # type: module + import struct def int2byte(*v) -> bytes: ... """) @@ -949,8 +947,8 @@ def f(x: float) -> Union[int, str]: ... v = foo.f(__any_object__) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo = ... # type: module v = ... # type: Union[int, str] """) @@ -970,7 +968,7 @@ def f(a, b): partial_f = functools.partial(f, 0) """) self.assertTypesMatchPytd(ty, """ - functools: module + import functools def f(a, b) -> None: ... partial_f: functools.partial """) @@ -1029,8 +1027,8 @@ def new_function(code, globals): return types.FunctionType(code, globals) """) self.assertTypesMatchPytd(ty, """ + import types from typing import Callable - types: module def new_function(code, globals) -> Callable: ... """) diff --git a/pytype/tests/test_generic1.py b/pytype/tests/test_generic1.py index a35d296f7..33661a503 100644 --- a/pytype/tests/test_generic1.py +++ b/pytype/tests/test_generic1.py @@ -21,7 +21,7 @@ def f(): return a.f() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def f() -> a.A[int]: ... """) @@ -61,7 +61,7 @@ def bar(): return {list(x.keys())[0]: list(x.values())[0]} """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def foo() -> a.B: ... def bar() -> dict[str, int]: ... """) @@ -88,8 +88,8 @@ def bar(): return foo()[0] """, pythonpath=[d1.path, d2.path]) self.assertTypesMatchPytd(ty, """ + import b from typing import Union - b = ... # type: module def foo() -> b.B: ... def bar() -> Union[int, str]: ... """) @@ -115,7 +115,7 @@ def qux(): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import List, Tuple - a = ... # type: module + import a def foo() -> a.A[nothing]: ... def bar() -> List[str]: ... def baz() -> a.B: ... @@ -137,7 +137,7 @@ def f(): return foo.B().bar() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo def f() -> int: ... """) @@ -162,7 +162,7 @@ def baz(): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Union - a = ... # type: module + import a def foo() -> a.A[nothing]: ... def bar() -> int: ... def baz() -> Union[int, str]: ... @@ -191,7 +191,7 @@ def f(): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Union - a = ... # type: module + import a def f() -> Union[int, float, complex, str]: ... """) @@ -219,7 +219,7 @@ def h(self) -> Tuple[T2, T3]: ... """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any, Tuple - a = ... # type: module + import a v1 = ... # type: int v2 = ... # type: str v3 = ... # type: Tuple[int, str] @@ -246,7 +246,7 @@ def __init__(self): w = a.C().g() """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a v = ... # type: str w = ... # type: str """) @@ -274,7 +274,7 @@ def g(self): """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Union - a = ... # type: module + import a # T1, T2, and T3 are all set to Any due to T1 being an alias for both # T2 and T3. v = ... # type: a.C[int, Union[float, int]] @@ -297,7 +297,7 @@ def f(): return foo.baz().bar() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo def f() -> int: ... """) @@ -323,8 +323,8 @@ def g(): return b.A(3.14).a() """, pythonpath=[d1.path, d2.path]) self.assertTypesMatchPytd(ty, """ + import b from typing import Union - b = ... # type: module def f() -> b.A[Union[int, str]]: ... def g() -> Union[int, float]: ... """) @@ -346,7 +346,7 @@ class Custom(MyDict[K, V], MyList[V]): pass x = a.Custom() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: a.Custom[nothing, nothing] """) @@ -366,7 +366,7 @@ def f(): return x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def f() -> a.C[nothing]: ... """) @@ -387,7 +387,7 @@ def f(): self.assertTypesMatchPytd(ty, """ import a - a = ... # type: module + import a def f() -> a.A[int]: ... """) @@ -406,7 +406,7 @@ def g(): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Union - a = ... # type: module + import a def f() -> a.A: ... def g() -> Union[int, str]: ... """) @@ -428,7 +428,7 @@ def f(): return x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def f() -> a.A[nothing, int]: ... """) @@ -455,7 +455,7 @@ def h(): return f()[0] """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def f() -> a.A[str, int]: ... def g() -> str: ... def h() -> int: ... @@ -476,7 +476,7 @@ def f(): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import List - a = ... # type: module + import a def f() -> List[nothing]: ... """) @@ -495,7 +495,7 @@ def f(): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Union - a = ... # type: module + import a def f() -> Union[str, unicode]: ... """) @@ -540,7 +540,7 @@ def h(): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any - a = ... # type: module + import a # T was made unsolvable by an AliasingDictConflictError. def f() -> a.A[int, str]: ... def g() -> int: ... @@ -561,7 +561,7 @@ class A(Dict[T, U], List[T], Generic[T, U]): ... """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any - a = ... # type: module + import a v = ... # type: a.A[nothing, nothing] """) @@ -579,7 +579,7 @@ def g(): return a.A()[0][0] """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def f() -> a.A: ... def g() -> a.A: ... """) @@ -603,7 +603,7 @@ def bar(): return a.B()[0] """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def foo() -> str: ... def bar() -> str: ... """) @@ -627,7 +627,7 @@ def bar(): return B()[0] """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a class B(a.A): pass def foo() -> str: ... def bar() -> str: ... @@ -653,8 +653,8 @@ def g(): return (v.x, v.y) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import a from typing import Tuple, Union - a: module def f() -> Tuple[Union[int, float], complex]: ... def g() -> Tuple[int, float]: ... """) @@ -679,7 +679,7 @@ def g(): return RE.pattern """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a RE = ... # type: a.MyPattern[str] def f(x) -> None: ... def g() -> str: ... @@ -718,7 +718,7 @@ def h(): return x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def f() -> str: ... def g() -> bool: ... def h() -> int: ... @@ -744,7 +744,7 @@ def g(): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any, TypeVar, Union - a = ... # type: module + import a T = TypeVar("T") class B(a.A[T]): x = ... # type: Union[int, float] @@ -769,7 +769,7 @@ def f(): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any - a = ... # type: module + import a def f() -> Any: ... """) @@ -797,7 +797,7 @@ def g(x): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Optional, Union - a = ... # type: module + import a def f(x) -> Union[int, float]: ... def g(x) -> Optional[float]: ... """) @@ -816,7 +816,7 @@ def f(): return abs(a.A([42]).x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def f() -> int: ... """) @@ -836,7 +836,7 @@ def f(x): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any - a = ... # type: module + import a def f(x) -> Any: ... """) @@ -854,7 +854,7 @@ def f(): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any - a = ... # type: module + import a def f() -> Any: ... """) @@ -870,7 +870,7 @@ class A(List[Any]): pass n = len(a.A()[0]) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a n = ... # type: int """) @@ -886,7 +886,7 @@ def f(x: Iterable[Q]) -> Q: ... x = a.f({True: "false"}) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: bool """) @@ -905,8 +905,8 @@ def __init__(self): v = list(foo.Foo()) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo = ... # type: module v = ... # type: list[Union[int, str]] """) @@ -926,7 +926,7 @@ def foo(self): return self.data """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a class B(a.A): data = ... # type: list def foo(self) -> list: ... @@ -949,7 +949,7 @@ def foo(self): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import List, Union - a = ... # type: module + import a class B(a.A): data = ... # type: List[Union[int, str]] def foo(self) -> List[Union[int, str]]: ... @@ -972,7 +972,7 @@ def foo(self): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import List - a = ... # type: module + import a class B(a.A): data = ... # type: List[complex] def foo(self) -> List[complex]: ... @@ -992,8 +992,8 @@ def make_A() -> A: ... v = foo.make_A().v """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo = ... # type: module v = ... # type: Union[int, float] """) @@ -1011,7 +1011,7 @@ def make_A() -> A: ... v = foo.make_A().v """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo v = ... # type: float """) @@ -1033,7 +1033,7 @@ def to_int(self): a.to_int() """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo a = ... # type: foo.A[int] """) @@ -1053,7 +1053,7 @@ def __init__(self): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any - a = ... # type: module + import a class Derived(a.Base): def __init__(self) -> None: ... """) diff --git a/pytype/tests/test_generic2.py b/pytype/tests/test_generic2.py index 25fe7970b..6e9ef0857 100644 --- a/pytype/tests/test_generic2.py +++ b/pytype/tests/test_generic2.py @@ -325,7 +325,6 @@ def put(self, elem: T): ... self.assertTypesMatchPytd(ty, """ import a - a = ... # type: module b = ... # type: a.A[int] """) @@ -404,18 +403,23 @@ def test_type_renaming_error(self): W = TypeVar('W') class A(Generic[T]): pass - class B(A[V]): pass # not-supported-yet[e1] + class B(A[V]): pass # bad-concrete-type[e1] class C(Generic[V]): pass class D(C[T]): pass - class E(D[S]): pass # not-supported-yet[e2] + class E(D[S]): pass # bad-concrete-type[e2] class F(Generic[U]): pass - class G(F[W]): pass # not-supported-yet[e3] + class G(F[W]): pass # bad-concrete-type[e3] """) - self.assertErrorRegexes(errors, {"e1": r"Renaming TypeVar `T`.*", - "e2": r"Renaming TypeVar `T`.*", - "e3": r"Renaming TypeVar `U`.*"}) + self.assertErrorSequences(errors, { + "e1": ["Expected: T", "Actually passed: V", + "T and V have incompatible"], + "e2": ["Expected: T", "Actually passed: S", + "T and S have incompatible"], + "e3": ["Expected: U", "Actually passed: W", + "U and W have incompatible"], + }) def test_type_parameter_conflict_error(self): ty, errors = self.InferWithErrors(""" @@ -806,6 +810,23 @@ class Bar(Foo[int]): x: int """) + def test_rename_bounded_typevar(self): + self.CheckWithErrors(""" + from typing import Callable, Generic, TypeVar + + T = TypeVar('T', bound=int) + No = TypeVar('No', bound=float) + Ok = TypeVar('Ok', bound=bool) + + class Box(Generic[T]): + def __init__(self, x: T): + self.x = x + def error(self, f: Callable[[T], No]) -> 'Box[No]': # bad-concrete-type + return Box(f(self.x)) # wrong-arg-types + def good(self, f: Callable[[T], Ok]) -> 'Box[Ok]': + return Box(f(self.x)) + """) + class GenericFeatureTest(test_base.BaseTest): """Tests for User-defined Generic Type.""" @@ -830,7 +851,6 @@ def f(): self.assertTypesMatchPytd(ty, """ import a - a = ... # type: module d = ... # type: a.A[int] ks = ... # type: dict_keys[int] vs = ... # type: dict_values[int] @@ -929,7 +949,7 @@ def __init__(self, x: T): x2 = foo.Foo[str](__any_object__).x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo x1: int x2: str """) diff --git a/pytype/tests/test_import1.py b/pytype/tests/test_import1.py index 642d74861..5a275dabe 100644 --- a/pytype/tests/test_import1.py +++ b/pytype/tests/test_import1.py @@ -20,7 +20,7 @@ def test_basic_import(self): import sys """) self.assertTypesMatchPytd(ty, """ - sys = ... # type: module + import sys """) def test_basic_import2(self): @@ -48,7 +48,7 @@ def foo(): return my_module.foo() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - my_module = ... # type: module + from path.to import my_module def foo() -> str: ... """) @@ -135,7 +135,7 @@ def foo(): return path.to.my_module.qqsv() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - path = ... # type: module + import path def foo() -> str: ... """) @@ -177,7 +177,7 @@ def f(): return sys """) self.assertTypesMatchPytd(ty, """ - sys = ... # type: module + import sys def f() -> module: ... """) @@ -191,7 +191,7 @@ def f(): return sys """) self.assertTypesMatchPytd(ty, """ - sys = ... # type: module + import sys def f() -> module: ... """) @@ -203,7 +203,7 @@ def f(): """) self.assertTypesMatchPytd(ty, """ from typing import List - sys = ... # type: module + import sys def f() -> List[str, ...]: ... """) @@ -226,7 +226,7 @@ def f(): return datetime.timedelta().total_seconds() """) self.assertTypesMatchPytd(ty, """ - datetime = ... # type: module + import datetime def f() -> float: ... """) @@ -278,8 +278,8 @@ def j(): """) ty = self.InferFromFile(filename=d["main.py"], pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - other_file = ... # type: module - sub = ... # type: module # from 'import sub.bar.baz' + import sub # from 'import sub.bar.baz' + from sub import other_file def g() -> float: ... def h() -> int: ... def i() -> float: ... @@ -349,7 +349,7 @@ def f(): d.create_file("foo/__init__.pyi", "") ty = self.InferFromFile(filename=d["foo/bar.py"], pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - baz = ... # type: module + from foo import baz def f() -> int: ... """) @@ -418,7 +418,7 @@ def f(): ty = self.InferFromFile(filename=d["foo/deep/bar.py"], pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - baz = ... # type: module + from foo import baz def f() -> int: ... """) @@ -452,7 +452,7 @@ def test_dot_dot_in_pyi(self): ty = self.InferFromFile(filename=d["foo/deep/bar.py"], pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - baz = ... # type: module + from foo import baz a: int """) @@ -479,8 +479,7 @@ def test_from_dot_in_pyi(self): ty = self.InferFromFile(filename=d["top.py"], pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Type - import foo.a - foo = ... # type: module + import foo x = ... # type: foo.a.X """) @@ -505,7 +504,7 @@ def my_foo(x): return path.to.some.module.foo(x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - path = ... # type: module + import path def my_foo(x) -> str: ... """) @@ -522,7 +521,7 @@ def my_foo(x): return module.foo(x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - module = ... # type: builtins.module + from path.to.some import module def my_foo(x) -> str: ... """) @@ -548,7 +547,7 @@ def f(): return __builtin__.int() """) self.assertTypesMatchPytd(ty, """ - __builtin__: module + import builtins as __builtin__ def f() -> int: ... """) @@ -560,7 +559,7 @@ class Foo: killpg = os.killpg """) self.assertTypesMatchPytd(ty, """ - os = ... # type: module + import os class Foo: def killpg(__pgid: int, __signal: int) -> None: ... """) @@ -588,7 +587,7 @@ def h(x): """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any - foo = ... # type: module + import foo def f(x, y) -> Any: ... def g(x) -> Any: ... def h(x) -> Any: ... @@ -614,7 +613,7 @@ def h(): return module.Foo.x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - module = ... # type: builtins.module + import module def f() -> int: ... def g() -> float: ... def h() -> float: ... @@ -645,7 +644,7 @@ class Z: zz = x.z """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - x = ... # type: module + import x xx = ... # type: x.X yy = ... # type: y.Y zz = ... # type: z.Z @@ -661,7 +660,7 @@ def test_reimport(self): d = foo.MyOrderedDict() """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo d = ... # type: collections.OrderedDict[nothing, nothing] """) @@ -677,7 +676,7 @@ def test_import_function(self): self.assertTypesMatchPytd(ty, """ from typing import Union from typing import SupportsFloat - foo = ... # type: module + import foo def d(__x: SupportsFloat, __y: SupportsFloat) -> float: ... """) @@ -692,7 +691,7 @@ def test_import_constant(self): y = mymath.half_tau """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - mymath = ... # type: module + import mymath x = ... # type: float y = ... # type: float """) @@ -728,8 +727,8 @@ def f(x: Foo) -> Foo: ... bar = b.f(foo) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import b from typing import Any - b = ... # type: module foo = ... # type: Any bar = ... # type: Any """) @@ -765,9 +764,9 @@ class Y: pass""") b = pkg.Y() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import pkg a = ... # type: pkg.pkg.pkg.X b = ... # type: pkg.bar.Y - pkg = ... # type: module """) def test_redefined_builtin(self): @@ -783,7 +782,7 @@ def f(x) -> Any: ... """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any - foo = ... # type: module + import foo x = ... # type: Any """) @@ -801,7 +800,7 @@ def f(x: object) -> object: ... foo.f(object()) # wrong-arg-types """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo x = ... # type: foo.object y = ... # type: foo.object """) @@ -826,7 +825,7 @@ def factory() -> type: ... A = a.factory() """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a A = ... # type: type """) @@ -918,8 +917,8 @@ class X: init_pyi = self.Infer( init_py % "", imports_map=imports_map, module_name="mod.__init__") self.assertTypesMatchPytd(init_pyi, """ + from mod import submod from typing import Type - submod: module X: Type[mod.submod.X] """) @@ -965,8 +964,9 @@ def __init__(self, x): init_pyi = self.Infer(init_py % "", imports_map=imports_map, module_name="mod.__init__") self.assertTypesMatchPytd(init_pyi, """ + import mod.submod + import typing from typing import Type - typing: module X: Type[mod.submod.X] class Y: def __init__(self, x: X) -> None: ... @@ -1025,14 +1025,11 @@ class Z: ... v = u.Y() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - import pkg.b - import pkg.d - import pkg.g - pkg = ... # type: module + import pkg + from pkg import d as u s = ... # type: pkg.b.X t = ... # type: pkg.b.e - u = ... # type: module - v = ... # type: pkg.d.Y + v = ... # type: u.Y """) def test_import_package_as_alias(self): @@ -1063,7 +1060,7 @@ class a: y = b.y """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - b: module + import b x: str y: int """) @@ -1084,7 +1081,7 @@ def test_import_package_alias_name_conflict2(self): y = c.y """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - c: module + import c x: str y: int """) @@ -1105,7 +1102,7 @@ def test_import_package_alias_name_conflict3(self): y = c.y """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - c: module + import c x: str y: int """) @@ -1120,8 +1117,8 @@ def __new__(cls): return object.__new__(cls) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + from foo import bar from typing import Type, TypeVar - bar = ... # type: module _Tfoo = TypeVar("_Tfoo", bound=foo) class foo: def __new__(cls: Type[_Tfoo]) -> _Tfoo: ... @@ -1137,7 +1134,7 @@ class foo: baz = foo """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - bar = ... # type: module + from foo import bar class foo: ... baz = foo """) @@ -1216,7 +1213,7 @@ def test_subpackage(self): v = foo.baz.v """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo v: str """) @@ -1236,7 +1233,7 @@ def test_attr_and_module(self): self.assertTypesMatchPytd(ty, """ from typing import Type import foo - other: module + import other X: Type[foo.X] v: str """) diff --git a/pytype/tests/test_import2.py b/pytype/tests/test_import2.py index 3d7df6db2..70583ecca 100644 --- a/pytype/tests/test_import2.py +++ b/pytype/tests/test_import2.py @@ -16,8 +16,8 @@ def test_module_attributes(self): p = os.__package__ """) self.assertTypesMatchPytd(ty, """ + import os from typing import Optional - os = ... # type: module f = ... # type: str n = ... # type: str d = ... # type: str @@ -36,9 +36,9 @@ def h(): return sys.getrecursionlimit() """, report_errors=False) self.assertTypesMatchPytd(ty, """ + import sys from typing import Any, TextIO bad_import = ... # type: Any - sys = ... # type: module def f() -> TextIO: ... def g() -> int: ... def h() -> int: ... @@ -53,7 +53,7 @@ def test_relative_priority(self): x = a.x """, deep=False, pythonpath=[d.path], module_name="b.main") self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: int """) diff --git a/pytype/tests/test_match1.py b/pytype/tests/test_match1.py index 80ad353b1..98811fc76 100644 --- a/pytype/tests/test_match1.py +++ b/pytype/tests/test_match1.py @@ -19,7 +19,7 @@ def f(): return foo.f(int) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo def f() -> str: ... """) @@ -48,7 +48,7 @@ def f(x: Iterable[str]) -> str: ... x = a.f(["a", "b", "c"]) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: str """) @@ -70,7 +70,7 @@ def f(x: Iterable[Q]) -> Q: ... x = a.f(a.B()) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: str """) @@ -87,7 +87,7 @@ def f(x: T) -> T: ... """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any - foo = ... # type: module + import foo v = ... # type: Any """) diff --git a/pytype/tests/test_match2.py b/pytype/tests/test_match2.py index 76e21fe65..13dde1bf4 100644 --- a/pytype/tests/test_match2.py +++ b/pytype/tests/test_match2.py @@ -124,7 +124,7 @@ def g4(x: int): pass """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import Any, List - foo = ... # type: module + import foo def g1() -> Any: ... def g2() -> int: ... def g3(x) -> Any: ... @@ -189,7 +189,7 @@ def f() -> Generator[Optional[str], None, None]: """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import List, Optional - foo = ... # type: module + import foo f = ... # type: List[Optional[str]] """) @@ -398,7 +398,7 @@ def f(): """) self.assertTypesMatchPytd(ty, """ from typing import Generator - tokenize = ... # type: module + import tokenize def f() -> NoneType: ... x = ... # type: Generator[tokenize.TokenInfo, None, None] """) @@ -438,8 +438,8 @@ def test_bound_against_callable(self): """) self.assertTypesMatchPytd(ty, """ from typing import Generator - io = ... # type: module - tokenize = ... # type: module + import io + import tokenize x = ... # type: Generator[tokenize.TokenInfo, None, None] """) diff --git a/pytype/tests/test_methods1.py b/pytype/tests/test_methods1.py index efb1f8060..2d72a39bd 100644 --- a/pytype/tests/test_methods1.py +++ b/pytype/tests/test_methods1.py @@ -603,8 +603,8 @@ def f(*args): return myjson.loads(*args) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import myjson from typing import Any - myjson = ... # type: module def f(*args) -> Any: ... """) @@ -620,8 +620,8 @@ def f(**args): return myjson.loads(**args) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import myjson from typing import Any - myjson = ... # type: module def f(**args) -> Any: ... """) @@ -637,8 +637,8 @@ def f(): return myjson.loads(s="{}") """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import myjson from typing import Any - myjson = ... # type: module def f() -> Any: ... """) @@ -877,7 +877,7 @@ def method(self): d = os.chmod.x """) self.assertTypesMatchPytd(ty, """ - os = ... # type: module + import os def f() -> NoneType: ... class Foo: def method(self) -> NoneType: ... @@ -893,7 +893,7 @@ def test_json(self): import json """, deep=False) self.assertTypesMatchPytd(ty, """ - json = ... # type: module + import json """) def test_new(self): diff --git a/pytype/tests/test_namedtuple1.py b/pytype/tests/test_namedtuple1.py index 6dde92597..3fa113a27 100644 --- a/pytype/tests/test_namedtuple1.py +++ b/pytype/tests/test_namedtuple1.py @@ -34,9 +34,7 @@ def _namedtuple_def(self, suffix="", **kws): """ (alias, (name, fields)), = kws.items() # pylint: disable=unbalanced-tuple-unpacking name = escape.pack_namedtuple(name, fields) - suffix += textwrap.dedent(""" - collections = ... # type: module - {alias} = {name}""").format(alias=alias, name=name) + suffix += f"\n{alias} = {name}" return pytd_utils.Print(self._namedtuple_ast(name, fields)) + "\n" + suffix def test_basic_namedtuple(self): @@ -325,7 +323,6 @@ def test_name_conflict(self): ast_z = self._namedtuple_ast(name_z, ["a"]) ast = pytd_utils.Concat(ast_x, ast_z) expected = pytd_utils.Print(ast) + textwrap.dedent(""" - collections = ... # type: module X = {name_x} Y = {name_x} Z = {name_z}""").format(name_x=name_x, name_z=name_z) @@ -341,7 +338,6 @@ def __new__(cls, _): name = escape.pack_namedtuple("X", []) ast = self._namedtuple_ast(name, []) expected = pytd_utils.Print(ast) + textwrap.dedent(""" - collections = ... # type: module _TX = TypeVar("_TX", bound=X) class X({name}): def __new__(cls: Type[_TX], _) -> _TX: ...""").format(name=name) diff --git a/pytype/tests/test_operators1.py b/pytype/tests/test_operators1.py index 5273b1f98..50728d8b9 100644 --- a/pytype/tests/test_operators1.py +++ b/pytype/tests/test_operators1.py @@ -271,8 +271,8 @@ def g(t): return (1, 2) | t """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import test from typing import Any - test = ... # type: module x = ... # type: bool y = ... # type: bool def f(t) -> Any: ... diff --git a/pytype/tests/test_overload.py b/pytype/tests/test_overload.py index 1a65dda31..f999c1e92 100644 --- a/pytype/tests/test_overload.py +++ b/pytype/tests/test_overload.py @@ -224,8 +224,8 @@ def f4(): return foo.f(**{"x": "y"}) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo: module def f1(*args) -> Any: ... def f2(**kwargs) -> Any: ... def f3() -> int: ... diff --git a/pytype/tests/test_pickle1.py b/pytype/tests/test_pickle1.py index 187c00066..21966b151 100644 --- a/pytype/tests/test_pickle1.py +++ b/pytype/tests/test_pickle1.py @@ -41,9 +41,8 @@ def test_type(self): r = u.x """, deep=False, pythonpath=[""], imports_map={"u": u}) self.assertTypesMatchPytd(ty, """ + import u from typing import Type - import collections - u = ... # type: module r = ... # type: Type[type] """) @@ -67,10 +66,10 @@ def test_copy_class_into_output(self): import bar r = bar.file_dispatcher(0) """, deep=False, pythonpath=[""], imports_map={"foo": foo, "bar": bar}) - self._verifyDeps(ty, ["asyncore", "builtins"], []) + self._verifyDeps(ty, ["asyncore"], []) self.assertTypesMatchPytd(ty, """ import asyncore - bar = ... # type: module + import bar r = ... # type: asyncore.file_dispatcher """) @@ -89,7 +88,7 @@ def f(): imports_map={"foo": foo}, module_name="bar", deep=True) bar = d.create_file("bar.pickled", pickled_bar) - self._verifyDeps(pickled_bar, ["builtins"], ["foo"]) + self._verifyDeps(pickled_bar, [], ["foo"]) self.Infer(""" import bar f = bar.f @@ -110,7 +109,7 @@ class X: pass class A(foo.X): pass class B(foo.Y): pass """, deep=False, pickle=True, imports_map={"foo": foo}, module_name="bar") - self._verifyDeps(pickled_bar, ["builtins"], ["foo"]) + self._verifyDeps(pickled_bar, [], ["foo"]) bar = d.create_file("bar.pickled", pickled_bar) # Now, replace the old foo.pickled with a version that doesn't have Y # anymore. diff --git a/pytype/tests/test_pickle2.py b/pytype/tests/test_pickle2.py index 2e7645efd..66ae035df 100644 --- a/pytype/tests/test_pickle2.py +++ b/pytype/tests/test_pickle2.py @@ -23,7 +23,7 @@ def g() -> json.JSONDecoder: """, deep=False, pythonpath=[""], imports_map={"u": u}) self.assertTypesMatchPytd(ty, """ import collections - u = ... # type: module + import u r = ... # type: collections.OrderedDict[int, int] """) diff --git a/pytype/tests/test_protocol_inference.py b/pytype/tests/test_protocol_inference.py index 7e914fe99..665891f49 100644 --- a/pytype/tests/test_protocol_inference.py +++ b/pytype/tests/test_protocol_inference.py @@ -22,8 +22,8 @@ def f(x, y): return foo.f(x, y) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo = ... # type: module def f(x, y: Union[int, str]) -> list: ... """) @@ -43,8 +43,8 @@ def f(y): return foo.f("", y) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import List - foo = ... # type: module def f(y: int) -> List[str]: ... """) @@ -61,8 +61,8 @@ def f(x): return foo.f(x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo = ... # type: module def f(x: Union[int, str]) -> Union[float, bool]: ... """) @@ -79,8 +79,8 @@ def f(x): return foo.f(x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo = ... # type: module def f(x: str) -> Union[int, float]: ... """) @@ -97,8 +97,8 @@ def f(x): return foo.f(y=x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo = ... # type: module def f(x: Union[int, str]) -> Union[bool, float]: ... """) @@ -324,8 +324,8 @@ def g(y): return foo.f(y) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Sized, SupportsAbs - foo = ... # type: module def g(y: SupportsAbs[Sized]) -> None: ... """) @@ -342,8 +342,8 @@ def g(y): return foo.f(y) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import SupportsAbs - foo = ... # type: module def g(y: SupportsAbs[int]) -> None: ... """) diff --git a/pytype/tests/test_protocols1.py b/pytype/tests/test_protocols1.py index e2f03b08d..fd19386a1 100644 --- a/pytype/tests/test_protocols1.py +++ b/pytype/tests/test_protocols1.py @@ -74,7 +74,7 @@ class Baz(foo.Foo[T]): pass """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo from typing import Generic, Protocol, TypeVar T = TypeVar('T') class Baz(Protocol, Generic[T]): ... diff --git a/pytype/tests/test_pyi1.py b/pytype/tests/test_pyi1.py index 0a04d8fd0..c7a45e563 100644 --- a/pytype/tests/test_pyi1.py +++ b/pytype/tests/test_pyi1.py @@ -34,7 +34,7 @@ def g(): return mod.f(3) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - mod = ... # type: module + import mod def f() -> NoneType: ... def g() -> NoneType: ... """) @@ -50,7 +50,7 @@ def g(x): return mod.f(x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - mod = ... # type: module + import mod def g(x) -> str: ... """) @@ -66,8 +66,8 @@ def g(x): return mod.split(x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import mod from typing import List - mod = ... # type: module def g(x) -> List[str, ...]: ... """) @@ -84,7 +84,7 @@ class B(A): x = classes.B().foo() """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - classes = ... # type: module + import classes x = ... # type: classes.A """) @@ -99,8 +99,8 @@ def __getattr__(name) -> Any: ... x = vague.foo + vague.bar """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import vague from typing import Any - vague = ... # type: module x = ... # type: Any """) @@ -124,7 +124,7 @@ def w(self, a, b) -> int: ... z = a.w(1, 2) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - decorated = ... # type: module + import decorated a = ... # type: decorated.A u = ... # type: int v = ... # type: int @@ -146,7 +146,7 @@ def w(self, x: classmethod) -> int: ... u = a.A().w(a.A.v) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a u = ... # type: int """) @@ -160,7 +160,7 @@ def parse(source, filename = ..., mode = ..., *args, **kwargs) -> int: ... u = a.parse("True") """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a u = ... # type: int """) @@ -179,8 +179,8 @@ def g(): out = out.split() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import a from typing import Any - a = ... # type: module def f(foo, bar) -> Any: ... def g() -> NoneType: ... """) @@ -196,7 +196,7 @@ def f(l: Iterable[int]) -> int: ... u = a.f([1, 2, 3]) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a u = ... # type: int """) @@ -214,7 +214,7 @@ def f(x=None): return True """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def f(x=...) -> bool: ... """) @@ -232,8 +232,8 @@ def bar(): x = foo.process_function(bar) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo = ... # type: module def bar() -> Any: ... # 'Any' because deep=False x = ... # type: NoneType """) @@ -269,8 +269,8 @@ def h(x): return x.baz() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo = ... # type: module def f(x) -> Any: ... def g(x) -> Any: ... def h(x) -> Any: ... @@ -289,8 +289,8 @@ def g(): return foo.f(foo.Foo()) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo = ... # type: module def g() -> Any: ... """) @@ -306,7 +306,7 @@ def f(x: T) -> T: ... x = foo.f(3) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo x = ... # type: int """) @@ -327,7 +327,7 @@ def f(x: T) -> T: ... x = bar.f("") """, pythonpath=[d1.path, d2.path]) self.assertTypesMatchPytd(ty, """ - bar = ... # type: module + import bar x = ... # type: str """) @@ -358,7 +358,7 @@ def f(x: int) -> str: ... x = a.f(a.lst[0]) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: str """) @@ -373,7 +373,7 @@ def foo(x: str, *y: Any, z: complex = ...) -> int: ... x = a.foo("foo %d %d", 3, 3) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: int """) @@ -389,7 +389,7 @@ def get_pos(x: T, *args: int, z: int, **kws: int) -> T: ... v = a.get_pos("foo", 3, 4, z=5) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a v = ... # type: str """) @@ -405,7 +405,7 @@ def get_kwonly(x: int, *args: int, z: T, **kwargs: int) -> T: ... v = a.get_kwonly(3, 4, z=5j) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a v = ... # type: complex """) @@ -425,8 +425,8 @@ def foo(a: K, *b, c: V, **d) -> Dict[K, V]: ... d = foo.foo(*(), **{"d": 3j}) # missing-parameter[e3] """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any, Dict - foo = ... # type: module a = ... # type: Any b = ... # type: Dict[int, complex] c = ... # type: Any @@ -453,7 +453,7 @@ def f(x): return a.A3() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a def f(x) -> a.A1: ... """) @@ -468,7 +468,7 @@ def test_builtins_module(self): x = a.x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: int """) @@ -486,7 +486,7 @@ def test_frozenset(self): """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ from typing import FrozenSet - a = ... # type: module + import a x = ... # type: FrozenSet[str] y = ... # type: FrozenSet[str] """) @@ -529,7 +529,7 @@ class Foo(MySupportsAbs[float], MyContextManager[Foo]): ... v = foo.Foo().__enter__() """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo v = ... # type: foo.Foo """) @@ -549,7 +549,7 @@ def bar(self, x:T2): x.bar(10) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo x = ... # type: foo.Bar[int] """) @@ -610,7 +610,7 @@ class Foo: pass v2 = bar.f("") """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - bar = ... # type: module + import bar v1 = ... # type: foo.Foo v2 = ... # type: str """) @@ -719,8 +719,8 @@ def t(a: str) -> None: ... ta = foo.A.t """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Callable - foo = ... # type: module ta = ... # type: Callable[[str], None] """) @@ -736,7 +736,7 @@ class Foo: Const = foo.Const """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo Const = ... # type: int """) @@ -752,7 +752,7 @@ def f(self) -> int: ... Func = foo.Func """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo def Func(self) -> int: ... """) @@ -773,7 +773,7 @@ def f(self) -> int: ... Func = foo.Func """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo Const = ... # type: int def Func(self) -> int: ... """) @@ -793,7 +793,7 @@ class Bar(Foo, MutableSequence[Bar]): ... v = foo.Bar()[0] """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo v = ... # type: foo.Bar """) @@ -809,7 +809,7 @@ def test_dot_import(self): a = b.X() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - b = ... # type: module + from foo import b a = ... # type: foo.a.A """) @@ -825,7 +825,7 @@ def test_dot_dot_import(self): a = b.X() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - b = ... # type: module + from foo.bar import b a = ... # type: foo.a.A """) @@ -848,8 +848,8 @@ def f(x: tuple[int]) -> tuple[int, int]: ... x = foo.f((0,)) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Tuple - foo: module x: Tuple[int, int] """) @@ -866,7 +866,7 @@ def __init__(self, x: T) -> None: ... x = foo.Foo(x=0) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo x: foo.Foo[int] """) diff --git a/pytype/tests/test_pyi2.py b/pytype/tests/test_pyi2.py index 230205df0..897c64d7a 100644 --- a/pytype/tests/test_pyi2.py +++ b/pytype/tests/test_pyi2.py @@ -14,7 +14,7 @@ def foo(**kwargs: typing.Any) -> int: return 1 def bar(*args: typing.Any) -> int: return 2 """) self.assertTypesMatchPytd(ty, """ - typing = ... # type: module + import typing def foo(**kwargs) -> int: ... def bar(*args) -> int: ... """) @@ -74,7 +74,7 @@ def f() -> bytes: ... x = foo.f() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo x = ... # type: bytes """) @@ -91,8 +91,8 @@ def f() -> Callable[[], float]: ... x = func() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Callable - foo: module func: Callable[[], float] x: float """) diff --git a/pytype/tests/test_recovery2.py b/pytype/tests/test_recovery2.py index f68f64497..f78c3bf0e 100644 --- a/pytype/tests/test_recovery2.py +++ b/pytype/tests/test_recovery2.py @@ -59,8 +59,8 @@ def g(): return '%s' % f() """, report_errors=False) self.assertTypesMatchPytd(ty, """ + import time from typing import Any - time = ... # type: module def f() -> Any: ... def g() -> str: ... """) diff --git a/pytype/tests/test_reingest1.py b/pytype/tests/test_reingest1.py index b91f9d0a1..628f44ec9 100644 --- a/pytype/tests/test_reingest1.py +++ b/pytype/tests/test_reingest1.py @@ -52,7 +52,7 @@ def g(): return f() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo def f() -> int: ... def g() -> int: ... """) @@ -74,7 +74,7 @@ def g(): return f() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo def f() -> int: ... def g() -> int: ... """) @@ -179,7 +179,7 @@ class MyList(list): """, pythonpath=[d.path]) # MyList is not parameterized because it inherits from List[Any]. self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo lst = ... # type: foo.MyList """) @@ -199,7 +199,7 @@ class MyList(List[T]): lst.write(42) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo lst = ... # type: foo.MyList[int] """) @@ -218,7 +218,7 @@ def __init__(self): x = foo.Foo[int]() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo x: foo.Foo[int] """) diff --git a/pytype/tests/test_reingest2.py b/pytype/tests/test_reingest2.py index fa75980f9..8b81147d8 100644 --- a/pytype/tests/test_reingest2.py +++ b/pytype/tests/test_reingest2.py @@ -1,7 +1,5 @@ """Tests for reloading generated pyi.""" -from pytype import file_utils -from pytype.pytd import pytd_utils from pytype.tests import test_base @@ -9,57 +7,54 @@ class ReingestTest(test_base.BaseTest): """Tests for reloading the pyi we generate.""" def test_type_parameter_bound(self): - foo = self.Infer(""" + foo = """ from typing import TypeVar T = TypeVar("T", bound=float) def f(x: T) -> T: return x - """, deep=False) - with file_utils.Tempdir() as d: - d.create_file("foo.pyi", pytd_utils.Print(foo)) + """ + with self.DepTree([("foo.py", foo, dict(deep=False))]): _, errors = self.InferWithErrors(""" import foo foo.f("") # wrong-arg-types[e] - """, pythonpath=[d.path]) + """) self.assertErrorRegexes(errors, {"e": r"float.*str"}) def test_default_argument_type(self): - foo = self.Infer(""" + foo = """ from typing import Any, Callable, TypeVar T = TypeVar("T") def f(x): return True def g(x: Callable[[T], Any]) -> T: ... - """) - with file_utils.Tempdir() as d: - d.create_file("foo.pyi", pytd_utils.Print(foo)) + """ + with self.DepTree([("foo.py", foo)]): self.Check(""" import foo foo.g(foo.f).upper() - """, pythonpath=[d.path]) + """) def test_duplicate_anystr_import(self): - dep1 = self.Infer(""" + dep1 = """ from typing import AnyStr def f(x: AnyStr) -> AnyStr: return x - """) - with file_utils.Tempdir() as d: - d.create_file("dep1.pyi", pytd_utils.Print(dep1)) - dep2 = self.Infer(""" - from typing import AnyStr - from dep1 import f - def g(x: AnyStr) -> AnyStr: - return x - """, pythonpath=[d.path]) - d.create_file("dep2.pyi", pytd_utils.Print(dep2)) - self.Check("import dep2", pythonpath=[d.path]) + """ + dep2 = """ + from typing import AnyStr + from dep1 import f + def g(x: AnyStr) -> AnyStr: + return x + """ + deps = [("dep1.py", dep1), ("dep2.py", dep2)] + with self.DepTree(deps): + self.Check("import dep2") class ReingestTestPy3(test_base.BaseTest): """Python 3 tests for reloading the pyi we generate.""" def test_instantiate_pyi_class(self): - foo = self.Infer(""" + foo = """ import abc class Foo(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -68,30 +63,28 @@ def foo(self): class Bar(Foo): def foo(self): pass - """) - with file_utils.Tempdir() as d: - d.create_file("foo.pyi", pytd_utils.Print(foo)) + """ + with self.DepTree([("foo.py", foo)]): _, errors = self.InferWithErrors(""" import foo foo.Foo() # not-instantiable[e] foo.Bar() - """, pythonpath=[d.path]) + """) self.assertErrorRegexes(errors, {"e": r"foo\.Foo.*foo"}) def test_use_class_attribute_from_annotated_new(self): - foo = self.Infer(""" + foo = """ class Foo: def __new__(cls) -> "Foo": return cls() class Bar: FOO = Foo() - """) - with file_utils.Tempdir() as d: - d.create_file("foo.pyi", pytd_utils.Print(foo)) + """ + with self.DepTree([("foo.py", foo)]): self.Check(""" import foo print(foo.Bar.FOO) - """, pythonpath=[d.path]) + """) if __name__ == "__main__": diff --git a/pytype/tests/test_six_overlay1.py b/pytype/tests/test_six_overlay1.py index 4059e8905..654d9e594 100644 --- a/pytype/tests/test_six_overlay1.py +++ b/pytype/tests/test_six_overlay1.py @@ -76,7 +76,7 @@ class Bar: x2 = Bar().x """) self.assertTypesMatchPytd(ty, """ - six: module + import six class Foo(type): x: int def __init__(self, *args) -> None: ... diff --git a/pytype/tests/test_six_overlay2.py b/pytype/tests/test_six_overlay2.py index 9f038e4fe..328002e0e 100644 --- a/pytype/tests/test_six_overlay2.py +++ b/pytype/tests/test_six_overlay2.py @@ -17,7 +17,7 @@ def test_version_check(self): v = None """) self.assertTypesMatchPytd(ty, """ - six = ... # type: module + import six v = ... # type: str """) @@ -34,7 +34,7 @@ def test_string_types(self): """) self.assertTypesMatchPytd(ty, """ from typing import List - six: module + import six a: List[str] b: str c: int diff --git a/pytype/tests/test_solver.py b/pytype/tests/test_solver.py index 2674cf692..4d58a45f3 100644 --- a/pytype/tests/test_solver.py +++ b/pytype/tests/test_solver.py @@ -182,7 +182,7 @@ def every(f, array): return all(itertools.chain(f, array)) """) self.assertTypesMatchPytd(ty, """ - itertools = ... # type: module + import itertools def every(f, array) -> bool: ... """) @@ -276,8 +276,8 @@ def f(date): return date.bad_method() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import bad_mod from typing import Any - bad_mod = ... # type: module def f(date) -> Any: ... """) @@ -288,8 +288,7 @@ def bar(l): l.append(collections.defaultdict(int, [(0, 0)])) """) self.assertTypesMatchPytd(ty, """ - import typing - collections = ... # type: module + import collections def bar(l) -> NoneType: ... """) @@ -319,8 +318,8 @@ def f(): return x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import List, Union - foo = ... # type: module def f() -> List[Union[int, str]]: ... """) @@ -336,7 +335,7 @@ def f(x, *args, y: T) -> T: ... x = foo.f(1, y=2j) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo x = ... # type: complex """) diff --git a/pytype/tests/test_special_builtins1.py b/pytype/tests/test_special_builtins1.py index a9558eb98..5a8c173df 100644 --- a/pytype/tests/test_special_builtins1.py +++ b/pytype/tests/test_special_builtins1.py @@ -64,8 +64,8 @@ class Bar(foo.Foo): foo = property(fget=foo.Foo.get_foo) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Annotated - foo = ... # type: module class Bar(foo.Foo): foo = ... # type: Annotated[int, 'property'] """) @@ -94,8 +94,8 @@ class Bar(foo.Foo): foo = property(fget=foo.Foo.get_foo) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Annotated, Union - foo = ... # type: module class Bar(foo.Foo): foo = ... # type: Annotated[Union[int, str], 'property'] """) diff --git a/pytype/tests/test_stdlib1.py b/pytype/tests/test_stdlib1.py index 2e38edd16..fad7db31e 100644 --- a/pytype/tests/test_stdlib1.py +++ b/pytype/tests/test_stdlib1.py @@ -14,7 +14,7 @@ def f(): return ast.parse("True") """) self.assertTypesMatchPytd(ty, """ - ast = ... # type: module + import ast def f() -> _ast.Module: ... """) @@ -23,7 +23,7 @@ def test_urllib(self): import urllib """) self.assertTypesMatchPytd(ty, """ - urllib = ... # type: module + import urllib """) def test_traceback(self): @@ -33,8 +33,8 @@ def f(exc): return traceback.format_exception(*exc) """) self.assertTypesMatchPytd(ty, """ + import traceback from typing import List - traceback = ... # type: module def f(exc) -> List[str]: ... """) @@ -44,8 +44,8 @@ def test_os_walk(self): x = list(os.walk("/tmp")) """, deep=False) self.assertTypesMatchPytd(ty, """ + import os from typing import List, Tuple - os = ... # type: module x = ... # type: List[Tuple[str, List[str], List[str]]] """) @@ -55,7 +55,7 @@ def test_struct(self): x = struct.Struct("b") """, deep=False) self.assertTypesMatchPytd(ty, """ - struct = ... # type: module + import struct x = ... # type: struct.Struct """) @@ -64,7 +64,7 @@ def test_warning(self): import warnings """, deep=False) self.assertTypesMatchPytd(ty, """ - warnings = ... # type: module + import warnings """) def test_path_conf(self): @@ -113,7 +113,7 @@ def test_defaultdict(self): f = collections.defaultdict(default_factory = int) """) self.assertTypesMatchPytd(ty, """ - collections = ... # type: module + import collections a = ... # type: collections.defaultdict[str, int] b = ... # type: collections.defaultdict[str, int] c = ... # type: collections.defaultdict[str, int] @@ -135,8 +135,8 @@ def test_defaultdict_no_factory(self): h = collections.defaultdict(default_factory = None) """) self.assertTypesMatchPytd(ty, """ + import collections from typing import Any - collections = ... # type: module a = ... # type: collections.defaultdict[nothing, nothing] b = ... # type: collections.defaultdict[nothing, nothing] c = ... # type: collections.defaultdict[nothing, Any] @@ -156,8 +156,8 @@ def test_defaultdict_diff_defaults(self): d = collections.defaultdict(int, {1: 'one'}) """) self.assertTypesMatchPytd(ty, """ + import collections from typing import Union - collections = ... # type: module a = ... # type: collections.defaultdict[str, Union[int, str]] b = ... # type: collections.defaultdict[str, Union[int, str]] c = ... # type: collections.defaultdict[str, int] @@ -205,7 +205,7 @@ def test_sys_version_info(self): major, minor, micro, releaselevel, serial = sys.version_info """) self.assertTypesMatchPytd(ty, """ - sys: module + import sys major: int minor: int micro: int diff --git a/pytype/tests/test_stdlib2.py b/pytype/tests/test_stdlib2.py index 43f19fc6b..1e0dd9ee5 100644 --- a/pytype/tests/test_stdlib2.py +++ b/pytype/tests/test_stdlib2.py @@ -27,7 +27,7 @@ def test_collections_deque_init(self): x = collections.deque([1, 2, 3], maxlen=10) """) self.assertTypesMatchPytd(ty, """ - collections = ... # type: module + import collections x = ... # type: collections.deque[int] """) @@ -94,7 +94,7 @@ def foo() -> MyClass: return MyClass(1, 2) """) self.assertTypesMatchPytd(ty, """ - fractions: module + import fractions class MyClass(fractions.Fraction): ... def foo() -> MyClass: ... """) @@ -162,11 +162,10 @@ def f(fi: typing.IO): f(tempfile.SpooledTemporaryFile(1048576, "wb", suffix=".foo")) """, deep=False) self.assertTypesMatchPytd(ty, """ - from typing import Any, Union + import os + import tempfile import typing - os = ... # type: module - tempfile = ... # type: module - typing = ... # type: module + from typing import Any, Union def f(fi: typing.IO) -> Union[bytes, str]: ... """) @@ -210,7 +209,7 @@ def test_sys_version_info_lt(self): v = "hello world" """) self.assertTypesMatchPytd(ty, """ - sys = ... # type: module + import sys v = ... # type: str """) @@ -223,7 +222,7 @@ def test_sys_version_info_le(self): v = "hello world" """) self.assertTypesMatchPytd(ty, """ - sys = ... # type: module + import sys v = ... # type: int """) @@ -238,7 +237,7 @@ def test_sys_version_info_eq(self): v = None """) self.assertTypesMatchPytd(ty, """ - sys = ... # type: module + import sys v = ... # type: str """) @@ -251,7 +250,7 @@ def test_sys_version_info_ge(self): v = "hello world" """) self.assertTypesMatchPytd(ty, """ - sys = ... # type: module + import sys v = ... # type: int """) @@ -264,7 +263,7 @@ def test_sys_version_info_gt(self): v = "hello world" """) self.assertTypesMatchPytd(ty, """ - sys = ... # type: module + import sys v = ... # type: int """) @@ -277,7 +276,7 @@ def test_sys_version_info_named_attribute(self): v = "hello world" """) self.assertTypesMatchPytd(ty, """ - sys: module + import sys v: str """) @@ -290,7 +289,7 @@ def test_sys_version_info_tuple(self): v = "hello world" """) self.assertTypesMatchPytd(ty, """ - sys: module + import sys v: int """) @@ -303,7 +302,7 @@ def test_sys_version_info_slice(self): v = "hello world" """) self.assertTypesMatchPytd(ty, """ - sys: module + import sys v: int """) @@ -333,10 +332,9 @@ async def test_with(x): event_loop.close() """) self.assertTypesMatchPytd(ty, """ - import asyncio.events + import asyncio from typing import Any, Coroutine - asyncio: module event_loop: asyncio.events.AbstractEventLoop class AsyncContextManager: @@ -369,8 +367,8 @@ async def iterate(x): iterate(AsyncIterable()) """) self.assertTypesMatchPytd(ty, """ + import asyncio from typing import Any, Coroutine, TypeVar - asyncio: module _TAsyncIterable = TypeVar('_TAsyncIterable', bound=AsyncIterable) class AsyncIterable: def __aiter__(self: _TAsyncIterable) -> _TAsyncIterable: ... @@ -395,7 +393,7 @@ def run(cmd): return stdout """) self.assertTypesMatchPytd(ty, """ - subprocess: module + import subprocess def run(cmd) -> bytes: ... """) @@ -408,7 +406,7 @@ def run(cmd): return stdout """) self.assertTypesMatchPytd(ty, """ - subprocess: module + import subprocess def run(cmd) -> bytes: ... """) @@ -422,7 +420,7 @@ def run(cmd): return stdout """) self.assertTypesMatchPytd(ty, """ - subprocess: module + import subprocess def run(cmd) -> bytes: ... """) @@ -435,7 +433,7 @@ def run(cmd): return stdout """) self.assertTypesMatchPytd(ty, """ - subprocess: module + import subprocess def run(cmd) -> str: ... """) @@ -449,7 +447,7 @@ def run(cmd): return stdout """) self.assertTypesMatchPytd(ty, """ - subprocess: module + import subprocess def run(cmd) -> str: ... """) @@ -476,8 +474,8 @@ def test_chainmap(self): v4 = v1.new_child() """) self.assertTypesMatchPytd(ty, """ + import collections from typing import ChainMap, List, Mapping, Union - collections: module v1: ChainMap[Union[bytes, str], Union[int, str]] v2: List[Mapping[Union[bytes, str], Union[int, str]]] v3: ChainMap[Union[bytes, str], Union[int, str]] @@ -493,8 +491,8 @@ def test_re(self): group = match[0] """) self.assertTypesMatchPytd(ty, """ + import re from typing import Match, Optional, Pattern - re: module pattern: Pattern[str] match: Optional[Match[str]] group: str @@ -513,7 +511,7 @@ def f(name): return io.open(name, "rb").read() """) self.assertTypesMatchPytd(ty, """ - io: module + import io def f(name) -> bytes: ... """) diff --git a/pytype/tests/test_super1.py b/pytype/tests/test_super1.py index 85de9a8ef..3431d8de1 100644 --- a/pytype/tests/test_super1.py +++ b/pytype/tests/test_super1.py @@ -181,8 +181,8 @@ def f(self): return super(Parent, self).f() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo = ... # type: module class Parent(foo.Grandparent): ... OtherParent = ... # type: Any class Child(Any, Parent): diff --git a/pytype/tests/test_tuple2.py b/pytype/tests/test_tuple2.py index 649e55f94..cfdb86574 100644 --- a/pytype/tests/test_tuple2.py +++ b/pytype/tests/test_tuple2.py @@ -227,8 +227,8 @@ def test_strptime(self): time.strptime('', '%m %d %Y')[0:5]) """) self.assertTypesMatchPytd(ty, """ + import time from typing import Union - time: module year: int month: int day: int diff --git a/pytype/tests/test_typevar1.py b/pytype/tests/test_typevar1.py index 32d30bdaa..96b35a7b9 100644 --- a/pytype/tests/test_typevar1.py +++ b/pytype/tests/test_typevar1.py @@ -185,7 +185,7 @@ def f(x: AnyInt) -> AnyInt: ... y = 3 """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: int y = ... # type: int """) @@ -208,9 +208,8 @@ class B(A): ... y = a.B().foo """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - from typing import List import a - a = ... # type: module + from typing import List x = ... # type: List[a.A] y = ... # type: List[a.B] """) @@ -235,9 +234,8 @@ def make_B() -> B[int]: ... y = a.make_B().foo """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - from typing import List import a - a = ... # type: module + from typing import List x = ... # type: List[a.A[int]] y = ... # type: List[a.B[int]] """) @@ -261,7 +259,7 @@ def make_A() -> A[int]: ... x = a.make_A().foo """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a = ... # type: module + import a x = ... # type: List[int] """) @@ -283,8 +281,8 @@ def make_A() -> A[int]: ... x = a.make_A().foo """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import a from typing import List - a = ... # type: module x = ... # type: List[int] """) @@ -306,9 +304,8 @@ class B(A): ... y = a.B().foo() """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - from typing import List import a - a = ... # type: module + from typing import List v = ... # type: List[a.A] w = ... # type: List[a.B] x = ... # type: List[a.A] @@ -332,9 +329,8 @@ class A(metaclass=Meta): x = a.A.foo """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - from typing import List import a - a = ... # type: module + from typing import List x = ... # type: List[a.A] """) diff --git a/pytype/tests/test_typevar2.py b/pytype/tests/test_typevar2.py index 2dba677f7..2807a8deb 100644 --- a/pytype/tests/test_typevar2.py +++ b/pytype/tests/test_typevar2.py @@ -18,8 +18,8 @@ def f(x: T) -> T: w = f("") """) self.assertTypesMatchPytd(ty, """ + import typing from typing import Any - typing = ... # type: module T = TypeVar("T") def f(x: T) -> T: ... v = ... # type: int @@ -247,7 +247,6 @@ def f(x: T) -> T: v = id(x) if x else 42 """, deep=False) self.assertTypesMatchPytd(ty, """ - import types from typing import Optional, TypeVar v = ... # type: int x = ... # type: Optional[int] @@ -341,8 +340,8 @@ def f(x: T) -> T: """) self.assertTypesMatchPytd(ty, """ import __future__ + import typing from typing import Any - typing = ... # type: module unicode_literals = ... # type: __future__._Feature T = TypeVar("T") def f(x: T) -> T: ... @@ -642,8 +641,8 @@ def f(x: T) -> foo.X[T]: return [x] """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import List, TypeVar - foo: module T = TypeVar('T') def f(x: T) -> List[T]: ... """) @@ -670,8 +669,8 @@ def f(x: int) -> foo.X[int]: return [x] """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import List - foo: module def f(x: int) -> List[int]: ... """) @@ -727,9 +726,9 @@ def f(x: foo.T, y: bar.X[foo.T]): pass """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import bar + import foo from typing import Callable, TypeVar - foo: module - bar: module T = TypeVar('T') def f(x: T, y: Callable[[T], T]) -> None: ... """) @@ -859,7 +858,7 @@ def f2(x: foo.X[T]) -> T: return x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo from typing import Optional, TypeVar T = TypeVar('T') def f1(x: Optional[int]) -> None: ... @@ -923,8 +922,8 @@ def run(args: List[str]): return result.stdout """) self.assertTypesMatchPytd(ty, """ + import subprocess from typing import List - subprocess: module def run(args: List[str]) -> str: ... """) diff --git a/pytype/tests/test_typing1.py b/pytype/tests/test_typing1.py index 3f2979f90..9bacbbd8c 100644 --- a/pytype/tests/test_typing1.py +++ b/pytype/tests/test_typing1.py @@ -13,8 +13,8 @@ def test_all(self): x = typing.__all__ """, deep=False) self.assertTypesMatchPytd(ty, """ + import typing from typing import List - typing = ... # type: module x = ... # type: List[str] """) @@ -26,8 +26,8 @@ def f(): return typing.cast(typing.List[int], []) """) self.assertTypesMatchPytd(ty, """ + import typing from typing import Any, List - typing = ... # type: module def f() -> List[int]: ... """) @@ -47,7 +47,7 @@ class A: pass """) self.assertTypesMatchPytd(ty, """ - typing: module + import typing v1: None v2: typing.Any v3: A @@ -106,7 +106,6 @@ class Foo(Tuple[Foo]): ... """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ import foo - foo: module x: foo.Foo """) @@ -239,7 +238,7 @@ def f(x: bool) -> complex: ... v3 = foo.f(x) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo x: bool v1: int v2: float @@ -257,7 +256,7 @@ def okay() -> Literal[True]: ... if not foo.okay(): x = "oh no" """, pythonpath=[d.path]) - self.assertTypesMatchPytd(ty, "foo: module") + self.assertTypesMatchPytd(ty, "import foo") def test_pyi_variable(self): with file_utils.Tempdir() as d: @@ -270,7 +269,7 @@ def test_pyi_variable(self): if not foo.OKAY: x = "oh no" """, pythonpath=[d.path]) - self.assertTypesMatchPytd(ty, "foo: module") + self.assertTypesMatchPytd(ty, "import foo") def test_pyi_typing_extensions(self): with file_utils.Tempdir() as d: @@ -283,7 +282,7 @@ def test_pyi_typing_extensions(self): if not foo.OKAY: x = "oh no" """, pythonpath=[d.path]) - self.assertTypesMatchPytd(ty, "foo: module") + self.assertTypesMatchPytd(ty, "import foo") # TODO(b/173742489): Include enums once we support looking up local enums. def test_pyi_value(self): @@ -322,7 +321,7 @@ def f(x) -> str: ... v3 = foo.f(True) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo v1: int v2: int v3: str @@ -342,7 +341,7 @@ def test_reexport(self): """, pythonpath=[d.path]) # TODO(b/123775699): The type of x should be Literal[True]. self.assertTypesMatchPytd(ty, """ - foo: module + import foo x: bool y: None """) @@ -364,7 +363,7 @@ def f3(f): return foo.open(f, mode="rb") """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo def f1(f) -> str: ... def f2(f) -> str: ... def f3(f) -> int: ... @@ -384,8 +383,8 @@ def f(x: Literal[False]) -> str: ... # Inference completing without type errors shows that `__any_object__` # matched both Literal[True] and Literal[False]. self.assertTypesMatchPytd(ty, """ + import foo from typing import Any - foo: module v: Any """) @@ -408,7 +407,7 @@ def f2(): return foo.f(foo.y) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo: module + import foo def f1() -> int: ... def f2() -> str: ... """) diff --git a/pytype/tests/test_typing2.py b/pytype/tests/test_typing2.py index dc6fe2976..feb3471ac 100644 --- a/pytype/tests/test_typing2.py +++ b/pytype/tests/test_typing2.py @@ -261,8 +261,7 @@ def g() -> Any: pass """, deep=False) self.assertTypesMatchPytd(ty, """ - import __future__ - typing = ... # type: module + import typing def f() -> typing.Any: ... def g() -> Any: ... class Any: @@ -632,9 +631,9 @@ def freqs(s: str) -> typing.Counter[str]: f = x | z """) self.assertTypesMatchPytd(ty, """ + import collections + import typing from typing import Counter, Iterable, List, Tuple, Union - collections: module - typing: module a: List[Tuple[str, int]] b: List[Tuple[str, int]] @@ -668,8 +667,8 @@ def f() -> NamedTuple("ret", [("x", int), ("y", str)]): ... z = foo.f()[2] # out of bounds, fall back to the combined element type """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo: module w: str x: int y: str diff --git a/pytype/tests/test_typing_annotated.py b/pytype/tests/test_typing_annotated.py index 5261ff166..d2c22d084 100644 --- a/pytype/tests/test_typing_annotated.py +++ b/pytype/tests/test_typing_annotated.py @@ -74,7 +74,7 @@ class A: x = a.A().x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a: module + import a x: int """) @@ -92,7 +92,7 @@ class A: x = a.A().x.w """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a: module + import a x: int """) @@ -110,7 +110,7 @@ class B(a.A): x = B().x """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - a: module + import a class B(a.A): ... x: int """) diff --git a/pytype/tests/test_typing_methods1.py b/pytype/tests/test_typing_methods1.py index 1d4e887ba..5137cba3a 100644 --- a/pytype/tests/test_typing_methods1.py +++ b/pytype/tests/test_typing_methods1.py @@ -89,9 +89,9 @@ def f() -> IO[str]: ... x.close() """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import IO, List fi = ... # type: IO[str] - foo = ... # type: module a = ... # type: int b = ... # type: bool c = ... # type: str @@ -132,8 +132,8 @@ def tpl() -> Tuple[str]: ... f = reversed(seq) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Iterator, List, Sequence, Tuple, Union - foo = ... # type: module seq = ... # type: Union[Sequence[str], Tuple[str]] a = ... # type: str b = ... # type: int @@ -165,8 +165,8 @@ def lst() -> List[str]: ... b = seq.extend([1,2,3]) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Iterator, List, Sequence, Union - foo = ... # type: module # TODO(b/159065400): Should be List[Union[int, str]] seq = ... # type: Union[list, typing.MutableSequence[Union[int, str]]] a = ... # type: None @@ -198,7 +198,7 @@ def deq() -> Deque[int]: ... d = q.rotate(3) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ - foo = ... # type: module + import foo from typing import Deque q = ... # type: Deque[int] a = ... # type: None @@ -234,7 +234,6 @@ def f() -> MyDict[str, int]: ... self.assertTypesMatchPytd(ty, """ from typing import Tuple, Union import foo - foo = ... # type: module m = ... # type: foo.MyDict[Union[complex, int, str], Union[complex, float, int]] a = ... # type: Union[complex, float, int] b = ... # type: Tuple[Union[complex, str], Union[float, int]] @@ -263,8 +262,8 @@ def f() -> AbstractSet[str]: ... f = x.isdisjoint([1,2,3]) """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import AbstractSet - foo = ... # type: module x = ... # type: AbstractSet[str] a = ... # type: bool b = ... # type: AbstractSet[str] @@ -299,8 +298,8 @@ def f() -> MutableSet[str]: ... e = 3 in x """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import MutableSet, Union - foo = ... # type: module a = ... # type: Union[int, str] # TODO(b/159067449): We do a clear() after adding "int". # Why does "int" still appear for b? @@ -347,8 +346,8 @@ def f() -> Pattern[str]: ... e = m1.end() """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import List, Match, Pattern - foo = ... # type: module a = ... # type: int b = ... # type: int c = ... # type: str diff --git a/pytype/tests/test_typing_methods2.py b/pytype/tests/test_typing_methods2.py index 194485322..80607c00e 100644 --- a/pytype/tests/test_typing_methods2.py +++ b/pytype/tests/test_typing_methods2.py @@ -30,7 +30,6 @@ def f() -> MyDict[str, int]: ... self.assertTypesMatchPytd(ty, """ from typing import List, Tuple, Union import foo - foo = ... # type: module m = ... # type: foo.MyDict[str, int] a = ... # type: typing.Mapping[str, int] b = ... # type: bool diff --git a/pytype/tests/test_typing_namedtuple1.py b/pytype/tests/test_typing_namedtuple1.py index 31a7d1e2f..30899c698 100644 --- a/pytype/tests/test_typing_namedtuple1.py +++ b/pytype/tests/test_typing_namedtuple1.py @@ -182,8 +182,8 @@ def test_unpacking(self): """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo = ... # type: module v = ... # type: foo.namedtuple_X_0 a = ... # type: str b = ... # type: int diff --git a/pytype/tests/test_typing_namedtuple2.py b/pytype/tests/test_typing_namedtuple2.py index 2ac4e867d..a3926d4de 100644 --- a/pytype/tests/test_typing_namedtuple2.py +++ b/pytype/tests/test_typing_namedtuple2.py @@ -79,8 +79,8 @@ def test_basic_namedtuple(self): ty, """ import collections + import typing from typing import Callable, Iterable, Sized, Tuple, Type, TypeVar, Union - typing = ... # type: module x = ... # type: X a = ... # type: int b = ... # type: str @@ -275,8 +275,8 @@ class X(NamedTuple): """, deep=False, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Union - foo = ... # type: module v = ... # type: foo.X a = ... # type: str b = ... # type: int diff --git a/pytype/tests/test_unpack.py b/pytype/tests/test_unpack.py index 3e46aaf7f..693737ee8 100644 --- a/pytype/tests/test_unpack.py +++ b/pytype/tests/test_unpack.py @@ -40,8 +40,8 @@ def test_unpack_indefinite_from_pytd(self): c = (*foo.a, *foo.b) """, pythonpath=[d.path]) self.assertTypesMatchPytd(ty, """ + import foo from typing import Tuple, Union - foo: module c: Tuple[Union[int, str], ...] """) diff --git a/pytype/tools/analyze_project/config.py b/pytype/tools/analyze_project/config.py index 2b38b1ee6..d35729b5b 100644 --- a/pytype/tools/analyze_project/config.py +++ b/pytype/tools/analyze_project/config.py @@ -68,6 +68,8 @@ 'build_dict_literals_from_kwargs': Item( None, 'False', ArgInfo('--build-dict-literals-from-kwargs', None), None), + 'gen_stub_imports': Item( + None, 'False', ArgInfo('--gen-stub-imports', None), None), 'disable': Item( None, 'pyi-error', ArgInfo('--disable', ','.join), 'Comma or space separated list of error names to ignore.'), diff --git a/pytype/tools/traces/traces_test.py b/pytype/tools/traces/traces_test.py index 7d0de5381..2ff1dcee0 100644 --- a/pytype/tools/traces/traces_test.py +++ b/pytype/tools/traces/traces_test.py @@ -83,6 +83,10 @@ class MatchAstTestCase(unittest.TestCase): def _parse(self, text, options=None): text = textwrap.dedent(text).lstrip() + if options: + options.tweak(gen_stub_imports=True) + else: + options = config.Options.create(gen_stub_imports=True) return ast.parse(text), traces.trace(text, options) def _get_traces(self, text, node_type, options=None): @@ -117,14 +121,14 @@ def test_not_implemented(self): def test_import(self): matches = self._get_traces("import os, sys as tzt", ast.Import) self.assertTracesEqual(matches, [ - ((1, 7), "IMPORT_NAME", "os", ("module",)), - ((1, 18), "STORE_NAME", "tzt", ("module",))]) + ((1, 7), "IMPORT_NAME", "os", ("import os",)), + ((1, 18), "STORE_NAME", "tzt", ("import sys",))]) def test_import_from(self): matches = self._get_traces( "from os import path as p, environ", ast.ImportFrom) self.assertTracesEqual(matches, [ - ((1, 23), "STORE_NAME", "p", ("module",)), + ((1, 23), "STORE_NAME", "p", ("import os.path",)), ((1, 26), "STORE_NAME", "environ", ("os._Environ[str]",))]) diff --git a/pytype/vm.py b/pytype/vm.py index d7e42cc04..65117b131 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -1061,11 +1061,11 @@ def trace_namedtuple(self, *args): return NotImplemented def call_init(self, node, unused_instance): - # This dummy implementation is overwritten in analyze.py. + # This dummy implementation is overwritten in tracer_vm.py. return node def init_class(self, node, cls, extra_key=None): - # This dummy implementation is overwritten in analyze.py. + # This dummy implementation is overwritten in tracer_vm.py. del cls, extra_key return node, None @@ -3247,7 +3247,7 @@ def byte_END_FINALLY(self, state, op): return state.set_why("reraise") def _check_return(self, node, actual, formal): - return False # overridden in analyze.py + return False # overwritten in tracer_vm.py def _set_frame_return(self, node, frame, var): if frame.allowed_returns is not None: @@ -3624,7 +3624,10 @@ def byte_YIELD_FROM(self, state, op): ret_var = val.get_instance_type_parameter(abstract_utils.T) else: ret_var = val.get_instance_type_parameter(abstract_utils.V) - result.PasteVariable(ret_var, state.node, {b}) + if ret_var.bindings: + result.PasteVariable(ret_var, state.node, {b}) + else: + result.AddBinding(self.ctx.convert.unsolvable, {b}, state.node) else: result.AddBinding(val, {b}, state.node) return state.push(result)