Skip to content

Commit

Permalink
Merge pull request #2524 from scauligi/async_gen_expr
Browse files Browse the repository at this point in the history
Async generators no longer implicitly return final expressions
  • Loading branch information
scauligi committed Oct 31, 2023
2 parents c08378f + dd1ba62 commit 47ab79f
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 12 deletions.
1 change: 1 addition & 0 deletions NEWS.rst
Expand Up @@ -60,6 +60,7 @@ Bug Fixes
already did.
* `nonlocal` now works for top-level `let`-bound names.
* `hy -i` with a filename now skips shebang lines.
* Implicit returns are now disabled in async generators.

0.27.0 (released 2023-07-06)
=============================
Expand Down
5 changes: 4 additions & 1 deletion docs/api.rst
Expand Up @@ -113,7 +113,10 @@ base names, such that ``hy.core.macros.foo`` can be called as just ``foo``.
forms, and the first of them is a string literal, this string becomes the
:term:`py:docstring` of the function. The final body form is implicitly
returned; thus, ``(defn f [] 5)`` is equivalent to ``(defn f [] (return
5))``.
5))``. There is one exception: due to Python limitations, no implicit return
is added if the function is an asynchronous generator (i.e., defined with
:hy:func:`defn/a` or :hy:func:`fn/a` and containing at least one
:hy:func:`yield` or :hy:func:`yield-from`).

``defn`` accepts a few more optional arguments: a bracketed list of
:term:`decorators <py:decorator>`, a list of type parameters (see below),
Expand Down
31 changes: 21 additions & 10 deletions hy/core/result_macros.py
Expand Up @@ -58,7 +58,7 @@
is_unpack,
)
from hy.reader import mangle
from hy.scoping import OuterVar, ScopeFn, ScopeGen, ScopeLet, is_inside_function_scope
from hy.scoping import OuterVar, ScopeFn, ScopeGen, ScopeLet, is_function_scope, is_inside_function_scope, nearest_python_scope

# ------------------------------------------------
# * Helpers
Expand Down Expand Up @@ -1457,24 +1457,27 @@ def compile_try_expression(compiler, expr, root, body, catchers, orelse, finalbo
@pattern_macro(["fn", "fn/a"],
[maybe(type_params), maybe_annotated(lambda_list), many(FORM)])
def compile_function_lambda(compiler, expr, root, tp, params, body):
is_async = root == "fn/a"
params, returns = params
posonly, args, rest, kwonly, kwargs = params
has_annotations = returns is not None or any(
isinstance(param, tuple) and param[1] is not None
for param in (posonly or []) + args + kwonly + [rest, kwargs]
)
args, ret = compile_lambda_list(compiler, params)
with compiler.local_state(), compiler.scope.create(ScopeFn, args):
with compiler.local_state(), compiler.scope.create(ScopeFn, args, is_async) as scope:
body = compiler._compile_branch(body)

# Compile to lambda if we can
if not (has_annotations or tp or body.stmts or root == "fn/a"):
if not (has_annotations or tp or body.stmts or is_async):
return ret + asty.Lambda(expr, args=args, body=body.force_expr)

# Otherwise create a standard function
node = asty.AsyncFunctionDef if root == "fn/a" else asty.FunctionDef
node = asty.AsyncFunctionDef if is_async else asty.FunctionDef
name = compiler.get_anon_var()
ret += compile_function_node(compiler, expr, node, [], tp, name, args, returns, body)
ret += compile_function_node(
compiler, expr, node, [], tp, name, args, returns, body, scope
)

# return its name as the final expr
return ret + Result(expr=ret.temp_variables[0])
Expand All @@ -1485,26 +1488,30 @@ def compile_function_lambda(compiler, expr, root, tp, params, body):
[maybe(brackets(many(FORM))), maybe(type_params), maybe_annotated(SYM), lambda_list, many(FORM)],
)
def compile_function_def(compiler, expr, root, decorators, tp, name, params, body):
is_async = root == "defn/a"
name, returns = name
node = asty.FunctionDef if root == "defn" else asty.AsyncFunctionDef
node = asty.AsyncFunctionDef if is_async else asty.FunctionDef
decorators, ret, _ = compiler._compile_collect(decorators[0] if decorators else [])
args, ret2 = compile_lambda_list(compiler, params)
ret += ret2
name = mangle(compiler._nonconst(name))
compiler.scope.define(name)
with compiler.local_state(), compiler.scope.create(ScopeFn, args):
with compiler.local_state(), compiler.scope.create(ScopeFn, args, is_async) as scope:
body = compiler._compile_branch(body)

