Skip to content

Commit

Permalink
Fix some corner cases in invalidating an enum tracker.
Browse files Browse the repository at this point in the history
Tracks the type of every MATCH_ opcode in a case branch; if we have anything
beyond a simple CMP we stop tracking enums for exhaustiveness.

PiperOrigin-RevId: 576982294
  • Loading branch information
martindemello authored and rchen152 committed Oct 27, 2023
1 parent 89cd610 commit c452719
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 13 deletions.
30 changes: 28 additions & 2 deletions pytype/pattern_matching.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Support for pattern matching."""

import collections
import enum

from typing import Dict, List, Optional, Set, Tuple, cast

Expand All @@ -18,6 +19,23 @@
_MatchSuccessType = Optional[bool]


class _MatchTypes(enum.Enum):
"""Track match types based on generated opcode."""

CLASS = enum.auto()
SEQUENCE = enum.auto()
KEYS = enum.auto()
MAPPING = enum.auto()
CMP = enum.auto()

@classmethod
def make(cls, op: opcodes.Opcode):
if op.name.startswith("MATCH_"):
return cls[op.name[len("MATCH_"):]]
else:
return cls.CMP


class _Matches:
"""Tracks branches of match statements."""

Expand Down Expand Up @@ -133,7 +151,8 @@ def __init__(self, ast_matches, ctx):
self._enum_tracker = {}
self._type_tracker: Dict[int, Dict[int, _TypeTracker]] = (
collections.defaultdict(dict))
self._match_types = {}
self._match_types: Dict[int, Set[_MatchTypes]] = (
collections.defaultdict(set))
self._active_ends = set()
# If we analyse the same match statement twice, the second time around we
# should not do exhaustiveness and redundancy checks since we have already
Expand Down Expand Up @@ -166,7 +185,8 @@ def _get_enum_tracker(
if match_line not in self._enum_tracker:
self._add_new_enum_match(match_val, match_line)
enum_tracker = self._enum_tracker[match_line]
if match_val.cls != enum_tracker.enum_cls:
if (match_val.cls != enum_tracker.enum_cls or
self._match_types[match_line] != {_MatchTypes.CMP}):
# We are matching a tuple or structure with different enums in it.
enum_tracker.invalidate()
return None
Expand Down Expand Up @@ -225,6 +245,12 @@ def is_current_as_name(self, op: opcodes.Opcode, name: str):
return None
return self.matches.as_names.get(op.line) == name

def register_match_type(self, op: opcodes.Opcode):
if op.line not in self.matches.match_cases:
return
match_line = self.matches.match_cases[op.line]
self._match_types[match_line].add(_MatchTypes.make(op))

def _add_enum_branch(
self,
op: opcodes.Opcode,
Expand Down
22 changes: 22 additions & 0 deletions pytype/tests/test_pattern_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,28 @@ def f(a: A, b: B):
print('bar')
""")

def test_enum_in_tuple(self):
"""Skip tracking if matching an enum in a tuple."""
# Python unpacks the tuple and compiles to a simple enum cmp in some cases.
# Check that we do not track exhaustive or redundant matches for this case.
self.Check("""
import enum
class Side(enum.Enum):
RIGHT = enum.auto()
LEFT = enum.auto()
CUSTOM = enum.auto()
def actuate_phrase(side: Side, assistant: bool):
match (side, assistant):
case (Side.LEFT | Side.RIGHT, _):
return 'preset side'
case (Side.CUSTOM, True):
return 'custom true'
case (Side.CUSTOM, False): # should not be redundant
return 'custom false'
""")

