Skip to content

Commit

Permalink
[eggex] Implement named capture groups
Browse files Browse the repository at this point in the history
They can be accessed with

    _group('month')
    _start('month')
    _end('month')

and

    m => group('month')
    m => start('month')
    m => end('month')
  • Loading branch information
Andy C committed Dec 17, 2023
1 parent 0f6fc2a commit cbbb4e6
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 51 deletions.
70 changes: 38 additions & 32 deletions builtin/func_eggex.py
Expand Up @@ -12,7 +12,7 @@
from frontend import typed_args
from mycpp.mylib import log, tagswitch

from typing import List, cast
from typing import List, Optional, cast

_ = log

Expand Down Expand Up @@ -46,9 +46,36 @@ def _GetMatch(s, indices, i, to_return, blame_loc):
raise error.UserError(2, msg, blame_loc)


def _GetGroupIndex(group, capture_names, blame_loc):
# type: (value_t, List[Optional[str]], loc_t) -> int
UP_group = group

with tagswitch(group) as case:
if case(value_e.Int):
group = cast(value.Int, UP_group)
group_index = group.i

elif case(value_e.Str):
group = cast(value.Str, UP_group)
group_index = -1
for i, name in enumerate(capture_names):
if name == group.s:
group_index = i + 1 # 1-based
break
if group_index == -1:
raise error.Expr('No such group %r' % group.s, blame_loc)
else:
raise error.TypeErr(group, 'Expected Int or Str', blame_loc)
return group_index


