Scala collections 101 and beyond 
=================

## Using Scala in competitive programming


#### Tamás Kurics

Data scientist at Twinner

Programming background:

* Spent many years in research (applied maths, scientific computing, math. biology) using MATLAB (2000-2012)
* Learnt QuickBASIC for a semester :-)
* Started to use Python (writing small scripts) (2013)
* Using Python for data science (2015-now)
* The product I worked on at Balabit was rewritten from Python to Scala, which was a game changer
* Converted Scala user (2016-2018 professionally, hobbyist since 2019, interested in FP-related concepts)
* Currently working on object detection and segmentation problems in the automotive industry (in Python)

How to learn a language?

Learning by doing (in your free time):

* Project Euler (cca. 2013-2014 in Matlab)
* HackerRank (2016-now in Python, Scala, Haskell), top 100 in the FP track
* Rosalind (2016-2017 in Python, Scala)
* Codility, etc. (occasionally)

### Transition from Python to Scala

Python is a good choice for data science tasks:

* large community
* a wide range of libraries for almost every task
* a library for tabular data (pandas)
* a library for numerical linear algebra (numpy)
* a small set of well-designed collections

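But even though Python is easy, one still has to learn how the basic features of the language are actually implemented: ```lst.append(x)``` runs in amortized $O(1)$, while ```lst.insert(0, x)``` is $O(n)$, because all existing elements have to be shifted.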

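```python
def consume(lst):
    # pop(0) removes the first element and shifts all the others: O(n) per call,
    # so emptying the list this way takes O(n^2) time
    while lst:
        _ = lst.pop(0)

def consume2(lst):
    # pop() removes the last element in O(1), so this loop is O(n) overall
    while lst:
        _ = lst.pop()
```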

```python
matrix = [[0] * 2] * 3
matrix[0][0] = 100

total = 0
for row in matrix:
    for elem in row:
        total += elem

print(total)
```

Every language comes with 'surprises' (a *surprise* is a language feature that the programmer thinks they understand, but which actually behaves differently).

Default Python collections:

* linear collections
  * dynamic array: **list**, $O(1)$ get and set, append is amortized $O(1)$
  * immutable linear collection: **tuple**
* **string**s  
* hashable elements:  
  * unique hashable elements: **set**, **frozenset**
  * key-value pairs: **dict**
* further structures from the standard library: **heapq**, **deque** (in ```collections```), ...

In 2016 we decided to rewrite our algorithms in Scala. Let's see what options we have for storing elements:

* List
* Vector
* Array
* String
* Set
* Map

And some others ...

Contents of ```scala.collection.immutable```:

::, AbstractMap, BitSet, DefaultMap, HashMap, HashSet, IndexedSeq, IntMap, IntMapEntryIterator, IntMapIterator, IntMapKeyIterator, IntMapUtils, IntMapValueIterator, Iterable, LinearSeq, List, ListMap, ListSerializeEnd, ListSet, LongMap, LongMapEntryIterator, LongMapIterator, LongMapKeyIterator, LongMapUtils, LongMapValueIterator, Map, MapLike, MapProxy, Nil, NumericRange, Page, Queue, Range, RedBlackTree, Seq, Set, SetProxy, SortedMap, SortedSet, Stack, Stream, StreamIterator, StreamView, StreamViewLike, StringLike, StringOps, Traversable, TreeMap, TreeSet, TrieIterator, Vector, VectorBuilder, VectorIterator, VectorPointer, WrappedString

Contents of ```scala.collection.mutable```:

AbstractBuffer, AbstractIterable, AbstractMap, AbstractSeq, AbstractSet, AbstractSortedMap, AbstractSortedSet, AnyRefMap, ArrayBuffer, ArrayBuilder, ArrayLike, ArrayOps, ArraySeq, ArrayStack, BitSet, Buffer, BufferLike, BufferProxy, Builder, Cloneable, DefaultEntry, DefaultMapModel, DoubleLinkedListLike, DoublingUnrolledBuffer, FlatHashTable, GrowingBuilder, HashEntry, HashMap, HashSet, HashTable, History, ImmutableMapAdaptor, ImmutableSetAdaptor, IndexedSeq, IndexedSeqLike, IndexedSeqOptimized, IndexedSeqView, Iterable, LazyBuilder, LinearSeq, LinkedEntry, LinkedHashMap, LinkedHashSet, LinkedListLike, ListBuffer, ListMap, LongMap, Map, MapBuilder, MapLike, MapProxy, MultiMap, MutableList, ObservableBuffer, ObservableMap, ObservableSet, OpenHashMap, PriorityQueue, PriorityQueueProxy, Publisher, Queue, QueueProxy, RedBlackTree, ResizableArray, ReusableBuilder, RevertibleHistory, Seq, SeqLike, Set, SetBuilder, SetLike, SetProxy, SortedMap, SortedSet, Stack, StackProxy, StringBuilder, Subscriber, SynchronizedBuffer, SynchronizedMap, SynchronizedPriorityQueue, SynchronizedQueue, SynchronizedSet, SynchronizedStack, Traversable, TreeMap, TreeSet, Undoable, UnrolledBuffer, WeakHashMap, WrappedArray, WrappedArrayBuilder

In [None]:
import $ivy.`com.storm-enroute::scalameter-core:0.8.2`

In [None]:
import org.scalameter._

object Config {
  val standardConfig: MeasureBuilder[Unit, Double] = config(
    Key.exec.minWarmupRuns -> 5,
    Key.exec.maxWarmupRuns -> 10,
    Key.exec.benchRuns -> 10,
    Key.verbose -> true
  ) withWarmer new Warmer.Default
}

In [None]:
import scala.annotation.tailrec
import scala.util.Random
import Config.standardConfig

In [None]:
scala.util.Properties.versionString

In [None]:
val list: List[Int] = List(2, 4, 1, 3, 6, 5, 0, 9, 7, 8)

In [None]:
list.isEmpty

list.head

list.tail

All three methods run in $O(1)$ on a ```List```, although ```head``` and ```tail``` are partial functions (they throw an exception on an empty list), so they are not necessarily the best things to use.
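A minimal sketch of the usual total alternatives: pattern matching on ```Nil``` / ```::```, ```headOption``` and ```drop(1)``` never throw on an empty list. The helper name ```describe``` is only for illustration.

```scala
// head and tail throw on an empty list; these alternatives never do.
def describe(xs: List[Int]): String = xs match {
  case Nil    => "empty"
  case h :: t => s"head = $h, tail has ${t.size} elements"
}

val emptyList: List[Int] = List.empty[Int]
val firstOrNone: Option[Int] = emptyList.headOption // None instead of an exception
val restOrEmpty: List[Int] = emptyList.drop(1)      // Nil instead of an exception
```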

Count the number of even elements in a list. In Python, there is only one way to do it: the most Pythonic one. In Scala, there are several:

In [None]:
def countEvens1(lst: List[Int]): Int = {
  @tailrec
  def loop(xs: List[Int], acc: Int): Int =
    if (xs.isEmpty) acc
    else if (xs.head % 2 == 0) loop(xs.tail, acc + 1)
    else loop(xs.tail, acc)

  loop(lst, 0)
}

def countEvens2(lst: List[Int]): Int = {
  @tailrec
  def loop(xs: List[Int], acc: Int): Int = xs match {
    case Nil => acc
    case x :: xss => if (x % 2 == 0) loop(xss, acc + 1) else loop(xss, acc)
  }
  loop(lst, 0)
}

In [None]:
def countEvens3(lst: List[Int]): Int =
  lst.foldLeft(0)((acc, elem) => if (elem % 2 == 0) acc + 1 else acc)


def countEvens4(lst: List[Int]): Int =
  lst.foldRight(0)((elem, acc) => if (elem % 2 == 0) acc + 1 else acc)


def countEvens5(lst: List[Int]): Int = lst.count(_ % 2 == 0)


def countEvens6(lst: List[Int]): Int = {
  var count: Int = 0
  for {
    elem <- lst
    if elem % 2 == 0
  } count += 1
  count
}

In [None]:
def generateIntList(size: Int): List[Int] = {
  val r = new Random(42)
  (0 until size).map(_ => r.nextInt()).toList
}

In [None]:
val size: Int = 1000000
val list: List[Int] = generateIntList(size)

In [None]:
val time: Quantity[Double] = standardConfig.measure {
  countEvens5(list)
}
println(time)

Methods such as ```isEmpty, head, tail``` are so fundamental that many collections implement these.

But what do they mean?
* What is the head of a Set?
* What is the tail of a Vector?
* What do we get if we take the first $k$ elements from a PriorityQueue?

Can we assume that all these methods have the same meaning and same computational complexity for all collections where they are implemented?
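A partial answer, as a small sketch: the method names are shared, but what they return and what they cost can differ. A ```TreeSet``` is sorted, so its ```head``` is the minimum; a ```HashSet``` has no meaningful order, so its ```head``` is just some element. ```tail``` always means "all but the first element", but according to the Scala collections performance tables it is $O(1)$ on a ```List``` and effectively constant on a ```Vector```, while on a ```String``` the remaining characters are copied, which is $O(n)$.

```scala
import scala.collection.immutable.{HashSet, TreeSet}

// TreeSet keeps its elements sorted: head is the minimum.
val treeHead: Int = TreeSet(5, 3, 8, 1).head

// HashSet iterates in hash order: head is *some* element, and which one
// is an implementation detail we should not rely on.
val hashHead: Int = HashSet(5, 3, 8, 1).head

// tail means "all but the first element" everywhere, at different costs.
val listTail: List[Int] = List(1, 2, 3).tail       // O(1)
val vectorTail: Vector[Int] = Vector(1, 2, 3).tail // effectively constant
val stringTail: String = "RGBY".tail               // O(n): copies the characters
```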

### Problem 1 - Sequence of full colors (Easy)

You are given a sequence of balls in 4 colors: red, green, yellow and blue. The sequence is full of colors if and only if all of the following conditions are true:

* There are as many red balls as green balls.
* There are as many yellow balls as blue balls.
* Difference between the number of red balls and green balls in every prefix of the sequence is at most 1.
* Difference between the number of yellow balls and blue balls in every prefix of the sequence is at most 1.

Your task is to write a program that, for a given sequence, prints True if it is full of colors and False otherwise.

In [None]:
def hasFullColors(sequence: String): Boolean = {
  def prefixCondition(colors: Map[Char, Int]): Boolean =
    math.abs(colors('R') - colors('G')) <= 1 && math.abs(colors('B') - colors('Y')) <= 1

  @tailrec
  def loop(acc: Map[Char, Int], colors: String): Boolean = {
    if (colors.isEmpty) prefixCondition(acc) && acc('R') == acc('G') && acc('B') == acc('Y')
    else {
      val firstCharacter: Char = colors.head
      val updatedColors: Map[Char, Int] = acc.updated(firstCharacter, acc(firstCharacter) + 1)
      if (prefixCondition(updatedColors)) loop(updatedColors, colors.tail)
      else false
    }
  }
  
  loop(Map('R' -> 0, 'G' -> 0, 'B' -> 0, 'Y' -> 0), sequence)
}

In [None]:
def generateRandomString(size: Int): String = {
  val r = new Random(42L)
  val colors: Vector[Char] = Vector('R', 'G', 'B', 'Y')
  val n: Int = colors.length
  (for {
    _ <- 0 until size
    ix = r.nextInt(n)
  } yield colors(ix)).mkString
}

val colors: String = generateRandomString(10)

In [None]:
val colors: String = generateRandomString(1000000)

val t0 = System.nanoTime()
val res: Boolean = hasFullColors(colors)
val t1 = System.nanoTime()
println((t1 - t0) / 1e9)

In [None]:
def generateTestString(size: Int): String = {
  val colors: List[Char] = List('R', 'G', 'B', 'Y')
  (0 until size / 4 + 1).flatMap(_ => colors).take(size).mkString
}

val colors: String = generateTestString(400000)

val t0 = System.nanoTime()
val res: Boolean = hasFullColors(colors)
val t1 = System.nanoTime()
println((t1 - t0) / 1e9)
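Why is this slow? ```colors.tail``` on a ```String``` copies the remaining characters, so each recursive step of ```hasFullColors``` does $O(n)$ work and the whole scan is $O(n^2)$, whereas ```List.tail``` is $O(1)$. A minimal alternative sketch walks an index over the string and keeps only two running differences instead of a ```Map```. It assumes the input contains only the characters R, G, Y and B (anything else is ignored), and the name ```hasFullColorsIndexed``` is our own.

```scala
import scala.annotation.tailrec

def hasFullColorsIndexed(sequence: String): Boolean = {
  // redMinusGreen and yellowMinusBlue are the two prefix differences;
  // charAt(ix) is O(1), so the whole scan is O(n).
  @tailrec
  def loop(ix: Int, redMinusGreen: Int, yellowMinusBlue: Int): Boolean =
    if (ix == sequence.length) redMinusGreen == 0 && yellowMinusBlue == 0
    else {
      val c: Char = sequence.charAt(ix)
      val rg: Int = redMinusGreen + (if (c == 'R') 1 else if (c == 'G') -1 else 0)
      val yb: Int = yellowMinusBlue + (if (c == 'Y') 1 else if (c == 'B') -1 else 0)
      // stop at the first prefix that violates the condition
      if (math.abs(rg) > 1 || math.abs(yb) > 1) false
      else loop(ix + 1, rg, yb)
    }

  loop(0, 0, 0)
}
```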

In [None]:
def hasFullColors2(sequence: String): Boolean = {
  def prefixCondition(colors: Map[Char, Int]): Boolean =
    math.abs(colors('R') - colors('G')) <= 1 && math.abs(colors('B') - colors('Y')) <= 1

  @tailrec
  def loop(acc: Map[Char, Int], colors: List[Char]): Boolean = {
    if (colors.isEmpty) prefixCondition(acc) && acc('R') == acc('G') && acc('B') == acc('Y')
    else {
      val firstCharacter: Char = colors.head
      val updatedColors: Map[Char, Int] = acc.updated(firstCharacter, acc(firstCharacter) + 1)
      if (prefixCondition(updatedColors)) loop(updatedColors, colors.tail)
      else false
    }
  }

  loop(Map('R' -> 0, 'G' -> 0, 'B' -> 0, 'Y' -> 0), sequence.toList)
}

In [None]:
val colors: String = generateTestString(2000000)

val t0 = System.nanoTime()
val res: Boolean = hasFullColors2(colors)
val t1 = System.nanoTime()
println((t1 - t0) / 1e9)

### Problem 2 - Pattern matching (finding strings in a text)
#### A walk around Maps

Run the Matching problem in IntelliJ IDEA.

In [None]:
val m: Map[String, Int] = Map("a" -> 1, "b" -> 2)

In [None]:
def func(x: Int): Int = {
  Thread.sleep(3000)
  x + 1
}

In [None]:
val updated: Map[String, Int] = m.map{ case (key, value) => (key, func(value))}

In [None]:
updated

In [None]:
val updated2: Map[String, Int] = m.mapValues(func)

In [None]:
updated2

In [None]:
updated2("a")

In [None]:
val updated3: Map[String, Int] = m.mapValues(func).view.force

In [None]:
updated3
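In Scala 2.12, ```mapValues``` returns a lazy view: ```func``` runs again on every lookup (which is why each ```updated2("a")``` takes three seconds), and ```.view.force``` is one way to materialize it. A strict alternative that, as far as we know, behaves the same in 2.12 and 2.13 is ```transform```, which applies the function exactly once per entry. The sketch below replaces ```Thread.sleep``` with a hypothetical call counter so that the number of evaluations is visible.

```scala
// Hypothetical stand-in for the slow `func`: it counts its own invocations.
var calls: Int = 0
def countingFunc(x: Int): Int = {
  calls += 1
  x + 1
}

val base: Map[String, Int] = Map("a" -> 1, "b" -> 2)

// transform is strict: countingFunc runs once per entry, right here.
val strict: Map[String, Int] = base.transform((_, v) => countingFunc(v))

// Later lookups only read stored values; countingFunc is not called again.
val a1: Int = strict("a")
val a2: Int = strict("a")
```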

In [None]:
sealed trait CoinState
case object Head extends CoinState
case object Tail extends CoinState

val r = new Random(42)

val distribution: Map[CoinState, Double] = Map(Head -> 0.5, Tail -> 0.5)

In [None]:
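// Surprise: in Scala 2.12, mapValues returns a lazy view, so r.nextGaussian()
// is called again every time a value of `perturbation` is read.
val perturbation: Map[CoinState, Double] = distribution.mapValues(p => p + 0.1 * r.nextGaussian())
// `sum` is computed from one particular set of random draws...
val sum: Double = perturbation.valuesIterator.sum
// ...while every lookup in `updated` draws fresh noise, so its values need not sum to 1.
val updated: Map[CoinState, Double] = perturbation.mapValues(_ / sum)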
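Because of this laziness, the values of ```updated``` are re-randomized on every access and in general do not add up to 1. A minimal strict sketch draws each noise term exactly once with ```transform``` and normalizes those same draws. The names ```CoinSide```, ```Heads```, ```Tails``` and ```perturbAndNormalize``` are hypothetical and redefined here so that the block is self-contained.

```scala
import scala.util.Random

sealed trait CoinSide
case object Heads extends CoinSide
case object Tails extends CoinSide

def perturbAndNormalize(dist: Map[CoinSide, Double], r: Random): Map[CoinSide, Double] = {
  // transform is strict, so each Gaussian noise term is drawn exactly once
  val perturbed: Map[CoinSide, Double] = dist.transform((_, p) => p + 0.1 * r.nextGaussian())
  val total: Double = perturbed.values.sum
  // normalize the *same* draws that produced `total`
  perturbed.transform((_, p) => p / total)
}

val normalized: Map[CoinSide, Double] =
  perturbAndNormalize(Map(Heads -> 0.5, Tails -> 0.5), new Random(42))
```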

### Problem 3 - Minimum multiple of a subarray (range LCM queries)

You are given an array of integers $a_1, a_2, \dots, a_N$ and $K$ queries of two types:
* Q l r : Find the minimum positive integer $M$ such that each element of the subarray $[a_l,a_{l+1},\dots,a_r]$ divides $M$, i.e. $M = \text{lcm}(a_l, a_{l+1}, \dots, a_r)$.
* U i v : Multiply $a_{i}$ by $v$.

Constraints:
$1\leq N \leq 2\cdot 10^4$, $1 \leq a_i\leq 100$, $1\leq K \leq 2\cdot 10^4$, $1\leq v\leq 100$

For each query of type Q l r, print the value of $M$ modulo $10^9 + 7$ on a new line.

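If
$$ n = p_1^{\alpha_1}p_2^{\alpha_2}\cdots p_k^{\alpha_k} \quad\text{and}\quad m = p_1^{\beta_1}p_2^{\beta_2}\cdots p_k^{\beta_k}, $$
where exponents may be zero so that both numbers are written over the same primes, then each number can be represented as a Map $p_i \to \alpha_i$ (resp. $p_i \to \beta_i$). Multiplication and least common multiples can then be calculated on the exponents alone, without Int/Long overflow: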

$$
n \cdot m = p_1^{\alpha_1 + \beta_1}p_2^{\alpha_2 + \beta_2}\cdots p_k^{\alpha_k + \beta_k},
$$
and
$$
\text{lcm}(n, m) =  p_1^{\max\{\alpha_1, \beta_1\}}p_2^{\max\{\alpha_2, \beta_2\}}\cdots p_k^{\max\{\alpha_k, \beta_k\}}
$$
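A small sketch of this representation with plain ```Map[Int, Int]``` values: multiplication adds exponents, the lcm takes their maximum, and only the final evaluation touches the actual number, reduced modulo $10^9+7$ after every step. For example, $12 = 2^2 3^1$ and $18 = 2^1 3^2$ give $12 \cdot 18 = 2^3 3^3 = 216$ and $\text{lcm}(12, 18) = 2^2 3^2 = 36$. The helpers ```factorize```, ```combineExponents```, ```multiply```, ```lcm``` and ```evalMod``` are hypothetical and independent of the ```CanonicalForm``` class below; trial division is assumed to be fast enough because the values are small.

```scala
import scala.annotation.tailrec

// Trial-division factorization n -> Map(prime -> exponent); fine for small n.
def factorize(n: Int): Map[Int, Int] = {
  @tailrec
  def loop(k: Int, p: Int, acc: Map[Int, Int]): Map[Int, Int] =
    if (k == 1) acc
    else if (p * p > k) acc.updated(k, acc.getOrElse(k, 0) + 1) // the remaining k is prime
    else if (k % p == 0) loop(k / p, p, acc.updated(p, acc.getOrElse(p, 0) + 1))
    else loop(k, p + 1, acc)
  loop(n, 2, Map.empty)
}

// Combine two factorizations prime by prime; a missing prime has exponent 0.
def combineExponents(a: Map[Int, Int], b: Map[Int, Int])(op: (Int, Int) => Int): Map[Int, Int] =
  (a.keySet ++ b.keySet).iterator.map(p => p -> op(a.getOrElse(p, 0), b.getOrElse(p, 0))).toMap

def multiply(a: Map[Int, Int], b: Map[Int, Int]): Map[Int, Int] = combineExponents(a, b)(_ + _)
def lcm(a: Map[Int, Int], b: Map[Int, Int]): Map[Int, Int] = combineExponents(a, b)(math.max(_, _))

// Evaluate the product of p^e modulo 1e9+7, reducing after every multiplication.
def evalMod(f: Map[Int, Int], modulus: Long = 1000000007L): Long =
  f.foldLeft(1L) { case (acc, (p, e)) =>
    (0 until e).foldLeft(acc)((prod, _) => prod * p % modulus)
  }
```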

In [None]:
import scala.collection.mutable.ListBuffer

val UpperLimit: Int = 100

def getPrimes(limit: Int): List[Int] = {
  def isPrime(n: Int, primes: Iterator[Int]): Boolean = primes.takeWhile(_ <= math.sqrt(n)).forall(p => n % p != 0)
  val primes: ListBuffer[Int] = ListBuffer(2)
  for {
    p <- 3 to limit by 2
    if isPrime(p, primes.iterator)
  } primes += p

  primes.toList
}

val Primes: List[Int] = getPrimes(UpperLimit)

def findLargestExponent(n: Int, prime: Int): (Int, Int) = {
  @tailrec
  def loop(k: Int, exponent: Int): (Int, Int) = {
    if (k % prime != 0) (k, exponent)
    else loop(k / prime, exponent + 1)
  }
  loop(n, 0)
}

In [None]:
import scala.collection.breakOut

val Modulus: Int = 1e9.toInt + 7

case class CanonicalForm(primeFactorization: Map[Int, Int]) {
  lazy val asInt: Int = primeFactorization.foldLeft(1L) {
    case (acc, (p, exponent)) => (0 until exponent).foldLeft(acc)((prod, _) => prod * p % Modulus)
  }.toInt

  def * (that: CanonicalForm): CanonicalForm = {
    val primes: Set[Int] = this.primeFactorization.keySet.union(that.primeFactorization.keySet)
    val product: Map[Int, Int] = 
      primes.map(
          p => p -> (this.primeFactorization.getOrElse(p, 0) + 
                     that.primeFactorization.getOrElse(p, 0)))(breakOut)
    CanonicalForm(product)
  }
}

object CanonicalForm {
  def apply(n: Int): CanonicalForm = {
    @tailrec
    def loop(k: Int, primes: List[Int], factorization: Map[Int, Int]): CanonicalForm = primes match {
      case Nil if k > 1 => throw new Exception(s"$n is not factorized!")
      case Nil => CanonicalForm(factorization)
      case p :: ps =>
        if (k == 1) CanonicalForm(factorization)
        else {
          val (remainder, exponent): (Int, Int) = findLargestExponent(k, p)
          if (exponent == 0) loop(remainder, ps, factorization)
          else loop(remainder, ps, factorization + (p -> exponent))
        }
    }
    loop(n, Primes, Map())
  }
}

In [None]:
val x = CanonicalForm(2 * 3 * 5 * 11)
val y = CanonicalForm(13 * 23)
val z = CanonicalForm(5 * 5 * 17 * 13)

In [None]:
x.asInt

In [None]:
val xs : List[CanonicalForm] = (0 until 500000).flatMap{ _ => List(x, y, z) }.toList

In [None]:
val t0 = System.nanoTime()
val _: CanonicalForm = xs.foldLeft(CanonicalForm(1))((acc, elem) => acc * elem)
val t1 = System.nanoTime()
println((t1 - t0) / 1e9)

In [None]:
import scala.collection.immutable.HashMap

val Modulus: Int = 1e9.toInt + 7

final case class CanonicalForm(primeFactorization: HashMap[Int, Int]) {
  lazy val asInt: Int = primeFactorization.foldLeft(1L) {
    case (acc, (p, exponent)) => (0 until exponent).foldLeft(acc)((prod, _) => prod * p % Modulus)
  }.toInt

  def * (that: CanonicalForm): CanonicalForm =
    CanonicalForm(this.primeFactorization.merged(that.primeFactorization){ 
        case ((p, e1), (_, e2)) => (p, e1 + e2) 
    })
}

object CanonicalForm {
  def apply(n: Int): CanonicalForm = {
    @tailrec
    def loop(k: Int, primes: List[Int], factorization: HashMap[Int, Int]): CanonicalForm = primes match {
      case Nil if k > 1 => throw new Exception(s"$n is not factorized!")
      case Nil => CanonicalForm(factorization)
      case p :: ps =>
        if (k == 1) CanonicalForm(factorization)
        else {
          val (remainder, exponent): (Int, Int) = findLargestExponent(k, p)
          if (exponent == 0) loop(remainder, ps, factorization)
          else loop(remainder, ps, factorization + (p -> exponent))
        }
    }
    loop(n, Primes, HashMap())
  }
}

What is returned by the ```Map``` constructor?

In [None]:
val m: Map[String, Int] = Map("a" -> 1, "b" -> 2)

m.isInstanceOf[HashMap[String, Int]]

In [None]:
val m: Map[String, Int] = Map("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4, "e" -> 5)

m.isInstanceOf[HashMap[String, Int]]
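The answer depends on the size: for up to four entries, the ```Map(...)``` factory returns small specialized classes (```Map1``` to ```Map4```), and only from five entries on does it return a ```HashMap```. Since ```merged``` is defined on ```HashMap``` but not on the general ```Map``` type, the second ```CanonicalForm``` has to store its factorization as a ```HashMap``` explicitly. A minimal sketch of converting an arbitrary ```Map``` (the helper ```toHashMap``` is hypothetical):

```scala
import scala.collection.immutable.HashMap

// Hypothetical helper: copy any immutable Map into a HashMap,
// whatever concrete class the Map(...) factory happened to choose.
def toHashMap[K, V](m: Map[K, V]): HashMap[K, V] = HashMap(m.toSeq: _*)

val small: Map[String, Int] = Map("a" -> 1, "b" -> 2)
val large: Map[String, Int] = Map("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4, "e" -> 5)

// merged becomes available once we have a HashMap; here it adds the values of common keys.
val combined: HashMap[String, Int] =
  toHashMap(small).merged(HashMap("a" -> 10)) { case ((k, v1), (_, v2)) => (k, v1 + v2) }
```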

In [None]:
val m1: HashMap[String, Int] = HashMap("a" -> 1, "b" -> 2)
val m2: HashMap[String, Int] = HashMap("a" -> 10, "c" -> 20)

m1.merged(m2){ case ((k1, v1), (k2, v2)) => (k1 + k2, v1 + v2) }

In [None]:
val x = CanonicalForm(2 * 3 * 5 * 11)
val y = CanonicalForm(13 * 23)
val z = CanonicalForm(5 * 5 * 17 * 13)

val xs : List[CanonicalForm] = (0 until 500000).flatMap{ _ => List(x, y, z) }.toList

In [None]:
val t0 = System.nanoTime()
val _: CanonicalForm = xs.foldLeft(CanonicalForm(1))((acc, elem) => acc * elem)
val t1 = System.nanoTime()
println((t1 - t0) / 1e9)

### Problem 4 - Running median of integers (Hard)

* Input: a list of integers $a_1, a_2, \dots, a_n$, where $n\leq 10^5$.
* Output: a list of medians, where the $k^{\text{th}}$ element is the median of $a_1,\dots,a_k$ (for even $k$, the integer average of the two middle elements).

This is a simplified version of a problem called Messy Medians.

In [None]:
def calcRunningMedians1(numbers: List[Int]): List[Int] = {
  @tailrec
  def loop(medians: List[Int], prefix: Vector[Int], ns: List[Int]): List[Int] = ns match {
    case Nil => medians.reverse
    case n :: nss =>
      val updated: Vector[Int] = n +: prefix
      val size: Int = updated.size
      val sorted: Vector[Int] = updated.sorted
      // average in Long to avoid Int overflow on large inputs
      val m: Int =
        if (size % 2 == 1) sorted(size / 2)
        else ((sorted(size / 2 - 1).toLong + sorted(size / 2)) / 2).toInt
      loop(m :: medians, sorted, nss)
  }
  loop(Nil, Vector(), numbers)
}

In [None]:
val numbers: List[Int] = generateIntList(20000)

val t0 = System.nanoTime()
val res: List[Int] = calcRunningMedians1(numbers)
val t1 = System.nanoTime()
println((t1 - t0) / 1e9)

This is an $O(n^2\log n)$ solution.
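The $\log n$ factor comes from re-sorting the whole prefix at every step, although it is already sorted. A sketch of a cheaper variant finds the insertion point by binary search and rebuilds the ```Vector``` around it, which costs $O(k)$ per step and $O(n^2)$ overall: better, but still quadratic. The helpers ```lowerBound```, ```insertSorted``` and ```runningMediansSortedInsert``` are our own names, not part of the original solution.

```scala
import scala.annotation.tailrec

// Index of the first element >= x in a sorted vector (binary search).
def lowerBound(sorted: Vector[Int], x: Int): Int = {
  @tailrec
  def loop(lo: Int, hi: Int): Int =
    if (lo >= hi) lo
    else {
      val mid = lo + (hi - lo) / 2
      if (sorted(mid) < x) loop(mid + 1, hi) else loop(lo, mid)
    }
  loop(0, sorted.length)
}

// Keep the prefix sorted by inserting in place: O(log k) search + O(k) rebuild.
def insertSorted(sorted: Vector[Int], x: Int): Vector[Int] = {
  val (left, right) = sorted.splitAt(lowerBound(sorted, x))
  (left :+ x) ++ right
}

def runningMediansSortedInsert(numbers: List[Int]): List[Int] = {
  @tailrec
  def loop(sorted: Vector[Int], ns: List[Int], acc: List[Int]): List[Int] = ns match {
    case Nil => acc.reverse
    case n :: rest =>
      val updated = insertSorted(sorted, n)
      val size = updated.length
      // average in Long to avoid Int overflow
      val m =
        if (size % 2 == 1) updated(size / 2)
        else ((updated(size / 2 - 1).toLong + updated(size / 2)) / 2).toInt
      loop(updated, rest, m :: acc)
  }
  loop(Vector.empty, numbers, Nil)
}
```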

In [None]:
import scala.collection.mutable.{PriorityQueue => Heap}

val heap = Heap(1, 3, 2, 6, 7, 9, 8, 0, 4)

What do these methods return, and how fast are they?

In [None]:
heap.head

heap.max

heap.take(4)

In [None]:
val list: List[Int] = generateIntList(1000000)
val heap = Heap(list: _*)

val time: Quantity[Double] = standardConfig.measure {
  //heap.max
  heap.head
}
println(time)
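The Scaladoc of ```PriorityQueue``` states that only ```dequeue``` and ```dequeueAll``` return elements in priority order; methods such as ```take```, ```iterator``` and ```toString``` traverse the internal heap array in whatever order is convenient. ```head``` reads the root of the heap in $O(1)$, while ```max``` is the generic collection method, which, as far as we can tell, is not overridden here and scans every element. A minimal sketch of getting the $k$ largest elements (the helper ```kLargest``` is hypothetical) dequeues from a clone so that the original heap is left untouched:

```scala
import scala.collection.mutable.PriorityQueue

// Hypothetical helper: the k largest elements in descending order.
// Each dequeue() is O(log n); working on a clone keeps `heap` unchanged.
def kLargest(heap: PriorityQueue[Int], k: Int): List[Int] = {
  val copy: PriorityQueue[Int] = heap.clone()
  List.fill(math.min(k, copy.size))(copy.dequeue())
}

val pq = PriorityQueue(1, 3, 2, 6, 7, 9, 8, 0, 4)
val top4: List[Int] = kLargest(pq, 4)
```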

In [None]:
def calcMedian(maxHeap: Heap[Int], minHeap: Heap[Int]): Int = {
  val s1: Int = maxHeap.size
  val s2: Int = minHeap.size
  if (s1 > s2) maxHeap.head
  else if (s2 > s1) minHeap.head
  else ((maxHeap.head.toLong + minHeap.head) / 2).toInt // Long avoids Int overflow
}

def addToHeap(n: Int, currentMedian: Int, maxHeap: Heap[Int], minHeap: Heap[Int]): Unit = {
  if (n < currentMedian) maxHeap.enqueue(n)
  else minHeap.enqueue(n)
}

def balanceHeaps(maxHeap: Heap[Int], minHeap: Heap[Int]): Unit = {
  val s1: Int = maxHeap.size
  val s2: Int = minHeap.size
  if (s1 > s2 + 1) {
    val largestLeft: Int = maxHeap.dequeue()
    minHeap.enqueue(largestLeft)
  } else if (s2 > s1 + 1) {
    val smallestRight: Int = minHeap.dequeue()
    maxHeap.enqueue(smallestRight)
  }
}

In [None]:
def calcRunningMedians2(numbers: List[Int]): List[Int] = {
  val maxHeap = Heap.empty[Int]
  val minHeap = Heap.empty[Int](Ordering[Int].reverse)

  @tailrec
  def loop(acc: List[Int], currentMedian: Int, ns: List[Int]): List[Int] = ns match {
    case Nil => acc.reverse
    case n :: nss =>
      addToHeap(n, currentMedian, maxHeap, minHeap)
      balanceHeaps(maxHeap, minHeap)
      val m: Int = calcMedian(maxHeap, minHeap)
      loop(m :: acc, m, nss)
  }

  numbers match {
    case Nil => Nil
    case first :: rest =>
      minHeap.enqueue(first)
      loop(List(first), first, rest)
  }
}

In [None]:
val numbers: List[Int] = generateIntList(1000000)

val t0 = System.nanoTime()
val res: List[Int] = calcRunningMedians2(numbers)
val t1 = System.nanoTime()
println((t1 - t0) / 1e9)

Scala's ```PriorityQueue``` is mutable: internally, the heap is stored in a resizable array.

Is there a persistent data structure that keeps its elements sorted and provides fast enough methods to extract its middle elements?

In [None]:
import scala.collection.immutable.TreeSet

val set = TreeSet(1, 4, 3, 5, 7, 8, 2, 2, 1, 1, 1)

How can we create a multiset?

In [None]:
val multiset = TreeSet((1, 0), (4, 1), (3, 2), (5, 3), (7, 4), (8, 5), (2, 6), (2, 7), (1, 8), (1, 9), (1, 10))
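Since ```(value, index)``` pairs are all distinct, tagging each element with its position turns a ```TreeSet``` into a sorted multiset: the tuple ordering compares values first and uses the index only to break ties. A small sketch (```sortedMultiset``` and ```sortedValues``` are hypothetical helper names):

```scala
import scala.collection.immutable.TreeSet

// Tag each value with its position so that equal values become distinct tuples.
def sortedMultiset(values: Seq[Int]): TreeSet[(Int, Int)] =
  TreeSet(values.zipWithIndex: _*)

// Drop the tags to read the values back in sorted order, duplicates included.
def sortedValues(ms: TreeSet[(Int, Int)]): List[Int] = ms.toList.map(_._1)

val ms = sortedMultiset(List(1, 4, 3, 5, 7, 8, 2, 2, 1, 1, 1))
```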

In [None]:
def calcRunningMedians3(numbers: List[Int]): List[Int] = {
  @tailrec
  def loop(treeSet: TreeSet[(Int, Int)], medians: List[Int], ns: List[(Int, Int)]): List[Int] = ns match {
    case Nil => medians.reverse
    case (n, ix) :: nss =>
      val updatedSet: TreeSet[(Int, Int)] = treeSet.insert((n, ix))
      val size: Int = updatedSet.size
      val m: Int =
        if (size % 2 == 1) {
          val List(elem) = updatedSet.slice(size / 2, size / 2 + 1).toList.map(_._1)
          elem
        }
        else {
          val List(a, b) = updatedSet.slice(size / 2 - 1, size / 2 + 1).toList.map(_._1)
          ((a.toLong + b) / 2).toInt // Long avoids Int overflow
        }
      loop(updatedSet, m :: medians, nss)
  }

  loop(TreeSet.empty[(Int, Int)], Nil, numbers.zipWithIndex)
}

In [None]:
val numbers: List[Int] = generateIntList(100000)

val t0 = System.nanoTime()
val res: List[Int] = calcRunningMedians3(numbers)
val t1 = System.nanoTime()
println((t1 - t0) / 1e9)