Skip to content

Commit

Permalink
[refactor] Extract core/code.py, to unify Proc/Func arg binding
Browse files Browse the repository at this point in the history
Had to work around some mycpp issues.

Also

- Simplify _ParamGroup() function
  • Loading branch information
Andy C committed Aug 28, 2023
1 parent e2f9c36 commit 185bad8
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 130 deletions.
149 changes: 149 additions & 0 deletions core/code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#!/usr/bin/env python2
"""
code.py: User-defined funcs and procs
"""
from __future__ import print_function

from _devbuild.gen.runtime_asdl import value, value_t, scope_e
from _devbuild.gen.syntax_asdl import proc_sig, proc_sig_e

from core import error
from core.error import e_die
from core import state
from core import vm
from frontend import lexer
from frontend import location

from typing import List, Dict, cast, TYPE_CHECKING
if TYPE_CHECKING:
from _devbuild.gen.syntax_asdl import command, loc_t
from _devbuild.gen.runtime_asdl import Proc
from core import ui
from osh import cmd_eval


class UserFunc(vm._Callable):
"""A user-defined function."""

def __init__(self, name, node, mem, cmd_ev):
# type: (str, command.Func, state.Mem, cmd_eval.CommandEvaluator) -> None
self.name = name
self.node = node
self.cmd_ev = cmd_ev
self.mem = mem

def Call(self, pos_args, named_args):
# type: (List[value_t], Dict[str, value_t]) -> value_t
nargs = len(pos_args)
expected = len(self.node.pos_params)
if self.node.pos_splat:
if nargs < expected:
raise error.TypeErrVerbose(
"%s() expects at least %d arguments but %d were given" %
(self.name, expected, nargs), self.node.keyword)
elif nargs != expected:
raise error.TypeErrVerbose(
"%s() expects %d arguments but %d were given" %
(self.name, expected, nargs), self.node.keyword)

nargs = len(named_args)
expected = len(self.node.named_params)
if nargs != expected:
raise error.TypeErrVerbose(
"%s() expects %d named arguments but %d were given" %
(self.name, expected, nargs), self.node.keyword)

with state.ctx_FuncCall(self.cmd_ev.mem, self):
nargs = len(self.node.pos_params)
for i in xrange(0, nargs):
pos_arg = pos_args[i]
pos_param = self.node.pos_params[i]

arg_name = location.LName(lexer.TokenVal(pos_param.name))
self.mem.SetValue(arg_name, pos_arg, scope_e.LocalOnly)

if self.node.pos_splat:
other_args = value.List(pos_args[nargs:])
arg_name = location.LName(lexer.TokenVal(self.node.pos_splat))
self.mem.SetValue(arg_name, other_args, scope_e.LocalOnly)

# TODO: pass named args

try:
self.cmd_ev._Execute(self.node.body)

return value.Null # implicit return
except vm.ValueControlFlow as e:
return e.value
except vm.IntControlFlow as e:
raise AssertionError('IntControlFlow in func')

raise AssertionError('unreachable')


def BindProcArgs(proc, argv, arg0_loc, mem, errfmt):
# type: (Proc, List[str], loc_t, state.Mem, ui.ErrorFormatter) -> int

UP_sig = proc.sig
if UP_sig.tag() != proc_sig_e.Closed: # proc is-closed ()
return 0

sig = cast(proc_sig.Closed, UP_sig)

n_args = len(argv)
for i, p in enumerate(sig.words):

# proc p(out Ref)
is_out_param = (p.type is not None and p.type.name == 'Ref')

param_name = p.name.tval
if i < n_args:
arg_str = argv[i]

# If we have myproc(p), and call it with myproc :arg, then bind
# __p to 'arg'. That is, the param has a prefix ADDED, and the arg
# has a prefix REMOVED.
#
# This helps eliminate "nameref cycles".
if is_out_param:
param_name = '__' + param_name

if not arg_str.startswith(':'):
# TODO: Point to the exact argument
e_die(
'Invalid argument %r. Expected a name starting with :'
% arg_str)
arg_str = arg_str[1:]

val = value.Str(arg_str) # type: value_t
else:
val = proc.defaults[i]
if val is None:
e_die("No value provided for param %r" %
p.name.tval)

if is_out_param:
flags = state.SetNameref
else:
flags = 0

