Skip to content
Browse files

Support matching against typing.Literal function parameters.

With this change, we should be able to support all Literal uses in typeshed
that don't involve enums.


PiperOrigin-RevId: 265144891
  • Loading branch information...
rchen152 committed Aug 23, 2019
1 parent 0a5639e commit 33f1cc5cf5d1ae2343d47e24e7c9964d975209b2
@@ -2162,6 +2162,25 @@ def get_special_attribute(self, node, name, valself):
return super(CallableClass, self).get_special_attribute(node, name, valself)

class LiteralClass(ParameterizedClass):
"""The class of a typing.Literal."""

def __init__(self, base_cls, instance, vm):
formal_type_parameters = {abstract_utils.T: instance.get_class()}
super(LiteralClass, self).__init__(base_cls, formal_type_parameters, vm)
self._instance = instance

def __repr__(self):
return "LiteralClass(%s)" % self._instance

def value(self):
if isinstance(self._instance, AbstractOrConcreteValue):
return self._instance
# TODO(b/123775699): Remove this workaround once we support literal enums.
return None

class PyTDClass(SimpleAbstractValue, mixin.Class):
"""An abstract wrapper for PyTD class objects.
@@ -64,7 +64,12 @@ def contains(subst, annot):
node, param, substs, instantiate_unbound)
for name, param in annot.formal_type_parameters.items()}
# annot may be a subtype of ParameterizedClass, such as TupleClass.
return type(annot)(annot.base_cls, type_parameters, self.vm)
if isinstance(annot, abstract.LiteralClass):
# We can't create a LiteralClass because we don't have a concrete value.
typ = abstract.ParameterizedClass
typ = type(annot)
return typ(annot.base_cls, type_parameters, self.vm)
elif isinstance(annot, abstract.Union):
options = tuple(self.sub_one_annotation(node, o, substs,
@@ -776,11 +776,10 @@ def _constant_to_value(self, pyval, subst, get_node):
template, parameters, subst)
return abstract_class(base_cls, type_parameters, self.vm)
elif isinstance(pyval, pytd.Literal):
# TODO(b/123775699): Create a ParameterizedClass(Literal) to record that
# this type is a literal.
value = self.constant_to_value(
self._get_literal_value(pyval.value), subst, self.vm.root_cfg_node)
return value.get_class()
return abstract.LiteralClass(
self.name_to_value("typing.Literal"), value, self.vm)
elif pyval.__class__ is tuple: # only match raw tuple, not namedtuple/Node
return self.tuple_to_value([self.constant_to_var(item, subst,
@@ -4,6 +4,7 @@

from pytype import abstract
from pytype import abstract_utils
from pytype import compat
from pytype import datatypes
from pytype import function
from pytype import mixin
@@ -13,6 +14,9 @@
from pytype.pytd import pep484
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
from pytype.pytd.parse import parser_constants

import six

log = logging.getLogger(__name__)
@@ -632,6 +636,33 @@ def _match_callable_instance(
return None
return subst

def _match_pyval_against_string(self, pyval, string, subst):
"""Matches a concrete value against a string literal."""
assert isinstance(string, str)

if pyval.__class__ is str: # native str
left_type = "bytes" if self.vm.PY2 else "unicode"
elif isinstance(pyval, compat.BytesType):
left_type = "bytes"
elif isinstance(pyval, compat.UnicodeType):
left_type = "unicode"
return None
# needs to be native str to match `string`
left_value = six.ensure_str(pyval)

right_prefix, right_value = (
if "b" in right_prefix or "u" not in right_prefix and self.vm.PY2:
right_type = "bytes"
right_type = "unicode"
right_value = right_value[1:-1] # remove quotation marks

if left_type == right_type and left_value == right_value:
return subst
return None

def _match_class_and_instance_against_type(
self, left, instance, other_type, subst, node, view):
"""Checks whether an instance of a type is compatible with a (formal) type.
@@ -646,7 +677,25 @@ def _match_class_and_instance_against_type(
A new type parameter assignment if the matching succeeded, None otherwise.
if isinstance(other_type, mixin.Class):
if isinstance(other_type, abstract.LiteralClass):
other_value = other_type.value
if other_value and isinstance(instance, abstract.AbstractOrConcreteValue):
if isinstance(other_value.pyval, str):
return self._match_pyval_against_string(
instance.pyval, other_value.pyval, subst)
return subst if instance.pyval == other_value.pyval else None
elif other_value:
# `instance` does not contain a concrete value. Literal overloads are
# always followed by at least one non-literal fallback, so we should
# fail here.
return None
# TODO(b/123775699): Remove this workaround once we can match against
# literal enums.
return self._match_type_against_type(
instance, other_type.formal_type_parameters[abstract_utils.T],
subst, node, view)
elif isinstance(other_type, mixin.Class):
base = self.match_from_mro(left, other_type)
if base is None:
if other_type.is_protocol:
@@ -830,8 +879,8 @@ def _enforce_common_superclass(self, var):
for cls in classes:
object_in_values |= cls == self.vm.convert.object_type
superclasses = {c.full_name for c in cls.mro}
for compat, name in _COMPATIBLE_BUILTINS:
if compat in superclasses:
for compat_name, name in _COMPATIBLE_BUILTINS:
if compat_name in superclasses:
if common_classes is None:
common_classes = superclasses
@@ -2,7 +2,6 @@

import collections
import hashlib
import re

from pytype import file_utils
from pytype import module_utils
@@ -19,7 +18,6 @@
# Typing members that represent sets of types.
_TYPING_SETS = ("typing.Intersection", "typing.Optional", "typing.Union")
_STRING_RE = re.compile("^([bu]?)(('[^']*')|(\"[^\"]*\"))$")

_Params = collections.namedtuple("_", ["required",
@@ -1619,7 +1617,7 @@ def _make_type_type(value):
def _handle_string_literal(value):
if not isinstance(value, str):
return value
match = _STRING_RE.match(value)
match = parser_constants.STRING_RE.match(value)
if not match:
return value
return match.groups()[1][1:-1]
@@ -57,3 +57,6 @@
# Marks external NamedTypes so that they do not get prefixed by the current
# module name.

# Regex for string literals.
STRING_RE = re.compile("^([bu]?)(('[^']*')|(\"[^\"]*\"))$")
@@ -210,8 +210,13 @@ def f(x: bool) -> complex
v2 = foo.f(False)
v3 = foo.f(x)
""", pythonpath=[d.path])
# TODO(b/123775699): Check the inference output.
del ty
self.assertTypesMatchPytd(ty, """
foo: module
x: bool
v1: int
v2: float
v3: complex

def test_pyi_return(self):
with file_utils.Tempdir() as d:
@@ -252,8 +257,7 @@ def test_pyi_typing_extensions(self):
""", pythonpath=[d.path])
self.assertTypesMatchPytd(ty, "foo: module")

# TODO(b/123775699): Include native strings, bytestrings, unicode strings, and
# enums once pytype supports parsing strings and looking up local enums.
# TODO(b/123775699): Include enums once we support looking up local enums.
def test_pyi_value(self):
with file_utils.Tempdir() as d:
d.create_file("foo.pyi", """
@@ -262,12 +266,18 @@ def test_pyi_value(self):
def f1(x: Literal[True]) -> None: ...
def f2(x: Literal[2]) -> None: ...
def f3(x: Literal[None]) -> None: ...
def f4(x: Literal['hello']) -> None: ...
def f5(x: Literal[b'hello']) -> None: ...
def f6(x: Literal[u'hello']) -> None: ...
import foo
""", pythonpath=[d.path])

def test_pyi_multiple(self):
@@ -283,8 +293,12 @@ def f(x) -> str
v2 = foo.f(None)
v3 = foo.f(True)
""", pythonpath=[d.path])
# TODO(b/123775699): Check the inference output.
del ty
self.assertTypesMatchPytd(ty, """
foo: module
v1: int
v2: int
v3: str

def test_reexport(self):
with file_utils.Tempdir() as d:
@@ -306,14 +320,46 @@ def test_reexport(self):

def test_string(self):
# TODO(b/123775699): test that we do the right thing for string literals,
# not just that we don't barf.
with file_utils.Tempdir() as d:
d.create_file("foo.pyi", """
from typing import IO, Literal
def open(f: str, mode: Literal["r", "rt"]) -> str: ...
def open(f: str, mode: Literal["rb"]) -> int: ...
ty = self.Infer("""
import foo
def f1(f):
return, mode="r")
def f2(f):
return, mode="rt")
def f3(f):
return, mode="rb")
""", pythonpath=[d.path])
self.assertTypesMatchPytd(ty, """
foo: module
def f1(f) -> str: ...
def f2(f) -> str: ...
def f3(f) -> int: ...

def test_unknown(self):
with file_utils.Tempdir() as d:
d.create_file("foo.pyi", """
from typing import Literal
def open(f: str, mode: Literal["r"]) -> None: ...
def f(x: Literal[True]) -> int: ...
def f(x: Literal[False]) -> str: ...
ty = self.Infer("""
import foo
v = foo.f(__any_object__)
""", pythonpath=[d.path])
# Inference completing without type errors shows that `__any_object__`
# matched both Literal[True] and Literal[False].
self.assertTypesMatchPytd(ty, """
from typing import Any
foo: module
v: Any
self.Check("import foo", pythonpath=[d.path])

test_base.main(globals(), __name__ == "__main__")

0 comments on commit 33f1cc5

Please sign in to comment.
You can’t perform that action at this time.