Skip to content

Commit

Permalink
[eggex refactor] Prepare for type conversion funcs
Browse files Browse the repository at this point in the history
MatchMethod and MatchFunc inherit from a base class that has an
ExprEvaluator.

[ASDL] Make note of inconsistent None/nullptr handling.
  • Loading branch information
Andy C committed Dec 18, 2023
1 parent 4b26395 commit f738a3a
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 51 deletions.
3 changes: 2 additions & 1 deletion asdl/gen_cpp.py
Expand Up @@ -541,7 +541,8 @@ def _EmitCodeForField(self, abbrev, field, counter):
child_code_str, none_guard = _HNodeExpr(abbrev, item_type,
iter_name)
if none_guard: # e.g. for List[Optional[value_t]]
# _ means None/nullptr, like asdl/runtime.py NewLeaf
# TODO: could consolidate with asdl/runtime.py NewLeaf(), which
# also uses _ to mean None/nullptr
self.Emit(
' hnode_t* h = (%s == nullptr) ? Alloc<hnode::Leaf>(StrFromC("_"), color_e::OtherConst) : %s;'
% (iter_name, child_code_str))
Expand Down
2 changes: 2 additions & 0 deletions asdl/gen_python.py
Expand Up @@ -256,6 +256,8 @@ def _EmitCodeForField(self, abbrev, field, counter):
iter_name)

if none_guard: # e.g. for List[Optional[value_t]]
# TODO: could consolidate with asdl/runtime.py NewLeaf(), which
# also uses _ to mean None/nullptr
self.Emit(
' h = (hnode.Leaf("_", color_e.OtherConst) if %s is None else %s)'
% (iter_name, child_code_str))
Expand Down
12 changes: 12 additions & 0 deletions asdl/runtime.py
Expand Up @@ -27,7 +27,19 @@ def NewRecord(node_type):

def NewLeaf(s, e_color):
# type: (Optional[str], color_t) -> hnode.Leaf
"""
TODO: _EmitCodeForField in asdl/gen_{cpp,python}.py does something like
this for non-string types. We should keep the style consistent.
It's related to the none_guard return value of _HNodeExpr().
The problem there is that we call i0->PrettyTree() or
i0->AbbreviatedTree(). However, those calls are only polymorphic in Python,
not in C++, so the C++ side could handle the nullptr case: PrettyTree()
could be a free function using static dispatch instead of a member, and
then it could handle nullptr itself.
"""
# for repr of BashArray, which can have 'None'
if s is None:
return hnode.Leaf('_', color_e.OtherConst)
Expand Down
93 changes: 53 additions & 40 deletions builtin/func_eggex.py
Expand Up @@ -12,7 +12,10 @@
from frontend import typed_args
from mycpp.mylib import log, tagswitch

from typing import List, Optional, cast
from typing import List, Optional, cast, TYPE_CHECKING
if TYPE_CHECKING:
from ysh.expr_eval import ExprEvaluator


_ = log

Expand All @@ -21,32 +24,39 @@
E = 2 # _end()


def _GetMatch(s, indices, i, to_return, blame_loc):
# type: (str, List[int], int, int, loc_t) -> value_t
num_groups = len(indices) / 2 # including group 0
if i < num_groups:
start = indices[2 * i]
if to_return == S:
return value.Int(start)

end = indices[2 * i + 1]
if to_return == E:
return value.Int(end)
class _MatchCallable(vm._Callable):

if start == -1:
return value.Null
else:
# TODO: Can apply type conversion function
# See osh/prompt.py:
# val = self.expr_ev.PluginCall(func_val, pos_args)
return value.Str(s[start:end])
else:
if num_groups == 0:
msg = 'No regex capture groups'
def __init__(self, to_return, expr_ev):
# type: (int, Optional[ExprEvaluator]) -> None
self.to_return = to_return
self.expr_ev = expr_ev