def test_pytd_enum_basic(self):
with self.DepTree([("foo.pyi", """
import enum
Expand Down
17 changes: 6 additions & 11 deletions pytype/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(self, ctx):
# Note that we don't need to scope this to the frame because we don't reuse
# variable ids.
self._var_names = {}
self._branch_tracker = None
self._branch_tracker: pattern_matching.BranchTracker = None

# Locals attached to the block graph
self.block_env = block_environment.Environment()
Expand Down Expand Up @@ -206,7 +206,6 @@ def is_at_maximum_depth(self):

def _is_match_case_op(self, op):
"""Should we handle case matching for this opcode."""
assert self._branch_tracker is not None
# A case statement generates multiple opcodes on the same line. Since the
# director matches on line numbers, we only trigger the case handler on a
# specific opcode (which varies depending on the type of match)
Expand All @@ -225,7 +224,6 @@ def _is_match_case_op(self, op):

def _handle_match_case(self, state, op):
"""Track type narrowing and default cases in a match statement."""
assert self._branch_tracker is not None
if not self._is_match_case_op(op):
return state
if op.line in self._branch_tracker.matches.defaults:
Expand Down Expand Up @@ -265,7 +263,6 @@ def run_instruction(
Raises:
VirtualMachineError: if a fatal error occurs.
"""
assert self._branch_tracker is not None
_opcode_counter.inc(op.name)
self.frame.current_opcode = op
self._importing = "IMPORT" in op.__class__.__name__
Expand Down Expand Up @@ -979,7 +976,6 @@ def _get_value_from_annotations(self, state, op, name, local, orig_val):
def _pop_and_store(self, state, op, name, local):
"""Pop a value off the stack and store it in a variable."""
state, orig_val = state.pop()
assert self._branch_tracker is not None
if (self._branch_tracker.is_current_as_name(op, name) and
self._branch_tracker.get_current_type_tracker(op, orig_val)):
# If we are storing the as name in a case match, i.e.
Expand Down Expand Up @@ -1156,7 +1152,6 @@ def _handle_311_pattern_match_on_dict(self, state, op, obj, ret):
# DELETE_SUBSCR on the concrete keys and binds the remaining dict to `rest`
# (3.10 had a specific COPY_DICT_WITHOUT_KEYS opcode to handle this but it
# was removed in 3.11).
assert self._branch_tracker is not None
if not (self.ctx.python_version == (3, 11) and
op.line in self._branch_tracker.matches.match_cases):
return state
Expand Down Expand Up @@ -1828,8 +1823,8 @@ def _replace_abstract_exception(self, state, exc_type):

def _compare_op(self, state, op_arg, op):
"""Pops and compares the top two stack values and pushes a boolean."""
assert self._branch_tracker is not None
state, (x, y) = state.popn(2)
self._branch_tracker.register_match_type(op)
match_enum = self._branch_tracker.add_cmp_branch(op, x, y)
if match_enum is not None:
# The match always succeeds/fails.
Expand Down Expand Up @@ -3103,15 +3098,15 @@ def byte_GET_LEN(self, state, op):
return state.push(length.instantiate(state.node))

def byte_MATCH_MAPPING(self, state, op):
del op
self._branch_tracker.register_match_type(op)
obj_var = state.top()
is_map = vm_utils.match_mapping(state.node, obj_var, self.ctx)
ret = self.ctx.convert.bool_values[is_map]
log.debug("match_mapping: %r", ret)
return state.push(ret.to_variable(state.node))

def byte_MATCH_SEQUENCE(self, state, op):
del op
self._branch_tracker.register_match_type(op)
obj_var = state.top()
is_seq = vm_utils.match_sequence(obj_var)
ret = self.ctx.convert.bool_values[is_seq]
Expand All @@ -3120,7 +3115,7 @@ def byte_MATCH_SEQUENCE(self, state, op):

def byte_MATCH_KEYS(self, state, op):
"""Implementation of the MATCH_KEYS opcode."""
del op
self._branch_tracker.register_match_type(op)
obj_var, keys_var = state.topn(2)
ret = vm_utils.match_keys(state.node, obj_var, keys_var, self.ctx)
vals = ret or self.ctx.convert.none.to_variable(state.node)
Expand All @@ -3143,6 +3138,7 @@ def _store_local_or_cellvar(self, state, name, var):
def byte_MATCH_CLASS(self, state, op):
"""Implementation of the MATCH_CLASS opcode."""
# NOTE: 3.10 specific; stack effects change somewhere en route to 3.12
self._branch_tracker.register_match_type(op)
posarg_count = op.arg
state, keys_var = state.pop()
state, (obj_var, cls_var) = state.popn(2)
Expand All @@ -3153,7 +3149,6 @@ def byte_MATCH_CLASS(self, state, op):
success = ret.success
vals = ret.values or self.ctx.convert.none.to_variable(state.node)
if ret.matched:
assert self._branch_tracker is not None
# Narrow the type of the match variable since we are in a case branch
# where it has matched the given class. The branch tracker will store the
# original (unnarrowed) type, since the new variable shadows it.
Expand Down

0 comments on commit c452719

Please sign in to comment.