Skip to content

Commit

Permalink
[eggex refactor] Use regex_match_t sum type for global match
Browse files Browse the repository at this point in the history
This will make it easier to add better location info.  And it may reduce
allocations.
  • Loading branch information
Andy C committed Dec 19, 2023
1 parent 668eb9e commit e31c87c
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 66 deletions.
40 changes: 26 additions & 14 deletions builtin/func_eggex.py
Expand Up @@ -5,7 +5,8 @@
from __future__ import print_function

from _devbuild.gen.syntax_asdl import loc_t
from _devbuild.gen.value_asdl import value, value_e, value_t
from _devbuild.gen.value_asdl import (value, value_e, value_t, regex_match_e,
RegexMatch)
from core import error
from core import state
from core import vm
Expand Down Expand Up @@ -54,11 +55,9 @@ def _GetMatch(self, s, indices, i, convert_func, blame_loc):
blame_loc)
return val
else:
if num_groups == 0:
msg = 'No regex capture groups'
else:
msg = 'Expected capture group less than %d, got %d' % (
num_groups, i)
assert num_groups != 0
msg = 'Expected capture group less than %d, got %d' % (num_groups,
i)
raise error.Expr(msg, blame_loc)


Expand Down Expand Up @@ -110,16 +109,29 @@ def Call(self, rd):
group = rd.PosValue()
rd.Done()

s, indices, capture_names, convert_funcs = self.mem.GetRegexIndices()
group_index = _GetGroupIndex(group, capture_names, rd.LeftParenToken())
match = self.mem.GetRegexIndices()
UP_match = match
with tagswitch(match) as case:
if case(regex_match_e.No):
# _group(0) etc. is illegal
raise error.Expr('No regex capture groups',
rd.LeftParenToken())

convert_func = None # type: Optional[value_t]
if len(convert_funcs): # for ERE string, it's []
if group_index != 0: # doesn't have a name or type attached to it
convert_func = convert_funcs[group_index - 1]
elif case(regex_match_e.Yes):
match = cast(RegexMatch, UP_match)

return self._GetMatch(s, indices, group_index, convert_func,
rd.LeftParenToken())
group_index = _GetGroupIndex(group, match.capture_names,
rd.LeftParenToken())

convert_func = None # type: Optional[value_t]
if len(match.convert_funcs): # for ERE string, it's []
if group_index != 0: # doesn't have a name or type attached to it
convert_func = match.convert_funcs[group_index - 1]

return self._GetMatch(match.s, match.indices, group_index,
convert_func, rd.LeftParenToken())

raise AssertionError()


class MatchMethod(_MatchCallable):
Expand Down
4 changes: 2 additions & 2 deletions builtin/method_str.py
Expand Up @@ -2,7 +2,7 @@

from __future__ import print_function

from _devbuild.gen.value_asdl import (value, value_e, value_t)
from _devbuild.gen.value_asdl import (value, value_e, value_t, RegexMatch)

from core import error
from core import vm
Expand Down Expand Up @@ -129,4 +129,4 @@ def Call(self, rd):
if indices is None:
return value.Null

return value.Match(string, indices, convert_funcs, capture_names)
return RegexMatch(string, indices, convert_funcs, capture_names)
54 changes: 18 additions & 36 deletions core/state.py
Expand Up @@ -17,7 +17,8 @@
from _devbuild.gen.types_asdl import opt_group_i
from _devbuild.gen.value_asdl import (value, value_e, value_t, sh_lvalue,
sh_lvalue_e, sh_lvalue_t, LeftName,
y_lvalue_e)
y_lvalue_e, regex_match, regex_match_e,
regex_match_t, RegexMatch)
from asdl import runtime
from core import error
from core.error import e_usage, e_die
Expand Down Expand Up @@ -969,10 +970,7 @@ def __init__(self, mem):
mem.pipe_status.append([])
mem.process_sub_status.append([])

mem.regex_indices.append([])
mem.regex_string.append('')
mem.capture_names.append([])
mem.convert_funcs.append([])
mem.regex_match.append(regex_match.No)

self.mem = mem

Expand All @@ -982,10 +980,7 @@ def __enter__(self):

