Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper tail recursion #728

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions docs/contrib/future.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
==========
__future__ features
==========

.. versionadded:: 0.10.2

Importing from ``__future__`` allows you to add features to
Hy that are not yet in the main language, due to slowing or
being harder to debug.


.. _tailrec:


``TailRec``
===========

The ``(import [__future__ [TailRec]])`` command
gives programmers a simple way to use tail-call optimization
(TCO) in their Hy code, with no need for trampoline or recur.
Supports mutually recursive functions.

A tail call is a subroutine call that happens inside another
procedure as its final action; it may produce a return value which
is then immediately returned by the calling procedure. If any call
that a subroutine performs, such that it might eventually lead to
this same subroutine being called again down the call chain, is in
tail position, such a subroutine is said to be tail-recursive,
which is a special case of recursion. Tail calls are significant
because they can be implemented without adding a new stack frame
to the call stack. Most of the frame of the current procedure is
not needed any more, and it can be replaced by the frame of the
tail call. The program can then jump to the called
subroutine. Producing such code instead of a standard call
sequence is called tail call elimination, or tail call
optimization. Tail call elimination allows procedure calls in tail
position to be implemented as efficiently as goto statements, thus
allowing efficient structured programming.

-- Wikipedia (http://en.wikipedia.org/wiki/Tail_call)

Example:

.. code-block:: hy

(import [__future__ [TailRec]])

(defn fact [n]
(defn facthelper [n acc]
(if (= n 0)
acc
(facthelper (- n 1) (* n acc))))
(do
(print "Using fact!")
(facthelper n 1)))

(print (fact 10000))

(defn odd [n]
(if (= n 0)
False
(even (- n 1))))

(defn even [n]
(if (= n 0)
True
(odd (- n 1))))

(print (even 1000))

1 change: 1 addition & 0 deletions docs/contrib/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ Contents:
anaphoric
loop
multi
future
46 changes: 46 additions & 0 deletions hy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ class HyASTCompiler(object):
def __init__(self, module_name):
self.anon_fn_count = 0
self.anon_var_count = 0
self.tail_rec = False
self.imports = defaultdict(set)
self.module_name = module_name
if not module_name.startswith("hy.core"):
Expand Down Expand Up @@ -1108,6 +1109,12 @@ def _compile_import(expr, module, names=None, importer=ast.Import):
while len(expr) > 0:
iexpr = expr.pop(0)

if iexpr[0] == "__future__" and "TailRec" in iexpr[1]:
self.tail_rec = True
TRInd = iexpr[1].index("TailRec")
iexpr[1].pop(TRInd)
if(len(iexpr[1]) == 0):
continue
if not isinstance(iexpr, (HySymbol, HyList)):
raise HyTypeError(iexpr, "(import) requires a Symbol "
"or a List.")
Expand Down Expand Up @@ -1976,6 +1983,17 @@ def compile_function_def(self, expression):
if body.contains_yield:
body += body.expr_as_stmt()
else:
if self.tail_rec:
expression[-1], changed = self.makeTailRec(expression[-1])
if changed:
new_expression = HyExpression([
HySymbol("with_decorator"),
HySymbol("HyTailRec"),
HyExpression([
called_as,
arglist] + expression)
]).replace(expression)
return self.compile(new_expression)
body += ast.Return(value=body.expr,
lineno=body.expr.lineno,
col_offset=body.expr.col_offset)
Expand Down Expand Up @@ -2003,6 +2021,34 @@ def compile_function_def(self, expression):

return ret

def exprToTailCall(self, expr):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use snake-casting rather than camel casing :)

"Takes an expression, and returns a TailCall of that expression"
return HyExpression([
HySymbol("raise"),
HyExpression([HySymbol("HyTailCall")] + expr),
]).replace(expr)

def makeTailRec(self, body):
""" Takes the body of an expression, and returns a tail recursive
TailCall form of the body """
if isinstance(body, HyExpression):
# Only compile expression, names and symbols should just stand
if body[0] == "if":
body[2], changed2 = self.makeTailRec(body[2])
changed3 = False
if len(body) == 4:
body[3], changed3 = self.makeTailRec(body[3])
return body, (changed2 or changed3)
if body[0] == "progn" or body[0] == "do":
body[-1], changed = self.makeTailRec(body[-1])
return body, changed
elif body[0] in _compile_table.keys():
# Bail on all keywords in hy and other special forms
return body, False
else:
return self.exprToTailCall(body), True
return body, False

@builds("defclass")
@checkargs(min=1)
def compile_class_expression(self, expression):
Expand Down
1 change: 1 addition & 0 deletions hy/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
STDLIB = [
"hy.core.language",
"hy.core.tailrec",
"hy.core.shadow"
]
30 changes: 30 additions & 0 deletions hy/core/tailrec.hy
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

(defclass HyTailCall [Exception]
"An exeception to implement Proper Tail Recursion"
[[--init--
(fn [self __TCFunc &rest args &kwargs kwargs]
(setv self.func __TCFunc)
(setv self.args args)
(setv self.kwargs kwargs)
nil)]])

(defn HyTailRec [func]
"""A decorator that takes functions that end in raise HyTailCall(func, *args, **kwargs)
and makes them tail recursive"""
(if (hasattr func "__nonTCO")
func
(do
(defn funcwrapper [&rest args &kwargs kwargs]
(setv funcwrapper.__nonTCO func)
(setv tc (apply HyTailCall (cons func (list args)) kwargs))
(while True
(try (if (hasattr tc.func "__nonTCO")
(setv ret (apply tc.func.__nonTCO (list tc.args) tc.kwargs))
(setv ret (apply tc.func (list tc.args) tc.kwargs)))
(catch [err HyTailCall]
(setv tc err))
(else (break))))
ret)
funcwrapper)))

(setv *exports* '[HyTailCall HyTailRec])
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .native_tests.contrib.walk import * # noqa
from .native_tests.contrib.multi import * # noqa
from .native_tests.contrib.curry import * # noqa
from .native_tests.tailrec import * # noqa

if PY3:
from .native_tests.py3_only_tests import * # noqa
19 changes: 19 additions & 0 deletions tests/native_tests/tailrec.hy
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
(import [__future__ [TailRec]])

(defn test-mutualtailrec []
"Testing whether tail recursion in mutually recursive functions work"
(do
(defn tcodd [n] (if (= n 0) False (tceven (- n 1))))
(defn tceven [n] (if (= n 0) True (tcodd (- n 1))))
(assert (tceven 1000))))

(defn test-selfrecur []
"Testing whether tail recusion in self recursive functions work"
(do
(defn fact [n]
(defn facthelper [n acc]
(if (= n 0)
acc
(facthelper (- n 1) (* n acc))))
(facthelper n 1))
(assert (< 0 (fact 1000)))))