class MatchFunc(vm._Callable):
"""
_group(0) or _group() : get the whole match
_group(i)
_start(i)
_end(i)
_group(0) : get the whole match
_group(1) to _group(N): get a submatch
_group('month') : get group by name
Expand All @@ -65,22 +92,13 @@ def Call(self, rd):
# type: (typed_args.Reader) -> value_t

group = rd.PosValue()
UP_group = group
with tagswitch(group) as case:
if case(value_e.Int):
group = cast(value.Int, UP_group)
i = group.i
elif case(value_e.Str):
group = cast(value.Str, UP_group)
# TODO: calculate from mem registers
i = 0
else:
raise error.TypeErr(group, 'Expected Int or Str',
rd.LeftParenToken())
rd.Done()

s, indices = self.mem.GetRegexIndices()
s, indices, capture_names = self.mem.GetRegexIndices()
group_index = _GetGroupIndex(group, capture_names, rd.LeftParenToken())

return _GetMatch(s, indices, i, self.to_return, rd.LeftParenToken())
return _GetMatch(s, indices, group_index, self.to_return,
rd.LeftParenToken())


class MatchMethod(vm._Callable):
Expand All @@ -99,25 +117,13 @@ def Call(self, rd):

# This is guaranteed
m = rd.PosMatch()

group = rd.PosValue()
UP_group = group
with tagswitch(group) as case:
if case(value_e.Int):
group = cast(value.Int, UP_group)
i = group.i
elif case(value_e.Str):
group = cast(value.Str, UP_group)
# TODO: calculate from mem registers
i = 0
else:
raise error.TypeErr(group, 'Expected Int or Str',
rd.LeftParenToken())

rd.Done()

#log('group %d, s %r indices %s', i, m.s, m.indices)
return _GetMatch(m.s, m.indices, i, self.to_return,
group_index = _GetGroupIndex(group, m.capture_names,
rd.LeftParenToken())

return _GetMatch(m.s, m.indices, group_index, self.to_return,
rd.LeftParenToken())


Expand Down
2 changes: 1 addition & 1 deletion builtin/method_str.py
Expand Up @@ -71,6 +71,7 @@ def Call(self, rd):
SEARCH = 0
LEFT_MATCH = 1


class SearchMatch(vm._Callable):

def __init__(self, which_method):
Expand Down Expand Up @@ -109,4 +110,3 @@ def Call(self, rd):

return value.Match(string, indices, eggex_val.capture_names,
eggex_val.func_names)

21 changes: 15 additions & 6 deletions core/state.py
Expand Up @@ -971,6 +971,7 @@ def __init__(self, mem):

mem.regex_indices.append([])
mem.regex_string.append('')
mem.capture_names.append([])
self.mem = mem

def __enter__(self):
Expand All @@ -981,8 +982,11 @@ def __exit__(self, type, value, traceback):
# type: (Any, Any, Any) -> None
self.mem.regex_string.pop()
self.mem.regex_indices.pop()
self.mem.capture_names.pop()

self.mem.process_sub_status.pop()
self.mem.pipe_status.pop()

self.mem.try_status.pop()
self.mem.last_status.pop()

Expand Down Expand Up @@ -1070,6 +1074,7 @@ def __init__(self, dollar0, argv, arena, debug_stack):
# 0 is the whole match, 1..n are submatches
self.regex_indices = [[]] # type: List[List[int]]
self.regex_string = [''] # type: List[str]
self.capture_names = [[]] # type: List[List[Optional[str]]]

self.last_bg_pid = -1 # Uninitialized value mutable public variable

Expand Down Expand Up @@ -2151,18 +2156,22 @@ def IsGlobalScope(self):

def ClearRegexIndices(self):
# type: () -> None
top = self.regex_indices[-1]
del top[:] # no clear() in Python 2
indices = self.regex_indices[-1]
del indices[:] # no clear() in Python 2
self.regex_string[-1] = ''
names = self.capture_names[-1]
del names[:]

def SetRegexIndices(self, s, indices):
# type: (str, List[int]) -> None
def SetRegexIndices(self, s, indices, capture_names):
# type: (str, List[int], List[Optional[str]]) -> None
self.regex_string[-1] = s
self.regex_indices[-1] = indices
self.capture_names[-1] = capture_names

def GetRegexIndices(self):
# type: () -> Tuple[str, List[int]]
return self.regex_string[-1], self.regex_indices[-1]
# type: () -> Tuple[str, List[int], List[Optional[str]]]
return (self.regex_string[-1], self.regex_indices[-1],
self.capture_names[-1])


#
Expand Down
2 changes: 1 addition & 1 deletion osh/sh_expr_eval.py
Expand Up @@ -1077,7 +1077,7 @@ def EvalB(self, node):
e_die_status(2, e.message, loc.Word(node.right))

if indices is not None:
self.mem.SetRegexIndices(s1, indices)
self.mem.SetRegexIndices(s1, indices, [])
return True
else:
self.mem.ClearRegexIndices()
Expand Down
54 changes: 48 additions & 6 deletions spec/ysh-regex-api.test.sh
Expand Up @@ -152,7 +152,7 @@ start=-1 end=-1

#### Str->search() method returns value.Match object

var s = '= hi5- bye6-'
var s = '= Hi5- Bye6-'

var m = s => search(/ <capture [a-z]+ > <capture d+> '-' ; i /)
echo "g0 $[m => start(0)] $[m => end(0)] $[m => group(0)]"
Expand All @@ -168,12 +168,12 @@ echo "g1 $[m => start(1)] $[m => end(1)] $[m => group(1)]"
echo "g2 $[m => start(2)] $[m => end(2)] $[m => group(2)]"

## STDOUT:
g0 2 6 hi5-
g1 2 4 hi
g0 2 6 Hi5-
g1 2 4 Hi
g2 4 5 5
---
g0 7 12 bye6-
g1 7 10 bye
g0 7 12 Bye6-
g1 7 10 Bye
g2 10 11 6
## END

Expand Down Expand Up @@ -209,6 +209,19 @@ pat=([[:digit:]]+)-
34-
## END


#### Str->search() accepts ERE string

var s = '= hi5- bye6-'

var m = s => search('([[:alpha:]]+)([[:digit:]]+)-')
echo "g0 $[m => start(0)] $[m => end(0)] $[m => group(0)]"
echo "g1 $[m => start(1)] $[m => end(1)] $[m => group(1)]"
echo "g2 $[m => start(2)] $[m => end(2)] $[m => group(2)]"

## STDOUT:
## END

#### Str->leftMatch() can implement lexer pattern

shopt -s ysh:upgrade
Expand Down Expand Up @@ -262,16 +275,45 @@ null/ab/null/
pos=2
## END

#### Named captures with _group
#### Named captures with m => group()
shopt -s ysh:all

var s = 'zz 2020-08-20'
var pat = /<capture d+ as year> '-' <capture d+ as month>/

var m = s => search(pat)
argv.py $[m => group('year')] $[m => group('month')]
echo $[m => start('year')] $[m => end('year')]
echo $[m => start('month')] $[m => end('month')]

argv.py $[m => group('oops')]
echo 'error'

## status: 3
## STDOUT:
['2020', '08']
3 7
8 10
## END

#### Named captures with _group() _start() _end()
shopt -s ysh:all

var x = 'zz 2020-08-20'

if (x ~ /<capture d+ as year> '-' <capture d+ as month>/) {
argv.py $[_group('year')] $[_group('month')]
echo $[_start('year')] $[_end('year')]
echo $[_start('month')] $[_end('month')]
}

argv.py $[_group('oops')]

## status: 3
## STDOUT:
['2020', '08']
3 7
8 10
## END

#### Named Capture Decays Without Name
Expand Down
4 changes: 2 additions & 2 deletions ysh/expr_to_ast.py
Expand Up @@ -1462,12 +1462,12 @@ def _ReAtom(self, p_atom):

tok = p_atom.GetChild(i).tok
if tok.id == Id.Expr_As:
as_name = p_atom.GetChild(i+1).tok
as_name = p_atom.GetChild(i + 1).tok
i += 2

tok = p_atom.GetChild(i).tok
if tok.id == Id.Arith_Colon:
func_name = p_atom.GetChild(i+1).tok
func_name = p_atom.GetChild(i + 1).tok

# TODO: is it possible to output the capture name <-> index mapping
# here for POSIX ERE?
Expand Down
11 changes: 8 additions & 3 deletions ysh/val_ops.py
Expand Up @@ -440,17 +440,22 @@ def RegexMatch(left, right, mem):
mem: Whether to set or clear matches
"""
UP_right = right
right_s = None # type: str
regex_flags = 0

with tagswitch(right) as case:
if case(value_e.Str): # plain ERE
right = cast(value.Str, UP_right)

right_s = right.s
regex_flags = 0
capture_names = [] # type: List[Optional[str]]

elif case(value_e.Eggex):
right = cast(value.Eggex, UP_right)

right_s = regex_translate.AsPosixEre(right)
regex_flags = regex_translate.LibcFlags(right.canonical_flags)
capture_names = right.capture_names

else:
raise error.TypeErr(right, 'Expected Str or Regex for RHS of ~',
loc.Missing)
Expand All @@ -467,7 +472,7 @@ def RegexMatch(left, right, mem):
indices = libc.regex_search(right_s, regex_flags, left_s, 0)
if indices is not None:
if mem:
mem.SetRegexIndices(left_s, indices)
mem.SetRegexIndices(left_s, indices, capture_names)
return True
else:
if mem:
Expand Down

0 comments on commit cbbb4e6

Please sign in to comment.