Skip to content

Commit

Permalink
Use new arg name ellision syntax
Browse files Browse the repository at this point in the history
Resolves   #811.
  • Loading branch information
evhub committed Apr 20, 2024
1 parent 878aacf commit 33ff96a
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 25 deletions.
8 changes: 5 additions & 3 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2219,7 +2219,7 @@ quad = 5 * x**2 + 3 * x + 1

When passing in long variable names as keyword arguments of the same name, Coconut supports the syntax
```
f(...=long_variable_name)
f(long_variable_name=)
```
as a shorthand for
```
Expand All @@ -2228,15 +2228,17 @@ f(long_variable_name=long_variable_name)

Such syntax is also supported in [partial application](#partial-application) and [anonymous `namedtuple`s](#anonymous-namedtuples).

_Deprecated: Coconut also supports `f(...=long_variable_name)` as an alternative shorthand syntax._

##### Example

**Coconut:**
```coconut
really_long_variable_name_1 = get_1()
really_long_variable_name_2 = get_2()
main_func(
...=really_long_variable_name_1,
...=really_long_variable_name_2,
really_long_variable_name_1=,
really_long_variable_name_2=,
)
```

Expand Down
32 changes: 22 additions & 10 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def reset(self, keep_state=False, filename=None):
self.add_code_before_replacements = {}
self.add_code_before_ignore_names = {}
self.remaining_original = None
self.shown_warnings = set()

@contextmanager
def inner_environment(self, ln=None):
Expand All @@ -556,6 +557,7 @@ def inner_environment(self, ln=None):
kept_lines, self.kept_lines = self.kept_lines, []
num_lines, self.num_lines = self.num_lines, 0
remaining_original, self.remaining_original = self.remaining_original, None
shown_warnings, self.shown_warnings = self.shown_warnings, set()
try:
with ComputationNode.using_overrides():
yield
Expand All @@ -571,6 +573,7 @@ def inner_environment(self, ln=None):
self.kept_lines = kept_lines
self.num_lines = num_lines
self.remaining_original = remaining_original
self.shown_warnings = shown_warnings

@contextmanager
def disable_checks(self):
Expand Down Expand Up @@ -937,11 +940,14 @@ def strict_err(self, *args, **kwargs):
if self.strict:
raise self.make_err(CoconutStyleError, *args, **kwargs)

def syntax_warning(self, *args, **kwargs):
def syntax_warning(self, message, original, loc, **kwargs):
"""Show a CoconutSyntaxWarning. Usage:
self.syntax_warning(message, original, loc)
"""
logger.warn_err(self.make_err(CoconutSyntaxWarning, *args, **kwargs))
key = (message, loc)
if key not in self.shown_warnings:
logger.warn_err(self.make_err(CoconutSyntaxWarning, message, original, loc, **kwargs))
self.shown_warnings.add(key)

def strict_err_or_warn(self, *args, **kwargs):
"""Raises an error if in strict mode, otherwise raises a warning. Usage:
Expand Down Expand Up @@ -2779,7 +2785,7 @@ def polish(self, inputstring, final_endline=True, **kwargs):
# HANDLERS:
# -----------------------------------------------------------------------------------------------------------------------

def split_function_call(self, tokens, loc):
def split_function_call(self, original, loc, tokens):
"""Split into positional arguments and keyword arguments."""
pos_args = []
star_args = []
Expand All @@ -2802,7 +2808,10 @@ def split_function_call(self, tokens, loc):
star_args.append(argstr)
elif arg[0] == "**":
dubstar_args.append(argstr)
elif arg[1] == "=":
kwd_args.append(arg[0] + "=" + arg[0])
elif arg[0] == "...":
self.strict_err_or_warn("'...={name}' shorthand is deprecated, use '{name}=' shorthand instead".format(name=arg[1]), original, loc)
kwd_args.append(arg[1] + "=" + arg[1])
else:
kwd_args.append(argstr)
Expand All @@ -2818,9 +2827,9 @@ def split_function_call(self, tokens, loc):

return pos_args, star_args, kwd_args, dubstar_args

def function_call_handle(self, loc, tokens):
def function_call_handle(self, original, loc, tokens):
"""Enforce properly ordered function parameters."""
return "(" + join_args(*self.split_function_call(tokens, loc)) + ")"
return "(" + join_args(*self.split_function_call(original, loc, tokens)) + ")"

def pipe_item_split(self, original, loc, tokens):
"""Process a pipe item, which could be a partial, an attribute access, a method call, or an expression.
Expand All @@ -2841,7 +2850,7 @@ def pipe_item_split(self, original, loc, tokens):
return "expr", tokens
elif "partial" in tokens:
func, args = tokens
pos_args, star_args, kwd_args, dubstar_args = self.split_function_call(args, loc)
pos_args, star_args, kwd_args, dubstar_args = self.split_function_call(original, loc, args)
return "partial", (func, join_args(pos_args, star_args), join_args(kwd_args, dubstar_args))
elif "attrgetter" in tokens:
name, args = attrgetter_atom_split(tokens)
Expand Down Expand Up @@ -3061,7 +3070,7 @@ def item_handle(self, original, loc, tokens):
elif trailer[0] == "$[":
out = "_coconut_iter_getitem(" + out + ", " + trailer[1] + ")"
elif trailer[0] == "$(?":
pos_args, star_args, base_kwd_args, dubstar_args = self.split_function_call(trailer[1], loc)
pos_args, star_args, base_kwd_args, dubstar_args = self.split_function_call(original, loc, trailer[1])

has_question_mark = False
needs_complex_partial = False
Expand Down Expand Up @@ -3232,7 +3241,7 @@ def classdef_handle(self, original, loc, tokens):
# handle classlist
base_classes = []
if classlist_toks:
pos_args, star_args, kwd_args, dubstar_args = self.split_function_call(classlist_toks, loc)
pos_args, star_args, kwd_args, dubstar_args = self.split_function_call(original, loc, classlist_toks)

# check for just inheriting from object
if (
Expand Down Expand Up @@ -3566,7 +3575,7 @@ def __hash__(self):

return "".join(out)

def anon_namedtuple_handle(self, tokens):
def anon_namedtuple_handle(self, original, loc, tokens):
"""Handle anonymous named tuples."""
names = []
types = {}
Expand All @@ -3579,7 +3588,10 @@ def anon_namedtuple_handle(self, tokens):
types[i] = typedef
else:
raise CoconutInternalException("invalid anonymous named item", tok)
if name == "...":
if item == "=":
item = name
elif name == "...":
self.strict_err_or_warn("'...={item}' shorthand is deprecated, use '{item}=' shorthand instead".format(item=item), original, loc)
name = item
names.append(name)
items.append(item)
Expand Down
9 changes: 5 additions & 4 deletions coconut/compiler/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,11 +1249,12 @@ class Grammar(object):

call_item = (
unsafe_name + default
# ellipsis must come before namedexpr_test
| ellipsis_tokens + equals.suppress() + refname
| namedexpr_test
| star + test
| dubstar + test
| refname + equals # new long name ellision syntax
| ellipsis_tokens + equals.suppress() + refname # old long name ellision syntax
# must come at end
| namedexpr_test
)
function_call_tokens = lparen.suppress() + (
# everything here must end with rparen
Expand Down Expand Up @@ -1303,7 +1304,7 @@ class Grammar(object):
maybe_typedef = Optional(colon.suppress() + typedef_test)
anon_namedtuple_ref = tokenlist(
Group(
unsafe_name + maybe_typedef + equals.suppress() + test
unsafe_name + maybe_typedef + (equals.suppress() + test | equals)
| ellipsis_tokens + maybe_typedef + equals.suppress() + refname
),
comma,
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.1.0"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 9
DEVELOP = 10
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
Expand Down
8 changes: 4 additions & 4 deletions coconut/tests/src/cocotest/agnostic/primary_2.coco
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,10 @@ def primary_test_2() -> bool:
f.is_f = True # type: ignore
assert (f ..*> (+)).is_f # type: ignore
really_long_var = 10
assert (...=really_long_var) == (10,)
assert (...=really_long_var, abc="abc") == (10, "abc")
assert (abc="abc", ...=really_long_var) == ("abc", 10)
assert (...=really_long_var).really_long_var == 10 # type: ignore
assert (really_long_var=) == (10,)
assert (really_long_var=, abc="abc") == (10, "abc")
assert (abc="abc", really_long_var=) == ("abc", 10)
assert (really_long_var=).really_long_var == 10 # type: ignore
n = [0]
assert n[0] == 0
assert_raises(-> m{{1:2,2:3}}, TypeError)
Expand Down
4 changes: 2 additions & 2 deletions coconut/tests/src/cocotest/agnostic/suite.coco
Original file line number Diff line number Diff line change
Expand Up @@ -1055,8 +1055,8 @@ forward 2""") == 900
assert InitAndIter(range(3)) |> fmap$((.+1), fallback_to_init=True) == InitAndIter(range(1, 4))
assert_raises(-> InitAndIter(range(3)) |> fmap$(.+1), TypeError)
really_long_var = 10
assert ret_args_kwargs(...=really_long_var) == ((), {"really_long_var": 10}) == ret_args_kwargs$(...=really_long_var)()
assert ret_args_kwargs(123, ...=really_long_var, abc="abc") == ((123,), {"really_long_var": 10, "abc": "abc"}) == ret_args_kwargs$(123, ...=really_long_var, abc="abc")()
assert ret_args_kwargs(really_long_var=) == ((), {"really_long_var": 10}) == ret_args_kwargs$(really_long_var=)()
assert ret_args_kwargs(123, really_long_var=, abc="abc") == ((123,), {"really_long_var": 10, "abc": "abc"}) == ret_args_kwargs$(123, really_long_var=, abc="abc")()
assert "Coconut version of typing" in typing.__doc__
numlist: NumList = [1, 2.3, 5]
assert hasloc([[1, 2]]).loc[0][1] == 2 == hasloc([[1, 2]]) |> .loc[0][1]
Expand Down
2 changes: 1 addition & 1 deletion coconut/tests/src/cocotest/agnostic/util.coco
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ match def fact(n) = fact(n, 1)
match addpattern def fact(0, acc) = acc # type: ignore
addpattern match def fact(n, acc) = fact(n-1, acc*n) # type: ignore

addpattern def factorial(0, acc=1) = acc
match def factorial(0, acc=1) = acc
addpattern def factorial(int() as n, acc=1 if n > 0) = # type: ignore
"""this is a docstring"""
factorial(n-1, acc*n)
Expand Down
5 changes: 5 additions & 0 deletions coconut/tests/src/cocotest/non_strict/non_strict_test.coco
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def non_strict_test() -> bool:
@recursive_iterator
def fib() = (1, 1) :: map((+), fib(), fib()$[1:])
assert fib()$[:5] |> list == [1, 1, 2, 3, 5]
addpattern def args_or_kwargs(*args) = args
addpattern def args_or_kwargs(**kwargs) = kwargs # type: ignore
assert args_or_kwargs(1, 2) == (1, 2)
very_long_name = 10
assert args_or_kwargs(short_name=5, very_long_name=) == {"short_name": 5, "very_long_name": 10}
return True

if __name__ == "__main__":
Expand Down

0 comments on commit 33ff96a

Please sign in to comment.