def __exit__(self, type, value, traceback):
# type: (Any, Any, Any) -> None
self.mem.convert_funcs.pop()
self.mem.capture_names.pop()
self.mem.regex_string.pop()
self.mem.regex_indices.pop()
self.mem.regex_match.pop()

self.mem.process_sub_status.pop()
self.mem.pipe_status.pop()
Expand Down Expand Up @@ -1073,12 +1068,7 @@ def __init__(self, dollar0, argv, arena, debug_stack):

# A stack but NOT a register?
self.this_dir = [] # type: List[str]

# 0 is the whole match, 1..n are submatches
self.regex_indices = [[]] # type: List[List[int]]
self.regex_string = [''] # type: List[str]
self.capture_names = [[]] # type: List[List[Optional[str]]]
self.convert_funcs = [[]] # type: List[List[Optional[value_t]]]
self.regex_match = [regex_match.No] # type: List[regex_match_t]

self.last_bg_pid = -1 # Uninitialized value mutable public variable

Expand Down Expand Up @@ -1830,8 +1820,13 @@ def GetValue(self, name, which_scopes=scope_e.Shopt):
return value.List(items)

if name == 'BASH_REMATCH':
groups = util.RegexGroups(self.regex_string[-1],
self.regex_indices[-1])
top_match = self.regex_match[-1]
with tagswitch(top_match) as case:
if case(regex_match_e.No):
groups = []
elif case(regex_match_e.Yes):
match = cast(RegexMatch, top_match)
groups = util.RegexGroups(match.s, match.indices)
return value.BashArray(groups)

# Do lookup of system globals before looking at user variables. Note: we
Expand Down Expand Up @@ -2160,28 +2155,15 @@ def IsGlobalScope(self):

def ClearRegexIndices(self):
# type: () -> None
indices = self.regex_indices[-1]
del indices[:] # no clear() in Python 2

self.regex_string[-1] = ''

names = self.capture_names[-1]
del names[:]

funcs = self.convert_funcs[-1]
del funcs[:]
self.regex_match[-1] = regex_match.No

def SetRegexIndices(self, s, indices, capture_names, convert_funcs):
# type: (str, List[int], List[Optional[str]], List[Optional[value_t]]) -> None
self.regex_string[-1] = s
self.regex_indices[-1] = indices
self.capture_names[-1] = capture_names
self.convert_funcs[-1] = convert_funcs
def SetRegexIndices(self, match):
# type: (RegexMatch) -> None
self.regex_match[-1] = match

def GetRegexIndices(self):
# type: () -> Tuple[str, List[int], List[Optional[str]], List[Optional[value_t]]]
return (self.regex_string[-1], self.regex_indices[-1],
self.capture_names[-1], self.convert_funcs[-1])
# type: () -> regex_match_t
return self.regex_match[-1]


#
Expand Down
12 changes: 10 additions & 2 deletions core/value.asdl
Expand Up @@ -44,6 +44,15 @@ module value
| Indexed(str name, int index, loc blame_loc)
| Keyed(str name, str key, loc blame_loc)

RegexMatch = (
str s, List[int] indices,
# TODO: These 2 fields could be None for BASH_REMATCH
List[value?] convert_funcs, List[str?] capture_names)

regex_match =
No
| Yes %RegexMatch

# Commands, words, and expressions from syntax.asdl are evaluated to a VALUE.
# value_t instances are stored in state.Mem().
value =
Expand Down Expand Up @@ -83,8 +92,7 @@ module value
# match, and each group has both a start and end index.
# It's flat to reduce allocations. The group() start() end() funcs/methods
# provide a nice interface.
| Match(str s, List[int] indices,
List[value?] convert_funcs, List[str?] capture_names)
| Match %RegexMatch

# ^[42 + a[i]]
| Expr(expr e)
Expand Down
8 changes: 4 additions & 4 deletions frontend/typed_args.py
Expand Up @@ -4,7 +4,7 @@
from _devbuild.gen.runtime_asdl import cmd_value
from _devbuild.gen.syntax_asdl import (loc, loc_t, ArgList, LiteralBlock,
command_t, expr_t)
from _devbuild.gen.value_asdl import (value, value_e, value_t)
from _devbuild.gen.value_asdl import (value, value_e, value_t, RegexMatch)
from core import error
from core.error import e_usage
from frontend import location
Expand Down Expand Up @@ -236,9 +236,9 @@ def _ToPlace(self, val):
self.BlamePos())

