Skip to content
Permalink
Browse files

Merge pull request #397 from google/google_sync

Google sync
  • Loading branch information...
rchen152 committed Aug 28, 2019
2 parents 0e8e75a + afdbe75 commit a5c85aaca0534730974c1c6fe6f625b1e413fa7c
@@ -2162,6 +2162,25 @@ def get_special_attribute(self, node, name, valself):
return super(CallableClass, self).get_special_attribute(node, name, valself)


class LiteralClass(ParameterizedClass):
"""The class of a typing.Literal."""

def __init__(self, base_cls, instance, vm):
formal_type_parameters = {abstract_utils.T: instance.get_class()}
super(LiteralClass, self).__init__(base_cls, formal_type_parameters, vm)
self._instance = instance

def __repr__(self):
return "LiteralClass(%s)" % self._instance

@property
def value(self):
if isinstance(self._instance, AbstractOrConcreteValue):
return self._instance
# TODO(b/123775699): Remove this workaround once we support literal enums.
return None


class PyTDClass(SimpleAbstractValue, mixin.Class):
"""An abstract wrapper for PyTD class objects.
@@ -2315,15 +2334,30 @@ def __init__(self, name, bases, members, cls, vm):

def type_param_check(self):
"""Throw exception for invalid type parameters."""

def update_sig(method):
method.signature.excluded_types.update(
[t.name for t in self.template])
method.signature.add_scope(self.full_name)

if self.template:
# For function type parameters check
for mbr in self.members.values():
mbr = abstract_utils.get_atomic_value(
m = abstract_utils.get_atomic_value(
mbr, default=self.vm.convert.unsolvable)
if isinstance(mbr, InterpreterFunction):
mbr.signature.excluded_types.update(
[t.name for t in self.template])
mbr.signature.add_scope(self.full_name)
if isinstance(m, InterpreterFunction):
update_sig(m)
elif mbr.data and all(
x.__class__.__name__ == "PropertyInstance" for x in mbr.data):
# We generate a new variable every time we add a property slot, so we
# take the last one (which contains bindings for all defined slots).
prop = mbr.data[-1]
for slot in (prop.fget, prop.fset, prop.fdel):
if slot:
for d in slot.data:
if isinstance(d, InterpreterFunction):
update_sig(d)

# nested class can not use the same type parameter
# in current generic class
inner_cls_types = self.collect_inner_cls_types()
@@ -64,7 +64,12 @@ def contains(subst, annot):
node, param, substs, instantiate_unbound)
for name, param in annot.formal_type_parameters.items()}
# annot may be a subtype of ParameterizedClass, such as TupleClass.
return type(annot)(annot.base_cls, type_parameters, self.vm)
if isinstance(annot, abstract.LiteralClass):
# We can't create a LiteralClass because we don't have a concrete value.
typ = abstract.ParameterizedClass
else:
typ = type(annot)
return typ(annot.base_cls, type_parameters, self.vm)
elif isinstance(annot, abstract.Union):
options = tuple(self.sub_one_annotation(node, o, substs,
instantiate_unbound)
@@ -133,6 +133,10 @@ def add_basic_options(o):
"--strict-import", action="store_true",
dest="strict_import", default=False,
help="Experimental: Only load submodules that are explicitly imported.")
o.add_argument(
"--precise-return", action="store_true", dest="precise_return",
default=False, help=("Experimental: Infer precise return types even for "
"invalid function calls."))


def add_subtools(o):
@@ -776,11 +776,10 @@ def _constant_to_value(self, pyval, subst, get_node):
template, parameters, subst)
return abstract_class(base_cls, type_parameters, self.vm)
elif isinstance(pyval, pytd.Literal):
# TODO(b/123775699): Create a ParameterizedClass(Literal) to record that
# this type is a literal.
value = self.constant_to_value(
self._get_literal_value(pyval.value), subst, self.vm.root_cfg_node)
return value.get_class()
return abstract.LiteralClass(
self.name_to_value("typing.Literal"), value, self.vm)
elif pyval.__class__ is tuple: # only match raw tuple, not namedtuple/Node
return self.tuple_to_value([self.constant_to_var(item, subst,
self.vm.root_cfg_node)
@@ -371,9 +371,25 @@ def get_variables(self):
return variables


class ReturnValueMixin(object):
"""Mixin for exceptions that hold a return node and variable."""

def __init__(self):
super(ReturnValueMixin, self).__init__()
self.return_node = None
self.return_variable = None

def set_return(self, node, var):
self.return_node = node
self.return_variable = var

def get_return(self, state):
return state.change_cfg_node(self.return_node), self.return_variable


# These names are chosen to match pytype error classes.
# pylint: disable=g-bad-exception-name
class FailedFunctionCall(Exception):
class FailedFunctionCall(Exception, ReturnValueMixin):
"""Exception for failed function calls."""

def __gt__(self, other):
@@ -388,7 +404,7 @@ def __init__(self, obj):
self.obj = obj


class DictKeyMissing(Exception):
class DictKeyMissing(Exception, ReturnValueMixin):
"""When retrieving a key that does not exist in a dict."""

def __init__(self, name):
@@ -580,48 +596,49 @@ def substitute_formal_args(self, node, args, view, alias_map):

return arg_dict, subst

def instantiate_return(self, node, subst, sources):
return_type = self.pytd_sig.return_type
for param in pytd_utils.GetTypeParameters(return_type):
if param.full_name in subst:
# This value, which was instantiated by the matcher, will end up in the
# return value. Since the matcher does not call __init__, we need to do
# that now.
node = self.vm.call_init(node, subst[param.full_name])
try:
ret = self.vm.convert.constant_to_var(
abstract_utils.AsReturnValue(return_type), subst, node,
source_sets=[sources])
except self.vm.convert.TypeParameterError:
# The return type contains a type parameter without a substitution.
subst = subst.copy()
visitor = visitors.CollectTypeParameters()
return_type.Visit(visitor)

for t in visitor.params:
if t.full_name not in subst:
subst[t.full_name] = self.vm.convert.empty.to_variable(node)
return node, self.vm.convert.constant_to_var(
abstract_utils.AsReturnValue(return_type), subst, node,
source_sets=[sources])
if not ret.bindings and isinstance(return_type, pytd.TypeParameter):
ret.AddBinding(self.vm.convert.empty, [], node)
return node, ret

def call_with_args(self, node, func, arg_dict,
subst, ret_map, alias_map=None):
"""Call this signature. Used by PyTDFunction."""
return_type = self.pytd_sig.return_type
t = (return_type, subst)
t = (self.pytd_sig.return_type, subst)
sources = [func] + list(arg_dict.values())
if t not in ret_map:
for param in pytd_utils.GetTypeParameters(return_type):
if param.full_name in subst:
# This value, which was instantiated by the matcher, will end up in
# the return value. Since the matcher does not call __init__, we need
# to do that now.
node = self.vm.call_init(node, subst[param.full_name])
try:
ret_map[t] = self.vm.convert.constant_to_var(
abstract_utils.AsReturnValue(return_type), subst, node,
source_sets=[sources])
except self.vm.convert.TypeParameterError:
# The return type contains a type parameter without a substitution.
subst = subst.copy()
visitor = visitors.CollectTypeParameters()
return_type.Visit(visitor)

for t in visitor.params:
if t.full_name not in subst:
subst[t.full_name] = self.vm.convert.empty.to_variable(node)
ret_map[t] = self.vm.convert.constant_to_var(
abstract_utils.AsReturnValue(return_type), subst, node,
source_sets=[sources])
else:
if (not ret_map[t].bindings and
isinstance(return_type, pytd.TypeParameter)):
ret_map[t].AddBinding(self.vm.convert.empty, [], node)
node, ret_map[t] = self.instantiate_return(node, subst, sources)
else:
# add the new sources
for data in ret_map[t].data:
ret_map[t].AddBinding(data, sources, node)
mutations = self._get_mutation(node, arg_dict, subst)
self.vm.trace_call(node, func, (self,),
tuple(arg_dict[p.name] for p in self.pytd_sig.params),
{},
ret_map[t])
{}, ret_map[t])
return node, ret_map[t], mutations

def _get_mutation(self, node, arg_dict, subst):
@@ -4,6 +4,7 @@

from pytype import abstract
from pytype import abstract_utils
from pytype import compat
from pytype import datatypes
from pytype import function
from pytype import mixin
@@ -13,6 +14,7 @@
from pytype.pytd import pep484
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
from pytype.pytd.parse import parser_constants


log = logging.getLogger(__name__)
@@ -632,6 +634,33 @@ def _match_callable_instance(
return None
return subst

def _match_pyval_against_string(self, pyval, string, subst):
"""Matches a concrete value against a string literal."""
assert isinstance(string, str)

if pyval.__class__ is str: # native str
left_type = "bytes" if self.vm.PY2 else "unicode"
elif isinstance(pyval, compat.BytesType):
left_type = "bytes"
elif isinstance(pyval, compat.UnicodeType):
left_type = "unicode"
else:
return None
# needs to be native str to match `string`
left_value = compat.native_str(pyval)

right_prefix, right_value = (
parser_constants.STRING_RE.match(string).groups()[:2])
if "b" in right_prefix or "u" not in right_prefix and self.vm.PY2:
right_type = "bytes"
else:
right_type = "unicode"
right_value = right_value[1:-1] # remove quotation marks

if left_type == right_type and left_value == right_value:
return subst
return None

def _match_class_and_instance_against_type(
self, left, instance, other_type, subst, node, view):
"""Checks whether an instance of a type is compatible with a (formal) type.
@@ -646,7 +675,25 @@ def _match_class_and_instance_against_type(
Returns:
A new type parameter assignment if the matching succeeded, None otherwise.
"""
if isinstance(other_type, mixin.Class):
if isinstance(other_type, abstract.LiteralClass):
other_value = other_type.value
if other_value and isinstance(instance, abstract.AbstractOrConcreteValue):
if isinstance(other_value.pyval, str):
return self._match_pyval_against_string(
instance.pyval, other_value.pyval, subst)
return subst if instance.pyval == other_value.pyval else None
elif other_value:
# `instance` does not contain a concrete value. Literal overloads are
# always followed by at least one non-literal fallback, so we should
# fail here.
return None
else:
# TODO(b/123775699): Remove this workaround once we can match against
# literal enums.
return self._match_type_against_type(
instance, other_type.formal_type_parameters[abstract_utils.T],
subst, node, view)
elif isinstance(other_type, mixin.Class):
base = self.match_from_mro(left, other_type)
if base is None:
if other_type.is_protocol:
@@ -830,8 +877,8 @@ def _enforce_common_superclass(self, var):
for cls in classes:
object_in_values |= cls == self.vm.convert.object_type
superclasses = {c.full_name for c in cls.mro}
for compat, name in _COMPATIBLE_BUILTINS:
if compat in superclasses:
for compat_name, name in _COMPATIBLE_BUILTINS:
if compat_name in superclasses:
superclasses.add(name)
if common_classes is None:
common_classes = superclasses
@@ -2,7 +2,6 @@

import collections
import hashlib
import re

from pytype import file_utils
from pytype import module_utils
@@ -19,7 +18,6 @@
_DEFAULT_PLATFORM = "linux"
# Typing members that represent sets of types.
_TYPING_SETS = ("typing.Intersection", "typing.Optional", "typing.Union")
_STRING_RE = re.compile("^([bu]?)(('[^']*')|(\"[^\"]*\"))$")


_Params = collections.namedtuple("_", ["required",
@@ -1619,7 +1617,7 @@ def _make_type_type(value):
def _handle_string_literal(value):
if not isinstance(value, str):
return value
match = _STRING_RE.match(value)
match = parser_constants.STRING_RE.match(value)
if not match:
return value
return match.groups()[1][1:-1]
@@ -57,3 +57,6 @@
# Marks external NamedTypes so that they do not get prefixed by the current
# module name.
EXTERNAL_NAME_PREFIX = '$external$'

# Regex for string literals.
STRING_RE = re.compile("^([bu]?)(('[^']*')|(\"[^\"]*\"))$")
@@ -194,6 +194,19 @@ class Foo(object):
bar: int
""")

def testGenericSuper(self):
self.Check("""
from typing import Callable, Generic, TypeVar
T = TypeVar('T')
Func = Callable[[T], str]
class Foo(Generic[T]):
def __init__(self, func: Func = str) -> None:
super(Foo, self).__init__()
self._func = func
def f(self, value: T) -> str:
return self._func(value)
""")


class ClassesTestPython3Feature(test_base.TargetPython3FeatureTest):
"""Tests for classes."""

0 comments on commit a5c85aa

Please sign in to comment.
You can’t perform that action at this time.