Skip to content

Commit

Permalink
Fix #649 Fix captures inside tail-rec optimized functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Jand42 committed Feb 16, 2017
1 parent b0787a6 commit 5d358bb
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 27 deletions.
19 changes: 19 additions & 0 deletions src/compiler/WebSharper.Compiler.FSharp/CodeReader.fs
Expand Up @@ -520,6 +520,15 @@ let curriedApplication func args =
| 1 -> Application (func, args, false, Some 1)
| _ -> CurriedApplication(func, args)

let (|CompGenLambda|_|) n (expr: FSharpExpr) =
let rec get acc n expr =
if n = 0 then Some (List.rev acc, expr) else
match expr with
| P.Lambda(id, body) when id.IsCompilerGenerated ->
get (id :: acc) (n - 1) body
| _ -> None
get [] n expr

let rec transformExpression (env: Environment) (expr: FSharpExpr) =
let inline tr x = transformExpression env x
let sr = env.SymbolReader
Expand Down Expand Up @@ -562,6 +571,16 @@ let rec transformExpression (env: Environment) (expr: FSharpExpr) =
let trBody = body |> transformExpression env
trBody |> List.foldBack (fun v e -> lam [v] e (obj.ReferenceEquals(trBody, isUnit) && isUnit body.Type)) vars
| P.Application(func, types, args) ->
match func with
| CompGenLambda args.Length (ids, body) ->
let vars, env =
(env, ids) ||> List.mapFold (fun env arg ->
let v = namedId arg
v, env.WithVar(v, arg)
)
let inline tr x = transformExpression env x
List.foldBack2 (fun i v b -> Let(i, tr v, b)) vars args (tr body)
| _ ->
match IgnoreExprSourcePos (tr func) with
| CallNeedingMoreArgs(thisObj, td, m, ca) ->
Call(thisObj, td, m, ca @ (args |> List.map tr))
Expand Down
5 changes: 2 additions & 3 deletions src/compiler/WebSharper.Compiler.FSharp/ProjectReader.fs
Expand Up @@ -495,9 +495,8 @@ let rec private transformClass (sc: Lazy<_ * StartupCode>) (comp: Compilation) (
Lambda(vars, b)
let currentMethod =
match memdef with
| Member.Method (_, m)
| Member.Implementation (_, m)
| Member.Override (_, m) -> Some (def, m)
| Member.Method (_, m) ->
Some (def, m)
| _ -> None
curriedArgs, TailCalls.optimize currentMethod inlinesOfClass res
with e ->
Expand Down
86 changes: 70 additions & 16 deletions src/compiler/WebSharper.Compiler.FSharp/TailCalls.fs
Expand Up @@ -227,6 +227,57 @@ type TailCallAnalyzer(env) =
env.TailPos <- p
this.VisitExpression h

type AddCapturing(vars : seq<Id>) =
inherit Transformer()

let defined = HashSet(vars)
let captured = HashSet()
let mutable scope = 0

override this.TransformNewVar(var, value) =
if scope = 0 then
defined.Add var |> ignore
base.TransformNewVar(var, value)

override this.TransformVarDeclaration(var, value) =
if scope = 0 then
defined.Add var |> ignore
base.TransformVarDeclaration(var, value)

override this.TransformLet(var, value, body) =
if scope = 0 then
defined.Add var |> ignore
base.TransformLet(var, value, body)

override this.TransformLetRec(defs, body) =
if scope = 0 then
for var, _ in defs do
defined.Add var |> ignore
base.TransformLetRec(defs, body)

override this.TransformId i =
if scope > 0 && defined.Contains i then
captured.Add i |> ignore
i

override this.TransformFunction(args, body) =
scope <- scope + 1
let res =
if scope = 1 then
captured.Clear()
let f = base.TransformFunction(args, body)
if captured.Count > 0 then
let cVars = captured |> List.ofSeq
let cArgs = cVars |> List.map (fun v -> Id.New(?name = v.Name, mut = false))
Application(
Function(cArgs, Return (ReplaceIds(Seq.zip cVars cArgs |> dict).TransformExpression f)),
cVars |> List.map Var, false, None)
else f
else
base.TransformFunction(args, body)
scope <- scope - 1
res

type TailCallTransformer(env) =
inherit Transformer()

Expand All @@ -238,6 +289,22 @@ type TailCallTransformer(env) =
let copying = HashSet<Id>()
let mutable selfCallArgs = None

let withCopiedArgs args b =
let copiedArgs =
args |> Seq.choose (fun a ->
match argCopies.TryGetValue a with
| true, c -> Some c
| _ -> None
) |> List.ofSeq
if List.isEmpty copiedArgs then
b
else
Block [
for a in copiedArgs -> VarDeclaration(a, Undefined)
yield b
]
|> AddCapturing(args).TransformStatement

member this.Recurse(fArgs, origArgs, args: list<_>, index) =
Sequential [
// if recurring with multiple arguments,
Expand Down Expand Up @@ -334,20 +401,6 @@ type TailCallTransformer(env) =
funcCount <- funcCount + 1
| _ -> matchedBindings.Add(var, Choice2Of2 value)
else matchedBindings.Add(var, Choice2Of2 value)
let withCopiedArgs args b =
let copiedArgs =
args |> Seq.choose (fun a ->
match argCopies.TryGetValue a with
| true, c -> Some c
| _ -> None
) |> List.ofSeq
if List.isEmpty copiedArgs then
b
else
Block [
for a in copiedArgs -> VarDeclaration(a, Undefined)
yield b
]
match funcCount with
| 0 -> base.TransformLetRec(List.ofArray bindings, body)
| 1 ->
Expand Down Expand Up @@ -406,7 +459,7 @@ type TailCallTransformer(env) =
recFunc, Function(indexVar :: List.ofSeq newArgs,
While (Value (Bool true),
Switch (Var indexVar, List.ofSeq trBodies))
|> withCopiedArgs newArgs
|> withCopiedArgs newArgs
)
LetRec(mainFunc :: List.ofSeq trBindings, base.TransformExpression body)

Expand All @@ -421,7 +474,8 @@ type TailCallTransformer(env) =
selfCallArgs <- Some args
Function(args,
While (Value (Bool true),
this.TransformStatement(body))
this.TransformStatement(body))
|> withCopiedArgs args
)
else
base.TransformFunction(args, body)
Expand Down
6 changes: 3 additions & 3 deletions src/compiler/WebSharper.Compiler/Breaker.fs
Expand Up @@ -272,7 +272,7 @@ let optimize expr =
else
List.foldBack2 bind vars args (Sequential [body; Value Null])
|> removeLets
| Application(TupledLambda(vars, body, isReturn), [ I.NewArray args ], isPure, _)
| Application(TupledLambda(vars, body, isReturn), [ I.NewArray args ], isPure, Some _)
when vars.Length = args.Length && not (needsScoping vars body) ->
if isReturn then
List.foldBack2 bind vars args body
Expand All @@ -282,11 +282,11 @@ let optimize expr =
| Application(I.ItemGet(I.Function (vars, I.Return body), I.Value (String "apply")), [ I.Value Null; argArr ], isPure, _) ->
List.foldBack2 bind vars (List.init vars.Length (fun i -> argArr.[Value (Int i)])) body
|> removeLets
| Application (I.Function (args, I.Return body), xs, _, _)
| Application (I.Function (args, I.Return body), xs, _, Some _)
when List.length args = List.length xs && not (needsScoping args body) ->
List.foldBack2 bind args xs body
|> removeLets
| Application (I.Function (args, I.ExprStatement body), xs, _, _)
| Application (I.Function (args, I.ExprStatement body), xs, _, Some _)
when List.length args = List.length xs && not (needsScoping args body) ->
List.foldBack2 bind args xs body
|> removeLets
Expand Down
4 changes: 0 additions & 4 deletions src/compiler/WebSharper.Compiler/Translator.fs
Expand Up @@ -585,10 +585,6 @@ type DotNetToJavaScript private (comp: Compilation, ?inProgress) =
let x = Id.New(mut = false)
Lambda ([x], c (Var x :: args) (a - 1))
c [] arity
// match arity with
// | 2 -> JSRuntime.Curried2 f
// | 3 -> JSRuntime.Curried3 f
// | n -> JSRuntime.Curried f n
| TupledFuncArg arity ->
let x = Id.New(mut = false)
let args =
Expand Down
12 changes: 11 additions & 1 deletion tests/WebSharper.Tests/Basis.fs
Expand Up @@ -57,13 +57,19 @@ let private tailRecFactorialTupled n =
| 0 -> acc
| n -> factorial (n * acc, n - 1)
factorial (1, n)


[<JavaScript>]
let private tailRecSingle n =
let rec f n =
if n > 0 then f (n - 1) else 0
f n

[<JavaScript>]
let tailRecScoping n =
let rec f acc n =
if n > 0 then f ((fun () -> n) :: acc) (n - 1) else acc
f [] n

[<JavaScript>]
let private tailRecSingleNoReturn n =
let rec f n =
Expand Down Expand Up @@ -116,6 +122,9 @@ type TailRec() =
let rec classTailRecSingle n =
if n > 0 then classTailRecSingle (n - 1) else 0

let rec classTailRecCurried n m =
if n > 0 then classTailRecCurried (n - 1) (m - 1) else 0

let rec classTailRecSingleUsedInside n =
let mutable setf = fun x -> 0
let rec f n =
Expand Down Expand Up @@ -257,6 +266,7 @@ let Tests =
equalMsg (6 * 5 * 4 * 3 * 2) (tailRecFactorialCurried2 6) "curried tail call with function"
equalMsg (6 * 5 * 4 * 3 * 2) (tailRecFactorialTupled 6) "tupled tail call"
equalMsg 0 (tailRecSingle 5) "single let rec"
equalMsg [1; 2; 3; 4; 5] (tailRecScoping 5 |> List.map (fun f -> f())) "scoping while tail call optimizing"
equalMsg [ 1; 2; 3 ] (tailRecWithMatch [ 3; 2; 1 ]) "single let rec with non-inlined match expression"
equalMsg 1 (tailRecMultiple 5) "mutually recursive let rec"
equalMsg 1 (tailRecWithValue 5) "mutually recursive let rec with a function and a value"
Expand Down

0 comments on commit 5d358bb

Please sign in to comment.