Skip to content

Commit

Permalink
[eggex refactor] Consolidate MatchFunc and MatchMethod
Browse files Browse the repository at this point in the history
Also fix a crash bug.  An eggex like

    / d+ ; ignorecase ; PCRE /

would crash when used because there are no canonical_flags.
  • Loading branch information
Andy C committed Dec 16, 2023
1 parent 1124655 commit 7b76121
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 50 deletions.
32 changes: 29 additions & 3 deletions builtin/func_eggex.py
Expand Up @@ -18,7 +18,7 @@
E = 2 # _end()


def GetMatch(s, indices, i, to_return, blame_loc):
def _GetMatch(s, indices, i, to_return, blame_loc):
# type: (str, List[int], int, int, loc_t) -> value_t
num_groups = len(indices) / 2 # including group 0
if i < num_groups:
Expand All @@ -43,7 +43,7 @@ def GetMatch(s, indices, i, to_return, blame_loc):
raise error.UserError(2, msg, blame_loc)


class MatchAccess(vm._Callable):
class MatchFunc(vm._Callable):
"""
_group(0) or _group() : get the whole match
_group(1) to _group(N): get a submatch
Expand All @@ -66,7 +66,33 @@ def Call(self, rd):

s, indices = self.mem.GetRegexIndices()

return GetMatch(s, indices, i, self.to_return, rd.LeftParenToken())
return _GetMatch(s, indices, i, self.to_return, rd.LeftParenToken())


class MatchMethod(vm._Callable):
"""
m => group(i)
m => start(i)
m => end(i)
"""

def __init__(self, to_return):
# type: (int) -> None
self.to_return = to_return

def Call(self, rd):
# type: (typed_args.Reader) -> value_t

# This is guaranteed
m = rd.PosMatch()
# TODO: Support strings for named captures
i = rd.OptionalInt(default_=0)
#val = rd.PosValue()

rd.Done()

return _GetMatch(m.s, m.indices, i, self.to_return,
rd.LeftParenToken())


# vim: sw=4
22 changes: 0 additions & 22 deletions builtin/method_other.py
Expand Up @@ -4,7 +4,6 @@

from _devbuild.gen.value_asdl import (value, value_t)

from builtin import func_eggex
from core import state
from core import vm
from frontend import typed_args
Expand All @@ -31,24 +30,3 @@ def Call(self, rd):
self.mem.SetPlace(place, val, rd.LeftParenToken())

return value.Null


class MatchAccess(vm._Callable):

def __init__(self, to_return):
# type: (int) -> None
self.to_return = to_return

def Call(self, rd):
# type: (typed_args.Reader) -> value_t

# This is guaranteed
m = rd.PosMatch()
# TODO: Support strings for named captures
i = rd.OptionalInt(default_=0)
#val = rd.PosValue()

rd.Done()

return func_eggex.GetMatch(m.s, m.indices, i, self.to_return,
rd.LeftParenToken())
12 changes: 6 additions & 6 deletions core/shell.py
Expand Up @@ -773,9 +773,9 @@ def Main(
}

methods[value_e.Match] = {
'group': method_other.MatchAccess(func_eggex.G),
'start': method_other.MatchAccess(func_eggex.S),
'end': method_other.MatchAccess(func_eggex.E),
'group': func_eggex.MatchMethod(func_eggex.G),
'start': func_eggex.MatchMethod(func_eggex.S),
'end': func_eggex.MatchMethod(func_eggex.E),
}

methods[value_e.IO] = {
Expand Down Expand Up @@ -815,9 +815,9 @@ def Main(
_SetGlobalFunc(mem, 'len', func_misc.Len())

# TODO: rename to group
_SetGlobalFunc(mem, '_match', func_eggex.MatchAccess(mem, func_eggex.G))
_SetGlobalFunc(mem, '_start', func_eggex.MatchAccess(mem, func_eggex.S))
_SetGlobalFunc(mem, '_end', func_eggex.MatchAccess(mem, func_eggex.E))
_SetGlobalFunc(mem, '_match', func_eggex.MatchFunc(mem, func_eggex.G))
_SetGlobalFunc(mem, '_start', func_eggex.MatchFunc(mem, func_eggex.S))
_SetGlobalFunc(mem, '_end', func_eggex.MatchFunc(mem, func_eggex.E))

_SetGlobalFunc(mem, 'join', func_misc.Join())
_SetGlobalFunc(mem, 'maybe', func_misc.Maybe())
Expand Down
18 changes: 18 additions & 0 deletions spec/ysh-regex.test.sh
Expand Up @@ -85,6 +85,24 @@ var pat3 = / @pat 'def' /
## STDOUT:
## END

#### Eggex with translation preference has arbitrary flags
shopt -s ysh:upgrade

# TODO: can provide introspection so users can translate it?
# This is kind of a speculative corner of the language.

var pat = / d+ ; ignorecase ; PCRE /

# This uses ERE, as a test
if ('ab 12' ~ pat) {
echo yes
}

## STDOUT:
yes
## END


#### Positional captures with _match
shopt -s ysh:all

Expand Down
6 changes: 3 additions & 3 deletions ysh/expr_to_ast.py
Expand Up @@ -861,11 +861,11 @@ def _Eggex(self, p_node):

# Canonicalize and validate flags for ERE only. Default is ERE.
if trans_pref is None or lexer.TokenVal(trans_pref) == 'ERE':
ere_flags = regex_translate.EncodeFlagsEre(flags)
canonical_flags = regex_translate.CanonicalFlags(flags)
else:
ere_flags = None
canonical_flags = None

return Eggex(left, regex, flags, trans_pref, ere_flags)
return Eggex(left, regex, flags, trans_pref, canonical_flags)

def YshCasePattern(self, pnode):
# type: (PNode) -> pat_t
Expand Down
21 changes: 20 additions & 1 deletion ysh/regex_translate.py
Expand Up @@ -28,6 +28,8 @@
if TYPE_CHECKING:
from _devbuild.gen.syntax_asdl import re_t

from libc import REG_ICASE, REG_NEWLINE

_ = log

PERL_CLASS = {
Expand Down Expand Up @@ -350,10 +352,13 @@ def AsPosixEre(eggex):
return eggex.as_ere


def EncodeFlagsEre(flags):
def CanonicalFlags(flags):
# type: (List[EggexFlag]) -> str
"""
Raises PARSE error on invalid flags.
In theory we could encode directly to integers like REG_ICASE, but a string
like like 'i' makes the error message slightly more legible.
"""
letters = [] # type: List[str]
for flag in flags:
Expand All @@ -370,3 +375,17 @@ def EncodeFlagsEre(flags):
# Normalize for comparison
letters.sort()
return ''.join(letters)


def LibcFlags(canonical_flags):
# type: (str) -> int
libc_flags = 0
for ch in canonical_flags:
if ch == 'i':
libc_flags |= REG_ICASE
elif ch == 'n':
libc_flags |= REG_NEWLINE
else:
# regex_translate should prevent this
raise AssertionError()
return libc_flags
12 changes: 6 additions & 6 deletions ysh/regex_translate_test.py
Expand Up @@ -21,21 +21,21 @@ def _Name(s):

class RegexTranslateTest(unittest.TestCase):

def testEncodeFlagsEre(self):
def testCanonicalFlags(self):
reg_icase = _Name('reg_icase')
i = _Name('i') # abbreviation
reg_newline = _Name('reg_newline')
bad = _Name('bad')

flags = [EggexFlag(False, reg_icase)]
self.assertEqual('i', regex_translate.EncodeFlagsEre(flags))
self.assertEqual('i', regex_translate.CanonicalFlags(flags))

flags = [EggexFlag(False, i)]
self.assertEqual('i', regex_translate.EncodeFlagsEre(flags))
self.assertEqual('i', regex_translate.CanonicalFlags(flags))

flags = [EggexFlag(False, bad)]
try:
regex_translate.EncodeFlagsEre(flags)
regex_translate.CanonicalFlags(flags)
except error.Parse as e:
print(e.UserErrorString())
else:
Expand All @@ -44,8 +44,8 @@ def testEncodeFlagsEre(self):
order1 = [EggexFlag(False, reg_icase), EggexFlag(False, reg_newline)]
order2 = [EggexFlag(False, reg_newline), EggexFlag(False, reg_icase)]

self.assertEqual('in', regex_translate.EncodeFlagsEre(order1))
self.assertEqual('in', regex_translate.EncodeFlagsEre(order2))
self.assertEqual('in', regex_translate.CanonicalFlags(order1))
self.assertEqual('in', regex_translate.CanonicalFlags(order2))


if __name__ == '__main__':
Expand Down
11 changes: 2 additions & 9 deletions ysh/val_ops.py
Expand Up @@ -14,7 +14,6 @@
from typing import TYPE_CHECKING, cast, Dict, List, Optional

import libc
from libc import REG_ICASE, REG_NEWLINE

if TYPE_CHECKING:
from core import state
Expand Down Expand Up @@ -451,14 +450,8 @@ def RegexMatch(left, right, mem):
elif case(value_e.Eggex):
right = cast(value.Eggex, UP_right)
right_s = regex_translate.AsPosixEre(right)
for ch in right.canonical_flags:
if ch == 'i':
regex_flags |= REG_ICASE
elif ch == 'n':
regex_flags |= REG_NEWLINE
else:
# regex_translate should prevent this
raise AssertionError()
if right.canonical_flags is not None:
regex_flags = regex_translate.LibcFlags(right.canonical_flags)
else:
raise error.TypeErr(right, 'Expected Str or Regex for RHS of ~',
loc.Missing)
Expand Down

0 comments on commit 7b76121

Please sign in to comment.