mem.SetValue(location.LName(param_name), val, scope_e.LocalOnly,
flags=flags)

n_params = len(sig.words)
if sig.rest_words:
items = [value.Str(s)
for s in argv[n_params:]] # type: List[value_t]
leftover = value.List(items)
mem.SetValue(location.LName(sig.rest_words.tval), leftover,
scope_e.LocalOnly)
else:
if n_args > n_params:
# TODO: Raise an exception?
errfmt.Print_(
"proc %r expected %d arguments, but got %d" %
(proc.name, n_params, n_args), arg0_loc)
# This should be status 2 because it's like a usage error.
return 2

return 0
4 changes: 2 additions & 2 deletions core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
from _devbuild.gen.option_asdl import option_t
from _devbuild.gen.runtime_asdl import Proc
from core import alloc
from core import code
from osh import sh_expr_eval
from osh.cmd_eval import Func

# This was derived from bash --norc -c 'argv "$COMP_WORDBREAKS".
# Python overwrites this to something Python-specific in Modules/readline.c, so
Expand Down Expand Up @@ -1121,7 +1121,7 @@ class ctx_FuncCall(object):
"""For func calls."""

def __init__(self, mem, func):
# type: (Mem, Func) -> None
# type: (Mem, code.UserFunc) -> None
mem.PushCall(func.name, func.node.name, None)
self.mem = mem

Expand Down
2 changes: 1 addition & 1 deletion frontend/syntax.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ module syntax
# - no 'Token right' for now, doesn't appear to be used
ShArrayLiteral = (Token left, List[word] words, Token right)

# For both proc and func
# Currently used for func
# Note that ...pos is expr.Spread. TODO: Make NamedArg consistent
ArgList = (
Token left, List[expr] positional, List[NamedArg] named, Token right
Expand Down
130 changes: 8 additions & 122 deletions osh/cmd_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
)
from _devbuild.gen.types_asdl import redir_arg_type_e

from core import code
from core import dev
from core import error
from core.error import e_die, e_die_status
Expand Down Expand Up @@ -224,65 +225,6 @@ def PlusEquals(old_val, val):
return val


class Func(vm._Callable):
"""A user-defined function."""

def __init__(self, name, node, mem, cmd_ev):
# type: (str, command.Func, state.Mem, CommandEvaluator) -> None
self.name = name
self.node = node
self.cmd_ev = cmd_ev
self.mem = mem

def Call(self, pos_args, named_args):
# type: (List[value_t], Dict[str, value_t]) -> value_t
nargs = len(pos_args)
expected = len(self.node.pos_params)
if self.node.pos_splat:
if nargs < expected:
raise error.TypeErrVerbose(
"%s() expects at least %d arguments but %d were given" %
(self.name, expected, nargs), self.node.keyword)
elif nargs != expected:
raise error.TypeErrVerbose(
"%s() expects %d arguments but %d were given" %
(self.name, expected, nargs), self.node.keyword)

nargs = len(named_args)
expected = len(self.node.named_params)
if nargs != expected:
raise error.TypeErrVerbose(
"%s() expects %d named arguments but %d were given" %
(self.name, expected, nargs), self.node.keyword)

with state.ctx_FuncCall(self.cmd_ev.mem, self):
nargs = len(self.node.pos_params)
for i in xrange(0, nargs):
pos_arg = pos_args[i]
pos_param = self.node.pos_params[i]

arg_name = location.LName(lexer.TokenVal(pos_param.name))
self.mem.SetValue(arg_name, pos_arg, scope_e.LocalOnly)

if self.node.pos_splat:
other_args = value.List(pos_args[nargs:])
arg_name = location.LName(lexer.TokenVal(self.node.pos_splat))
self.mem.SetValue(arg_name, other_args, scope_e.LocalOnly)

# TODO: pass named args

try:
self.cmd_ev._Execute(self.node.body)

return value.Null # implicit return
except vm.ValueControlFlow as e:
return e.value
except vm.IntControlFlow as e:
raise AssertionError('IntControlFlow in func')

raise AssertionError('unreachable')


class ctx_LoopLevel(object):
"""For checking for invalid control flow."""

Expand Down Expand Up @@ -1521,7 +1463,7 @@ def _Dispatch(self, node, cmd_st):
# Needed in case the func is an existing variable name
self.mem.SetTokenForLine(node.name)

