# Continuations and Continuation Passing Style

* Every function gets an extra argument called the 'continuation'.
* The continuation is a function that is passed in and specifies what the caller wishes to do with the result that the current frame computes.
  
Take, for instance, a function `func` that takes an integer and returns an integer. Its direct-style and CPS versions look like this:

```
// direct style
def func(x: Int): Int = {
    // ... do some work to compute the result ...
    ???
}

// CPS: the extra argument k is the continuation
def func_k(x: Int, k: Int => Int): Int = {
    val result: Int = ??? // ... do some work to compute the result ...
    k(result)  // pass the result on to the continuation
}
```
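To make the template concrete, here is a minimal sketch using a hypothetical `square` function (not part of the original notes). The CPS version never returns its own answer directly; whatever the caller passes as `k` decides what happens to it.

```scala
// Direct style: compute the result and return it.
def square(x: Int): Int = x * x

// CPS style (hypothetical example): compute the result, then hand it to k.
def square_k(x: Int, k: Int => Int): Int = {
  val result = x * x
  k(result) // the caller's continuation decides what to do next
}

println(square(4))               // 16
println(square_k(4, x => x))     // identity continuation: 16
println(square_k(4, x => x + 1)) // continuation adds one: 17
```

Passing the identity function `x => x` as the continuation recovers the plain direct-style answer, which is why the cells below call the CPS versions with `x=>x`.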

In [1]:
def addUp(x: Int, y: Int, z: Int): Int = {
    x + y + z 
}

def multiply(x: Int, y: Int): Int = {
    x * y
}

def madd(x: Int, y: Int, z: Int): Int = {
    val v1 = multiply(x,y)
    val v2 = addUp(v1,y,z)
    return v2
}

println(madd(1,2,3))

7


defined function addUp
defined function multiply
defined function madd

__Now we are going to create the CPS version of these functions__

In [3]:
def add_k(x:Int, y:Int, z:Int, k: Int => Int):Int = {
    val v1 = x + y + z  // compute the result for the frame
    k(v1)   // pass it onto the continuation 
}
def multiply_k(x:Int, y:Int, k:Int => Int):Int ={
    val result = x*y // compute the result of the frame
    k(result)  // pass it onto the continuation
}

/* This function gets interesting because it isn't a one-step
    function like the previous two.
    
   Original order of work done in madd:
       multiply(x,y)
       addUp(v1,y,z)
       
    Generally, you complete one step at a time and pass the rest
    of the work on as a continuation, so that no stack frame is
    ever left waiting for a result: every call is a tail call.
*/
def madd_k(x:Int, y:Int, z:Int, k:Int => Int): Int = {
    /* what you are going to do is form the continuation
    for the function call that needs to happen after
    'multiply_k' */
    
    def remainingWork(result:Int):Int = {
        add_k(result,y,z,k)
    }
    
    /* now that you have created the next step,
        do the first step in the process
    */
    multiply_k(x,y,remainingWork)
}

println(madd_k(1,2,3,x=>x))


7


defined function add_k
defined function multiply_k
defined function madd_k

__Things to notice about what we just did__

* The translations of `addUp` and `multiply` were straightforward. Each just got a new argument, `k`, for the continuation. Each computes what it did originally and, instead of returning the result, calls `k` on it.
* The tricky function was `madd_k`. What did `madd` do?
    * called `multiply` on `x`, `y`
    * took the result and called `addUp` on it (together with `y` and `z`)
* Thus we can write down what `madd_k` should do: call `multiply_k` on `x`, `y`, and pass it a new continuation (`remainingWork` in the code). What must this continuation do?
* It should do the remaining work `madd` would have done after the call to `multiply` returned:
    1. call `add_k` on the result of `multiply_k` (with `y` and `z`)
    2. pass the result of that on to `k`, by giving `k` to `add_k` as its continuation.
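The same translation can be written with an anonymous function in place of the named `remainingWork`. A sketch, with compact copies of `multiply_k` and `add_k` so the block runs on its own; `madd_k2` is a hypothetical name:

```scala
def add_k(x: Int, y: Int, z: Int, k: Int => Int): Int = k(x + y + z)
def multiply_k(x: Int, y: Int, k: Int => Int): Int = k(x * y)

// Same behavior as madd_k: first multiply, then the lambda does the
// remaining work (add_k) and hands the final answer to k.
def madd_k2(x: Int, y: Int, z: Int, k: Int => Int): Int =
  multiply_k(x, y, v1 => add_k(v1, y, z, k))

println(madd_k2(1, 2, 3, x => x)) // 7
```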
     
__Another example__

In [18]:
def f1(x: Int): Int = {
    if (x <= 0){
        1
    }else{
        3 + f1(x-10)
    }
    /*
        order of work in the else branch:
        > result1 = x - 10
        > result2 = f1(result1)
        > result2 + 3
    */
}
println(f1(14))

def f1_k(x:Int, k:Int => Int): Int = {
    if (x <= 0){
        k(1)
    }else{
        val resultOfthisFrame = x-10 
        /* create the continuation so the next function call
        knows what to do */ 
        def k1(resultofNextFrame:Int): Int = {
            k(resultofNextFrame + 3)
        }
        f1_k(resultOfthisFrame, k1)
    }
}
println(f1_k(14,x=>x))

7
7


defined function f1
defined function f1_k
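To see how the continuation grows, here is a hand-unrolled trace of `f1_k(14, x => x)`, written as a sketch: each step names the continuation that the recursive call builds. The names `k0`, `k1`, `k2` are illustrative, and `f1_k` is an equivalent compact version of the one above.

```scala
// Compact equivalent of f1_k, so this block runs on its own.
def f1_k(x: Int, k: Int => Int): Int =
  if (x <= 0) k(1) else f1_k(x - 10, v => k(v + 3))

// Hand-unrolled trace of f1_k(14, x => x):
val k0: Int => Int = x => x         // call 1: f1_k(14, k0)
val k1: Int => Int = v => k0(v + 3) // call 2: f1_k(4, k1)
val k2: Int => Int = v => k1(v + 3) // call 3: f1_k(-6, k2), the base case
val answer = k2(1)                  // k0((1 + 3) + 3) = 7

println(answer)           // 7
println(f1_k(14, x => x)) // 7
```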

# Start from the ground up: the most basic forms of CPS

In [6]:
def addNums(x: Int, y: Int): Int = {
    x + y
}

def subNums(x: Int, y: Int): Int = {
    val result = x - y
    result
}

//cps version
def addNums_k(x:Int, y: Int, k: Int => Int):Int ={
    val result = x + y
    k(result) 
}
//cps version
def subNums_k(x: Int, y:Int, k: Int => Int): Int= {
    val result = x - y 
    k(result) // give the result for the function to handle
}


// println(subNums(10,4))
// println(subNums_k(10,4,x=>x))
// println(addNums(2,-30))
// println(addNums_k(2,-30,x=>x))



/* now we will make it interesting by making a function that 
does two things */ 

/* function adds x,y then subtracts z from that sum */
def addSub(x:Int, y:Int, z:Int):Int = {
    val sum = addNums(x,y) 
    val diff = subNums(sum,z)
    diff
}

/* calculate the result of the frame, then just pass 
along the work to a continuation you design 
    addNums_k first
    then subNums_k
*/ 
def addSub_k(x:Int, y: Int, z:Int, k: Int => Int): Int ={
    /* create the continuation; z and k are captured in its closure */
    
    // k1 is waiting for the result of addNums_k
    def k1(v1:Int): Int = subNums_k(v1, z, k)
    addNums_k(x,y,k1) // return whatever this call returns
}


println(addSub(1,22,84))
println(addSub_k(1,22,84,x => x))

-61
-61


defined function addNums
defined function subNums
defined function addNums_k
defined function subNums_k
defined function addSub
defined function addSub_k
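So far every continuation has type `Int => Int`, but nothing forces the continuation to return an `Int`. A sketch, assuming a type parameter `R` for the final answer (an extension not used elsewhere in these notes), lets the last step turn the result into, say, a `String`:

```scala
// Each step still produces an Int, but the caller chooses the answer type R.
def addNums_k[R](x: Int, y: Int, k: Int => R): R = k(x + y)
def subNums_k[R](x: Int, y: Int, k: Int => R): R = k(x - y)

// Same structure as addSub_k: add first, then subtract z, then hand off to k.
def addSub_k[R](x: Int, y: Int, z: Int, k: Int => R): R =
  addNums_k(x, y, (sum: Int) => subNums_k(sum, z, k))

println(addSub_k(1, 22, 84, (x: Int) => x))              // -61
println(addSub_k(1, 22, 84, (x: Int) => s"result = $x")) // result = -61
```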

# Three-step continuation

In [9]:
def addNums(x: Int, y: Int): Int = {
    x + y
}

def subNums(x: Int, y: Int): Int = {
    val result = x - y
    result
}

def multNums(x: Int, p: Int): Int = {
    val result = x * p 
    result
}

def addSubMult(x:Int, y:Int, z:Int, p: Int):Int = {
    val sum = addNums(x,y) 
    val diff = subNums(sum,z)
    val prod = multNums(diff, p)
    prod
}


println(addSubMult(-1,3,4,10))






//cps version
def addNums_k(x:Int, y: Int, k: Int => Int):Int ={
    val result = x + y
    k(result) 
}
//cps version
def subNums_k(x: Int, y:Int, k: Int => Int): Int= {
    val result = x - y 
    k(result) // give the result for the function to handle
}
// cps version; k is a function that is waiting for the result
def multNums_k(x:Int, p:Int, k: Int => Int): Int = {
    val result = x * p 
    k(result) 
}

/* here you build up the whole chain of continuations in this
one call; once addNums_k runs, it all executes */
def addSubMult_k(x:Int, y: Int, z:Int, p:Int, k: Int => Int): Int ={
    /* create the continuations, last step first */
    /* knext: multiply the difference by p, then hand the answer to k */
    def knext(v1:Int): Int = multNums_k(v1,p,k)
    /* k1: subtract z from the sum, then continue with knext */
    def k1(v1:Int): Int = subNums_k(v1, z, knext)
    addNums_k(x,y,k1) // return whatever this call returns
}

println(addSubMult_k(-1,3,4,10, x=>x))

-20
-20


defined function addNums
defined function subNums
defined function multNums
defined function addSubMult
defined function addNums_k
defined function subNums_k
defined function multNums_k
defined function addSubMult_k
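In `addSubMult_k` the named continuations are written last step first (`knext` before `k1`), so the code reads backwards. Writing them as nested anonymous functions instead lets the code read in the order the work happens. A sketch with the same behavior; `addSubMult_k2` is a hypothetical name:

```scala
def addNums_k(x: Int, y: Int, k: Int => Int): Int = k(x + y)
def subNums_k(x: Int, y: Int, k: Int => Int): Int = k(x - y)
def multNums_k(x: Int, p: Int, k: Int => Int): Int = k(x * p)

def addSubMult_k2(x: Int, y: Int, z: Int, p: Int, k: Int => Int): Int =
  addNums_k(x, y, sum =>       // step 1: x + y
    subNums_k(sum, z, diff =>  // step 2: sum - z
      multNums_k(diff, p, k))) // step 3: diff * p, then hand off to k

println(addSubMult_k2(-1, 3, 4, 10, x => x)) // -20
```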

# Now try writing the CPS version of a self-recursive function

In [2]:
def increment(x:Int, max:Int):Int = {
    if (x == max){
        println(x)
        max
    }else{
        println(x)
        increment(x+1, max)
    }
}

println(increment(0,10))

0
1
2
3
4
5
6
7
8
9
10
10


defined function increment

In [13]:
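/* The recursive call is in tail position, so no frame is left
waiting on something else to do, and the Scala compiler can turn
such a self-call into a loop. Note, though, that when k(0) finally
runs, the chain of nested continuations still uses one stack
frame per level. */

/* essentially, the basis of this function is to build up the 
    continuation function so that when you hit the base case
    you just give it the initial value 
*/
def increment_k(x:Int, max:Int, k:Int => Int): Int = {
    if (x == max){
        // start from zero because the built-up continuation
        // has the form ((((0 + 1) + 1) + 1) ... + 1)
        k(0) 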
    }else{
        println(x+1)
        def k1(v1:Int):Int = k(v1+1) // building up the continuation 
        increment_k(x+1,max, k1)
        
    }
}

println(increment_k(0,10,x => x))

1
2
3
4
5
6
7
8
9
10
10


defined function increment_k
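One way to check the tail-position claim is Scala's `@tailrec` annotation, which makes compilation fail unless the compiler can turn the recursive call into a loop. A sketch (the `println` is dropped so the result is easy to test). The caveat from the comment still applies: only the self-call is optimized, and the final `k(0)` runs one nested continuation call per level, so a large enough `max` can still overflow the JVM stack.

```scala
import scala.annotation.tailrec

// Compiles only because the recursive call is the last thing each frame does.
@tailrec
def increment_k(x: Int, max: Int, k: Int => Int): Int = {
  if (x == max) {
    k(0) // run the chain of continuations built so far
  } else {
    def k1(v1: Int): Int = k(v1 + 1) // wrap k with one more "+ 1"
    increment_k(x + 1, max, k1)
  }
}

println(increment_k(0, 10, x => x)) // 10
```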

# Working through the examples provided in notes

In [11]:
// Convert this into cps 
def f1(x:Int): Int = {
    if (x <= 0){
        1
    }else{
        3 + f1(x - 10)
    }
}
println(f1(25))


// cps version 
def f1_k(x:Int, k:Int => Int): Int ={
    if (x <= 0){
        k(1)
    }else{
        /* create the continuation for the next call */ 
        def k1(v1:Int):Int = k(v1 + 3) // THIS IS AN ACCUMULATOR
        f1_k(x-10,k1)
        
    }
}
println(f1_k(25, x=>x))

10
10


defined function f1
defined function f1_k

# Try converting the factorial function 

In [21]:
def factorial(x: Int): Int = {
    if (x <= 1){
        1
    }else{
        x * factorial(x-1)
    }
}

// cps version
// this is another case of building up the continuation, giving
// ((1*2)*3)*...*n
def factorial_k(x: Int, k: Int => Int): Int ={
    if (x <= 1){
        k(1)
    }else{
        println(x)
        def k1(v1:Int): Int = k(v1 * x) // building up the continuation
        factorial_k(x-1, k1)
    }
}
println(s"CPS: ${factorial_k(8, x => x)}")
println(s"NON CPS: ${factorial(8)}")

8
7
6
5
4
3
2
CPS: 40320
NON CPS: 40320


defined function factorial
defined function factorial_k

# KEY ASPECT

Create an accumulator simply by building up the continuation!
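To make the key aspect concrete, compare the CPS factorial with a conventional accumulator-passing version (`factorial_acc`, a hypothetical name). Both are tail recursive; the difference is that the accumulator holds a number that is updated immediately, while the continuation holds the pending multiplications as a function that only runs at the base case.

```scala
// Accumulator-passing style: the running product is an Int.
def factorial_acc(x: Int, acc: Int): Int =
  if (x <= 1) acc else factorial_acc(x - 1, acc * x)

// Continuation-passing style: the "running product" is a function
// that still has to be applied to the base-case value 1.
def factorial_k(x: Int, k: Int => Int): Int =
  if (x <= 1) k(1) else factorial_k(x - 1, v => k(v * x))

println(factorial_acc(8, 1))    // 40320
println(factorial_k(8, x => x)) // 40320
```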

# Harder continuation: Fibonacci
## This will go over handling multiple function calls in a function

In [30]:
def fibonacci(x: Int): Int  = {
    if (x <= 1){
        1
    }
    else{
        fibonacci(x-1) + fibonacci(x-2)
    }
}

//cps version; this is going to be an accumulated sum 
def fibonacci_k(x: Int,k: Int => Int): Int = {
    if (x <= 1){
    /* this will always be the initial value that 
    fits the built-up continuation function when 
    it gets here */
        k(1)
    }
    else{
        /* create the continuations: this frame has to
            combine two recursive results without
            becoming non-tail-recursive...
        */ 
        
        /* this is what I do with the result of the first call, fib(x-1) */
        def k1(v1:Int):Int = { 
         /* this is what I do with the second result, fib(x-2) */   
            def k2(v2:Int): Int = {
                k(v1 + v2)  // both results are known here, so hand their sum to k
            }
         /* with the result I just got, calculate fib(x-2) */
            fibonacci_k(x-2, k2)
        }
        
        fibonacci_k(x-1, k1)
    }
}
// only the innermost continuation (k2) calls k: it is the first
// point where both v1 and v2 are available


println(fibonacci_k(4,x=>x))
println(fibonacci(4))

// this is still really hard to trace by hand

5
5


defined function fibonacci
defined function fibonacci_k
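Written with anonymous functions, the two-call pattern becomes a one-liner that may be easier to trace: the outer lambda receives `fib(x - 1)` as `v1` and starts the second call, and the inner lambda receives `fib(x - 2)` as `v2`. A sketch; `fibonacci_k2` is a hypothetical name, and the loop compares it against the direct version.

```scala
def fibonacci(x: Int): Int =
  if (x <= 1) 1 else fibonacci(x - 1) + fibonacci(x - 2)

// Only the innermost lambda has both results, so only it calls k.
def fibonacci_k2(x: Int, k: Int => Int): Int =
  if (x <= 1) k(1)
  else fibonacci_k2(x - 1, v1 => fibonacci_k2(x - 2, v2 => k(v1 + v2)))

// Compare both versions on a small range of inputs.
for (n <- 0 to 10) assert(fibonacci_k2(n, x => x) == fibonacci(n))
println(fibonacci_k2(4, x => x)) // 5
```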

# Easier functions to convert

In [31]:
def simple_fun(x: Int): Int = {
    val y = x * x
    val z = y + y - 5 * x
    if(z <= 0){
        1
    }
    else{
        z
    }
}

// cps version 
/* nothing much changes here: there is no call whose result
    is needed later, so each value the function would have
    returned simply becomes a call to k
*/ 
def simple_fun_k(x: Int, k: Int => Int): Int = {
    val y = x * x
    val z = y + y - 5 * x
    if(z <= 0){
        k(1)
    }
    else{
        k(z)
    }
}

println(simple_fun(10))
println(simple_fun_k(10,x=>x))

150
150


defined function simple_fun
defined function simple_fun_k
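A quick sanity check for a CPS translation of a pure function like these is that `f_k(x, k)` should equal `k(f(x))` for every input and continuation. A sketch of such a check, using a hypothetical helper `agrees` and the `simple_fun` pair from this cell:

```scala
def simple_fun(x: Int): Int = {
  val y = x * x
  val z = y + y - 5 * x
  if (z <= 0) 1 else z
}

def simple_fun_k(x: Int, k: Int => Int): Int = {
  val y = x * x
  val z = y + y - 5 * x
  if (z <= 0) k(1) else k(z)
}

// Hypothetical helper: true if f_k(x, k) == k(f(x)) for every x in inputs.
def agrees(f: Int => Int, f_k: (Int, Int => Int) => Int,
           k: Int => Int, inputs: Seq[Int]): Boolean =
  inputs.forall(x => f_k(x, k) == k(f(x)))

println(agrees(simple_fun, simple_fun_k, x => x, -10 to 10))     // true
println(agrees(simple_fun, simple_fun_k, x => x * 3, -10 to 10)) // true
```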

# When there's a function call in each branch

In [None]:
// this only makes tail calls (it isn't recursive), but we still want it in CPS
def tail_call_fun(x: Int): Int = {
    if ( x >= 0){
        simple_fun(x + 1)
    }else{
        val y = x * x - 2
        simple_fun(y)
    }
}

// no base case is needed: nothing here is recursive, and each branch hands k along unchanged
def tail_call_fun_k(x: Int, k: Int => Int): Int = {
    if ( x >= 0){
        simple_fun_k(x + 1, k)
    }else{
        val y = x * x - 2
        simple_fun_k(y, k)
    }
}

# More complicated

In [31]:
def fancy_function(x: Int, y: Int): Int = {
    if (x == 0)
        return 0
    else if (x > 0) {
        val s1 = 25
        val y1 = x * y + x - y
        s1 + y1
    } else {
        val y1 = tail_call_fun(x)
        y1 + y - 2 * x
    }
    
}
def fancy_function_k(x: Int, y: Int, k: Int => Int): Int = {
    if (x == 0)
        return k(0)
    else if (x > 0) {
        val s1 = 25
        val y1 = x * y + x - y
        k(s1 + y1)
    } else {
        // turn the code that ran after the call,
        //  y1 + y - 2 * x, into the continuation k1
        def k1(y1: Int): Int = {
            k(y1 + y - 2 * x)
        }
        tail_call_fun_k(x,  k1)
    }
    
}

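cmd31.sc:9: not found: value tail_call_fun
        val y1 = tail_call_fun(x)
                 ^
cmd31.sc:27: not found: value tail_call_fun_k
        tail_call_fun_k(x,  k1)
        ^
cmd31.sc:27: missing argument list for method k1
Unapplied methods are only converted to functions when a function type is expected.
You can make this conversion explicit by writing `k1 _` or `k1(_)` instead of `k1`.
        tail_call_fun_k(x,  k1)
                            ^
Compilation Failed

This cell fails only because the `tail_call_fun` cell above (`In [None]`) was never executed, so `tail_call_fun` and `tail_call_fun_k` are undefined. The `k1` error is a knock-on effect: with `tail_call_fun_k` unknown, the compiler has no function type to convert `k1` to. Running the previous cell first should let this one compile.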
