<center> 

<h1 style="text-align:center"> Monads </h1>
<h2 style="text-align:center"> CSCI7000-11 S23: Principles of Functional Programming </h2>
</center>

## Whence Monads

* The term "monad" comes from **Category Theory**
  + Category theory is the study of mathematical abstractions
  + Out of scope for this course
  + We will focus on **programming with monads**.

## Monads for programming

* Monads were popularized by the Haskell programming language
  + Haskell is a **purely functional** programming language
  + Unlike OCaml, Haskell separates pure code from side-effecting code through the use of monads.
* Monads are a way to *simulate* and *encapsulate* effects in a pure setting
  + ... similar to how we simulated advanced language features in lambda calculus encodings.
* Monad is an _idiom_ / _a design pattern_
  + not a primitive language feature

## What is a Monad?

A monad is any implementation that satisfies the following signature:

In [None]:
module type Monad = sig
  type 'a t                                 (* computation *)
  val return : 'a -> 'a t                   (* lift a value to a computation *)
  val bind   : 'a t -> ('a -> 'b t) -> 'b t (* sequence two computations *)
end

and the **monad laws**.

## Is that it?

* All of this seems **very abstract** (as many FP concepts are).
* An example will help us see the pattern.
  + Over time, you'll spot monads everywhere.
* Let's write an interpreter for arithmetic expressions.

## Interpreting arithmetic expressions

In [None]:
type expr = 
| Val of int 
| Plus of expr * expr 
| Div of expr * expr

## Interpreting arithmetic expressions

* Our goal is to make the interpreter a **total function**.
  + Produces a **value** for every arithmetic expression.

In [None]:
let rec eval e = match e with
  | Val v -> v
  | Plus (e1,e2) -> 
    let v1 = eval e1 in
    let v2 = eval e2 in
    v1 + v2
  | Div (e1,e2) -> 
    let v1 = eval e1 in
    let v2 = eval e2 in
    v1 / v2

## Interpreting arithmetic expressions : examples

In [None]:
eval (Plus (Div (Val 4, Val 2), Val 7)) (* 4 / 2 + 7 *)

## Division by zero

This looks fine. But what happens if the denominator in the division is 0?

In [None]:
eval (Div (Val 1, Val 0))

* Recall that our goal is to make the interpreter a **total function**
  + Due to exceptions, the function is not total.
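
The failure can be observed directly: in OCaml, integer division by zero raises the built-in `Division_by_zero` exception. Here is a minimal sketch that repeats the definitions above in condensed form so that it runs on its own. The helper `raises_div_by_zero` is a hypothetical name introduced only for this illustration.

```ocaml
type expr =
  | Val of int
  | Plus of expr * expr
  | Div of expr * expr

(* Same interpreter as above, written more compactly *)
let rec eval e = match e with
  | Val v -> v
  | Plus (e1, e2) -> eval e1 + eval e2
  | Div (e1, e2) -> eval e1 / eval e2

(* Hypothetical helper: reports whether evaluating [e] raises Division_by_zero *)
let raises_div_by_zero e =
  match eval e with
  | _ -> false
  | exception Division_by_zero -> true
```

`raises_div_by_zero (Div (Val 1, Val 0))` is `true`, so `eval` does not produce a value for every expression.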

How can we avoid this?

## Interpreting Arithmetic Expressions: Take 2

* Rewrite the `eval` function to have the type `expr -> int option`
  + Return `None` for division by zero.

In [None]:
let rec eval e = match e with
  | Val v -> Some v
  | Plus (e1,e2) ->
      begin match eval e1 with 
      | None -> None
      | Some v1 -> 
          match eval e2 with
          | None -> None 
          | Some v2 -> Some (v1 + v2)
      end
  | Div (e1,e2) ->
      begin match eval e1 with 
      | None -> None
      | Some v1 -> 
          match eval e2 with
          | None -> None 
          | Some v2 -> if v2 = 0 then None else Some (v1 / v2)
      end

## Interpreting Arithmetic Expressions: Take 2

In [None]:
eval (Plus (Div (Val 4, Val 2), Val 7)) (* 4 / 2 + 7 *)

In [None]:
eval (Div (Val 1, Val 0)) (* 1 / 0 *)

## Abstraction

* There is a lot of repeated code in the interpreter above.
  + Factor out common code.

In [None]:
let return v = Some v

In [None]:
let bind m f = match m with
  | None -> None 
  | Some v -> f v

**Convention:** These functions are named `return` and `bind` because together they define a monad, but any other names would work just as well.
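
Nothing in these two functions is specific to the interpreter. As a small sketch, the code below uses them to chain two fallible steps. The names `safe_div` and `div_twice` are hypothetical and exist only for this illustration.

```ocaml
let return v = Some v

let bind m f = match m with
  | None -> None
  | Some v -> f v

(* Hypothetical helper: division that returns None instead of raising *)
let safe_div x y = if y = 0 then None else Some (x / y)

(* Computes (a / b) / c, short-circuiting to None if either divisor is 0 *)
let div_twice a b c =
  bind (safe_div a b) (fun q ->
  bind (safe_div q c) (fun r ->
  return r))
```

For example, `div_twice 100 5 2` is `Some 10`, while `div_twice 100 0 2` is `None` because the first step fails and the second is never run.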

## Abstraction 

Let's rewrite the interpreter using these functions.

```ocaml
let return v = Some v

let bind m f = match m with
  | None -> None 
  | Some v -> f v
```

In [None]:
let rec eval e = match e with
  | Val v -> return v
  | Plus (e1,e2) ->
      bind (eval e1) (fun v1 -> 
      bind (eval e2) (fun v2 ->
      return (v1+v2)))
  | Div (e1,e2) ->
      bind (eval e1) (fun v1 -> 
      bind (eval e2) (fun v2 ->
      if v2 = 0 then None else return (v1 / v2)))

The layout is deliberately suggestive: it leads on to a nicer syntax.

## Infix bind operation

Usually `bind` is defined as an infix function `>>=`.

In [None]:
let (>>=) = bind

In [None]:
let rec eval e = match e with
  | Val v -> return v
  | Plus (e1,e2) ->
      eval e1 >>= fun v1 -> 
      eval e2 >>= fun v2 ->
      return (v1+v2)
  | Div (e1,e2) ->
      eval e1 >>= fun v1 -> 
      eval e2 >>= fun v2 ->
      if v2 = 0 then None else return (v1 / v2)

## `let*` syntax extension 

OCaml 4.08, released in June 2019, added binding operators such as `let*`, a syntax that makes monadic programs easier to write.

In [None]:
let ( let* ) = bind
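
A binding operator is ordinary function application in disguise: `let* x = m in e` is rewritten to `( let* ) m (fun x -> e)`. The sketch below writes the same computation both ways. The function names `add_opt_sugar` and `add_opt_plain` are hypothetical, chosen for this illustration.

```ocaml
let bind m f = match m with
  | None -> None
  | Some v -> f v

let ( let* ) = bind

(* Written with the binding operator *)
let add_opt_sugar a b =
  let* x = a in
  let* y = b in
  Some (x + y)

(* The same function after desugaring by hand *)
let add_opt_plain a b =
  ( let* ) a (fun x ->
  ( let* ) b (fun y ->
  Some (x + y)))
```

Both functions return `Some 3` on `(Some 1) (Some 2)` and `None` as soon as either argument is `None`.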

## `let*` syntax extension 


In [None]:
let rec eval e = match e with
  | Val v -> return v
  | Plus (e1,e2) ->
      let* v1 = eval e1 in
      let* v2 = eval e2 in
      return (v1+v2)
  | Div (e1,e2) ->
      let* v1 = eval e1 in 
      let* v2 = eval e2 in
      if v2 = 0 then None 
      else return (v1 / v2)

## Compare this to our initial take

```ocaml
let rec eval e = match e with
  | Val v -> v
  | Plus (e1,e2) -> 
    let v1 = eval e1 in
    let v2 = eval e2 in
    v1 + v2
  | Div (e1,e2) -> 
    let v1 = eval e1 in
    let v2 = eval e2 in
    v1 / v2
```

The monadic version has additional `return` and `let*`, but the overall structure remains the same.

## Modularise

* The `return` and `let*` we have defined for the interpreter work for any computation on the option type.
  + Put them in a module and we get the **Option Monad**.
* The option monad _simulates_ **exceptions**.

In [None]:
module type MONAD = sig
  type 'a t                                  (* computation *)
  val return  : 'a -> 'a t                   (* lift a value to a computation *)
  val (let*)  : 'a t -> ('a -> 'b t) -> 'b t (* sequence two computations *)
end

module OptionMonad : (MONAD with type 'a t = 'a option) = struct
  type 'a t = 'a option
  let return v = Some v
  let (let*) m f = match m with
  | Some v -> f v
  | None -> None
end
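
Once packaged as a module, the operations can be brought into scope locally with `let open`. The following sketch repeats the definitions above so that it runs on its own. The function `sum_first_two` is a hypothetical example, and `List.nth_opt` is the standard library function that returns `None` when the index is out of range.

```ocaml
module type MONAD = sig
  type 'a t
  val return   : 'a -> 'a t
  val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
end

module OptionMonad : (MONAD with type 'a t = 'a option) = struct
  type 'a t = 'a option
  let return v = Some v
  let ( let* ) m f = match m with
    | Some v -> f v
    | None -> None
end

(* Hypothetical example: sum the first two elements of a list, if both exist *)
let sum_first_two l =
  let open OptionMonad in
  let* x = List.nth_opt l 0 in
  let* y = List.nth_opt l 1 in
  return (x + y)
```

`sum_first_two [1; 2; 3]` is `Some 3`, and `sum_first_two [1]` is `None`, with no explicit pattern matching on the intermediate results.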

## Monad Laws

The monad laws constrain what `return` and `>>=` can do.

Any implementation of the monad signature must satisfy the following laws:


```ocaml
1. return v >>= f   ≡  f v   (* Left Identity *)
2. v >>= return     ≡  v     (* Right Identity *)
3. (m >>= f) >>= g  ≡  m >>= (fun x -> f x >>= g) 
                             (* Associativity *)
```

## Option monad satisfies monad laws

**Left Identity**: `return v >>= f  ≡  f v`

```ocaml
  return v >>= f
≡ (Some v) >>= f  (* by definition of return *)
≡ match Some v with 
  | None   -> None 
  | Some v -> f v 
                  (* by definition of >>= *)
≡ f v             (* by beta reduction *)
```

**Exercise:** Prove the other laws.
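
Before attempting a proof, it can help to spot-check the laws on concrete values. The sketch below does this for the option monad. The sample functions `f` and `g` are arbitrary choices made for this illustration, and a passing check is evidence, not a proof.

```ocaml
let return v = Some v

let ( >>= ) m f = match m with
  | None -> None
  | Some v -> f v

(* Arbitrary sample functions, chosen only for illustration *)
let f x = if x > 0 then Some (x * 2) else None
let g x = Some (x + 1)

(* Each function returns true when the law holds for the given input *)
let left_identity v  = (return v >>= f) = f v
let right_identity m = (m >>= return) = m
let associativity m  = ((m >>= f) >>= g) = (m >>= fun x -> f x >>= g)
```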

## Simulating state

* Recall, monads simulate **effects** in a **pure** setting.
  + **option** monad simulates **exceptions**
* How can we simulate **mutability**?
  + For a start, a single, typed, mutable location in the whole program.
  + Operations to `get` the current state and `put` a new state.

**Idea:** _Thread_ the state through the program.

_Threading_ the state means passing the state as an additional function argument to _every_ function and returning the new state along with the function result.

## Threading the state

What does threading the state look like? 

The usual Fibonacci function looks like:

In [None]:
let rec fib n = 
  if n < 2 then 1 
  else fib (n-1) + fib (n-2)

## Threading the state

Here is the Fibonacci function with the state threaded through. It

* takes the state as an additional last argument, and
* returns a pair of the new state and the result of the function.

In [None]:
let rec fib n (s (* threaded state *)) = 
  if n < 2 then (s, 1) 
  else 
    let (s1, v1) = fib (n-1) s in
    let (s2, v2) = fib (n-2) s1 in
    (s2, v1 + v2)

The above function neither reads the state nor writes to the state.
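
To see the threaded state doing some work, the variant below uses it as a call counter. This is a sketch for illustration only; the name `fib_count` is not part of the development in these slides.

```ocaml
(* [fib_count n s] returns (s + number of calls made, Fibonacci result) *)
let rec fib_count n s =
  let s = s + 1 in                      (* count this call *)
  if n < 2 then (s, 1)
  else
    let (s1, v1) = fib_count (n-1) s in
    let (s2, v2) = fib_count (n-2) s1 in
    (s2, v1 + v2)
```

`fib_count 5 0` evaluates to `(15, 8)`: 15 calls in total, with result 8.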

## Manipulating the state

In order to read and write the state, we implement the following functions.

In [None]:
let get () (s (* threaded state *)) = (s,s)

let put new_s (_ (* threaded state *)) = (new_s, ())

`get` seems a little bit pointless at the moment, but let's hold on.

## Fibonacci with state

Here is a function that computes the Fibonacci number reading the input from the state and writing the output to the state.

In [None]:
let fib_state () s =
  let (s1,n) = get () s in
  Printf.printf "get: state=%d result=%d\n%!" s1 n;
  
  let (s2,r) = fib n s1 in
  Printf.printf "fib: state=%d result=%d\n%!" s2 r;
  
  let (s3,()) = put r s2 in
  Printf.printf "put: state=%d result=()\n%!" s3;
  
  (s3,())

## Fibonacci with state


In [None]:
fib_state () 10

## Remove tedium

It is quite tedious to write functions that explicitly thread the state through (and possibly never even touch it).

**Note:** Using the type name `state` for the type of the state.

```ocaml
val fib : int -> state -> state * int

let rec fib n (s (* threaded state *)) = 
  if n < 2 then (s, 1) 
  else 
    let (s1, v1) = fib (n-1) s in
    let (s2, v2) = fib (n-2) s1 in
    (s2, v1 + v2)
```



## Remove tedium

Look at the types:
```ocaml
type state
val fib : int -> state -> state * int
val get : state -> state * state
val put : state -> state -> state * unit
```
Factor out common parts:
```ocaml
type state
type 'a t (* computation type *) = state -> state * 'a
val fib : int -> int t
val get : state t
val put : state -> unit t
```
`'a` is the return type of the computation.

## `bind` computations

How to make this better?

```ocaml
...
    let (s1, v1) = fib (n-1) s in
    let (s2, v2) = fib (n-2) s1 in
    (s2, v1 + v2)
```    

Use `bind` to forward the state to the subsequent computation. 


```ocaml
type state
type 'a t = state -> state * 'a (* computation *)

let bind (m : 'a t) (f : 'a -> 'b t) : 'b t = 
  fun s (* current state *) ->
    let ((s' : state), (v : 'a)) = m s in
    let ((s'' : state), (res : 'b)) = f v s' in
    (s'' (* resultant state *), res)
```

## `bind` computations

With `let (let*) = bind`, we get:

```ocaml
...
    let* v1 = fib (n-1) in
    let* v2 = fib (n-2) in
    return (v1 + v2)
```

👍
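
Putting the pieces together, here is a self-contained sketch with the state fixed to `int`. That choice is an assumption made so the code can run; the slides keep `state` abstract. `return` is taken to be the computation that leaves the state untouched.

```ocaml
type state = int                       (* assumption: a concrete state type *)
type 'a t = state -> state * 'a        (* computation *)

(* Lift a value to a computation that does not change the state *)
let return (v : 'a) : 'a t = fun s -> (s, v)

(* Run [m], then pass its result and its final state on to [f] *)
let bind (m : 'a t) (f : 'a -> 'b t) : 'b t =
  fun s ->
    let (s', v) = m s in
    f v s'

let ( let* ) = bind

let rec fib n : int t =
  if n < 2 then return 1
  else
    let* v1 = fib (n-1) in
    let* v2 = fib (n-2) in
    return (v1 + v2)
```

Since `fib` never reads or writes the state, `fib 5 0` evaluates to `(0, 8)`: the state comes back unchanged alongside the result.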

## State Monad

What we've defined is a state monad.

* A State Monad introduces a **single, typed mutable cell**.
* Offers
  + `get` and `put` functions for reading and writing the state, and 
  + a `run_state` function for actually running computations with an initial state.

## State Monad

In [None]:
module type STATE = sig
  type state
  include MONAD
  val get : state t
  val put : state -> unit t
  val run_state : 'a t 
                -> init:state (* labelled argument for initial state *) 
                -> state (* final state *) * 'a
end

## State Monad

Here's an implementation of `State`, parameterised by the type of the state:

In [None]:
module State (S : sig type t end) 
  : STATE with type state = S.t = struct

  type state = S.t
  type 'a t = state -> state * 'a (* computation *)

  let return v = fun s -> (s, v)

  let (let*) m f = fun s -> 
    let (s', a) = m s in 
    f a s'

  let get = fun s -> (s, s)

  let put s' = fun _ -> (s', ())

  let run_state m ~init = m init
end

## Using State Monad

In [None]:
module IntState = State (struct type t = int end)
open IntState 

(* [inc v] increments the state by [v] *)
let inc v = 
  let* s = get in 
  put (s+v)

(* [dec v] decrements the state by [v] *)
let dec v = 
  let* s = get in
  put (s-v)

(* [double] doubles the state *)
let double =
  let* s = get in
  put (s*2)

## Using State Monad

In [None]:
let comp = 
  let* () = inc 20 in
  let* () = double in
  dec 10
in

IntState.run_state ~init:10 comp

In [None]:
module FloatState = State (struct type t = float end)
open FloatState

let comp = 
  let* v = get in 
  let* () = put (v +. 1.0) in
  return "Hello, world"
;;

run_state ~init:5.4 comp

## Fibonacci, again (in a monad)

In [None]:
open State (struct type t = int end)

let rec fib n = 
  if n < 2 then return 1
  else
    let* v1 = fib (n-1) in
    let* v2 = fib (n-2) in
    return (v1 + v2)

let fib_state = 
  let* n = get in
  let* r = fib n in
  put r
;;
  
run_state ~init:10 fib_state
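
The monadic `fib` above still never touches the state on its own. As a sketch of a computation that does, the variant below counts its own calls in the state. It repeats the `MONAD`, `STATE` and `State` definitions so that it runs on its own; `Counter`, `tick` and `fib_count` are hypothetical names introduced for this illustration.

```ocaml
module type MONAD = sig
  type 'a t
  val return   : 'a -> 'a t
  val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
end

module type STATE = sig
  type state
  include MONAD
  val get : state t
  val put : state -> unit t
  val run_state : 'a t -> init:state -> state * 'a
end

module State (S : sig type t end) : STATE with type state = S.t = struct
  type state = S.t
  type 'a t = state -> state * 'a
  let return v = fun s -> (s, v)
  let ( let* ) m f = fun s -> let (s', a) = m s in f a s'
  let get = fun s -> (s, s)
  let put s' = fun _ -> (s', ())
  let run_state m ~init = m init
end

module Counter = State (struct type t = int end)

(* Hypothetical helper: increments the state by one *)
let tick =
  let open Counter in
  let* s = get in
  put (s + 1)

(* Fibonacci that records the number of calls in the state *)
let rec fib_count n =
  let open Counter in
  let* () = tick in
  if n < 2 then return 1
  else
    let* v1 = fib_count (n-1) in
    let* v2 = fib_count (n-2) in
    return (v1 + v2)
```

`Counter.run_state (fib_count 5) ~init:0` evaluates to `(15, 8)`: the state threading is handled entirely by `let*`, and only `tick` mentions the state.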

## State monad satisfies monad laws

**Right Identity**: `v >>= return  ≡  v`

```ocaml
  v >>= return
≡ fun s -> 
    let (s', a) = v s in 
    return a s' (* by definition of >>= *)
≡ fun s -> 
    let (s', a) = v s in 
    (fun v s -> (s,v)) a s' (* by definition of return *)
≡ fun s -> let (s', a) = v s in (s',a) (* by beta reduction *)
≡ fun s -> v s (* by eta reduction *)
≡ v (* by eta reduction *)
```

**Exercise**: Prove other laws.

## Type of State

* State in the state monad is of a single type
  + In our example, the state was of `int` type
* *Can we change type of state as the computation evolves?*

## Parameterised monads

* Parameterised monads add two additional type parameters to `t` representing the start and end states of a computation.
* A computation of type `('p, 'q, 'a) t` has 
  + *precondition* (or starting state) `'p`
  + *postcondition* (or ending state) `'q`
  + *produces a result* of type `'a`.

## Parameterised monads

Here's the parameterised monad signature:

In [None]:
module type PARAMETERISED_MONAD =
sig
  (* A computation on a state of type ['s] returning ['a] and
     changing the state to a value of type ['t] *)
  type ('s,'t,'a) t 
  
  (* Lift a value to a computation that does not change the state type *)
  val return : 'a -> ('s,'s,'a) t
  
  val (let*) : ('r,'s,'a) t ->
       ('a -> ('s,'t,'b) t) ->
              ('r,'t,'b) t
end

## Parameterised state monad

Here's a parameterised monad version of the `STATE` signature, using the extra parameters to represent the type of the reference cell.

In [None]:
module type PSTATE =
sig
 include PARAMETERISED_MONAD
 val get : ('s,'s,'s) t
 val put : 's -> (_,'s,unit) t
 val run_state : ('s,'t,'a) t -> init:'s -> 't * 'a
end

## Parameterised state monad


Here's an implementation of `PSTATE`.

In [None]:
module PState : PSTATE =
struct
  type ('s, 't, 'a) t = 's -> 't * 'a

  let return v = fun s -> (s, v)

  let (let*) m k = fun s -> 
    let t, a = m s in 
    k a t

  let put s = fun _ -> (s, ())

  let get = fun s -> (s, s)

  let run_state m ~init = m init
end

## Computation with changing state

In [None]:
open PState

let inc v = let* s = get in put (s+v)
let dec v = let* s = get in put (s-v)
let double = let* s = get in put (s*2)
  
let to_string = let* i = get in put (string_of_int i)
let of_string = let* s = get in put (int_of_string s)

## Computation with changing state

In [None]:
let foo = let* _ = inc 5 in to_string
let bar = let* s = get in put (s ^ "00")
  
let baz = let* _ = foo in bar

In [None]:
run_state ~init:5 baz

## Computation with changing state

```ocaml
let foo = let* _ = inc 5 in to_string
let bar = let* s = get in put (s ^ "00")
```

In [None]:
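(* Rejected by the type checker: [bar] leaves a string in the state,
   but [foo] expects to start from an int *)
let quz = let* _ = bar in foo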

## Use-case: A well-typed stack machine

* Let's build a tiny stack machine with 3 instructions
  + `push` pushes a constant on to the stack. Constant could be of any type. 
  + `add` adds the top two integers on the stack and pushes the result
  + `_if_` expects the stack to have the shape `[b;v1;v2] @ rest_of_stack`.
      * if `b` is true then result stack will be `v1::rest_of_stack`
      * otherwise, `v2::rest_of_stack`.
* Our stack machine will not get stuck! 
  + recall the definition from lambda calculus lectures
  + For any program, the type of the program will tell you precisely the **shape** of the stack it needs to execute on.
* [WebAssembly](https://webassembly.org/) takes a similar approach: its validation rules type each instruction by the stack shape it consumes and produces.

## Stack operations

* Because our stack will have values of different types, encode them using pairs.
  + `[]` will be `()`
  + `[1;2;3]` will be `(1, (2, (3, ())))`
  + `[1;true;3]` (which is not a well-typed OCaml expression) will be `(1, (true, (3, ())))`
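
With this encoding, the type of a stack records the type of every element, so plain polymorphic functions can inspect it safely. A minimal sketch follows; `top` and `pop` are hypothetical helpers introduced for this illustration, not part of the stack machine defined later.

```ocaml
(* A heterogeneous stack encoded as nested pairs, ending in unit.
   Its type is int * (bool * (int * unit)). *)
let stack = (1, (true, (3, ())))

(* These work on any non-empty stack, whatever its element types *)
let top (x, _) = x
let pop (_, rest) = rest
```

`top stack` is `1` and `top (pop stack)` is `true`. Applying `top` to the empty stack `()` is rejected by the type checker, since `()` is not a pair.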

## Stack Operations

In [None]:
module type STACK_OPS =
sig
  type ('s,'r,'a) t
  val add : unit -> (int * (int * 's), 
                     int * 's, 
                     unit) t
  val _if_ : unit -> (bool * ('a * ('a * 's)), 
                      'a * 's, 
                      unit) t
  val push_const : 'a -> ('s, 
                          'a * 's, 
                          unit) t
end

## Stack Machine

We can combine the stack operations with the parameterised monad signature to
build a signature for a stack machine:

In [None]:
module type STACKM = sig
 include PARAMETERISED_MONAD
 include STACK_OPS
   with type ('s,'t,'a) t := ('s,'t,'a) t
 val execute : ('s,'t,'a) t -> 's -> 't * 'a
end

## Stack Machine

Here is the implementation of the stack machine

In [None]:
module StackM : STACKM =
struct
  include PState
 
  let add () =
    let* (x,(y,s)) = get in
    put (x+y,s)
 
  let _if_ () =
    let* (c,(t,(e,s))) = get in
    put ((if c then t else e),s)

  let push_const k =
    let* s = get in
    put (k, s)

  let execute c s = run_state ~init:s c
end

## Using the stack machine

In [None]:
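(* Rejected by the type checker: after [add] the top of the stack is an int,
   but [_if_] expects a bool there *)
let program = let open StackM in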
  let* _ = push_const 4 in
  let* _ = push_const 5 in
  let* _ = add () in
  _if_ ()

## Using the stack machine

In [None]:
let program = let open StackM in
  let* _ = push_const 4 in
  let* _ = push_const 5 in
  let* _ = push_const true in
  let* _ = _if_ () in
  add ()

In [None]:
(* Rejected by the type checker: the final [add] needs an int from the initial stack *)
StackM.execute program ()

In [None]:
StackM.execute program (20,(10,())) (* evaluates to ((25,(10,())), ()) *)

## Using the stack machine

In [None]:
(* Rejected by the type checker: [_if_] needs a bool and two values of the same type *)
StackM.execute (StackM._if_ ()) (false,(10,()))

In [None]:
(* Rejected by the type checker: [add] needs two ints on the stack *)
StackM.execute (StackM.add ()) ()