# Topic 5. Recursive functions and data types

The goals of this topic are to understand:

* How recursive types (lists, trees, etc.) are defined algebraically
* How functions over recursive types are defined recursively
* The two major types of recursive functions: general and tail-recursive

### References

[__Programming in Scala, A comprehensive step-by-step guide__](https://www.artima.com/shop/programming_in_scala_3ed), Third Edition, by Martin Odersky, Lex Spoon, and Bill Venners.

- Chapter 16. Working with Lists
- Chapter 26. Extractors (optional)

[__Functional programming in Scala__](https://www.manning.com/books/functional-programming-in-scala), by Paul Chiusano and Runar Bjarnason.

- Chapter 3. Functional data structures

[__Functional programming simplified__](https://alvinalexander.com/downloads/fpsimplified-free-preview.pdf), by Alvin Alexander.

- Chapters 29-36. Recursion.

## Recursive types

### The `List` type

Lists are data structures that represent finite sequences of values of the same type. They can be defined recursively, in an informal way, as follows: 
- A list is either the empty sequence,
- or a non-empty sequence made of a value and another list, which represent the head and tail of the list, respectively

Thus, the type `IntList`, which represents lists of integers, must satisfy the following algebraic equation:

`IntList = 1 + Int * IntList`

i.e., a list of integers is the empty sequence (represented by the singleton type `1`), or an integer (the head) and a list (its tail).



In [None]:
// IntList = 1 + Int + Int * Int + Int * Int * Int + ....
// IntList = 1 + Int * (1 + Int + Int * Int + ...)
// IntList = 1 + Int * IntList

In [None]:
enum IntList: 
    case Empty()
    case NonEmpty(head: Int, tail: IntList)

In [None]:
object AdHoc: 

    enum List[A]: 
        case Empty[A]() extends List[A]
        case NonEmpty[A](head: A, tail: List[A]) extends List[A]

In [None]:
import IntList._

In [None]:
// [1,2,3]
NonEmpty(1, NonEmpty(2, NonEmpty(3, Empty()))) 

The implementation in the Scala standard library is similar to the following one, which names the two cases `Nil` and `::` (we show it for `IntList`; the standard library provides the generic version `List[A]`):

In [None]:
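object StdDefinition:
    enum IntList: 
        case Nil
        case ::(h: Int, t: IntList)

        // Operators ending in `:` are right-associative: `1 :: l` means `l.::(1)`,
        // so the infix syntax used below needs this method
        def ::(h: Int): IntList = IntList.::(h, this)

    import IntList._

    // Prefix syntax: plain constructor application
    val l: IntList = 
        ::(1, ::(2, ::(3, Nil)))

    // Infix syntax: uses the method `::` defined above
    val l2: IntList = 
        1 :: (2 :: (3 :: IntList.Nil))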

In [None]:
val l: List[Int] = 1 :: (2 :: (3 :: Nil))

In [None]:
val l: List[Int] = ::(1, ::(2, ::(3, Nil)))

In [None]:
val l: List[Int] = List.apply(1,2,3)
val l2: List[Int] = List(1,2,3)

However, the actual implementation of [immutable lists](https://github.com/scala/scala/blob/v2.13.1/src/library/scala/collection/immutable/List.scala#L79) in the Scala standard library defines the empty list as an object, rather than a class. This forces us to declare the list covariantly in its generic parameter `A`, which is somewhat inconvenient at times. The standard definition looks as follows:

In [None]:
object ActualStdDefinition:
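
A minimal sketch of what such a covariant definition might look like is given below. The name `CovariantSketch` is made up for this illustration, and the real library class has many more members; only the variance annotation `+A`, the `Nil` object and the `::` constructor are meant to mirror the standard library.

```scala
object CovariantSketch:
    // Covariant in A: since Nothing is a subtype of every type,
    // List[Nothing] is a subtype of List[A] for every A
    sealed abstract class List[+A]:
        // Prepending may widen the element type, hence the lower bound B >: A
        def ::[B >: A](head: B): List[B] = new ::[B](head, this)

    // A single empty list, shared by all element types
    case object Nil extends List[Nothing]

    final case class ::[+A](head: A, tail: List[A]) extends List[A]

    val l: List[Int] = 1 :: 2 :: 3 :: Nil
```

Because `Nil` is a single object of type `List[Nothing]`, it can only be used as a `List[Int]`, a `List[String]`, etc., if `List` is covariant.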


### Some syntactic sugar

Note that we can write standard lists with a more compact syntax: 

In [None]:
// Less beautifully 

// More idiomatically


And we can also pattern match on lists, similarly:

In [None]:
object Std: 

    enum Either[A, B]: 
        case Left[A, B](a: A) extends Either[A, B]
        case Right[A, B](b: B) extends Either[A, B]

In [None]:
val e: Either[Int, String] = Left(1)

e match 
    case Left(i) => "an integer"
    case Right(s) => "a string"

In [None]:
// Less beautifully
def isEmpty[A](l: List[A]): Boolean = 
    l match 
        case Nil => true
        case ::(h: A, t: List[A]) => false 

// more idiomatically

def isEmpty2[A](l: List[A]): Boolean = 
    l match 
        case Nil => true
        case h :: t => false 
// or

val l: List[Int] = List(1,2,3)

l match 
    case List(x, y, z) => x+y+z
    case _ => 0



##  Recursive functions

Since lists are defined recursively, functions over lists are commonly recursive as well. For instance, let's implement a recursive function that computes the length of a list. But first, let's implement the function imperatively for the sake of comparison:

In [None]:
val i: Int = 1


In [None]:
i = 5

In [1]:
// Using mutable variables

def lengthI[A](l: List[A]): Int = 
    var out: Int = 0 
    var aux: List[A] = l 
    while (aux != Nil)
        aux = aux.tail
        out += 1 
    out


defined function lengthI

In [None]:
// Using mutable variables

def lengthI[A](l: List[A]): Int = 
    var out: Int = ???
    var aux: List[A] = l 
    while (aux != Nil)
        aux = aux.tail
        out = ???(out, aux.head)
    out


-- [E050] Type Error: cell1.sc:6:14 --------------------------------------------
6 |        out = ???(out, aux.head)
  |              ^^^
  |              method ??? in object Predef does not take parameters
  |
  | longer explanation available when compiling with `-explain`
Compilation Failed

In [None]:
// invoke
lengthI(List(1,2,3,4))
lengthI(List())
lengthI(List(1))

The recursive function is implemented as follows: 

In [None]:
// Using recursive functions

def lengthR[A](l: List[A]): Int = 
    ??? : Int

In [None]:
// Using recursive functions

def lengthR[A](l: List[A]): Int = 
    l match 
        case Nil => ??? : Int
        case h :: (t: List[A]) => 
            val tailSol: Int = lengthR(t)
            ??? : Int 

In [None]:
// Using recursive functions

def lengthR[A, B](l: List[A]): B = 
    l match 
        case Nil => ??? : B
        case h :: (t: List[A]) => 
            val tailSol: B = lengthR(t)
            ??? : B 

In [13]:
// Using recursive functions

@scala.annotation.tailrec
def lengthR[A](l: List[A]): Int = 
    l match 
        case Nil => 0 : Int
        case h :: (t: List[A]) => 
            val tailSol: Int = lengthR(t)
            tailSol + 1 : Int 

-- Error: cell14.sc:6:38 -------------------------------------------------------
6 |            val tailSol: Int = lengthR(t)
  |                               ^^^^^^^^^^
  |                 Cannot rewrite recursive call: it is not in tail position
Compilation Failed

In [4]:
lengthR(Nil)
lengthR(List())
lengthR(List(1,2,3,4))

res4_0: Int = 0
res4_1: Int = 0
res4_2: Int = 4

In [8]:
lengthR(List.fill(8000)(1))

java.lang.StackOverflowError: null

In [None]:
List.fill(10)(1)

In [2]:
lengthI(List.fill(1000000)(1))

res2: Int = 1000000

Some comments: 
- The recursive function is implemented in a _type-driven development_ style: we proceed, step-by-step, analysing the types of input data that are available, and the types of output that we have to generate. This leads to a divide-and-conquer problem solving strategy and hugely facilitates the implementation.
- The plain recursive function is not stack-safe: each recursive call adds a new stack frame, so very long lists blow up the stack (see the `StackOverflowError` above).

### Tail-recursive functions

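The implementation using tail recursion solves the issues with the stack: since the recursive call is the last action of the function, the compiler can turn it into a loop that reuses the current stack frame. It commonly makes use of an auxiliary function that carries the partial result in an accumulator: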

In [10]:
// Using tail-recursive functions

def lengthTR[A](l: List[A]): Int = 

    @scala.annotation.tailrec
    def step(out: Int, aux: List[A]): Int = 
        if aux == Nil then out 
        else step(out+1, aux.tail)

    step(0, l)


defined function lengthTR

In [32]:
// Using tail-recursive functions

def lengthTR[A](l: List[A]): Int = 

    @scala.annotation.tailrec
    def step(out: Int, aux: List[A]): Int = 
        aux match 
            case Nil => out 
            case _ :: t => 
                step(out+1, t)

    step(0, l)


defined function lengthTR

In [16]:
// Using tail-recursive functions

def lengthTR[A](l: List[A]): Int = 

    @scala.annotation.tailrec
    def step(out: Int, aux: List[A]): Int = 
        aux match 
            case Nil => out 
            case h :: t =>
                step(???, aux.tail)

    step(???, l)


defined function lengthTR

In [13]:
lengthTR(List.fill(1000000)(0))

res13: Int = 1000000

We can check the stack-safety problems of non-tail recursive functions by calculating the length of a very big list. We will use the following function, which creates a constant list of given length.

In [None]:
// First, imperatively



In [None]:
// Next, tail-recursively



We can also use the function [`fill`](https://www.scala-lang.org/api/2.13.3/scala/collection/immutable/List$.html#fill[A](n:Int)(elem:=%3EA):CC[A]) of the Scala standard library.

Now, let's calculate the length of a list long enough to blow up the stack, using each of the three implementations:

In [None]:
// Imperatively


In [None]:
// Tail-recursive


In [None]:
// Plain recursive
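
The comparison could be sketched as follows. The three functions are repeated so that the cell stands alone; the list size of one million is an assumption, and whether the plain recursive version actually overflows depends on the stack size of the JVM, so the error is caught explicitly instead of being taken for granted.

```scala
def lengthI[A](l: List[A]): Int =
    var out: Int = 0
    var aux: List[A] = l
    while aux != Nil do
        aux = aux.tail
        out += 1
    out

def lengthR[A](l: List[A]): Int =
    l match
        case Nil => 0
        case _ :: t => lengthR(t) + 1

def lengthTR[A](l: List[A]): Int =
    @scala.annotation.tailrec
    def step(out: Int, aux: List[A]): Int =
        aux match
            case Nil => out
            case _ :: t => step(out + 1, t)
    step(0, l)

// Assumed size, large enough to exhaust a default-sized stack in most setups
val big: List[Int] = List.fill(1000000)(0)

// Imperatively and tail-recursively: constant stack space
val resI: Int = lengthI(big)
val resTR: Int = lengthTR(big)

// Plain recursive: one stack frame per element, so it may overflow
val resR: Option[Int] =
    try Some(lengthR(big))
    catch case _: StackOverflowError => None
```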


### Unit testing with `scalatest`

In [17]:
import $ivy.`org.scalatest::scalatest:3.2.16`
import org.scalatest.{Filter => _, _}, flatspec._, matchers._

import $ivy.$
import org.scalatest.{Filter => _, _}, flatspec._, matchers._

From now on, we will also make extensive use of unit testing for the functions that we implement, using the [`scalatest`](http://www.scalatest.org/) library. In particular, for each function we will implement a test catalogue that tests it against different test cases. The test catalogue receives the actual function to be tested as a parameter. For instance, this is a possible test class for the `length` function:

In [30]:
class TestLength(lengthF: List[Int] => Int) 
extends AnyFlatSpec with should.Matchers:
    "length" should "work" in:
        lengthF(Nil) shouldBe 0
        lengthF(List(1,2,3)) shouldBe 3
        

defined class TestLength

In [33]:
run(new TestLength(lengthI))
run(new TestLength(lengthR))
run(new TestLength(lengthTR))

cell30$Helper$TestLength:
length
- should work
cell30$Helper$TestLength:
length
- should work
cell30$Helper$TestLength:
length
- should work


In [28]:
object TestLengthTR extends AnyFlatSpec with should.Matchers:
    "lengthTR" should "work" in:
        lengthTR(Nil) shouldBe 0
        lengthTR(List(1,2,3)) shouldBe 3
        

defined object TestLengthTR

In [26]:
object TestLengthR extends AnyFlatSpec with should.Matchers:
    "lengthR" should "work" in:
        lengthR(Nil) shouldBe 0
        lengthR(List(1,2,3)) shouldBe 3
        

defined object TestLengthR

In [27]:
run(TestLengthR)

cell26$Helper$TestLengthR:
lengthR
- should work


In [24]:
object TestLengthI extends AnyFlatSpec with should.Matchers:
    "lengthI" should "work" in:
        lengthI(Nil) shouldBe 0
        lengthI(List(1,2,3)) shouldBe 3
        

defined object TestLengthI

In [25]:
run(TestLengthI)

cell24$Helper$TestLengthI:
lengthI
- should work


In [18]:
class TestLength(lengthF: List[Int] => Int) extends AnyFlatSpec with should.Matchers:
    "length" should "work" in:
        

defined class TestLength

The method `shouldBe` is a _matcher_. The scalatest library offers an extensive catalogue of [them](http://www.scalatest.org/user_guide/using_matchers). Similarly, scalatest also supports many different [testing styles](http://www.scalatest.org/user_guide/selecting_a_style). The one chosen here is `AnyFlatSpec`. To execute the test catalogue, we can simply use the scalatest method `run`:

In [None]:
run(TestLength(lengthR))

### Example: adding numbers

Let's implement a function that sums all the numbers of a list.

In [37]:
class TestSum(sum: List[Int] => Int) extends AnyFlatSpec with should.Matchers:
    "sum" should "work" in:
        sum(Nil) shouldBe 0
        sum(List(3)) shouldBe 3
        sum(List(1,2,3,4)) shouldBe 10
        sum(1 :: List(2,3,4)) shouldBe 1+9

defined class TestSum

In [38]:
def sumR(l: List[Int]): Int = 
    l match 
        case Nil => 0 : Int
        case h :: (t: List[Int]) => 
            val tailSol: Int = sumR(t)
            tailSol + h : Int

defined function sumR

In [42]:
def sumR(l: List[Int]): Int = 
    l match 
        case Nil => 0
        case h :: t => 
            sumR(t) + h

defined function sumR

In [39]:
run(TestSum(sumR))

cell37$Helper$TestSum:
sum
- should work


In [40]:
// With tail-recursion

// Using tail-recursive functions

def sumTR(l: List[Int]): Int = 

    @scala.annotation.tailrec
    def step(out: Int, aux: List[Int]): Int = 
        aux match 
            case Nil => out 
            case h :: t =>
                step(out + h, aux.tail)

    step(0, l)


defined function sumTR

In [41]:
run(TestSum(sumTR))

cell37$Helper$TestSum:
sum
- should work


### Example: multiplying list elements

Let's multiply the elements of a list. If the list is empty, we return 1, the identity element of integer multiplication.

In [43]:
class TestProduct(product: List[Int] => Int) extends AnyFlatSpec with should.Matchers:
    "product" should "work" in:
        product(List(1,2,3,4)) shouldBe 24
        product(1 :: List(2,3,4)) shouldBe 1 * 24 
        product(List(3)) shouldBe 3
        product(Nil) shouldBe 1
        product(List(1,2,3,0,5,6)) shouldBe 0

defined class TestProduct

This is the common recursive implementation:

In [44]:
def product(l: List[Int]): Int = 
    l match 
        case Nil => ??? : Int
        case h :: t => 
            val tailSol: Int = product(t)
            ??? : Int

defined function product

In [45]:
def product(l: List[Int]): Int = 
    l match 
        case Nil => 1 : Int
        case h :: t => 
            val tailSol: Int = product(t)
            h * tailSol : Int

defined function product

In [47]:
def product(l: List[Int]): Int = 
    l match 
        case Nil => 1 : Int
        case h :: t => 
            if h == 0 then 0
            else 
                val tailSol: Int = product(t)
                h * tailSol : Int

defined function product

In [51]:
def product(l: List[Int]): Int = 
    l match 
        case Nil => 1 : Int
        case 0 :: t => 0
        // case h :: _ if h == 0 => 0
        case h :: t =>  
            val tailSol: Int = product(t)
            h * tailSol : Int

defined function product

In [48]:
run(TestProduct(product))

cell43$Helper$TestProduct:
product
- should work


But we can optimize the function a little bit. Note that if the number 0 belongs to the list, then the result is 0, no matter how many elements the list has. So, once we find the element 0 it's a waste of resources to make the recursive call. Let's take this into account.

In [None]:
// optimization for 0



In [None]:
run(TestProduct(product2))

A similar optimization can be made for the tail-recursive implementation.
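
As a rough sketch, the tail-recursive variant with the same shortcut might look like this. The name `productTR` is hypothetical; it is not defined elsewhere in this notebook.

```scala
// Tail-recursive product that stops as soon as a 0 is found
// (`productTR` is a hypothetical name used for illustration)
def productTR(l: List[Int]): Int =
    @scala.annotation.tailrec
    def step(out: Int, aux: List[Int]): Int =
        aux match
            case Nil => out
            case 0 :: _ => 0                 // a zero forces the result: stop here
            case h :: t => step(out * h, t)  // accumulate and continue
    step(1, l)
```

It should be usable with the same test catalogue, e.g. `run(TestProduct(productTR))`.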

### Example: membership

Let's implement a function that, given a list and an element, returns whether the element belongs to that list.

In [52]:
class TestMember(member: (List[Int], Int) => Boolean) extends AnyFlatSpec with should.Matchers:
    "member" should "work" in:
        member(List(1,2,3), 3) shouldBe true
        member(List(1,2,3), 0) shouldBe false
        member(Nil, 9) shouldBe false

defined class TestMember

In [60]:
def memberR(l: List[Int], e: Int): Boolean = 
    l match 
        case Nil => false : Boolean 
        case h :: t => 
            val tailSol: Boolean = memberR(t, e)
            tailSol || (h == e) : Boolean
            // (h == e) || tailSol : Boolean

defined function memberR

In [63]:
def memberR(l: List[Int], e: Int): Boolean = 
    l match 
        case Nil => false : Boolean 
        case h :: t => 
            // memberR(t, e) || (h == e) : Boolean   // always traverses the whole list
            (h == e) || memberR(t, e) : Boolean      // short-circuits once `e` is found

defined function memberR

In [65]:
def memberR(l: List[Int], e: Int): Boolean = 
    l match 
        case Nil => false : Boolean 
        case h :: t if h == e => true
        case h :: t => memberR(t, e)

defined function memberR

In [73]:
def memberR[A](l: List[A], e: A): Boolean = 
    l match 
        case Nil => false : Boolean 
        case `e` :: t => true
        case h :: t => memberR(t, e)

defined function memberR

In [75]:
run(TestMember(memberR[Int]))

cell52$Helper$TestMember:
member
- should work


In [77]:
run(TestMember(memberR))

cell52$Helper$TestMember:
member
- should work


Note that the last version of `memberR` pattern matches against a specific value, the parameter `e`, by enclosing the identifier in backticks (`` `e` ``).
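
The role of the backticks can be illustrated with the following sketch. Both function names are hypothetical and serve only to contrast the two patterns.

```scala
// With backticks: the head is compared with the value of the parameter `e`
def startsWith[A](l: List[A], e: A): Boolean =
    l match
        case `e` :: _ => true
        case _ => false

// Without backticks: `e` is a fresh pattern variable that shadows the
// parameter, so this case matches any non-empty list
def shadowed[A](l: List[A], e: A): Boolean =
    l match
        case e :: _ => true
        case _ => false
```

For instance, `startsWith(List(1,2,3), 2)` is `false`, whereas `shadowed(List(1,2,3), 2)` is `true`.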

### Example: last element

Let's implement a function that returns the last element of a given list. Note that an empty list does not have elements, and, hence, does not have a last element.

In [78]:
class TestLast(last: List[Int] => Option[Int]) extends AnyFlatSpec with should.Matchers:
    "last" should "work" in:
        last(List(1,2,3)) shouldBe Some(3)
        last(List(1)) shouldBe Some(1)
        last(Nil) shouldBe None

defined class TestLast

In [80]:
def last[A](l: List[A]): Option[A] = 
    @annotation.tailrec
    def step(out: Option[A], aux: List[A]): Option[A] = 
        aux match 
            case Nil => out
            case h :: t => 
                step( Some(h) /*???(out, h)*/, t)

    step(None, l)

defined function last

In [81]:
run(TestLast(last))

cell78$Helper$TestLast:
last
- should work


### Example: insert last

Now, a function that allows us to insert an element at the end of the list. 

In [None]:
class TestInsertLast(insertLast: (List[Int], Int) => List[Int]) 
extends AnyFlatSpec with should.Matchers:
    "insertLast" should "work" in:
        ???

In [None]:
run(TestInsertLast(insertLast))
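
A minimal sketch of one possible implementation follows, in the same plain recursive style used for `lengthR`. The body is an assumption about what the exercise expects, not the notebook's official solution.

```scala
// Plain recursive sketch: rebuild the list, placing `e` after the last element
def insertLast[A](l: List[A], e: A): List[A] =
    l match
        case Nil => e :: Nil                     // empty list: `e` is the only element
        case h :: t => h :: insertLast(t, e)     // keep the head, insert into the tail
```

Plausible test cases for the catalogue would be `insertLast(Nil, 1) shouldBe List(1)` and `insertLast(List(1,2), 3) shouldBe List(1,2,3)`. Note that this version is not tail-recursive, since the result of the recursive call is still passed to `::`.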

### Example: reverse lists

Implement a function which receives a list and returns its reverse.

In [82]:
class TestReverse(reverse: List[Int] => List[Int]) extends AnyFlatSpec with should.Matchers:
    "reverse" should "work" in:
        ???

defined class TestReverse

In [None]:
// Really inefficient 



In [None]:
run(TestReverse(reverse))

In [None]:
// Tail-recursive, efficiently



In [None]:
run(TestReverse(reverseTR))
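
The two versions could be sketched as follows. These bodies are assumptions about the intended solutions; the names `reverse` and `reverseTR` are taken from the `run` calls above.

```scala
// Inefficient: each step appends at the end of the reversed tail,
// which traverses it again, so the total cost is quadratic
def reverse[A](l: List[A]): List[A] =
    l match
        case Nil => Nil
        case h :: t => reverse(t) ::: List(h)

// Tail-recursive: each element is prepended to the accumulator,
// so the list is traversed only once
def reverseTR[A](l: List[A]): List[A] =
    @scala.annotation.tailrec
    def step(out: List[A], aux: List[A]): List[A] =
        aux match
            case Nil => out
            case h :: t => step(h :: out, t)
    step(Nil, l)
```

Plausible test cases would be `reverse(Nil) shouldBe Nil` and `reverse(List(1,2,3)) shouldBe List(3,2,1)`.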

### Example: concatenate lists

In [None]:
class TestConcatenate(concatenate: (List[Int], List[Int]) => List[Int]) 
extends AnyFlatSpec with should.Matchers:
    "concatenate" should "work" in:
        ???

In [None]:
run(TestConcatenate(concatenate))

Tail-recursive concatenation:

In [None]:
run(TestConcatenate(concatenateTR))
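
Both versions might be implemented along these lines. The bodies are assumptions, sketched in the style of the previous examples; the names come from the `run` calls above.

```scala
// Plain recursive: copy the first list in front of the second one
def concatenate[A](l1: List[A], l2: List[A]): List[A] =
    l1 match
        case Nil => l2
        case h :: t => h :: concatenate(t, l2)

// Tail-recursive: `step` pushes the elements of `aux`, one by one, onto `out`
def concatenateTR[A](l1: List[A], l2: List[A]): List[A] =
    @scala.annotation.tailrec
    def step(out: List[A], aux: List[A]): List[A] =
        aux match
            case Nil => out
            case h :: t => step(h :: out, t)
    // First reverse l1, then push its elements onto l2, which restores their order
    step(l2, step(Nil, l1))
```

Plausible test cases would be `concatenate(Nil, List(1)) shouldBe List(1)` and `concatenate(List(1,2), List(3,4)) shouldBe List(1,2,3,4)`.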