Skip to content

Commit

Permalink
Merge pull request #397 from google/google_sync
Browse files Browse the repository at this point in the history
Google sync
  • Loading branch information
rchen152 committed Aug 28, 2019
2 parents 0e8e75a + afdbe75 commit a5c85aa
Show file tree
Hide file tree
Showing 21 changed files with 573 additions and 121 deletions.
44 changes: 39 additions & 5 deletions pytype/abstract.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -2162,6 +2162,25 @@ def get_special_attribute(self, node, name, valself):
return super(CallableClass, self).get_special_attribute(node, name, valself) return super(CallableClass, self).get_special_attribute(node, name, valself)




class LiteralClass(ParameterizedClass):
"""The class of a typing.Literal."""

def __init__(self, base_cls, instance, vm):
formal_type_parameters = {abstract_utils.T: instance.get_class()}
super(LiteralClass, self).__init__(base_cls, formal_type_parameters, vm)
self._instance = instance

def __repr__(self):
return "LiteralClass(%s)" % self._instance

@property
def value(self):
if isinstance(self._instance, AbstractOrConcreteValue):
return self._instance
# TODO(b/123775699): Remove this workaround once we support literal enums.
return None


class PyTDClass(SimpleAbstractValue, mixin.Class): class PyTDClass(SimpleAbstractValue, mixin.Class):
"""An abstract wrapper for PyTD class objects. """An abstract wrapper for PyTD class objects.
Expand Down Expand Up @@ -2315,15 +2334,30 @@ def __init__(self, name, bases, members, cls, vm):


def type_param_check(self): def type_param_check(self):
"""Throw exception for invalid type parameters.""" """Throw exception for invalid type parameters."""

def update_sig(method):
method.signature.excluded_types.update(
[t.name for t in self.template])
method.signature.add_scope(self.full_name)

if self.template: if self.template:
# For function type parameters check # For function type parameters check
for mbr in self.members.values(): for mbr in self.members.values():
mbr = abstract_utils.get_atomic_value( m = abstract_utils.get_atomic_value(
mbr, default=self.vm.convert.unsolvable) mbr, default=self.vm.convert.unsolvable)
if isinstance(mbr, InterpreterFunction): if isinstance(m, InterpreterFunction):
mbr.signature.excluded_types.update( update_sig(m)
[t.name for t in self.template]) elif mbr.data and all(
mbr.signature.add_scope(self.full_name) x.__class__.__name__ == "PropertyInstance" for x in mbr.data):
# We generate a new variable every time we add a property slot, so we
# take the last one (which contains bindings for all defined slots).
prop = mbr.data[-1]
for slot in (prop.fget, prop.fset, prop.fdel):
if slot:
for d in slot.data:
if isinstance(d, InterpreterFunction):
update_sig(d)

# nested class can not use the same type parameter # nested class can not use the same type parameter
# in current generic class # in current generic class
inner_cls_types = self.collect_inner_cls_types() inner_cls_types = self.collect_inner_cls_types()
Expand Down
7 changes: 6 additions & 1 deletion pytype/annotations_util.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ def contains(subst, annot):
node, param, substs, instantiate_unbound) node, param, substs, instantiate_unbound)
for name, param in annot.formal_type_parameters.items()} for name, param in annot.formal_type_parameters.items()}
# annot may be a subtype of ParameterizedClass, such as TupleClass. # annot may be a subtype of ParameterizedClass, such as TupleClass.
return type(annot)(annot.base_cls, type_parameters, self.vm) if isinstance(annot, abstract.LiteralClass):
# We can't create a LiteralClass because we don't have a concrete value.
typ = abstract.ParameterizedClass
else:
typ = type(annot)
return typ(annot.base_cls, type_parameters, self.vm)
elif isinstance(annot, abstract.Union): elif isinstance(annot, abstract.Union):
options = tuple(self.sub_one_annotation(node, o, substs, options = tuple(self.sub_one_annotation(node, o, substs,
instantiate_unbound) instantiate_unbound)
Expand Down
4 changes: 4 additions & 0 deletions pytype/config.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def add_basic_options(o):
"--strict-import", action="store_true", "--strict-import", action="store_true",
dest="strict_import", default=False, dest="strict_import", default=False,
help="Experimental: Only load submodules that are explicitly imported.") help="Experimental: Only load submodules that are explicitly imported.")
o.add_argument(
"--precise-return", action="store_true", dest="precise_return",
default=False, help=("Experimental: Infer precise return types even for "
"invalid function calls."))