return ret + compile_function_node(
compiler, expr, node, decorators, tp, name, args, returns, body
compiler, expr, node, decorators, tp, name, args, returns, body, scope
)


def compile_function_node(compiler, expr, node, decorators, tp, name, args, returns, body):
def compile_function_node(compiler, expr, node, decorators, tp, name, args, returns, body, scope):
ret = Result()

if body.expr:
body += asty.Return(body.expr, value=body.expr)
# implicitly return final expression,
# except for async generators
enode = asty.Expr if scope.is_async and scope.has_yield else asty.Return
body += enode(body.expr, value=body.expr)

ret += node(
expr,
Expand Down Expand Up @@ -1665,6 +1672,8 @@ def compile_return(compiler, expr, root, arg):

@pattern_macro("yield", [maybe(FORM)])
def compile_yield_expression(compiler, expr, root, arg):
if is_inside_function_scope(compiler.scope):
nearest_python_scope(compiler.scope).has_yield = True
ret = Result()
if arg is not None:
ret += compiler.compile(arg)
Expand All @@ -1673,6 +1682,8 @@ def compile_yield_expression(compiler, expr, root, arg):

@pattern_macro(["yield-from", "await"], [FORM])
def compile_yield_from_or_await_expression(compiler, expr, root, arg):
if root == "yield-from" and is_inside_function_scope(compiler.scope):
nearest_python_scope(compiler.scope).has_yield = True
ret = Result() + compiler.compile(arg)
node = asty.YieldFrom if root == "yield-from" else asty.Await
return ret + node(expr, value=ret.force_expr)
Expand Down
8 changes: 7 additions & 1 deletion hy/scoping.py
Expand Up @@ -292,7 +292,7 @@ def add(self, target, new_name=None):
class ScopeFn(ScopeBase):
"""Scope that corresponds to Python's own function or class scopes."""

def __init__(self, compiler, args=None):
def __init__(self, compiler, args=None, is_async=False):
super().__init__(compiler)
self.defined = set()
"set: of all vars defined in this scope"
Expand All @@ -305,6 +305,12 @@ def __init__(self, compiler, args=None):
bool: `True` if this scope is being used to track a python
function `False` for classes
"""
self.is_async = is_async
"""bool: `True` if this scope is for an async function,
which may need special handling during compilation"""
self.has_yield = False
"""bool: `True` if this scope is tracking a function that has `yield`
statements, as generator functions may need special handling"""

if args:
for arg in itertools.chain(
Expand Down
13 changes: 13 additions & 0 deletions tests/native_tests/functions.hy
Expand Up @@ -189,6 +189,19 @@
(assert (= (asyncio.run (coro-test)) [1 2 3])))


(defn [async-test] test-no-async-gen-return []
; https://github.com/hylang/hy/issues/2523
(defn/a runner [gen]
(setv vals [])
(for [:async val (gen)]
(.append vals val))
vals)
(defn/a naysayer []
(yield "nope"))
(assert (= (asyncio.run (runner naysayer)) ["nope"]))
(assert (= (asyncio.run (runner (fn/a [] (yield "dope!")) ["dope!"])))))


(defn test-root-set-correctly []
; https://github.com/hylang/hy/issues/2475
((. defn) not-async [] "ok")
Expand Down

0 comments on commit 47ab79f

Please sign in to comment.