def _GetMatch(self, s, indices, i, blame_loc):
# type: (str, List[int], int, loc_t) -> value_t
num_groups = len(indices) / 2 # including group 0
if i < num_groups:
start = indices[2 * i]
if self.to_return == S:
return value.Int(start)

end = indices[2 * i + 1]
if self.to_return == E:
return value.Int(end)

if start == -1:
return value.Null
else:
# TODO: Can apply type conversion function
# See osh/prompt.py:
# val = self.expr_ev.PluginCall(func_val, pos_args)
return value.Str(s[start:end])
else:
msg = 'Expected capture group less than %d, got %d' % (num_groups,
i)
raise error.Expr(msg, blame_loc)
if num_groups == 0:
msg = 'No regex capture groups'
else:
msg = 'Expected capture group less than %d, got %d' % (num_groups,
i)
raise error.Expr(msg, blame_loc)


def _GetGroupIndex(group, capture_names, blame_loc):
Expand All @@ -73,7 +83,7 @@ def _GetGroupIndex(group, capture_names, blame_loc):
return group_index


class MatchFunc(vm._Callable):
class MatchFunc(_MatchCallable):
"""
_group(i)
_start(i)
Expand All @@ -85,12 +95,10 @@ class MatchFunc(vm._Callable):
Ditto for _start() and _end()
"""

def __init__(self, mem, to_return):
# type: (state.Mem, int) -> None
vm._Callable.__init__(self)
def __init__(self, to_return, expr_ev, mem):
# type: (int, Optional[ExprEvaluator], state.Mem) -> None
_MatchCallable.__init__(self, to_return, expr_ev)
self.mem = mem
self.to_return = to_return

def Call(self, rd):
# type: (typed_args.Reader) -> value_t
Expand All @@ -101,20 +109,18 @@ def Call(self, rd):
s, indices, capture_names = self.mem.GetRegexIndices()
group_index = _GetGroupIndex(group, capture_names, rd.LeftParenToken())

return _GetMatch(s, indices, group_index, self.to_return,
rd.LeftParenToken())
return self._GetMatch(s, indices, group_index, rd.LeftParenToken())


class MatchMethod(vm._Callable):
class MatchMethod(_MatchCallable):
"""
m => group(i)
m => start(i)
m => end(i)
"""

def __init__(self, to_return):
# type: (int) -> None
self.to_return = to_return
def __init__(self, to_return, expr_ev):
# type: (int, Optional[ExprEvaluator]) -> None
_MatchCallable.__init__(self, to_return, expr_ev)

def Call(self, rd):
# type: (typed_args.Reader) -> value_t
Expand All @@ -127,8 +133,15 @@ def Call(self, rd):
group_index = _GetGroupIndex(group, m.capture_names,
rd.LeftParenToken())

return _GetMatch(m.s, m.indices, group_index, self.to_return,
rd.LeftParenToken())
#log('group_index %d', group_index)
#log('m.convert_funcs %s', m.convert_funcs)
convert_func = None # type: value_t
if len(m.convert_funcs): # for ERE string, it's []
if group_index != 0: # doesn't have a name or type attached to it
convert_func = m.convert_funcs[group_index - 1]
#log('conv %s', convert_func)

return self._GetMatch(m.s, m.indices, group_index, rd.LeftParenToken())


# vim: sw=4
12 changes: 6 additions & 6 deletions core/shell.py
Expand Up @@ -774,9 +774,9 @@ def Main(
}

methods[value_e.Match] = {
'group': func_eggex.MatchMethod(func_eggex.G),
'start': func_eggex.MatchMethod(func_eggex.S),
'end': func_eggex.MatchMethod(func_eggex.E),
'group': func_eggex.MatchMethod(func_eggex.G, expr_ev),
'start': func_eggex.MatchMethod(func_eggex.S, None),
'end': func_eggex.MatchMethod(func_eggex.E, None),
}

