Skip to content

Commit

Permalink
[eggex] Canonicalize regex flags at parse time
Browse files Browse the repository at this point in the history
Only when translating to ERE
  • Loading branch information
Andy Chu committed Dec 13, 2023
1 parent d52d495 commit a4845be
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 21 deletions.
8 changes: 7 additions & 1 deletion frontend/syntax.asdl
Expand Up @@ -289,7 +289,13 @@ module syntax
| YshExpr(expr e)

EggexFlag = (bool negated, Token flag)
Eggex = (Token left, re regex, List[EggexFlag] flags, Token? trans_pref)

# ere_flags is a canonical version of flags that can be compared for
# equality. This is so we can splice eggexes correctly, e.g.
# / 'abc' @pat ; i /
Eggex = (
Token left, re regex, List[EggexFlag] flags, Token? trans_pref,
str? ere_flags)

pat =
Else
Expand Down
18 changes: 13 additions & 5 deletions test/ysh-parse-errors.sh
Expand Up @@ -1291,11 +1291,22 @@ test-eggex-capture() {


test-eggex-flags() {
_should-parse '= / d+ ; ignorecase /'
_should-parse '= / d+ ; reg_icase /'
_should-parse '= / d+ ; i /' # shortcut

# can't negate these
_parse-error '= / d+ ; !i /'

# typo should be parse error
_parse-error '= / d+ ; reg_oops /'

# With PCRE, flags are not validated at parse time
_should-parse '= / d+ ; !i; PCRE /'
_should-parse '= / d+ ; reg_oops; PCRE /'

# ERE is the default; it means POSIX ERE.
# The other option is PCRE.
_should-parse '= / d+ ; ignorecase !multiline ; ERE /'
_should-parse '= / d+ ; i reg_newline ; ERE /'
_should-parse '= / d+ ; ; ERE /'

# trailing ; is OK
Expand All @@ -1304,9 +1315,6 @@ test-eggex-flags() {
# doesn't make sense
_parse-error '= / d+ ; ; /'
_parse-error '= / d+ ; ; ; /'

# typo should be parse error
#_parse-error '= / d+ ; ignorecas /'
}

#
Expand Down
9 changes: 8 additions & 1 deletion ysh/expr_to_ast.py
Expand Up @@ -51,6 +51,7 @@
from mycpp import mylib
from mycpp.mylib import log, tagswitch
from ysh import expr_parse
from ysh import regex_translate

from typing import TYPE_CHECKING, Dict, List, Tuple, Optional, cast
if TYPE_CHECKING:
Expand Down Expand Up @@ -858,7 +859,13 @@ def _Eggex(self, p_node):
i += 1
trans_pref = p_node.GetChild(i).tok

return Eggex(left, regex, flags, trans_pref)
# Canonicalize and validate flags for ERE only. Default is ERE.
if trans_pref is None or lexer.TokenVal(trans_pref) == 'ERE':
ere_flags = regex_translate.EncodeFlagsEre(flags)
else:
ere_flags = None

return Eggex(left, regex, flags, trans_pref, ere_flags)

def YshCasePattern(self, pnode):
# type: (PNode) -> pat_t
Expand Down
14 changes: 7 additions & 7 deletions ysh/regex_translate.py
Expand Up @@ -18,7 +18,7 @@
)
from _devbuild.gen.id_kind_asdl import Id
from _devbuild.gen.value_asdl import value
from core.error import e_die
from core.error import e_die, p_die
from frontend import lexer
from mycpp.mylib import log, tagswitch
from osh import glob_ # for ExtendedRegexEscape
Expand Down Expand Up @@ -350,23 +350,23 @@ def AsPosixEre(eggex):
return eggex.as_ere


def _EncodeFlags(flags):
def EncodeFlagsEre(flags):
# type: (List[EggexFlag]) -> str
"""
Raises fatal exception on invalid flags.
Raises PARSE error on invalid flags.
"""
letters = [] # type: List[str]
for flag in flags:
if flag.negated:
e_die("Flag can't be negated", flag.flag)
p_die("Flag can't be negated", flag.flag)
flag_name = lexer.TokenVal(flag.flag)
if flag_name in ('i', 'reg_icase'):
letters.append('i')
elif flag_name == 'reg_newline':
letters.append('n')
else:
e_die("Invalid regex flag %r" % flag_name, flag.flag)
p_die("Invalid regex flag %r" % flag_name, flag.flag)

# Normalize for comparison
return ''.join(sorted(letters))

letters.sort()
return ''.join(letters)
15 changes: 8 additions & 7 deletions ysh/regex_translate_test.py
Expand Up @@ -7,6 +7,7 @@
from _devbuild.gen.syntax_asdl import EggexFlag, Token, source, SourceLine

from asdl import runtime
from core import error
from ysh import regex_translate


Expand All @@ -20,31 +21,31 @@ def _Name(s):

class RegexTranslateTest(unittest.TestCase):

def testEncodeFlags(self):
def testEncodeFlagsEre(self):
reg_icase = _Name('reg_icase')
i = _Name('i') # abbreviation
reg_newline = _Name('reg_newline')
bad = _Name('bad')

flags = [EggexFlag(False, reg_icase)]
self.assertEqual('i', regex_translate._EncodeFlags(flags))
self.assertEqual('i', regex_translate.EncodeFlagsEre(flags))

flags = [EggexFlag(False, i)]
self.assertEqual('i', regex_translate._EncodeFlags(flags))
self.assertEqual('i', regex_translate.EncodeFlagsEre(flags))

flags = [EggexFlag(False, bad)]
try:
regex_translate._EncodeFlags(flags)
except Exception as e:
regex_translate.EncodeFlagsEre(flags)
except error.Parse as e:
print(e.UserErrorString())
else:
self.fail('Should have failed')

order1 = [EggexFlag(False, reg_icase), EggexFlag(False, reg_newline)]
order2 = [EggexFlag(False, reg_newline), EggexFlag(False, reg_icase)]

self.assertEqual('in', regex_translate._EncodeFlags(order1))
self.assertEqual('in', regex_translate._EncodeFlags(order2))
self.assertEqual('in', regex_translate.EncodeFlagsEre(order1))
self.assertEqual('in', regex_translate.EncodeFlagsEre(order2))


if __name__ == '__main__':
Expand Down

0 comments on commit a4845be

Please sign in to comment.