def add_subtools(o): def add_subtools(o):
Expand Down
5 changes: 2 additions & 3 deletions pytype/convert.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -776,11 +776,10 @@ def _constant_to_value(self, pyval, subst, get_node):
template, parameters, subst) template, parameters, subst)
return abstract_class(base_cls, type_parameters, self.vm) return abstract_class(base_cls, type_parameters, self.vm)
elif isinstance(pyval, pytd.Literal): elif isinstance(pyval, pytd.Literal):
# TODO(b/123775699): Create a ParameterizedClass(Literal) to record that
# this type is a literal.
value = self.constant_to_value( value = self.constant_to_value(
self._get_literal_value(pyval.value), subst, self.vm.root_cfg_node) self._get_literal_value(pyval.value), subst, self.vm.root_cfg_node)
return value.get_class() return abstract.LiteralClass(
self.name_to_value("typing.Literal"), value, self.vm)
elif pyval.__class__ is tuple: # only match raw tuple, not namedtuple/Node elif pyval.__class__ is tuple: # only match raw tuple, not namedtuple/Node
return self.tuple_to_value([self.constant_to_var(item, subst, return self.tuple_to_value([self.constant_to_var(item, subst,
self.vm.root_cfg_node) self.vm.root_cfg_node)
Expand Down
81 changes: 49 additions & 32 deletions pytype/function.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -371,9 +371,25 @@ def get_variables(self):
return variables return variables




class ReturnValueMixin(object):
"""Mixin for exceptions that hold a return node and variable."""

def __init__(self):
super(ReturnValueMixin, self).__init__()
self.return_node = None
self.return_variable = None

def set_return(self, node, var):
self.return_node = node
self.return_variable = var

def get_return(self, state):
return state.change_cfg_node(self.return_node), self.return_variable


# These names are chosen to match pytype error classes. # These names are chosen to match pytype error classes.
# pylint: disable=g-bad-exception-name # pylint: disable=g-bad-exception-name
class FailedFunctionCall(Exception): class FailedFunctionCall(Exception, ReturnValueMixin):
"""Exception for failed function calls.""" """Exception for failed function calls."""


def __gt__(self, other): def __gt__(self, other):
Expand All @@ -388,7 +404,7 @@ def __init__(self, obj):
self.obj = obj self.obj = obj




class DictKeyMissing(Exception): class DictKeyMissing(Exception, ReturnValueMixin):
"""When retrieving a key that does not exist in a dict.""" """When retrieving a key that does not exist in a dict."""


def __init__(self, name): def __init__(self, name):
Expand Down Expand Up @@ -580,48 +596,49 @@ def substitute_formal_args(self, node, args, view, alias_map):


return arg_dict, subst return arg_dict, subst


def instantiate_return(self, node, subst, sources):
return_type = self.pytd_sig.return_type
for param in pytd_utils.GetTypeParameters(return_type):
if param.full_name in subst:
# This value, which was instantiated by the matcher, will end up in the
# return value. Since the matcher does not call __init__, we need to do
# that now.
node = self.vm.call_init(node, subst[param.full_name])
try:
ret = self.vm.convert.constant_to_var(
abstract_utils.AsReturnValue(return_type), subst, node,
source_sets=[sources])
except self.vm.convert.TypeParameterError:
# The return type contains a type parameter without a substitution.
subst = subst.copy()
visitor = visitors.CollectTypeParameters()
return_type.Visit(visitor)

for t in visitor.params:
if t.full_name not in subst:
subst[t.full_name] = self.vm.convert.empty.to_variable(node)
return node, self.vm.convert.constant_to_var(
abstract_utils.AsReturnValue(return_type), subst, node,
source_sets=[sources])
if not ret.bindings and isinstance(return_type, pytd.TypeParameter):
ret.AddBinding(self.vm.convert.empty, [], node)
return node, ret

def call_with_args(self, node, func, arg_dict, def call_with_args(self, node, func, arg_dict,
subst, ret_map, alias_map=None): subst, ret_map, alias_map=None):
"""Call this signature. Used by PyTDFunction.""" """Call this signature. Used by PyTDFunction."""
return_type = self.pytd_sig.return_type t = (self.pytd_sig.return_type, subst)
t = (return_type, subst)
sources = [func] + list(arg_dict.values()) sources = [func] + list(arg_dict.values())
if t not in ret_map: if t not in ret_map:
for param in pytd_utils.GetTypeParameters(return_type): node, ret_map[t] = self.instantiate_return(node, subst, sources)
if param.full_name in subst:
# This value, which was instantiated by the matcher, will end up in
# the return value. Since the matcher does not call __init__, we need
# to do that now.
node = self.vm.call_init(node, subst[param.full_name])
try:
ret_map[t] = self.vm.convert.constant_to_var(
abstract_utils.AsReturnValue(return_type), subst, node,
source_sets=[sources])
except self.vm.convert.TypeParameterError:
# The return type contains a type parameter without a substitution.
subst = subst.copy()
visitor = visitors.CollectTypeParameters()
return_type.Visit(visitor)

for t in visitor.params:
if t.full_name not in subst:
subst[t.full_name] = self.vm.convert.empty.to_variable(node)
ret_map[t] = self.vm.convert.constant_to_var(
abstract_utils.AsReturnValue(return_type), subst, node,
source_sets=[sources])
else:
if (not ret_map[t].bindings and
isinstance(return_type, pytd.TypeParameter)):
ret_map[t].AddBinding(self.vm.convert.empty, [], node)
else: else:
# add the new sources # add the new sources
for data in ret_map[t].data: for data in ret_map[t].data:
ret_map[t].AddBinding(data, sources, node) ret_map[t].AddBinding(data, sources, node)
mutations = self._get_mutation(node, arg_dict, subst) mutations = self._get_mutation(node, arg_dict, subst)
self.vm.trace_call(node, func, (self,), self.vm.trace_call(node, func, (self,),
tuple(arg_dict[p.name] for p in self.pytd_sig.params), tuple(arg_dict[p.name] for p in self.pytd_sig.params),
{}, {}, ret_map[t])
ret_map[t])
return node, ret_map[t], mutations return node, ret_map[t], mutations


def _get_mutation(self, node, arg_dict, subst): def _get_mutation(self, node, arg_dict, subst):
Expand Down
53 changes: 50 additions & 3 deletions pytype/matcher.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


from pytype import abstract from pytype import abstract
from pytype import abstract_utils from pytype import abstract_utils
from pytype import compat
from pytype import datatypes from pytype import datatypes
from pytype import function from pytype import function
from pytype import mixin from pytype import mixin
Expand All @@ -13,6 +14,7 @@
from pytype.pytd import pep484 from pytype.pytd import pep484
from pytype.pytd import pytd from pytype.pytd import pytd
from pytype.pytd import pytd_utils from pytype.pytd import pytd_utils
from pytype.pytd.parse import parser_constants




log = logging.getLogger(__name__) log = logging.getLogger(__name__)
Expand Down Expand Up @@ -632,6 +634,33 @@ def _match_callable_instance(
return None return None
return subst return subst


def _match_pyval_against_string(self, pyval, string, subst):
"""Matches a concrete value against a string literal."""
assert isinstance(string, str)

if pyval.__class__ is str: # native str
left_type = "bytes" if self.vm.PY2 else "unicode"
elif isinstance(pyval, compat.BytesType):
left_type = "bytes"
elif isinstance(pyval, compat.UnicodeType):
left_type = "unicode"
else:
return None
# needs to be native str to match `string`
left_value = compat.native_str(pyval)

right_prefix, right_value = (
parser_constants.STRING_RE.match(string).groups()[:2])
if "b" in right_prefix or "u" not in right_prefix and self.vm.PY2:
right_type = "bytes"
else:
right_type = "unicode"
right_value = right_value[1:-1] # remove quotation marks

if left_type == right_type and left_value == right_value:
return subst
return None

def _match_class_and_instance_against_type( def _match_class_and_instance_against_type(
self, left, instance, other_type, subst, node, view): self, left, instance, other_type, subst, node, view):
"""Checks whether an instance of a type is compatible with a (formal) type. """Checks whether an instance of a type is compatible with a (formal) type.
Expand All @@ -646,7 +675,25 @@ def _match_class_and_instance_against_type(
Returns: Returns:
A new type parameter assignment if the matching succeeded, None otherwise. A new type parameter assignment if the matching succeeded, None otherwise.
""" """
if isinstance(other_type, mixin.Class): if isinstance(other_type, abstract.LiteralClass):
other_value = other_type.value
if other_value and isinstance(instance, abstract.AbstractOrConcreteValue):
if isinstance(other_value.pyval, str):
return self._match_pyval_against_string(
instance.pyval, other_value.pyval, subst)
return subst if instance.pyval == other_value.pyval else None
elif other_value:
# `instance` does not contain a concrete value. Literal overloads are
# always followed by at least one non-literal fallback, so we should
# fail here.
return None
else:
# TODO(b/123775699): Remove this workaround once we can match against
# literal enums.
return self._match_type_against_type(
instance, other_type.formal_type_parameters[abstract_utils.T],
subst, node, view)
elif isinstance(other_type, mixin.Class):
base = self.match_from_mro(left, other_type) base = self.match_from_mro(left, other_type)
if base is None: if base is None:
if other_type.is_protocol: if other_type.is_protocol:
Expand Down Expand Up @@ -830,8 +877,8 @@ def _enforce_common_superclass(self, var):
for cls in classes: for cls in classes:
object_in_values |= cls == self.vm.convert.object_type object_in_values |= cls == self.vm.convert.object_type
superclasses = {c.full_name for c in cls.mro} superclasses = {c.full_name for c in cls.mro}
for compat, name in _COMPATIBLE_BUILTINS: for compat_name, name in _COMPATIBLE_BUILTINS:
if compat in superclasses: if compat_name in superclasses:
superclasses.add(name) superclasses.add(name)
if common_classes is None: if common_classes is None:
common_classes = superclasses common_classes = superclasses
Expand Down
4 changes: 1 addition & 3 deletions pytype/pyi/parser.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


import collections import collections
import hashlib import hashlib
import re


from pytype import file_utils from pytype import file_utils
from pytype import module_utils from pytype import module_utils
Expand All @@ -19,7 +18,6 @@
_DEFAULT_PLATFORM = "linux" _DEFAULT_PLATFORM = "linux"
# Typing members that represent sets of types. # Typing members that represent sets of types.
_TYPING_SETS = ("typing.Intersection", "typing.Optional", "typing.Union") _TYPING_SETS = ("typing.Intersection", "typing.Optional", "typing.Union")
_STRING_RE = re.compile("^([bu]?)(('[^']*')|(\"[^\"]*\"))$")




_Params = collections.namedtuple("_", ["required", _Params = collections.namedtuple("_", ["required",
Expand Down Expand Up @@ -1619,7 +1617,7 @@ def _make_type_type(value):
def _handle_string_literal(value): def _handle_string_literal(value):
if not isinstance(value, str): if not isinstance(value, str):
return value return value
match = _STRING_RE.match(value) match = parser_constants.STRING_RE.match(value)
if not match: if not match:
return value return value
return match.groups()[1][1:-1] return match.groups()[1][1:-1]
3 changes: 3 additions & 0 deletions pytype/pytd/parse/parser_constants.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -57,3 +57,6 @@
# Marks external NamedTypes so that they do not get prefixed by the current # Marks external NamedTypes so that they do not get prefixed by the current
# module name. # module name.
EXTERNAL_NAME_PREFIX = '$external$' EXTERNAL_NAME_PREFIX = '$external$'

# Regex for string literals.
STRING_RE = re.compile("^([bu]?)(('[^']*')|(\"[^\"]*\"))$")
13 changes: 13 additions & 0 deletions pytype/tests/py3/test_classes.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -194,6 +194,19 @@ class Foo(object):
bar: int bar: int
""") """)


def testGenericSuper(self):
self.Check("""
from typing import Callable, Generic, TypeVar
T = TypeVar('T')
Func = Callable[[T], str]
class Foo(Generic[T]):
def __init__(self, func: Func = str) -> None:
super(Foo, self).__init__()
self._func = func
def f(self, value: T) -> str:
return self._func(value)
""")



class ClassesTestPython3Feature(test_base.TargetPython3FeatureTest): class ClassesTestPython3Feature(test_base.TargetPython3FeatureTest):
"""Tests for classes.""" """Tests for classes."""
Expand Down
Loading

0 comments on commit a5c85aa

Please sign in to comment.