In [1]:
trait RNG:
    def nextInt: (Int, RNG)

case class SimpleRNG(seed: Long) extends RNG:
    def nextInt: (Int, RNG) =
        val newSeed = (seed * 0x5DeeCE66DL + 0xBL) & 0xFFFFFFFFFFFFL
        val nextRNG = SimpleRNG(newSeed)
        val n = (newSeed >>> 16).toInt
        (n, nextRNG)

defined [32mtrait[39m [36mRNG[39m
defined [32mclass[39m [36mSimpleRNG[39m

#### Exercise 6.1

Write a function that uses `RNG.nextInt` to generate a random integer between `0` and `Int.MaxValue` (inclusive). Make sure to handle the corner case when `nextInt` returns `Int.MinValue`, which doesn't have a nonnegative counterpart.

In [2]:
def nonNegativeInt(rng: RNG): (Int, RNG) =
    val (n, nextRNG) = rng.nextInt
    val absN = if n == Int.MinValue then 0 else scala.math.abs(n)
    (n, nextRNG)

defined [32mfunction[39m [36mnonNegativeInt[39m

### Exercise 6.2

Write a function to generate a `Double` between `0` and `1`, not including `1`. Note that you can use `Int.MaxValue` to obtain the maximum postive integer value, and you can use `x.toDouble` to convert an `x: Int` to a `Double`.

In [3]:
def double(rng: RNG): (Double, RNG) =
    val (n, nextRNG) = nonNegativeInt(rng)
    ((if n > 0 then n-1 else n).toDouble / Int.MaxValue.toDouble, nextRNG)

defined [32mfunction[39m [36mdouble[39m

#### Exercise 6.3

Write functions to generate an `(Int, Double)` pair, a `(Double, Int)` pair, and a `(Double, Double, Double)` 3-tuple. You should be able to reuse the functions you've already written.

In [4]:
def intDouble(rng: RNG): ((Int, Double), RNG) =
    val (i, r) = rng.nextInt
    val (d, r2) = double(rng)
    ((i,d), r2)

def doubleInt(rng: RNG): ((Double, Int), RNG) =
    val (d, r) = double(rng)
    val (i, r2) = r.nextInt
    ((d, i), r2)

def double3(rng: RNG): ((Double, Double, Double), RNG) =
    val (d1, r1) = double(rng)
    val (d2, r2) = double(r1)
    val (d3, r3) = double(r2)
    ((d1, d2, d3), r2)

val r = SimpleRNG(42)
val ((i, d), r2) = intDouble(r)
val ((d2, i2), r3) = doubleInt(r2)
val ((d3, d4, d5), r4) = double3(r3)

(i, d, d2, i2, d3, d4, d5, r4)

defined [32mfunction[39m [36mintDouble[39m
defined [32mfunction[39m [36mdoubleInt[39m
defined [32mfunction[39m [36mdouble3[39m
[36mr[39m: [32mSimpleRNG[39m = [33mSimpleRNG[39m(seed = [32m42L[39m)
[36mr2[39m: [32mRNG[39m = [33mSimpleRNG[39m(seed = [32m1059025964525L[39m)
[36mr3[39m: [32mRNG[39m = [33mSimpleRNG[39m(seed = [32m259172689157871L[39m)
[36mr4[39m: [32mRNG[39m = [33mSimpleRNG[39m(seed = [32m115998806404289L[39m)
[36mres3_7[39m: ([32mInt[39m, [32mDouble[39m, [32mDouble[39m, [32mInt[39m, [32mDouble[39m, [32mDouble[39m, [32mDouble[39m, [32mRNG[39m) = (
  [32m16159453[39m,
  [32m0.007524831224011644[39m,
  [32m-0.5967354856416283[39m,
  [32m-340305902[39m,
  [32m-0.9386595436086224[39m,
  [32m0.8242210921944217[39m,
  [32m-0.900863232044905[39m,
  [33mSimpleRNG[39m(seed = [32m115998806404289L[39m)
)

#### Exercise 6.4

Write a function to generate a list of random integers

In [5]:
def ints(count: Int)(rng: RNG): (List[Int], RNG) =
    @annotation.tailrec
    def go(l: List[Int], count: Int)(rng: RNG): (List[Int], RNG) =
        if count <= 0 then (l, rng)
        else
            val (i, r) = rng.nextInt
            go(i :: l, count-1)(r)
    go(List(), count)(rng)

val r = SimpleRNG(42)
val (is, r2) = ints(10)(r)

defined [32mfunction[39m [36mints[39m
[36mr[39m: [32mSimpleRNG[39m = [33mSimpleRNG[39m(seed = [32m42L[39m)
[36mis[39m: [32mList[39m[[32mInt[39m] = [33mList[39m(
  [32m1837487774[39m,
  [32m-94901159[39m,
  [32m-1163632441[39m,
  [32m1015914512[39m,
  [32m-1934589059[39m,
  [32m1770001318[39m,
  [32m-2015756020[39m,
  [32m-340305902[39m,
  [32m-1281479697[39m,
  [32m16159453[39m
)
[36mr2[39m: [32mRNG[39m = [33mSimpleRNG[39m(seed = [32m120421598792892L[39m)

In [6]:
type Rand[+A] = RNG => (A, RNG)

def map[A, B](s: Rand[A])(f: A => B): Rand[B] =
    rng =>
        val (a, rng2) = s(rng)
        (f(a), rng2)

defined [32mtype[39m [36mRand[39m
defined [32mfunction[39m [36mmap[39m

### Exercise 6.5

Use `map` to reimplement `double` in a more succint way.

In [7]:
def double: Rand[Double] = map(nonNegativeInt)(_ / (Int.MaxValue.toDouble + 1))

defined [32mfunction[39m [36mdouble[39m

In [8]:
double(SimpleRNG(42))

[36mres7[39m: ([32mDouble[39m, [32mRNG[39m) = ([32m0.007524831686168909[39m, [33mSimpleRNG[39m(seed = [32m1059025964525L[39m))

#### Exercise 6.6

Write the implementation of `map2` based on the following signature. This function takes two actions, `ra` and `rb`, and a function, `f`, for combining their results and returns a new action that combines them.

In [9]:
def map2[A, B, C](ra: Rand[A], rb: Rand[B])(f: (A, B) => C): Rand[C] =
    rng => 
        val (a, r2) = ra(rng)
        val (b, r3) = rb(r2)
        (f(a, b), r3)

defined [32mfunction[39m [36mmap2[39m

#### Exercise 6.7

If you can combine two RNG actions, you should be able to combine an entire list of them. Implement `sequence` for combining a `List` of actions into a single action. Use it to reimplement the `ints` function you wrote before. For the latter, you can use the standard library function `List.fill(n)(x)` to make a list with `x` repeated `n` times.

In [10]:
def unit[A](a: A): Rand[A] =
    rng => (a, rng)


def sequence[A](rs: List[Rand[A]]): Rand[List[A]] = rs match
    case r :: rs => map2(r, sequence(rs))(_ :: _)
    case _ => unit(List())


defined [32mfunction[39m [36munit[39m
defined [32mfunction[39m [36msequence[39m

In [12]:
def sequence[A](rs: List[Rand[A]]): Rand[List[A]] =
    rs.foldRight(unit(Nil : List[A]))((r, acc) => map2(r, acc)(_ :: _))

defined [32mfunction[39m [36msequence[39m

#### Exercise 6.8

Implement `flatMap`, and then use it to implement `nonNegativeLessThan`.

In [18]:
def flatMap[A, B](r: Rand[A])(f: A => Rand[B]): Rand[B] =
    rng => 
        val (a, r2) = r(rng)
        f(a)(r2)

def nonNegativeLessThan(n: Int): Rand[Int] =
    flatMap(nonNegativeInt) { i =>
        val mod = i % n
        if i + (n-1) - mod >= 0 then unit(mod) else nonNegativeLessThan(n)
    }

defined [32mfunction[39m [36mflatMap[39m
defined [32mfunction[39m [36mnonNegativeLessThan[39m

#### Exercise 6.9

Implement `flatMap`, and then use it to implement `nonNegativeLessThan`.

In [20]:
def map[A, B](r: Rand[A])(f: A => B): Rand[B] =
    flatMap(r)(a => unit(f(a)))

def map2[A, B, C](ra: Rand[A], rb: Rand[B])(f: (A, B) => C): Rand[C] =
    flatMap(ra)(a => map(rb)(b => f(a, b)))

defined [32mfunction[39m [36mmap[39m
defined [32mfunction[39m [36mmap2[39m

#### Exercise 6.10

Generalize the functions `unit`, `map`, `map2`, `flatMap`, and `sequence`. Add them as extenion methods on the `State` type where possible. Otherwise, you should put them in the `State` companion object.

In [31]:
object Fps2 {
    opaque type State[S, +A] = S => (A, S)

    object State:
        def apply[S, A](f: S => (A, S)): State[S, A] = f

        def unit[S, A](a: A): State[S, A] =
            s => (a, s)

        def sequence[S, A](s: List[State[S, A]]): State[S, List[A]] =
            s.foldRight(unit(Nil : List[A]))((s, acc) => s.map2(acc)(_ :: _))

        extension [S, A](underlying: State[S, A])
            def run(s: S): (A, S) = underlying(s)

            def flatMap[B](f: A => State[S, B]): State[S, B] =
                s => 
                    val (a, s2) = underlying(s)
                    f(a)(s2)
            
            def map[B](f: A => B): State[S, B] =
                underlying.flatMap(a => unit(f(a)))

            def map2[B, C](sb: State[S, B])(f: (A, B) => C): State[S, C] =
                underlying.flatMap(a => sb.map(b => f(a, b)))
}

import Fps2.State

type Rand[A] = State[RNG, A]

def int : Rand[Int] = State(rng => rng.nextInt)

int.map(_.toString).run(SimpleRNG(42))

int.map2(int)(_ + _).run(SimpleRNG(42))

defined [32mobject[39m [36mFps2[39m
[32mimport [39m[36mFps2.State

[39m
defined [32mtype[39m [36mRand[39m
defined [32mfunction[39m [36mint[39m
[36mres30_4[39m: ([32mString[39m, [32mRNG[39m) = ([32m"16159453"[39m, [33mSimpleRNG[39m(seed = [32m1059025964525L[39m))
[36mres30_5[39m: ([32mInt[39m, [32mRNG[39m) = ([32m-1265320244[39m, [33mSimpleRNG[39m(seed = [32m197491923327988L[39m))