# Continuations
__RECAP of Tail Recursion__
Converting non-tail recursion into tail recursion:
* The value of writing tail-recursive functions is that they keep the stack from growing.
* The result of the recursive call is returned directly to the caller, without any further processing.
    
    
// Non-tail-recursive function
```
def factorial(x: Int): Int = {
    if (x <= 1)
        1
    else {
        x * factorial(x - 1)   // the multiplication still has to happen after the call returns
    }
}
```
        
// Tail-recursive version using an accumulator (call it as factorialTail(x, 1))
```
def factorialTail(x: Int, acc: Int): Int = {
    if (x <= 1) {
        acc
    }
    else {
        factorialTail(x - 1, acc * x)    // works because * is associative and commutative
    }
}
```
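As a quick check of the recap above, a minimal sketch: the standard `@tailrec` annotation (from `scala.annotation.tailrec`) makes the compiler reject the method unless the recursive call is in tail position. The wrapper `factorialWrapper` is a hypothetical helper, not part of the original notes.

```scala
import scala.annotation.tailrec

// The compiler checks that the recursive call is a tail call and
// compiles it into a loop, so the stack does not grow with x.
@tailrec
def factorialTail(x: Int, acc: Int): Int =
  if (x <= 1) acc
  else factorialTail(x - 1, acc * x)

// Hypothetical convenience wrapper: the accumulator starts at 1.
def factorialWrapper(x: Int): Int = factorialTail(x, 1)

// Putting @tailrec on the non-tail version would be a compile error,
// because `x * factorial(x - 1)` still multiplies after the call returns.
```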


This kind of conversion does not always work. Consider the `eval` function of an interpreter:

```
    def eval(e: Expr, env: ...): Value = {
        eval(...) // eval on newExpr: a recursive call that is often followed by more work
    } // we want this to be tail recursive
```
This is complicated to convert into tail recursion with an accumulator. Instead, we can make any function tail recursive by rewriting it in __continuation passing style__ (CPS).
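To make the `eval` problem concrete, here is a minimal sketch with a hypothetical three-case expression language (`Num`, `Plus`, `Times`, and no environment); it is not the interpreter from the course. Each recursive call in `eval` is followed by an addition or multiplication, so none of them is a tail call. In `eval_cps`, that pending arithmetic moves into continuations and every call becomes a tail call.

```scala
// Hypothetical expression language, for illustration only.
sealed trait Expr
case class Num(n: Int) extends Expr
case class Plus(l: Expr, r: Expr) extends Expr
case class Times(l: Expr, r: Expr) extends Expr

// Direct style: the + and * still have to run after the recursive calls return.
def eval(e: Expr): Int = e match {
  case Num(n)      => n
  case Plus(l, r)  => eval(l) + eval(r)
  case Times(l, r) => eval(l) * eval(r)
}

// CPS: every call is a tail call; the pending work lives in the continuations.
def eval_cps[T](e: Expr, k: Int => T): T = e match {
  case Num(n)      => k(n)
  case Plus(l, r)  => eval_cps(l, (v1: Int) => eval_cps(r, (v2: Int) => k(v1 + v2)))
  case Times(l, r) => eval_cps(l, (v1: Int) => eval_cps(r, (v2: Int) => k(v1 * v2)))
}

// 2 + (3 * 4)
val expr = Plus(Num(2), Times(Num(3), Num(4)))
```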

__Every function we write from now on will look like this.__

```
    What is a continuation? It is just an extra function that will be called next, on the result of the current function.

1. Add an extra parameter, called the continuation; it is a function.
    def foo(x: Int): Int = {                def foo_k(x: Int, k: Int => Int): Int = {


    }                                       }

    def bar(y: String): String = {          def bar_k(y: String, k: String => String): String = {


    }                                       }
```

### Concrete Example (factorial) 

__Non-Tail-Recursive Version__ (left) and __CPS Version, tail recursive__ (right)
```
def fact(x: Int): Int = {                   def fact_k(x: Int, k: Int => Int): Int = {
    if (x <= 1) {                               if (x <= 1) {
        1                                           k(1)
    } else {                                    } else {
        val v = fact(x - 1)                         fact_k(x - 1, v => k(v * x))
        v * x                                   }
    }                                       }
}
```

What is the main idea here? Originally, `fact` took in `x` and gave you back a result.
The new `fact_k` has the following shape:
```
        It takes in some x and a continuation k (a function).
            fact_k never returns its result directly to the user. It ALWAYS passes the result to the continuation k.
            Whatever the continuation returns is what is given back to the user.

        Always pass the result to the continuation function.
```


The continuation specifies what is left to do after the current call returns. At heart this is still recursion, but we package up the leftover work and hand it to the subsequent calls. __The important part is that wherever the original function would return a value, the CPS version instead makes a tail call to `k` with that value.__
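To see what "packaging up the leftover work" means, the sketch below writes out by hand the continuations that `fact_k(3, k0)` builds. The names `k0`, `k2`, and `k3` are introduced here only for illustration.

```scala
def fact_k(x: Int, k: Int => Int): Int =
  if (x <= 1) k(1)
  else fact_k(x - 1, v => k(v * x))

val k0: Int => Int = v => v          // top-level (identity) continuation

// The continuations that fact_k(3, k0) creates, written out by hand:
val k3: Int => Int = v => k0(v * 3)  // built by the call with x = 3
val k2: Int => Int = v => k3(v * 2)  // built by the call with x = 2
// The call with x = 1 reaches the base case and runs k2(1).

val byHand = k2(1)                   // k2(1) = k3(2) = k0(6) = 6
val byFunction = fact_k(3, k0)
```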

## Recipe for continuations
```
    1. Add an extra parameter k: a function that takes in the original result type and returns the final answer type.
           def fn_cps(..., k: Tr => T): T            , Tr is the original return type, T is some generic answer type
         The important point is that the input type of the continuation must be the same as what the original function returned.

      This is how you want continuations to look:

    def foo_cps(..., k) = {
        ... do what foo did before ...
        return k(originalRetValue)   <-- pass the return value through the continuation (at the top level, k is normally the identity function x => x)
    }
```
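As a sketch of the recipe applied to a function that is not in the notes, consider a hypothetical `sum` over a list of integers. Here `Tr` is `Int`, and the answer type `T` is left generic.

```scala
// Original, non-tail-recursive: the + happens after the recursive call returns.
def sum(xs: List[Int]): Int = xs match {
  case Nil          => 0
  case head :: tail => head + sum(tail)
}

// Step 1: add k, whose input type is sum's original return type (Int).
// Step 2: wherever sum returned a value, pass that value to k instead.
// Step 3: the work left after the recursive call (head + _) becomes a new continuation.
def sum_cps[T](xs: List[Int], k: Int => T): T = xs match {
  case Nil          => k(0)
  case head :: tail => sum_cps(tail, (v: Int) => k(head + v))
}
```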

__Ex 1: a simple example where the continuation's input and output types are the same (`k: Int => Int`)__

```
                                               the input type of k must be foo's original return type (Int)
    def foo(x: String): Int = {                def foo_cps(x: String, k: Int => Int): Int = {
        if (x.length >= 10) {                      if (x.length >= 10)
            x.length - 10                              k(x.length - 10)
        }                                          else
        else {                                         k(25)
            25                                 }
        }
    }
```
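A runnable version of Ex 1 shows that calling `foo_cps` with the identity continuation gives back the original result, while any other `k` simply runs on that result afterwards.

```scala
def foo(x: String): Int =
  if (x.length >= 10) x.length - 10
  else 25

// Both branches used to return a value; now both pass it to k.
def foo_cps(x: String, k: Int => Int): Int =
  if (x.length >= 10) k(x.length - 10)
  else k(25)
```

For example, `foo_cps("hi", v => v)` gives `25`, and `foo_cps("hi", v => v + 1)` gives `26`.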

__Ex 2: a little more complicated__
```
    def bar(x: String): String = {                  def bar_cps(x: String, k: String => String): String = {
        if (x.length <= 0) {                            if (x.length <= 0)
            "hello"                                         k("hello")
        }                                               else
        else {                                              substring_cps(x, 0, x.length - 2, k)
            substring(x, 0, x.length - 2)           }
        }
    }
```
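The notes do not define `substring`. The sketch below assumes a hypothetical `substring_cps(s, from, until, k)` that behaves like Java's `String.substring(from, until)` and passes the result to `k`. Because the call to `substring` was already `bar`'s last step, `bar_cps` passes its own `k` through unchanged. Under this assumption, a one-character input would give a negative end index and throw.

```scala
// Hypothetical helper standing in for the notes' substring.
def substring_cps[T](s: String, from: Int, until: Int, k: String => T): T =
  k(s.substring(from, until))

// The substring call is already a tail call, so k is handed over as-is.
def bar_cps[T](x: String, k: String => T): T =
  if (x.length <= 0) k("hello")
  else substring_cps(x, 0, x.length - 2, k)
```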

Example 2 explained: in `bar`, the call to `substring` is already the last thing the function does (a tail call), so `bar_cps` can hand its own continuation `k` straight to `substring_cps`. The general picture:
```
   The goal of CPS: every function call must be a tail call.
           def complicated(x1: t1, x2: t2, ..., xn: tn): Tr = {
                  ....
                  otherComp(...)   <-- calls like these are not tail calls:
                  ....
                  otherComp(...)   <-- there is still more to do after they return
                  ...
           }
                   if k has output type T, then the whole CPS function has output type T
           def complicated_CPS(x1: t1, x2: t2, ..., xn: tn, k: Tr => T): T = {
                  ....
                  otherComp_CPS(..., newContinuation)
           }

What is the new continuation?
    1. It is THE REST of the code / computations that happen after 'otherComp' returns in the old version of the function.
    2. At the end, it passes the original return value through k, the continuation.
```
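A small sketch of the general pattern, using hypothetical helpers `double` and `inc` (not from the notes) that are not recursive at all: each "rest of the code" after a call becomes that call's continuation, and the original return value `a + b` is finally passed to `k`.

```scala
// Hypothetical helpers, for illustration only.
def double(x: Int): Int = x * 2
def inc(x: Int): Int = x + 1

// Direct style: neither inner call is a tail call.
def complicated(x: Int): Int = {
  val a = double(x)   // rest of the code after this call: the two lines below
  val b = inc(a)      // rest of the code after this call: a + b
  a + b
}

// CPS versions of the helpers just pass their result to k.
def double_cps[T](x: Int, k: Int => T): T = k(x * 2)
def inc_cps[T](x: Int, k: Int => T): T = k(x + 1)

// Each "rest of the code" becomes the continuation of the call before it.
def complicated_cps[T](x: Int, k: Int => T): T =
  double_cps(x, (a: Int) => {
    inc_cps(a, (b: Int) => {
      k(a + b)        // the original return value goes through k
    })
  })
```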


__Ex 3: the Fibonacci example__
```
    def fibonacci(i: Int): Int = {
        if (i <= 2)
            1
        else
            fibonacci(i-1) + fibonacci(i-2)
    }

    // T is a type parameter (a generic): the caller chooses the type of the final answer,
    // and k turns the Int result into a value of that type.
    def fibonacci_cps[T](i: Int, k: Int => T): T = {
        if (i <= 2)
            k(1)
        else
            fibonacci_cps(i-1, v1 => { fibonacci_cps(i-2, v2 => { k(v1+v2) }) })
    }

    // we have packed up the remaining computation as an anonymous function, which becomes the new continuation
```
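Because `T` is chosen by the caller, one CPS function can produce answers of different types depending only on the top-level continuation. A minimal sketch:

```scala
def fibonacci_cps[T](i: Int, k: Int => T): T =
  if (i <= 2) k(1)
  else fibonacci_cps(i - 1, (v1: Int) => {
    fibonacci_cps(i - 2, (v2: Int) => k(v1 + v2))
  })

// Same function, three answer types, chosen only by the continuation.
val asInt: Int = fibonacci_cps(10, (v: Int) => v)
val asString: String = fibonacci_cps(10, (v: Int) => s"fib(10) = $v")
val isEven: Boolean = fibonacci_cps(10, (v: Int) => v % 2 == 0)
```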

# Coded Examples

## Factorial Function

```
/* Our goal is to convert this function into a CPS function so that it is tail recursive */
def factorial(x: Int): Int = {
    if (x <= 1){
        1
    }
    else{
        // x * factorial(x - 1): spill the recursive call into its own val
        val v1 = factorial(x-1)
        x * v1
    }
}
/* keeps the original parameter and adds a continuation k: a function that takes an Int and returns an Int */
def factorial_k(x: Int, k: Int => Int): Int = {
    /* the base case stays the same, except that its value is passed to k */
    if (x <= 1){
        k(1)    // all you have to do is call k(<base value>)
    }
    else{
        /* this is where it gets interesting: we do one step of the work and pass the rest along.
           The new continuation receives the result v of the next call,
           multiplies it by x, and hands the product to this call's own k
        */
        factorial_k(x-1, v => {k(x*v)})
    }
}

/* the same function, but with a type parameter T for the continuation's return type */
// [T] is a placeholder for whatever type the final answer will have
/*
TYPE variables
factorial_g is a generic function that has a type variable T
*/
def factorial_g[T](x: Int, k: Int => T): T = {
    if (x<=1){
        k(1)
    }
    else{
        /* "Take whatever value v the next call produces, multiply it by x,
            and pass the product to my continuation k"
        */
        factorial_g(x-1, v => {k(x*v)})
    }
}
```

Output:
```
defined function factorial
defined function factorial_k
defined function factorial_g
res5_3: Int = 6
```

```
val v1 = factorial_g(10, v => v)  // identity continuation: just return the result, nothing else
factorial_g(10, v => println(v))  // make the continuation print the result
```

Output:
```
3628800
v1: Int = 3628800
```

# The Fibonacci Problem 
## This is where it gets interesting because we have to decompose the function 

```
import scala.annotation.tailrec
def fibonacci(n: Int): Int = {
    if (n <= 2){
        1
    }else{
        // fibonacci(n-1) + fibonacci(n-2): always spill the calls
        val v1 = fibonacci(n-1)  // the call we convert first
        val v2 = fibonacci(n-2)  // everything from here down is work that remains once v1 is known
        v1 + v2
    }
}

def fibonacci_k[T](n: Int, k: Int => T): T = {
    if (n <= 2){
        k(1)
    }
    /* This is where it gets interesting, because every call must be a tail call */
    else{
        // Scala does not treat this whole method as tail recursive: the inner recursive call
        // sits inside a lambda, so it is not a tail call of fibonacci_k itself
        fibonacci_k(n-1, {(v1: Int) => fibonacci_k(n-2, (v2: Int) => { k(v1 + v2)})} )

        /*  THIS IS THE DIRTY WAY (named inner functions instead of lambdas)
        // Remaining computation to be performed after fibonacci(n-1)
        def remainingComputationToBePerformed(v1: Int) = {
            // Remaining computation after fibonacci(n-2)
            def remainingRemainingComp(v2: Int) = {
                k(v1+v2)
            }
            fibonacci_k(n-2, remainingRemainingComp)
        }
        fibonacci_k(n-1, remainingComputationToBePerformed)
        */
    }
}

fibonacci_k(15, x=>x)
```

Output:
```
import scala.annotation.tailrec
defined function fibonacci
defined function fibonacci_k
res15_3: Int = 610
```

# Lecture Notes


# Continuation Passing Style


```
// CPS: removing non-tail function calls by converting them into tail calls
// Why? Stack efficiency: no stack overflows

def factorial(x: Int): Int = {
    if (x <= 1){
        1
    }
    else{
        x * factorial(x-1)
    }
}

// note the ": Int =": without it, Scala's procedure syntax would make the return type Unit
def factorialTail(x: Int, acc: Int): Int = {
    if (x <= 1){
        acc
    }
    else{
        factorialTail(x-1, acc * x)
    }
}
```

Output:
```
defined function factorial
defined function factorialTail
```
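A caveat on the "Why? Stack efficiency, no overflows" motivation above: scalac only rewrites self-recursive tail calls into loops, and the JVM does not eliminate tail calls in general. In a CPS function such as `factorial_k`, once the base case is reached the continuations call one another as ordinary nested calls, so a deep enough input can still overflow the stack. One common remedy, which goes beyond these notes, is to trampoline the CPS code. The sketch below uses the standard library's `scala.util.control.TailCalls` (`tailcall`, `done`, `.result`); the names `factorial_tramp` and `factorialBig` are hypothetical, and `BigInt` is used only so large results do not overflow `Int`.

```scala
import scala.util.control.TailCalls._

// Every call, including each call to a continuation, is wrapped in `tailcall`,
// so it is returned as a suspended step instead of being run on the stack.
def factorial_tramp(x: Int, k: BigInt => TailRec[BigInt]): TailRec[BigInt] =
  if (x <= 1) tailcall(k(BigInt(1)))
  else tailcall(factorial_tramp(x - 1, (v: BigInt) => tailcall(k(v * x))))

// `done` ends the chain; `.result` runs the suspended steps in a loop.
def factorialBig(x: Int): BigInt =
  factorial_tramp(x, (v: BigInt) => done(v)).result
```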

```
//def eval(e: Expr, env: Map[String, Value]): Value = {
    // ... has a lot of non-tail-recursive calls back to eval.
    // This is a problem:
    // we want to implement eval with only tail calls.
//}
```

```
// Continuation passing style (CPS)
// CPS transformation

/* factorial_cps never returns a value directly; it passes its return value through k.
    Any return value must be passed through the continuation, and
    whatever the continuation returns must be the final return value of that function frame.
*/
def factorial_cps(x: Int, k: Int => Int): Int = {
    println(s"Hello: You called factorial_cps on $x")
    if (x <= 1){
        k(1)  // TAIL CALL
    }
    else{
        // x * factorial(x-1): spill the statements so the computations are clear.
        // val v = factorial(x-1)
        // v * x
        // package up all the computation that still needs to be done
        def allComputationThatHappensOnTheReturnValue(v: Int) = {
            println(s"Hello: I am continuation that multiplies by $x")
            k(v * x)  // the computation that is left over, passed along to my caller's continuation
        }

        /* hand all the remaining work to the next call */
        // My leftover work is to
        // 1. calculate v * x
        // 2. pass it to "the leftover work" of my caller (k)
        // whatever step 2 returns is my return value.
        factorial_cps(x-1, allComputationThatHappensOnTheReturnValue)  // TAIL CALL
    }
}

factorial_cps(3, x => x)
```

Output:
```
Hello: You called factorial_cps on 3
Hello: You called factorial_cps on 2
Hello: You called factorial_cps on 1
Hello: I am continuation that multiplies by 2
Hello: I am continuation that multiplies by 3

defined function factorial_cps
res21_1: Int = 6
```

```
def fibonacci(x: Int): Int = {
    if ( x <= 2){
        1
    }
    else{
        // fibonacci(x-1) + fibonacci(x-2): spill it
        val v1 = fibonacci(x-1)
        val v2 = fibonacci(x-2)
        v1 + v2
    }
}
```

Output:
```
defined function fibonacci
```

```
/* This is one way to do it: with named inner functions */
def fibonacci_cps(x: Int, k: Int => Int): Int = {    // 1. get the signature for fibonacci_cps
    if ( x <= 2){
        k(1)
    }
    else{
        def leftOverWorkForFibbyXMinusOne(v1: Int) = {  // v1 is the result of fibonacci(x-1); do the rest of the work that remains
            // val v2 = fibonacci(x-2): yes, this is the remaining work, but it also has to be a tail call,
            // so write a continuation for it
            def leftOverWorkForFibbyXMinusTwo(v2: Int) = {
                k(v1 + v2)
            }
            fibonacci_cps(x-2, leftOverWorkForFibbyXMinusTwo)  // run fibonacci(x-2), then its leftover work
        }
        fibonacci_cps(x-1, leftOverWorkForFibbyXMinusOne)
    }
}
```

Output:
```
defined function fibonacci_cps
```

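```
/* this is a better, cleaner way to do it: anonymous functions and a generic answer type T */
def fibonacci_cps_g[T](x: Int, k: Int => T): T = {    // 1. get the signature for fibonacci_cps_g
    if ( x <= 2){
        k(1)
    }
    else{
        // recurse on fibonacci_cps_g itself: the Int => Int version above cannot accept a k that returns T
        fibonacci_cps_g(x-1, (v1) => {
            fibonacci_cps_g(x-2, (v2) => {k(v1 + v2)})
        })
    }
}
```

Output:
```
defined function fibonacci_cps
```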

```
def foo_cps[T](x: String, k: Int => T): T = {
    k(x.toInt)
}

// bar's call to foo is already a tail call, so bar_cps passes its own k straight through
def bar_cps[T](x: Int, k: Int => T): T = {
    foo_cps((x+1).toString, k)
}
```

Output:
```
defined function foo_cps
defined function bar_cps
```

```
def factorial_cpsG[T](x: Int, k: Int => T): T = {
    println(s"Hello: You called factorial_cps on $x")
    if (x <= 1){
        k(1)  // TAIL CALL
    }
    else{
        // x * factorial(x-1): spill the statements so the computations are clear.
        // val v = factorial(x-1)
        // v * x
        // package up all the computation that still needs to be done
        def allComputationThatHappensOnTheReturnValue(v: Int): T = {
            println(s"Hello: I am continuation that multiplies by $x")
            k(v * x)  // the computation that is left over, passed along to my caller's continuation
        }

        /* hand all the remaining work to the next call */
        // My leftover work is to
        // 1. calculate v * x
        // 2. pass it to "the leftover work" of my caller (k)
        // whatever step 2 returns is my return value.
        factorial_cpsG(x-1, allComputationThatHappensOnTheReturnValue)  // TAIL CALL (must recurse on the generic version)
    }
}
// the terminal (top-level) continuation turns the final Int into a String
def terminal_continuation(v: Int): String = v.toString
factorial_cpsG(8, terminal_continuation)
```

Output:
```
Hello: You called factorial_cps on 8
Hello: You called factorial_cps on 7
Hello: You called factorial_cps on 6
Hello: You called factorial_cps on 5
Hello: You called factorial_cps on 4
Hello: You called factorial_cps on 3
Hello: You called factorial_cps on 2
Hello: You called factorial_cps on 1
Hello: I am continuation that multiplies by 2
Hello: I am continuation that multiplies by 3
Hello: I am continuation that multiplies by 4
Hello: I am continuation that multiplies by 5
Hello: I am continuation that multiplies by 6
Hello: I am continuation that multiplies by 7
Hello: I am continuation that multiplies by 8

defined function factorial_cpsG
defined function terminal_continuation
res29_2: String = "40320"
```

```
// convert this into a CPS function
def foo(x: Int): Int = {
    if ( x <= 5){
        3
    }
    else{
        // foo(foo(x-1) - 3): spill it
        val v1 = foo(x-1)     // we convert only this call first and bundle up the rest
        val v2 = foo(v1 - 3)  // bundle this up
        v2
    }
}
def foo_cps[T](x: Int, k: Int => T): T = {
    if ( x <= 5){
        k(3)
    }
    /* now we bundle up the remaining work and pass it on as the new continuation */
    else{
        foo_cps(x-1, v1 => {
            foo_cps(v1 - 3, (v2) => {
                k(v2)   // v2 => k(v2) is just k, so foo_cps(v1 - 3, k) would also work
            })
        })
    }
}
foo(10)
foo_cps(10, x => x)
```

Output:
```
defined function foo
defined function foo_cps
res31_2: Int = 3
res31_3: Int = 3
```