Skip to content

Commit

Permalink
Protocol attribute matching fix #4.
Browse files Browse the repository at this point in the history
Who knew there were so many different ways to arrange three for loops?

PiperOrigin-RevId: 387919695
  • Loading branch information
rchen152 committed Aug 3, 2021
1 parent 5bc9c3b commit 2906c81
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 16 deletions.
39 changes: 23 additions & 16 deletions pytype/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,16 +1030,18 @@ def _match_protocol_attribute(self, left, other_type, attribute, subst, view):
# TODO(rechen): Even if other_type isn't parameterized, we should run
# _match_protocol_attribute to catch mismatches in method signatures.
return subst
# The entire match succeeds if left_attribute matches *any* binding of
# protocol_attribute_var. A binding matches if *any* options for
# left_attribute match *all* options for the binding's types.
bad_matches = []
for protocol_attribute in protocol_attribute_var.data:
protocol_attribute_types = list(
self._get_attribute_types(other_type, protocol_attribute))
for new_view in abstract_utils.get_views([left_attribute], self._node):
new_view.update(view)
new_substs = []
# Every binding of left_attribute needs to match at least one binding of
# protocol_attribute_var.
new_substs = []
for new_view in abstract_utils.get_views([left_attribute], self._node):
new_view.update(view)
bad_matches = []
for protocol_attribute in protocol_attribute_var.data:
# For this binding of left_attribute to match this binding of
# protocol_attribute_var, *all* options in protocol_attribute_types need
# to match.
protocol_attribute_types = list(
self._get_attribute_types(other_type, protocol_attribute))
for protocol_attribute_type in protocol_attribute_types:
match_result = self.match_var_against_type(
left_attribute, protocol_attribute_type, subst, new_view)
Expand All @@ -1050,12 +1052,17 @@ def _match_protocol_attribute(self, left, other_type, attribute, subst, view):
else:
new_substs.append(match_result)
else:
return self._merge_substs(subst, new_substs)
bad_left, bad_right = zip(*bad_matches)
self._protocol_error = ProtocolTypeError(
left_cls, other_type, attribute, self.vm.merge_values(bad_left),
self.vm.merge_values(bad_right))
return None
# We've successfully matched all options in protocol_attribute_types.
break
else:
# This binding of left_attribute has not matched any binding of
# protocol_attribute_var.
bad_left, bad_right = zip(*bad_matches)
self._protocol_error = ProtocolTypeError(
left_cls, other_type, attribute, self.vm.merge_values(bad_left),
self.vm.merge_values(bad_right))
return None
return self._merge_substs(subst, new_substs)

def _get_concrete_values_and_classes(self, var):
# TODO(rechen): For type parameter instances, we should extract the concrete
Expand Down
24 changes: 24 additions & 0 deletions pytype/tests/py3/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,5 +1119,29 @@ class Baz:
"e": (r"expected Dict\[str, List\[int\]\], "
r"got Dict\[str, List\[str\]\]")})

def test_match_multi_attributes_against_dataclass_protocol(self):
errors = self.CheckWithErrors("""
from typing import Dict, Protocol, TypeVar, Union
import dataclasses
T = TypeVar('T')
class Dataclass(Protocol[T]):
__dataclass_fields__: Dict[str, dataclasses.Field[T]]
def f(x: Dataclass[int]):
pass
@dataclasses.dataclass
class ShouldMatch:
x: int
y: int
@dataclasses.dataclass
class ShouldNotMatch:
x: int
y: str
f(ShouldMatch(0, 0))
f(ShouldNotMatch(0, '')) # wrong-arg-types[e]
""")
self.assertErrorRegexes(errors, {
"e": (r"expected Dict\[str, dataclasses\.Field\[int\]\], "
r"got Dict\[str, dataclasses\.Field\[Union\[int, str\]\]\]")})


test_base.main(globals(), __name__ == "__main__")

0 comments on commit 2906c81

Please sign in to comment.