# Optimizing recursive functions

Over the years I've written a lot of F# code. As a personal challenge, I try to use mutable data structures as little as possible. One area where I keep running into problems is recursion: my recursive functions end up too slow, or they blow the stack.

This post is a mostly *theoretical* deep dive into several techniques to tackle these problems. We'll explore:

* How we can use memoization to speed up recursive functions using a mutable data type
* How we can write an immutable variant of any memoized function
* How we can rewrite any recursive function so it becomes tail recursive and does not blow the stack.

By mostly *theoretical* I mean that while we'll show it is technically possible to do all of this (and even to combine all three techniques), you'll soon discover, as I did, that things get very complicated very quickly and the original algorithm becomes almost unrecognizable. That's not a desirable property for source code.

We'll explore all these topics and close with the trade-offs I plan to make the next time I run into these situations.

## A simple recursive function

```fsharp
let rec fac n =
    match n with
    | 0UL -> 1UL
    | _ -> n * (fac (n - 1UL))
```

```fsharp
fac 15UL
```

While this technically works, you'll soon find that factorials quickly grow too large for 64-bit integers: 20! is the largest one that fits in a `uint64`. Let's rewrite `fac` so we can calculate factorials of big numbers.
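
Before doing so, it's worth seeing what "too large" means in practice. F#'s default arithmetic operators are unchecked, so once the result no longer fits in a `uint64`, `fac` doesn't fail: it silently wraps around. The sketch below (the names `fac64` and `fac64Checked` are hypothetical, not from the original) contrasts the default multiplication with `Checked.( * )`, which raises an `OverflowException` instead.

```fsharp
// Same as fac above: default (unchecked) uint64 multiplication.
let rec fac64 n =
    match n with
    | 0UL -> 1UL
    | _ -> n * (fac64 (n - 1UL))

// Identical, except the multiplication is checked and throws on overflow.
let rec fac64Checked n =
    match n with
    | 0UL -> 1UL
    | _ -> Checked.( * ) n (fac64Checked (n - 1UL))

printfn "20! = %d" (fac64 20UL) // 2432902008176640000, still correct
printfn "21! = %d" (fac64 21UL) // wrapped around modulo 2^64, silently wrong

try
    fac64Checked 21UL |> ignore
with :? System.OverflowException ->
    printfn "21! does not fit in a uint64"
```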

```fsharp
type Ordering = Less | Equal | Greater
module BigInteger =

    open System.Numerics

    let fromInt (i : int) = BigInteger i
    let Zero = fromInt 0
    let One = fromInt 1

    let compare (other : bigint) (one : bigint) =
        match one.CompareTo(other) with
        | 0 -> Equal
        | n when n > 0 -> Greater
        | _ -> Less

    let rec fac (n : bigint) =
        match (n |> compare (BigInteger.Zero)) with
        | Less -> failwithf "negative input: %A" n
        | Equal -> BigInteger.One
        | Greater -> n * (fac (n - BigInteger.One))
```

```fsharp
BigInteger.fac (BigInteger.fromInt 15)
```

```fsharp
100 |> BigInteger.fromInt |> BigInteger.fac
```

This works. Victory! But wait, what happens if we try to calculate a really really big number?

```fsharp
//8_000 |> BigInteger.fromInt |> BigInteger.fac

(* This blows the stack:

    Stack overflow.
    Repeat 7379 times:
    --------------------------------
    at FSI_0003+BigInteger.fac(System.Numerics.BigInteger)

*)
```

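The typical way to fix this is to rewrite the algorithm in a procedural style, using a loop and mutable state. "Never use recursion!" you might hear some people yell. Functional programmers decided to solve it differently: by having the compiler do the rewriting for us. This only works if our recursive function is *tail recursive*: the recursive call has to be the very last thing the function evaluates, so nothing is left to do with its result and the current stack frame can be reused. Our `fac` is not tail recursive, because after `fac (n - 1UL)` returns we still have to multiply the result by `n`. You can rewrite most recursive functions into a tail-recursive variant by adding an extra "accumulator" argument that carries the intermediate result along.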
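
For comparison, the procedural fix could look like the following sketch (a hypothetical `facLoop`, not part of the original post): a `while` loop with mutable state. Roughly speaking, a loop like this is what the F# compiler produces from a self-recursive tail call, which is why the tail-recursive versions don't grow the stack.

```fsharp
// Hypothetical loop-based factorial: mutable state instead of recursion,
// so the stack depth stays constant no matter how large n is.
let facLoop (n : bigint) =
    if n.Sign < 0 then failwithf "negative input: %A" n
    let mutable acc = 1I
    let mutable i = n
    while i > 0I do
        acc <- acc * i
        i <- i - 1I
    acc

facLoop 8000I |> string |> String.length |> printfn "8000! has %d digits"
```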

Let's make tail-recursive variants of our `fac` functions.

```fsharp
let facTail n =
    // acc carries the product computed so far; the recursive call is
    // the last thing evaluated, so it is in tail position.
    let rec fac acc n =
        match n with
        | 0UL -> acc
        | _ -> (fac (n * acc) (n - 1UL))
    fac 1UL n

module BigInteger =
    let facTail (n : bigint) =
        let rec fac acc (n : bigint) =
            match (n |> BigInteger.compare (BigInteger.Zero)) with
            | Less -> failwithf "negative input: %A" n
            | Equal -> acc
            | Greater -> (fac (n * acc) (n - BigInteger.One))
        fac BigInteger.One n
```

```fsharp
// The tail-recursive version handles this just fine:
8_000 |> BigInteger.fromInt |> BigInteger.facTail
```

That's cool. The accumulator trick isn't always this easy to apply, though. Let's take a look at a more complex recursive function, `fib`, which calls itself twice:

```fsharp
let rec fib n =
    match n with
    | 1 -> 1
    | 2 -> 1
    | _ -> (fib (n - 1)) + (fib (n - 2))
```

```fsharp
fib 40 // Takes almost a second!
```

This takes almost a second to calculate! To understand why, let's expand the recursive calls, for example for *fib 42*:



```
fib 42 = fib 41 + fib 40
fib 41 = fib 40 + fib 39
fib 40 = fib 39 + fib 38
fib 39 = fib 38 + fib 37
fib 38 = fib 37 + fib 36
```

In order to calculate *fib 42*, we need the results of *fib 41* and *fib 40*. In order to calculate *fib 41*, we **also** need the value of *fib 40*, but we redo that entire calculation from scratch. The same happens at every level, so the number of calls grows exponentially with `n`. So, what would help? Keeping track of intermediate results, a.k.a. *memoization*!
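
The blow-up can be measured. The hypothetical `fibCalls` below (not from the original) is the naive `fib` instrumented to return the number of calls it made alongside its result, threading the count through the return value instead of mutating a counter. For this definition the count works out to exactly `2 * fib n - 1`, so `fib 40` makes over 200 million calls.

```fsharp
// Naive fib that also reports how many calls it made: (value, calls).
let rec fibCalls n =
    match n with
    | 1 | 2 -> 1, 1
    | _ ->
        let (a, callsA) = fibCalls (n - 1)
        let (b, callsB) = fibCalls (n - 2)
        // this call, plus everything the two sub-calls did
        a + b, 1 + callsA + callsB

for n in [ 10; 20; 30 ] do
    let (value, calls) = fibCalls n
    printfn "fib %d = %d took %d calls" n value calls
```

This prints 109, 13529 and 1664079 calls for `n` = 10, 20 and 30: roughly a factor of 1.6 more work for every step up in `n`.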

```fsharp
open System.Collections.Generic

let fibMemo n =
    // Note: int silently overflows past fib 46, so for large n the value is wrong.
    let d = new Dictionary<int,int>()
    let rec f n =
        match d.TryGetValue(n) with
        | true, result -> result
        | false, _ ->
            let result =
                match n with
                | 1 -> 1
                | 2 -> 1
                | _ -> (f (n - 1)) + (f (n - 2))
            d.Add(n, result) // Mutation: we're actually updating the dictionary here!
            result
    f n
```

```fsharp
fibMemo 4_000
```

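Cool, that's way faster: every value is now computed only once, so the work grows linearly with `n` instead of exponentially. But hold on: just like the original `fib`, this function is not tail recursive. Each call to `f n` has to wait for `f (n - 1)` to return, so the recursion is still roughly `n` levels deep, and now that it is fast enough to try really large inputs, we hit the stack limit: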

```fsharp
//fibMemo 20_000

(*Here we go again:

Stack overflow.
   at System.Collections.Generic.Dictionary`2[[System.Int32, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e],[System.Int32, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e]].FindValue(Int32)
   at System.Collections.Generic.Dictionary`2[[System.Int32, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e],[System.Int32, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e]].TryGetValue(Int32, Int32 ByRef)
   at FSI_0003+f@48.Invoke(Int32)
*)
```
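
A side note on correctness: `fibMemo` stores `int` values, and F#'s `int` arithmetic is unchecked. Since fib 46 is the largest Fibonacci number that fits in an `int`, `fibMemo 4_000` returns a silently wrapped-around value rather than the actual 4,000th Fibonacci number. A minimal variant (the name `fibMemoBig` is hypothetical) that keeps the same mutable-`Dictionary` memoization but stores `bigint` results could look like this. It is still not tail recursive, so large enough inputs will blow the stack just the same.

```fsharp
open System.Collections.Generic

// Same memoization as fibMemo, but the cache stores bigint values,
// so results past fib 46 are exact instead of wrapping around.
let fibMemoBig n =
    let d = Dictionary<int, bigint>()
    let rec f n =
        match d.TryGetValue(n) with
        | true, result -> result
        | false, _ ->
            let result =
                match n with
                | 1 | 2 -> 1I
                | _ -> f (n - 1) + f (n - 2)
            d.Add(n, result)
            result
    f n

fibMemoBig 4_000 |> string |> String.length |> printfn "fib 4000 has %d digits"
```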

There are two problems I want to tackle:

1. Can we rewrite fibMemo so it gets rid of the mutable Dictionary in favor of an immutable data type?
2. Can we make fibMemo tail recursive so it does not blow the stack anymore?

## Making fibMemo immutable

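It turns out we can create an immutable variant. One approach that always works is to thread the lookup table explicitly through the recursive calls: every call receives the current table as an extra argument and returns its result together with an updated table, which is then passed on to the next call.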

### Threading the lookup table

```fsharp
let fibThreading n =
    // Note how the signature has changed: f takes an extra argument (the lookup table)
    // and returns a richer type: the result along with an updated lookup table.
    let rec f d n =
        match d |> Map.tryFind n with
        | Some result -> result, d
        | None ->
            match n with
            | 1 -> 1, d |> Map.add 1 1
            | 2 -> 1, d |> Map.add 2 1
            | _ ->
                let (n1, d1) = f d (n - 1)
                let (n2, d2) = f d1 (n - 2) // we thread the memo from the n - 1 call into the n - 2 call
                let result = n1 + n2
                let nd = d2 |> Map.add n result // instead of mutating the map, we create a new, updated version of it!
                result, nd
    let (result, _) = f Map.empty n
    result
```

```fsharp
fibThreading 4_000
```

### Sidebar: Extracting a generic memoize function using a Y-combinator approach

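Can we make a higher-order "memoize" function that can memoize any recursive function? Somewhat, but we still need to tweak the original function: instead of calling itself directly, it takes the function to recurse into as an extra argument. The memoizer then passes in a memoized version of that function, tying the recursive knot the way a Y combinator does.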

* [Combine memoization and tail recursion, stackoverflow, user kvb](https://stackoverflow.com/questions/3459422/combine-memoization-and-tail-recursion/3459864#3459864)
* [The Y combinator, mvanier](https://mvanier.livejournal.com/2897.html)
* [Y in practical programs, Bruce McAdam](https://blog.klipse.tech/assets/y-in-practical-programs.pdf)

```fsharp
open System.Collections.Generic

let memoY f =
    // One cache per memoized function; it persists across calls.
    let cache = Dictionary<_,_>()
    let rec fn x =
        match cache.TryGetValue(x) with
        | true, y -> y
        | _ ->
            // f receives fn (the memoized function itself) for its recursive calls.
            let v = f fn x
            cache.Add(x, v)
            v
    fn

// Note that this is not a recursive function: it takes an input argument fib
// that substitutes for the recursive call.
let fib_ fib n =
    match n with
    | 1 -> 1
    | 2 -> 1
    | _ -> (fib (n-1)) + (fib (n-2))

let fib = memoY fib_

fib 2000 // int silently overflows past fib 46, so this is not the true fib 2000
```

While this works, and the extra indirection doesn't seem to cost much performance, all versions of our memoized function still blow the stack for large inputs:

```fsharp
//fibThreading 20_000
(*
Stack overflow.
Repeat 6852 times:
--------------------------------
   at FSI_0003+f@62-1.Invoke(Microsoft.FSharp.Collections.FSharpMap`2<Int32,Int32>, Int32)
*)
```

So can we make our memoized functions tail recursive as well?

## CPS as a way to rewrite anything to tail recursion

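Continuation-passing style (CPS) can make (just about?) anything tail recursive! Instead of returning its result to the caller, a function in CPS takes an extra argument, the *continuation*: a function that says what to do with the result. Every call then becomes a tail call, and the work that used to wait on the stack is captured in closures on the heap instead.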

* [Bouncing around with recursion, John Azariah](https://johnazariah.github.io/2020/12/07/bouncing-around-with-recursion.html)
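
To see the transformation on the simplest example first, here is a CPS version of `fac` (a sketch; the name `facCps` is hypothetical). The extra argument `c` is the continuation. Instead of multiplying after the recursive call returns, we hand the recursive call a new continuation that does the multiplication and then passes the product on to `c`.

```fsharp
// fac in continuation-passing style: every call is a tail call.
// c : bigint -> 'r is "the rest of the computation", waiting for the value of fac n.
let rec facCps (n : bigint) (c : bigint -> 'r) : 'r =
    if n.Sign < 0 then failwithf "negative input: %A" n
    elif n.IsZero then c 1I
    else facCps (n - 1I) (fun r -> c (n * r))

// Start with the identity continuation: "just give me the result".
facCps 20I id |> printfn "20! = %O"
```

Because the continuation decides what happens to the result, `facCps 5I string` hands back the string `"120"` instead of a number.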

```fsharp
let rec fib n c =
    match n with
    | 1 -> c 1
    | 2 -> c 1
    | _ ->
        // Note how the last thing we do is call the recursive function. "Next steps" get
        // delayed to the continuation functions. This effectively moves our calculation
        // from the stack to the (much bigger) heap.
        fib (n-1) (fun n1 -> fib (n-2) (fun n2 -> c (n1 + n2)))

fib 40 id // no memoization yet, so this still makes an exponential number of calls
```

Let's apply this CPS technique to our fully immutable memoized version!

```fsharp
// Redefined so this cell is self-contained; only BigInteger.One is needed here.
module BigInteger =

    open System.Numerics

    let fromInt (i : int) = BigInteger i
    let Zero = fromInt 0
    let One = fromInt 1

let fibThreadingTail n =
    let rec f d n c =
        match d |> Map.tryFind n with
        | Some result -> (result,d) |> c
        | None ->
            match n with
            | 1 -> (BigInteger.One,d |> Map.add 1 BigInteger.One) |> c
            | 2 -> (BigInteger.One,d |> Map.add 2 BigInteger.One) |> c
            | _ ->
                f d (n - 1) (fun (n1,d1) ->
                    f d1 (n - 2) (fun (n2,d2) ->
                        let result = n1 + n2
                        let nd = d2 |> Map.add n result
                        (result, nd) |> c))

    let (result, _) = f Map.empty n id
    result
```

Well, there you have it: a memoized `fib` function that uses an immutable lookup table and tail recursion. It's fast, and it handles big inputs without blowing the stack. If you've ever wondered what the 100,000th number in the Fibonacci sequence looks like, here you go:

```fsharp
fibThreadingTail 100_000
```

## Closing thoughts

Can you rewrite a recursive function so that it both *memoizes* intermediate results using an immutable data structure and is *tail recursive*? You sure can! There's even a mechanical way to do it:

1. In order to speed up a recursive function, we can use memoization. We can even extract a generic memoizer higher-order function if we really really want to, but in order to do that we have to tweak the original function so it accepts the function to recurse into as an argument. We saw this with our Y-combinator approach.
2. In order to memoize any function using an immutable data structure, we can explicitly thread the lookup table through subsequent recursive calls. This requires even more tinkering with the original function.
3. In order to make any function tail recursive, you can rewrite it using Continuation Passing Style (CPS).
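
Steps 1 and 3 can also be combined into a single higher-order function. The sketch below is not from the original post, and its names (`memoCps`, `fibCpsOpen`, `fibFast`) are hypothetical. Like `memoY`, it takes an "open" function that receives the function to recurse into, but that function is written in CPS, so the memoizer can record each result inside the continuation before passing it on. It uses a mutable `Dictionary` as the cache and `bigint` results. Every call stays in tail position, the same shape as `fibThreadingTail`.

```fsharp
open System.Collections.Generic

// Memoizer for "open" recursive functions written in CPS.
// f receives the memoized function to recurse into, an input and a continuation.
let memoCps (f : ('a -> ('b -> 'b) -> 'b) -> 'a -> ('b -> 'b) -> 'b) : 'a -> 'b =
    let cache = Dictionary<'a, 'b>()
    let rec g x k =
        match cache.TryGetValue(x) with
        | true, v -> k v
        | false, _ ->
            // Compute via f; store the result inside the continuation, then continue.
            f g x (fun v ->
                cache.[x] <- v
                k v)
    fun x -> g x id

// fib as an open CPS function: it never refers to itself directly.
let fibCpsOpen (fib : int -> (bigint -> bigint) -> bigint) (n : int) (k : bigint -> bigint) : bigint =
    match n with
    | 1 | 2 -> k 1I
    | _ -> fib (n - 1) (fun a -> fib (n - 2) (fun b -> k (a + b)))

let fibFast = memoCps fibCpsOpen

fibFast 100 |> printfn "fib 100 = %O"
```

As with `memoY`, the cache lives as long as `fibFast` does, so later calls reuse earlier results.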

But should you? That's the question I'll leave you to ponder. I might reach for CPS the next time I blow the stack, but for memoization I think I'll stick to mutable lookup tables for now.