Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async generators no longer implicitly return final expressions #2524

Merged
merged 2 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Bug Fixes
already did.
* `nonlocal` now works for top-level `let`-bound names.
* `hy -i` with a filename now skips shebang lines.
* Implicit returns are now disabled in async generators.

0.27.0 (released 2023-07-06)
=============================
Expand Down
5 changes: 4 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ base names, such that ``hy.core.macros.foo`` can be called as just ``foo``.
forms, and the first of them is a string literal, this string becomes the
:term:`py:docstring` of the function. The final body form is implicitly
returned; thus, ``(defn f [] 5)`` is equivalent to ``(defn f [] (return
5))``.
5))``. There is one exception: due to Python limitations, no implicit return
is added if the function is an asynchronous generator (i.e., defined with
:hy:func:`defn/a` or :hy:func:`fn/a` and containing at least one
:hy:func:`yield` or :hy:func:`yield-from`).

``defn`` accepts a few more optional arguments: a bracketed list of
:term:`decorators <py:decorator>`, a list of type parameters (see below),
Expand Down
31 changes: 21 additions & 10 deletions hy/core/result_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
is_unpack,
)
from hy.reader import mangle
from hy.scoping import OuterVar, ScopeFn, ScopeGen, ScopeLet, is_inside_function_scope
from hy.scoping import OuterVar, ScopeFn, ScopeGen, ScopeLet, is_function_scope, is_inside_function_scope, nearest_python_scope

# ------------------------------------------------
# * Helpers
Expand Down Expand Up @@ -1457,24 +1457,27 @@ def compile_try_expression(compiler, expr, root, body, catchers, orelse, finalbo
@pattern_macro(["fn", "fn/a"],
[maybe(type_params), maybe_annotated(lambda_list), many(FORM)])
def compile_function_lambda(compiler, expr, root, tp, params, body):
is_async = root == "fn/a"
params, returns = params
posonly, args, rest, kwonly, kwargs = params
has_annotations = returns is not None or any(
isinstance(param, tuple) and param[1] is not None
for param in (posonly or []) + args + kwonly + [rest, kwargs]
)
args, ret = compile_lambda_list(compiler, params)
with compiler.local_state(), compiler.scope.create(ScopeFn, args):
with compiler.local_state(), compiler.scope.create(ScopeFn, args, is_async) as scope:
body = compiler._compile_branch(body)

# Compile to lambda if we can
if not (has_annotations or tp or body.stmts or root == "fn/a"):
if not (has_annotations or tp or body.stmts or is_async):
return ret + asty.Lambda(expr, args=args, body=body.force_expr)

# Otherwise create a standard function
node = asty.AsyncFunctionDef if root == "fn/a" else asty.FunctionDef
node = asty.AsyncFunctionDef if is_async else asty.FunctionDef
name = compiler.get_anon_var()
ret += compile_function_node(compiler, expr, node, [], tp, name, args, returns, body)
ret += compile_function_node(
compiler, expr, node, [], tp, name, args, returns, body, scope
)

# return its name as the final expr
return ret + Result(expr=ret.temp_variables[0])
Expand All @@ -1485,26 +1488,30 @@ def compile_function_lambda(compiler, expr, root, tp, params, body):
[maybe(brackets(many(FORM))), maybe(type_params), maybe_annotated(SYM), lambda_list, many(FORM)],
)
def compile_function_def(compiler, expr, root, decorators, tp, name, params, body):
is_async = root == "defn/a"
name, returns = name
node = asty.FunctionDef if root == "defn" else asty.AsyncFunctionDef
node = asty.AsyncFunctionDef if is_async else asty.FunctionDef
decorators, ret, _ = compiler._compile_collect(decorators[0] if decorators else [])
args, ret2 = compile_lambda_list(compiler, params)
ret += ret2
name = mangle(compiler._nonconst(name))
compiler.scope.define(name)
with compiler.local_state(), compiler.scope.create(ScopeFn, args):
with compiler.local_state(), compiler.scope.create(ScopeFn, args, is_async) as scope:
body = compiler._compile_branch(body)

return ret + compile_function_node(
compiler, expr, node, decorators, tp, name, args, returns, body
compiler, expr, node, decorators, tp, name, args, returns, body, scope
)


def compile_function_node(compiler, expr, node, decorators, tp, name, args, returns, body):
def compile_function_node(compiler, expr, node, decorators, tp, name, args, returns, body, scope):
ret = Result()

if body.expr:
body += asty.Return(body.expr, value=body.expr)
# implicitly return final expression,
# except for async generators
enode = asty.Expr if scope.is_async and scope.has_yield else asty.Return
body += enode(body.expr, value=body.expr)

ret += node(
expr,
Expand Down Expand Up @@ -1665,6 +1672,8 @@ def compile_return(compiler, expr, root, arg):

@pattern_macro("yield", [maybe(FORM)])
def compile_yield_expression(compiler, expr, root, arg):
if is_inside_function_scope(compiler.scope):
nearest_python_scope(compiler.scope).has_yield = True
ret = Result()
if arg is not None:
ret += compiler.compile(arg)
Expand All @@ -1673,6 +1682,8 @@ def compile_yield_expression(compiler, expr, root, arg):

@pattern_macro(["yield-from", "await"], [FORM])
def compile_yield_from_or_await_expression(compiler, expr, root, arg):
if root == "yield-from" and is_inside_function_scope(compiler.scope):
nearest_python_scope(compiler.scope).has_yield = True
ret = Result() + compiler.compile(arg)
node = asty.YieldFrom if root == "yield-from" else asty.Await
return ret + node(expr, value=ret.force_expr)
Expand Down
8 changes: 7 additions & 1 deletion hy/scoping.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def add(self, target, new_name=None):
class ScopeFn(ScopeBase):
"""Scope that corresponds to Python's own function or class scopes."""

def __init__(self, compiler, args=None):
def __init__(self, compiler, args=None, is_async=False):
super().__init__(compiler)
self.defined = set()
"set: of all vars defined in this scope"
Expand All @@ -305,6 +305,12 @@ def __init__(self, compiler, args=None):
bool: `True` if this scope is being used to track a python
function `False` for classes
"""
self.is_async = is_async
"""bool: `True` if this scope is for an async function,
which may need special handling during compilation"""
self.has_yield = False
"""bool: `True` if this scope is tracking a function that has `yield`
statements, as generator functions may need special handling"""

if args:
for arg in itertools.chain(
Expand Down
13 changes: 13 additions & 0 deletions tests/native_tests/functions.hy
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,19 @@
(assert (= (asyncio.run (coro-test)) [1 2 3])))


(defn [async-test] test-no-async-gen-return []
; https://github.com/hylang/hy/issues/2523
(defn/a runner [gen]
(setv vals [])
(for [:async val (gen)]
(.append vals val))
vals)
(defn/a naysayer []
(yield "nope"))
(assert (= (asyncio.run (runner naysayer)) ["nope"]))
(assert (= (asyncio.run (runner (fn/a [] (yield "dope!")) ["dope!"])))))


(defn test-root-set-correctly []
; https://github.com/hylang/hy/issues/2475
((. defn) not-async [] "ok")
Expand Down