# Recursive data types and functions

## Recursive types

### The `List` type

Lists are data structures that represent finite sequences of values of the same type. Informally, they can be defined recursively as follows: 
- A list is the empty sequence, or
- A list is a non-empty sequence made of a value and another list, which represent the head and the tail of the list, respectively.

Thus, the type `IntList`, which represents lists of integers, must satisfy the following equation:

`IntList = 1 + Int * IntList`

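i.e., a list of integers is either the empty sequence (represented by the singleton type `1`, which has exactly one value), or a pair made of an integer (the head) and a list of integers (the tail). Here `+` denotes a sum type (a choice between alternatives) and `*` a product type (a pair of values).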



Its implementation in Scala is as follows (we directly give the generic version `List[A]`, rather than the implementation of `IntList`):

In [None]:
sealed abstract class List[A]
case class NonEmpty[A](head: A, tail: List[A]) extends List[A]
case class Empty[A]() extends List[A]
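
For comparison, the monomorphic type `IntList` could be transcribed directly from its equation. In this sketch the names `IntList`, `IntEmpty` and `IntNonEmpty` are hypothetical. The singleton type `1` becomes a `case object` (a type with exactly one value), and the product `Int * IntList` becomes a case class with two fields:

```scala
// Hypothetical non-generic transcription of IntList = 1 + Int * IntList
sealed abstract class IntList
// `1`: the empty list, a type with a single value
case object IntEmpty extends IntList
// `Int * IntList`: a head paired with a tail
case class IntNonEmpty(head: Int, tail: IntList) extends IntList

// The list [1, 2, 3]
val ints: IntList = IntNonEmpty(1, IntNonEmpty(2, IntNonEmpty(3, IntEmpty)))
```

Because `IntEmpty` carries no type parameter, a single object suffices here. The generic `Empty[A]()` above is a case class so that each element type `A` gets its own empty list.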

Note that the actual implementation of [immutable lists](https://github.com/scala/scala/blob/v2.13.1/src/library/scala/collection/immutable/List.scala#L79) in the Scala standard library defines the empty list as an object (`Nil`), rather than a class. However, since this single object must serve as the empty list of every element type, it extends `List[Nothing]`. This forces us to declare the list covariant in its type parameter `A`, which is somewhat inconvenient at times. A simplified version of the standard definition looks as follows:

In [None]:
object AlternativeDefinition{
    sealed abstract class List[+A]
    case class NonEmpty[A](head: A, tail: List[A]) extends List[A]
    case object Empty extends List[Nothing]
}
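
To see why covariance is required, note that `Empty` is a `List[Nothing]`. With `+A`, `List[Nothing]` is a subtype of `List[Int]`, `List[String]`, and so on, so the single `Empty` object can terminate a list of any element type. A minimal sketch (the object name `CovariantDemo` is hypothetical):

```scala
object CovariantDemo {
  sealed abstract class List[+A]
  case class NonEmpty[+A](head: A, tail: List[A]) extends List[A]
  case object Empty extends List[Nothing]

  // Both type-check because List is covariant: List[Nothing] <: List[Int], List[String]
  val ints: List[Int] = NonEmpty(1, NonEmpty(2, Empty))
  val strings: List[String] = NonEmpty("a", Empty)
}
```

If `List` were invariant (and `NonEmpty` with it), these definitions would be rejected, because `List[Nothing]` would not be a subtype of `List[Int]`.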

We will stick to the former definition. Some examples of lists: 

In [None]:
// The empty list
val l0: List[Int] = Empty()

In [None]:
// Non-empty list [1, 2, 3]
val l1: List[Int] = NonEmpty(1, NonEmpty(2, NonEmpty(3, Empty())))

### Some syntactic sugar

Note that we can write standard lists with a more compact syntax: 

In [None]:
import scala.{List => IList}

val l2: IList[Int] = 1 :: 2 :: 3 :: Nil
val l3: IList[Int] = IList(1,2,3)
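
The first, compact syntax works because, in Scala, an operator whose name ends in a colon is right-associative and is invoked as a method on its right operand. The expression `1 :: 2 :: 3 :: Nil` is therefore roughly equivalent to the explicit method calls in this sketch:

```scala
// `::` ends in ':', so `a :: rest` means `rest.::(a)`
val sugared   = 1 :: 2 :: 3 :: Nil
val desugared = Nil.::(3).::(2).::(1)
```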

How can we do that with our own lists? We define a smart constructor in the companion object, using a variadic parameter: 

In [None]:
object List{
    def apply[A](elem: A*): List[A] = 
        if (elem.isEmpty) Empty()
        else NonEmpty(elem.head, apply(elem.tail: _*))
}

In [None]:
val l1: List[Int] = List(1,2,3)
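
The ascription `elem.tail: _*` in the constructor tells the compiler to pass an existing sequence where a variable number of arguments is expected. The same ascription lets us build one of our lists from any Scala `Seq`. As an alternative (an assumption, not the definition used in the rest of this notebook), the constructor can also be written with the standard `foldRight` method, which builds the list from the last element backwards without explicit recursion. The names `SmartConstructorSketch` and `fromSeq` are hypothetical:

```scala
object SmartConstructorSketch {
  sealed abstract class List[A]
  case class NonEmpty[A](head: A, tail: List[A]) extends List[A]
  case class Empty[A]() extends List[A]

  object List {
    // Start from the empty list and prepend the elements from right to left
    def apply[A](elems: A*): List[A] =
      elems.foldRight[List[A]](Empty[A]())((elem, acc) => NonEmpty(elem, acc))

    // `: _*` expands an existing Seq into the varargs parameter of `apply`
    def fromSeq[A](seq: Seq[A]): List[A] = apply(seq: _*)
  }
}
```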

Note that the smart constructor `apply` is defined recursively. Let's dive into recursion.

## Recursive functions

Since lists are defined recursively, functions over lists are commonly recursive as well. For instance, let's implement a recursive function that computes the length of a list. But first, let's implement the function imperatively, for the sake of comparison:

In [None]:
// Using mutable variables

def lengthI[A](list: List[A]): Int = {
    var acc: Int = 0
    var aux: List[A] = list
    while (aux != Empty()){
        aux = aux.asInstanceOf[NonEmpty[A]].tail
        acc += 1
    }
    acc
}

In [None]:
lengthI(List())
lengthI(List(1,2,3,4))

The recursive function is implemented as follows: 

In [None]:
// Using recursive functions

def lengthR[A](list: List[A]): Int = 
    list match {
        case Empty()           => 0
        case NonEmpty(_, tail) => 1 + lengthR(tail)
    }

Some comments: 
- The recursive function is implemented in a _type-driven development_ style: we proceed, step by step, analysing the types of the input data available so far and the type of the output that we have to produce. This leads to a divide-and-conquer problem-solving strategy and greatly facilitates the implementation.
- The recursive function is not stack-safe: the addition in `1 + lengthR(tail)` can only be performed after the recursive call returns, so each element of the list adds a new frame to the call stack, and long enough lists cause a stack overflow.
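
Unfolding a call by hand shows where the stack usage comes from: each `1 + …` stays pending until the recursive call returns, so a list of length $n$ needs $n + 1$ nested calls on the stack at the same time. Writing lists in bracket notation for brevity:

$$
\begin{aligned}
\mathit{lengthR}([1,2,3]) &= 1 + \mathit{lengthR}([2,3]) \\
&= 1 + (1 + \mathit{lengthR}([3])) \\
&= 1 + (1 + (1 + \mathit{lengthR}([\,]))) \\
&= 1 + (1 + (1 + 0)) = 3
\end{aligned}
$$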

### Tail-recursive functions

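A tail-recursive implementation solves this problem: the recursive call is the last operation performed, so the compiler can turn it into a loop that runs in constant stack space (the `@annotation.tailrec` annotation asks the compiler to check that this is indeed the case). Such implementations commonly make use of an auxiliary function with an accumulator parameter: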

In [None]:
// Using tail-recursive functions

def lengthTR[A](list: List[A]): Int = {

    @annotation.tailrec
    def lengthAux(acc: Int, aux: List[A]): Int = 
        aux match {
            case Empty() => acc
            case NonEmpty(_, tail) => lengthAux(acc+1, tail)
        }
    
    lengthAux(0, list)
}

In [None]:
lengthTR(List())
lengthTR(List(1,2,3))
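
By contrast, unfolding `lengthTR` shows that the call to `lengthAux` is the last thing each step does. Nothing remains pending, so the compiler can reuse a single stack frame (in effect, a loop):

$$
\begin{aligned}
\mathit{lengthAux}(0, [1,2,3]) &= \mathit{lengthAux}(1, [2,3]) \\
&= \mathit{lengthAux}(2, [3]) \\
&= \mathit{lengthAux}(3, [\,]) \\
&= 3
\end{aligned}
$$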

We can check the stack-safety problems of non-tail-recursive functions by calculating the length of a very long list. We will use the following function, which creates a list of a given length whose elements all have the same value.

In [None]:
// First, imperatively

def constantList[A](value: A, length: Int): List[A] = {
    var acc: List[A] = Empty()
    for (i <- 1 to length)
        acc = NonEmpty(value, acc)
    acc
}

In [None]:
// Next, tail-recursively

def constantList[A](value: A, length: Int): List[A] = {

    @annotation.tailrec
    def constantAux(acc: List[A], i: Int): List[A] = 
        if (i == 0) acc
        else constantAux(NonEmpty(value, acc), i-1)
    
    constantAux(Empty(), length)
}

Now, let's calculate the length of a long enough list, using each of the three implementations:

In [None]:
// Imperatively
lengthI(constantList(0, 100000))

In [None]:
// Tail-recursive
lengthTR(constantList(0, 100000))

In [None]:
// Plain recursive
lengthR(constantList(0, 100000))
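
With the default JVM stack size, the last cell is expected to fail with a `java.lang.StackOverflowError`, whereas the other two return 100000. To observe the difference without aborting the cell, one can catch the error. In this sketch, `StackSafetyDemo` and its helper `attempt` are hypothetical, and catching `StackOverflowError` is done for demonstration only:

```scala
object StackSafetyDemo {
  sealed abstract class List[A]
  case class NonEmpty[A](head: A, tail: List[A]) extends List[A]
  case class Empty[A]() extends List[A]

  // Plain recursion: one stack frame per element
  def lengthR[A](list: List[A]): Int =
    list match {
      case Empty()           => 0
      case NonEmpty(_, tail) => 1 + lengthR(tail)
    }

  // Tail recursion: compiled into a loop
  def lengthTR[A](list: List[A]): Int = {
    @annotation.tailrec
    def lengthAux(acc: Int, aux: List[A]): Int =
      aux match {
        case Empty()           => acc
        case NonEmpty(_, tail) => lengthAux(acc + 1, tail)
      }
    lengthAux(0, list)
  }

  def constantList[A](value: A, length: Int): List[A] = {
    @annotation.tailrec
    def constantAux(acc: List[A], i: Int): List[A] =
      if (i == 0) acc else constantAux(NonEmpty(value, acc), i - 1)
    constantAux(Empty(), length)
  }

  // Hypothetical helper: Some(result) on success, None if the stack overflowed
  def attempt[T](body: => T): Option[T] =
    try Some(body)
    catch { case _: StackOverflowError => None }
}
```

For instance, `StackSafetyDemo.attempt(StackSafetyDemo.lengthTR(StackSafetyDemo.constantList(0, 100000)))` yields `Some(100000)`, while the same call with `lengthR` is expected to yield `None` on a default-sized stack.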

### Example: adding numbers

Let's implement a function that sums all the numbers of a list.

In [None]:
// Recursively

def sum(list: List[Int]): Int = 
    list match {
        case Empty() => 0 : Int
        case NonEmpty(head, tail) => head + sum(tail) : Int 
    }

In [None]:
sum(List(1,2,3,4))

In [None]:
// With tail-recursion

def sum(list: List[Int]): Int = {

    @annotation.tailrec
    def sumAux(acc: Int, list: List[Int]): Int = 
        list match {
            case Empty() => acc : Int
            case NonEmpty(head, tail) => sumAux(head + acc, tail) : Int 
        }
    
    sumAux(0, list)
}

In [None]:
sum(List(1,2,3))

### Example: multiplying list elements

Let's multiply the elements of a list. If the list is empty, we return 1, the identity element of multiplication. This is the common recursive implementation:

In [None]:
def product(list: List[Int]): Int = 
    list match {
        case Empty() => 1
        case NonEmpty(head, tail) => 
            head * product(tail)
    }

It works as expected: 

In [None]:
assert(product(List(1,2,3)) == 6)
assert(product(List(1,2,0,3,4,5,6,6)) == 0)

But we can optimize the function a little bit. Note that if the number 0 belongs to the list, then the result is 0, regardless of the other elements. So, once we find a 0, making further recursive calls is a waste of resources. Let's take this into account.

In [None]:
def product(list: List[Int]): Int = 
    list match {
        case Empty() => 1
        // We add this extra case
        case NonEmpty(0, _) => 0
        case NonEmpty(head, tail) => head * product(tail)
    }

A similar optimization can be made for the tail-recursive implementation.
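
For instance, a tail-recursive version might carry the running product in an accumulator and return 0 as soon as a 0 is found, without visiting the rest of the list. A sketch, where the names `ProductSketch` and `productAux` are hypothetical:

```scala
object ProductSketch {
  sealed abstract class List[A]
  case class NonEmpty[A](head: A, tail: List[A]) extends List[A]
  case class Empty[A]() extends List[A]

  object List {
    def apply[A](elems: A*): List[A] =
      if (elems.isEmpty) Empty() else NonEmpty(elems.head, apply(elems.tail: _*))
  }

  def product(list: List[Int]): Int = {
    @annotation.tailrec
    def productAux(acc: Int, list: List[Int]): Int =
      list match {
        case Empty()              => acc
        // A zero makes the whole product zero: stop without visiting the tail
        case NonEmpty(0, _)       => 0
        case NonEmpty(head, tail) => productAux(acc * head, tail)
      }
    productAux(1, list)
  }
}
```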

### Example: membership

Let's implement a function that, given a list and an element, returns whether the element belongs to that list.

In [None]:
def member[A](list: List[A], elem: A): Boolean = 
    list match {
        case Empty() => false
        case NonEmpty(head, tail) => 
            head == elem || member(tail, elem)
    }

In [None]:
assert(member(List(1,2,3), 0) == false)

In [None]:
assert(member(List(1,2,3), 2) == true)

In [None]:
assert(member(List(1,2,3), 4) == false)

We can also pattern match against the value of an existing variable by enclosing its name in backticks, as follows:

In [None]:
def member[A](list: List[A], elem: A): Boolean = 
    list match {
        case Empty() => false
        case NonEmpty(`elem`, _) => true
        case NonEmpty(_, tail) => member(tail, elem)
    }

In [None]:
assert(member(List(1,2,3), 1) == true)

In [None]:
assert(member(List(), 1) == false)

In [None]:
assert(member(List(1,2,3), 4) == false)
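
The backticks matter: without them, a lowercase name in a pattern introduces a new variable that matches anything, shadowing the parameter `elem`. This sketch, with hypothetical names `memberWrong` and `memberRight`, contrasts both readings:

```scala
object BacktickSketch {
  sealed abstract class List[A]
  case class NonEmpty[A](head: A, tail: List[A]) extends List[A]
  case class Empty[A]() extends List[A]

  object List {
    def apply[A](elems: A*): List[A] =
      if (elems.isEmpty) Empty() else NonEmpty(elems.head, apply(elems.tail: _*))
  }

  // Without backticks: `elem` binds a fresh variable, so ANY head matches
  def memberWrong[A](list: List[A], elem: A): Boolean =
    list match {
      case Empty()           => false
      case NonEmpty(elem, _) => true
    }

  // With backticks: the head is compared (using ==) with the parameter `elem`
  def memberRight[A](list: List[A], elem: A): Boolean =
    list match {
      case Empty()             => false
      case NonEmpty(`elem`, _) => true
      case NonEmpty(_, tail)   => memberRight(tail, elem)
    }
}
```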

### Example: last element

Let's implement a function that returns the last element of a given list. Note that an empty list does not have elements, and, hence, does not have a last element.

In [None]:
// Tail-recursively (checked by @annotation.tailrec)

@annotation.tailrec
def last[A](list: List[A]): Option[A] =
    list match {
        case Empty() => None
        case NonEmpty(head, Empty()) => Some(head)
        case NonEmpty(_, tail) => last(tail)
    }

In [None]:
assert(last(List()) == None)

In [None]:
assert(last(List(1, 2)) == Some(2))

### Example: insert last

Now, a function that allows us to insert an element at the end of the list. 

In [None]:
def insertLast[A](list: List[A], elem: A): List[A] = 
    list match {
        case Empty() => List(elem)
        case NonEmpty(head, tail) => 
            NonEmpty(head, insertLast(tail, elem))
    }

In [None]:
assert(insertLast(List(), 1) == List(1))

In [None]:
assert(insertLast(List(1,2,3), 0) == List(1,2,3,0))

### Example: concatenate lists

Let's implement this function step by step, following the types. We start from the signature of the desired function:

In [None]:
def concatenate[A](list1: List[A], list2: List[A]): List[A] = ???

1. Pattern match on `list1`:

In [None]:
def concatenate[A](list1: List[A], list2: List[A]): List[A] =
    list1 match {
        case Empty() => ??? : List[A]
        case NonEmpty(head, tail) => ??? : List[A]
    }

2. Solve empty case:

In [None]:
def concatenate[A](list1: List[A], list2: List[A]): List[A] =
    list1 match {
        case Empty() => list2 : List[A]
        case NonEmpty(head, tail) => ??? : List[A]
    }

3. Solve non-empty case:

In [None]:
def concatenate[A](list1: List[A], list2: List[A]): List[A] =
    list1 match {
        case Empty() => list2 : List[A]
        case NonEmpty(head, tail) => 
            NonEmpty(head, concatenate(tail, list2)) : List[A]
    }
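
Unfolding the final version on a small example shows that `list1` is traversed and rebuilt element by element, while `list2` is reused as it is. Hence the cost is proportional to the length of `list1` only:

$$
\begin{aligned}
\mathit{concatenate}([1,2], [3]) &= \mathit{NonEmpty}(1, \mathit{concatenate}([2], [3])) \\
&= \mathit{NonEmpty}(1, \mathit{NonEmpty}(2, \mathit{concatenate}([\,], [3]))) \\
&= \mathit{NonEmpty}(1, \mathit{NonEmpty}(2, [3])) = [1,2,3]
\end{aligned}
$$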

### Example: reverse lists

In [None]:
// Really inefficient: quadratic time, since each step concatenates onto the reversed tail

def reverseR[A](list: List[A]): List[A] = 
    list match {
        case Empty() => Empty()
        case NonEmpty(head, tail) =>
                concatenate(reverseR(tail), NonEmpty(head, Empty()))
    }
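
The inefficiency comes from `concatenate`, whose cost is proportional to the length of its first argument. Reversing a list of length $n$ concatenates onto reversed tails of lengths $n-1, n-2, \dots, 0$, so the total number of steps grows quadratically:

$$
\sum_{k=0}^{n-1} k = \frac{n(n-1)}{2} \in O(n^2)
$$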

In [None]:
// Tail-recursive, linear time

def reverseTR[A](list: List[A]): List[A] = {
    @annotation.tailrec
    def reverseAux(acc: List[A], list: List[A]): List[A] = 
        list match {
            case Empty() => acc
            case NonEmpty(head, tail) => 
                reverseAux(NonEmpty(head, acc), tail)
        }
    
    reverseAux(Empty(), list)
}

In [None]:
assert(reverseTR(List(1,2,3)) == List(3,2,1))

### Example: tail-recursive concatenation

In [None]:
def concatenate[A](list1: List[A], list2: List[A]): List[A] = {

    @annotation.tailrec
    def concAux(acc: List[A], list: List[A]): List[A] = 
        list match {
            case Empty() => acc
            case NonEmpty(head, tail) => 
                concAux(NonEmpty(head, acc), tail)
        }
    
    concAux(Empty(), concAux(concAux(Empty(), list1), list2))
}
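
Here `concAux(acc, list)` pushes the elements of `list` onto `acc` one by one, so it returns the reverse of `list` followed by `acc`. Writing $\mathrm{rev}$ for reversal and $\mathbin{+\!\!+}$ for concatenation, the three nested calls compute:

$$
\begin{aligned}
\mathit{concAux}([\,], \mathit{list1}) &= \mathrm{rev}(\mathit{list1}) \\
\mathit{concAux}(\mathrm{rev}(\mathit{list1}), \mathit{list2}) &= \mathrm{rev}(\mathit{list2}) \mathbin{+\!\!+} \mathrm{rev}(\mathit{list1}) = \mathrm{rev}(\mathit{list1} \mathbin{+\!\!+} \mathit{list2}) \\
\mathit{concAux}([\,], \mathrm{rev}(\mathit{list1} \mathbin{+\!\!+} \mathit{list2})) &= \mathit{list1} \mathbin{+\!\!+} \mathit{list2}
\end{aligned}
$$

Each pass runs in constant stack space, at the price of three linear traversals instead of one.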

In [None]:
val l1 = List(1,2,3)

assert(concatenate(l1, reverseTR(l1)) == List(1,2,3,3,2,1))

### Example: drop elements

Implement a function that drops the first _n_ elements of a list. If _n_ is zero or negative, the same list must be returned; if _n_ exceeds the length of the list, the result is the empty list.

In [None]:
def drop[A](list: List[A], n: Int): List[A] = 
    (list, n) match {
        case (NonEmpty(_, tail), n) if n > 0 => 
            drop(tail, n-1)
        case (list, _) => 
            list
    }

In [None]:
assert(drop(List(1,2,3,4,5,6), 3) == List(4,5,6))

In [None]:
assert(drop(List(1,2,3), 5) == List())

In [None]:
assert(drop(List(1,2,3), 0) == List(1,2,3))