methods[value_e.IO] = {
Expand Down Expand Up @@ -815,11 +815,11 @@ def Main(

_SetGlobalFunc(mem, 'len', func_misc.Len())

g = func_eggex.MatchFunc(mem, func_eggex.G)
g = func_eggex.MatchFunc(func_eggex.G, expr_ev, mem)
_SetGlobalFunc(mem, '_group', g)
_SetGlobalFunc(mem, '_match', g) # TODO: remove this backward compat alias
_SetGlobalFunc(mem, '_start', func_eggex.MatchFunc(mem, func_eggex.S))
_SetGlobalFunc(mem, '_end', func_eggex.MatchFunc(mem, func_eggex.E))
_SetGlobalFunc(mem, '_start', func_eggex.MatchFunc(func_eggex.S, None, mem))
_SetGlobalFunc(mem, '_end', func_eggex.MatchFunc(func_eggex.E, None, mem))

_SetGlobalFunc(mem, 'join', func_misc.Join())
_SetGlobalFunc(mem, 'maybe', func_misc.Maybe())
Expand Down
29 changes: 25 additions & 4 deletions ysh/expr_eval.py
Expand Up @@ -340,7 +340,6 @@ def PluginCall(self, func_val, pos_args):
arg_list = ArgList.CreateNull() # There's no call site
rd = typed_args.Reader(pos_args, named_args, arg_list)

# TODO: catch exceptions
try:
val = func_proc.CallUserFunc(func_val, rd, self.mem,
self.cmd_ev)
Expand All @@ -356,6 +355,22 @@ def PluginCall(self, func_val, pos_args):

return val

def CallConvertFunc(self, func_val, arg):
    # type: (value_t, value_t) -> value_t
    """Call a conversion func on an eggex capture -- currently a stub.

    Args:
      func_val: the conversion function.  Not consulted yet, because the
        actual call below is still commented out.
      arg: the single positional argument that would be passed to it.

    Returns:
      Always None for now.  The typed-args reader is built but nothing is
      invoked with it.
    """
    # NOTE(review): presumably this runs the call under YSH expression
    # options -- confirm against state.ctx_YshExpr.
    with state.ctx_YshExpr(self.mutable_opts):
        # One positional arg, no named args
        pos_args = [arg]
        named_args = {} # type: Dict[str, value_t]
        arg_list = ArgList.CreateNull() # There's no call site
        rd = typed_args.Reader(pos_args, named_args, arg_list)

        # TODO: Use logic from _EvalFuncCall

        # The intended call is kept for reference; 'rd' is unused until it is
        # enabled.
        #val = func_proc.CallUserFunc(func_val, rd, self.mem, self.cmd_ev)
        val = None

    return val

def SpliceValue(self, val, part):
# type: (value_t, word_part.Splice) -> List[str]
""" write -- @myvar """
Expand Down Expand Up @@ -1347,9 +1362,14 @@ def _EvalEggex(self, node, parent_flags, convert_funcs):

elif case(value_e.Eggex):
val = cast(value.Eggex, UP_val)

# Splicing means we get the conversion funcs too.
convert_funcs.extend(val.convert_funcs)

# Splicing requires flags to match. This check is
# transitive.
to_splice = val.spliced

if val.canonical_flags != parent_flags:
e_die(
"Expected eggex flags %r, but got %r" %
Expand All @@ -1371,13 +1391,14 @@ def _EvalEggex(self, node, parent_flags, convert_funcs):
def EvalEggex(self, node):
# type: (Eggex) -> value.Eggex

# Splice and check flags consistency
# Splice, check flags consistency, and accumulate convert_funcs indexed
# by capture group
convert_funcs = [] # type: List[Optional[value_t]]
spliced = self._EvalEggex(node.regex, node.canonical_flags,
convert_funcs)
#log('convert_funcs %s', convert_funcs)

# as_ere and capture_names filled in during translation
# TODO: func_names should be done above
# as_ere and capture_names filled by ~ operator or Str method
return value.Eggex(spliced, node.canonical_flags, convert_funcs, None,
[])

Expand Down

0 comments on commit f738a3a

Please sign in to comment.