val = value.Func(Func(name, node, self.mem, self))
val = value.Func(code.UserFunc(name, node, self.mem, self))
self.mem.SetValue(lval, val, scope_e.LocalOnly,
_PackFlags(Id.KW_Func, state.SetReadOnly))

Expand Down Expand Up @@ -1999,7 +1941,7 @@ def _MaybeRunErrTrap(self):

def RunProc(self, proc, argv, arg0_loc):
# type: (Proc, List[str], loc_t) -> int
"""Run a shell "functions".
"""Run procs aka "shell functions".
For SimpleCommand and registered completion hooks.
"""
Expand All @@ -2010,67 +1952,11 @@ def RunProc(self, proc, argv, arg0_loc):
else:
proc_argv = argv

# Hm this sets "$@". TODO: Set ARGV only
with state.ctx_ProcCall(self.mem, self.mutable_opts, proc, proc_argv):
n_args = len(argv)
UP_sig = sig

if UP_sig.tag() == proc_sig_e.Closed: # proc is-closed ()
sig = cast(proc_sig.Closed, UP_sig)
for i, p in enumerate(sig.words):

# proc p(out Ref)
is_out_param = (p.type is not None and p.type.name == 'Ref')

param_name = p.name.tval
if i < n_args:
arg_str = argv[i]

# If we have myproc(p), and call it with myproc :arg, then bind
# __p to 'arg'. That is, the param has a prefix ADDED, and the arg
# has a prefix REMOVED.
#
# This helps eliminate "nameref cycles".
if is_out_param:
param_name = '__' + param_name

if not arg_str.startswith(':'):
# TODO: Point to the exact argument
e_die(
'Invalid argument %r. Expected a name starting with :'
% arg_str)
arg_str = arg_str[1:]

val = value.Str(arg_str) # type: value_t
else:
val = proc.defaults[i]
if val is None:
e_die("No value provided for param %r" %
p.name.tval)

if is_out_param:
flags = state.SetNameref
else:
flags = 0

self.mem.SetValue(location.LName(param_name),
val,
scope_e.LocalOnly,
flags=flags)

n_params = len(sig.words)
if sig.rest_words:
items = [value.Str(s)
for s in argv[n_params:]] # type: List[value_t]
leftover = value.List(items)
self.mem.SetValue(location.LName(sig.rest_words.tval), leftover,
scope_e.LocalOnly)
else:
if n_args > n_params:
self.errfmt.Print_(
"proc %r expected %d arguments, but got %d" %
(proc.name, n_params, n_args), arg0_loc)
# This should be status 2 because it's like a usage error.
return 2
status = code.BindProcArgs(proc, argv, arg0_loc, self.mem, self.errfmt)
if status != 0:
return status

# Redirects still valid for functions.
# Here doc causes a pipe and Process(SubProgramThunk).
Expand Down Expand Up @@ -2129,6 +2015,7 @@ def EvalBlock(self, block):

def RunFuncForCompletion(self, proc, argv):
# type: (Proc, List[str]) -> int

# TODO: Change this to run YSH procs and funcs too
try:
status = self.RunProc(proc, argv, loc.Missing)
Expand All @@ -2145,5 +2032,4 @@ def RunFuncForCompletion(self, proc, argv):
# NOTE: (IOError, OSError) are caught in completion.py:ReadlineCallback
return status


# vim: sw=4
5 changes: 5 additions & 0 deletions ysh/expr_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,9 @@ def _EvalArgListUntyped(self, args):
def _EvalArgList(self, args, me=None):
# type: (ArgList, Optional[value_t]) -> Tuple[List[value_t], Dict[str, value_t]]
"""For procs and args - TYPED """

# TODO: CommandEvaluator.RunProc is similar

pos_args = [] # type: List[value_t]

if me: # self/this argument
Expand Down Expand Up @@ -784,6 +787,8 @@ def _EvalFuncCall(self, node):
if mylib.PYTHON:
f = func.callable
if isinstance(f, vm._Callable): # typed
# TODO: consider using typed_args.Reader

pos_args, named_args = self._EvalArgList(node.args)
#log('pos_args %s', pos_args)

Expand Down

0 comments on commit 185bad8

Please sign in to comment.