def _ToMatch(self, val):
# type: (value_t) -> value.Match
# type: (value_t) -> RegexMatch
if val.tag() == value_e.Match:
return cast(value.Match, val)
return cast(RegexMatch, val)

raise error.TypeErr(val,
'Arg %d should be a Match' % self.pos_consumed,
Expand Down Expand Up @@ -347,7 +347,7 @@ def PosEggex(self):
return self._ToEggex(val)

def PosMatch(self):
# type: () -> value.Match
# type: () -> RegexMatch
val = self.PosValue()
return self._ToMatch(val)

Expand Down
2 changes: 1 addition & 1 deletion osh/cmd_eval.py
Expand Up @@ -1486,7 +1486,7 @@ def _DoCase(self, node):
eggex = cast(Eggex, case_arm.pattern)
eggex_val = self.expr_ev.EvalEggex(eggex)

if val_ops.RegexMatch(to_match, eggex_val, self.mem):
if val_ops.MatchRegex(to_match, eggex_val, self.mem):
status = self._ExecuteList(case_arm.action)
matched = True
break
Expand Down
4 changes: 3 additions & 1 deletion osh/sh_expr_eval.py
Expand Up @@ -41,6 +41,7 @@
sh_lvalue_e,
sh_lvalue_t,
LeftName,
RegexMatch,
)
from core import alloc
from core import error
Expand Down Expand Up @@ -1077,7 +1078,8 @@ def EvalB(self, node):
e_die_status(2, e.message, loc.Word(node.right))

if indices is not None:
self.mem.SetRegexIndices(s1, indices, [], [])
self.mem.SetRegexIndices(
RegexMatch(s1, indices, [], []))
return True
else:
self.mem.ClearRegexIndices()
Expand Down
7 changes: 4 additions & 3 deletions ysh/expr_eval.py
Expand Up @@ -369,7 +369,8 @@ def CallConvertFunc(self, func_val, arg, blame_loc):
except error.FatalRuntime as e:
# TODO: it needs a name
# This blames the group() call
self.errfmt.Print_('Fatal error calling Eggex conversion func', blame_loc)
self.errfmt.Print_('Fatal error calling Eggex conversion func',
blame_loc)
raise

return val
Expand Down Expand Up @@ -762,11 +763,11 @@ def _EvalCompare(self, node):
else:
try:
if op.id == Id.Arith_Tilde:
result = val_ops.RegexMatch(left, right, self.mem)
result = val_ops.MatchRegex(left, right, self.mem)

elif op.id == Id.Expr_NotTilde:
# don't pass self.mem to not set a match
result = not val_ops.RegexMatch(left, right, None)
result = not val_ops.MatchRegex(left, right, None)

else:
raise AssertionError(op)
Expand Down
7 changes: 4 additions & 3 deletions ysh/val_ops.py
Expand Up @@ -5,7 +5,7 @@
from __future__ import print_function

from _devbuild.gen.syntax_asdl import loc, loc_t, command_t
from _devbuild.gen.value_asdl import (value, value_e, value_t)
from _devbuild.gen.value_asdl import (value, value_e, value_t, RegexMatch)
from core import error
from core import ui
from mycpp.mylib import tagswitch
Expand Down Expand Up @@ -433,7 +433,7 @@ def Contains(needle, haystack):
return False


def RegexMatch(left, right, mem):
def MatchRegex(left, right, mem):
# type: (value_t, value_t, Optional[state.Mem]) -> bool
"""
Args:
Expand Down Expand Up @@ -474,7 +474,8 @@ def RegexMatch(left, right, mem):
indices = libc.regex_search(right_s, regex_flags, left_s, 0)
if indices is not None:
if mem:
mem.SetRegexIndices(left_s, indices, capture_names, convert_funcs)
mem.SetRegexIndices(
RegexMatch(left_s, indices, convert_funcs, capture_names))
return True
else:
if mem:
Expand Down

0 comments on commit e31c87c

Please sign in to comment.