Skip to content

Commit

Permalink
[eggex refactor] Use sum type for eggex_ops
Browse files Browse the repository at this point in the history
This will reduce allocations in the BASH_REMATCH case.

We don't pay for what we're not using.
  • Loading branch information
Andy C committed Dec 19, 2023
1 parent 9d93286 commit d89e2be
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 51 deletions.
58 changes: 38 additions & 20 deletions builtin/func_eggex.py
Expand Up @@ -5,7 +5,8 @@
from __future__ import print_function

from _devbuild.gen.syntax_asdl import loc_t
from _devbuild.gen.value_asdl import (value, value_e, value_t, regex_match_e,
from _devbuild.gen.value_asdl import (value, value_e, value_t, eggex_ops,
eggex_ops_e, eggex_ops_t, regex_match_e,
RegexMatch)
from core import error
from core import state
Expand Down Expand Up @@ -56,27 +57,31 @@ def _ReturnValue(self, s, indices, i, convert_func, blame_loc):
return val
else:
assert num_groups != 0
msg = 'Expected capture group less than %d, got %d' % (num_groups,
i)
raise error.Expr(msg, blame_loc)
raise error.Expr(
'Expected capture group less than %d, got %d' %
(num_groups, i), blame_loc)

def _Call(self, match, group_arg, blame_loc):
# type: (RegexMatch, value_t, loc_t) -> value_t
group_index = _GetGroupIndex(group_arg, match.capture_names,
blame_loc)

group_index = _GetGroupIndex(group_arg, match.ops, blame_loc)

convert_func = None # type: Optional[value_t]
if len(match.convert_funcs): # for ERE string, it's []
if group_index != 0: # doesn't have a name or type attached to it
convert_func = match.convert_funcs[group_index - 1]

with tagswitch(match.ops) as case:
if case(eggex_ops_e.Yes):
ops = cast(eggex_ops.Yes, match.ops)

# group 0 doesn't have a name or type attached to it
if len(ops.convert_funcs) and group_index != 0:
convert_func = ops.convert_funcs[group_index - 1]

return self._ReturnValue(match.s, match.indices, group_index,
convert_func, blame_loc)



def _GetGroupIndex(group, capture_names, blame_loc):
# type: (value_t, List[Optional[str]], loc_t) -> int
def _GetGroupIndex(group, ops, blame_loc):
# type: (value_t, eggex_ops_t, loc_t) -> int
UP_group = group

with tagswitch(group) as case:
Expand All @@ -86,13 +91,26 @@ def _GetGroupIndex(group, capture_names, blame_loc):

elif case(value_e.Str):
group = cast(value.Str, UP_group)
group_index = -1
for i, name in enumerate(capture_names):
if name == group.s:
group_index = i + 1 # 1-based
break
if group_index == -1:
raise error.Expr('No such group %r' % group.s, blame_loc)

UP_ops = ops
with tagswitch(ops) as case2:

if case2(eggex_ops_e.No):
raise error.Expr(
"ERE captures don't have names (%r)" % group.s,
blame_loc)

elif case2(eggex_ops_e.Yes):
ops = cast(eggex_ops.Yes, UP_ops)
group_index = -1
for i, name in enumerate(ops.capture_names):
if name == group.s:
group_index = i + 1 # 1-based
break
if group_index == -1:
raise error.Expr('No such group %r' % group.s,
blame_loc)

else:
# TODO: add method name to this error
raise error.TypeErr(group, 'expected Int or Str', blame_loc)
Expand Down Expand Up @@ -123,7 +141,7 @@ def Call(self, rd):
group_arg = rd.PosValue()
rd.Done()

match = self.mem.GetRegexIndices()
match = self.mem.GetRegexMatch()
UP_match = match
with tagswitch(match) as case:
if case(regex_match_e.No):
Expand Down
14 changes: 7 additions & 7 deletions builtin/method_str.py
Expand Up @@ -2,8 +2,8 @@

from __future__ import print_function

from _devbuild.gen.value_asdl import (value, value_e, value_t, RegexMatch)

from _devbuild.gen.value_asdl import (value, value_e, value_t, eggex_ops,
eggex_ops_t, RegexMatch)
from core import error
from core import vm
from frontend import typed_args
Expand Down Expand Up @@ -96,14 +96,14 @@ def Call(self, rd):
# lazily converts to ERE
ere = regex_translate.AsPosixEre(eggex_val)
cflags = regex_translate.LibcFlags(eggex_val.canonical_flags)
capture_names = eggex_val.capture_names
convert_funcs = eggex_val.convert_funcs
capture = eggex_ops.Yes(
eggex_val.convert_funcs,
eggex_val.capture_names) # type: eggex_ops_t

elif case(value_e.Str):
ere = cast(value.Str, pattern).s
cflags = 0
capture_names = []
convert_funcs = []
capture = eggex_ops.No

else:
# TODO: add method name to this error
Expand All @@ -129,4 +129,4 @@ def Call(self, rd):
if indices is None:
return value.Null

return RegexMatch(string, indices, convert_funcs, capture_names)
return RegexMatch(string, indices, capture)
10 changes: 3 additions & 7 deletions core/state.py
Expand Up @@ -2153,15 +2153,11 @@ def IsGlobalScope(self):
# type: () -> bool
return len(self.var_stack) == 1

def ClearRegexIndices(self):
# type: () -> None
self.regex_match[-1] = regex_match.No

def SetRegexIndices(self, match):
# type: (RegexMatch) -> None
def SetRegexMatch(self, match):
# type: (regex_match_t) -> None
self.regex_match[-1] = match

def GetRegexIndices(self):
def GetRegexMatch(self):
# type: () -> regex_match_t
return self.regex_match[-1]

Expand Down
11 changes: 7 additions & 4 deletions core/value.asdl
Expand Up @@ -44,10 +44,13 @@ module value
| Indexed(str name, int index, loc blame_loc)
| Keyed(str name, str key, loc blame_loc)

RegexMatch = (
str s, List[int] indices,
# TODO: These 2 fields could be None for BASH_REMATCH
List[value?] convert_funcs, List[str?] capture_names)
eggex_ops =
# for BASH_REMATCH or ~ with a string
No
# These lists are indexed by group number, and will have None entries
| Yes(List[value?] convert_funcs, List[str?] capture_names)

RegexMatch = (str s, List[int] indices, eggex_ops ops)

regex_match =
No
Expand Down
11 changes: 6 additions & 5 deletions osh/sh_expr_eval.py
Expand Up @@ -10,8 +10,7 @@
"""

from _devbuild.gen.id_kind_asdl import Id
from _devbuild.gen.runtime_asdl import (
scope_t, )
from _devbuild.gen.runtime_asdl import scope_t
from _devbuild.gen.syntax_asdl import (
word_t,
CompoundWord,
Expand Down Expand Up @@ -41,6 +40,8 @@
sh_lvalue_e,
sh_lvalue_t,
LeftName,
eggex_ops,
regex_match,
RegexMatch,
)
from core import alloc
Expand Down Expand Up @@ -1078,11 +1079,11 @@ def EvalB(self, node):
e_die_status(2, e.message, loc.Word(node.right))

if indices is not None:
self.mem.SetRegexIndices(
RegexMatch(s1, indices, [], []))
self.mem.SetRegexMatch(
RegexMatch(s1, indices, eggex_ops.No))
return True
else:
self.mem.ClearRegexIndices()
self.mem.SetRegexMatch(regex_match.No)
return False

if op_id == Id.Op_Less:
Expand Down
7 changes: 7 additions & 0 deletions test/ysh-runtime-errors.sh
Expand Up @@ -385,7 +385,14 @@ EOF

# Exercise the failure paths of the _group() eggex API.
# Each snippet is handed to _expr-error-case (defined elsewhere in this
# file); presumably that helper asserts the snippet fails with an expression
# error -- TODO confirm against its definition.
test-eggex-api() {
_expr-error-case '= _group(0)' # No groups

# Eggex pattern that contains no capture groups: asking for group 1, or for
# a group by name, should be rejected.
_expr-error-case 'if ("foo" ~ /[a-z]/) { echo $[_group(1)] }'
_expr-error-case 'if ("foo" ~ /[a-z]/) { echo $[_group("name")] }'

# ERE: the right-hand side of ~ is a plain string rather than an eggex, so
# the match carries no capture names or conversion funcs at all.
_expr-error-case 'if ("foo" ~ "[a-z]") { echo $[_group(1)] }'
_expr-error-case 'if ("foo" ~ "[a-z]") { echo $[_group("name")] }'

_expr-error-case '= _group("foo")' # No such group
}

Expand Down
14 changes: 6 additions & 8 deletions ysh/val_ops.py
Expand Up @@ -5,7 +5,8 @@
from __future__ import print_function

from _devbuild.gen.syntax_asdl import loc, loc_t, command_t
from _devbuild.gen.value_asdl import (value, value_e, value_t, RegexMatch)
from _devbuild.gen.value_asdl import (value, value_e, value_t, eggex_ops,
eggex_ops_t, regex_match, RegexMatch)
from core import error
from core import ui
from mycpp.mylib import tagswitch
Expand Down Expand Up @@ -447,16 +448,14 @@ def MatchRegex(left, right, mem):

right_s = right.s
regex_flags = 0
capture_names = [] # type: List[Optional[str]]
convert_funcs = [] # type: List[Optional[value_t]]
capture = eggex_ops.No # type: eggex_ops_t

elif case(value_e.Eggex):
right = cast(value.Eggex, UP_right)

right_s = regex_translate.AsPosixEre(right)
regex_flags = regex_translate.LibcFlags(right.canonical_flags)
capture_names = right.capture_names
convert_funcs = right.convert_funcs
capture = eggex_ops.Yes(right.convert_funcs, right.capture_names)

else:
raise error.TypeErr(right, 'Expected Str or Regex for RHS of ~',
Expand All @@ -474,12 +473,11 @@ def MatchRegex(left, right, mem):
indices = libc.regex_search(right_s, regex_flags, left_s, 0)
if indices is not None:
if mem:
mem.SetRegexIndices(
RegexMatch(left_s, indices, convert_funcs, capture_names))
mem.SetRegexMatch(RegexMatch(left_s, indices, capture))
return True
else:
if mem:
mem.ClearRegexIndices()
mem.SetRegexMatch(regex_match.No)
return False


Expand Down

0 comments on commit d89e2be

Please sign in to comment.