This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

Simple Monads in Lean

Lean is a pure functional programming language. This means all inputs and outputs must be described in a function's arguments and return type, and no other side effects are allowed. The Lean compiler enforces this, which allows it to do some very interesting things, like checking proofs.

Pure functions are easy to test, completely predictable, less prone to bugs, and compose incredibly well into larger functions. The following is an example of a pure function:

def divide (x: Float) (y: Float) : Float :=
  x / y

This function takes two floating point numbers and returns x divided by y:

#eval divide 6 2    -- 3.000000

But, real world programs usually need to have some side effects, like reading and writing files, terminal IO, networking, logging, exception handling, and reading or writing some sort of global settings, or state. Monads can solve this problem, and they can be used for other useful things like metaprogramming.

More generally monads are useful when you want to introduce a new concept to your programs without messing up the clarity, composability and maintainability of your pure functions.

For example, suppose your Lean program must not allow divide by zero because that might cause your variables to get tainted with Float.inf and let's pretend that for your particular application that would be a huge problem (there are many real world applications where this is the case).

But you don't want to stop using Float because you would lose the nice benefits of the system-provided Float type. So how can you get the compiler to ensure that Float.inf never happens in your program?

Lean defines a monad that adds exception handling as an add-on feature. To use it, you add information about the exception handling behavior to your function's return type:

def divide (x: Float ) (y: Float): ExceptT String Id Float :=
  if y == 0 then
    throw "can't divide by zero"
  else
    pure (x / y)

#eval divide 5 0  -- Except.error "can't divide by zero"

Here the throw function is available because our return type includes ExceptT. So throw is not available everywhere like it is in most imperative programming languages. You have to declare your function can throw by including the ExceptT type in your result type. This creates a function that might return an error of type String or it might return a value of type Float in the non-error case.

Once your function is monadic you also need to use the pure constructor of the ExceptT monad to convert the floating point return value x / y into the Except object.

This return typing would get tedious if you had to include it everywhere that you call this function. However, Lean's type inference can clean this up. For example, you can define a test function that calls the divide function, and you don't need to say anything about the fact that it might throw an error, because that is inferred:

def test :=
  divide 5 0

The Lean compiler propagates the type information for you:

#check test     -- ExceptT String Id Float

And now we can run this test and get the expected exception:

#eval test      -- Except.error "can't divide by zero"

But with all good exception handling you also want to be able to catch exceptions so your program can continue on, which you can do like this:

def testCatch := do
  try
    let r ← divide 8 0
    return toString r
  catch e =>
    return s!"Caught exception: {e}"

Note that the type inferred by Lean for this function is ExceptT String Id String, while divide returns ExceptT String Id Float. The monad ExceptT String Id stays the same; only the ok type changed from Float to String. The T in ExceptT stands for "transformer", which is covered below. Part of what makes Lean pleasant here is how much of the plumbing it handles for you: notice you didn't have to do any extra work for the compiler to figure out the type you were after.

You can now see the try/catch working in this eval:

#eval testCatch -- Except.ok "Caught exception: can't divide by zero"

Notice the Caught exception: wrapped message is returned, and that it is returned as an Except.ok value, meaning testCatch eliminated the error result as expected.
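
The try/catch syntax can also be written directly with the tryCatch function from the MonadExcept class. The following is a minimal sketch of the same handler, assuming a stand-in function safeDiv (a hypothetical name that mirrors divide so the block is self-contained) and a hypothetical test name testCatchManual:

```lean
-- `safeDiv` is a hypothetical stand-in for the `divide` function in the text.
def safeDiv (x y : Float) : ExceptT String Id Float :=
  if y == 0 then throw "can't divide by zero" else pure (x / y)

-- `tryCatch body handler` runs `body` and calls `handler` only if it failed.
def testCatchManual : ExceptT String Id String :=
  tryCatch
    (do
      let r ← safeDiv 8 0
      pure (toString r))
    (fun e => pure s!"Caught exception: {e}")

#eval testCatchManual -- Except.ok "Caught exception: can't divide by zero"
```

The handler has to return the same type as the body, which is why both branches end up as ExceptT String Id String.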

So we've interleaved a new concept into our functions (exception handling) and the compiler is still able to type check everything just as well as it does for pure functions. It has also been able to infer some things along the way to make it even easier to manage.

Under the covers

So what really just happened under the covers? Exceptions start with this inductive type:

inductive Except (ε : Type u) (α : Type v) where
  | error : ε → Except ε α
  | ok    : α → Except ε α

Notice this is very generic: it can represent an error case where the error is any type ε, or an ok case where the ok value is any type α. So the type Except String Float represents an Except type that has a string in the error case or a floating point value in the ok case.
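
As a small illustration, here is a sketch that consumes an Except String Float by pattern matching on its two constructors. The function name describe is hypothetical and only used for this example:

```lean
-- `describe` is a hypothetical helper, not part of Lean or the original text.
def describe (r : Except String Float) : String :=
  match r with
  | .error msg => s!"failed: {msg}"   -- the ε case, here a String
  | .ok v      => s!"succeeded: {v}"  -- the α case, here a Float

#eval describe (.ok 1.5)        -- "succeeded: 1.500000"
#eval describe (.error "oops")  -- "failed: oops"
```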

This Except type is then turned into a Monad by declaring this Monad type instance:

instance : Monad (Except ε) where
  pure := Except.pure
  bind := Except.bind
  map  := Except.map

The ExceptT function uses a monad m to transform the type Except ε α:

def ExceptT (ε : Type u) (m : Type u → Type v) (α : Type u) : Type v :=
  m (Except ε α)

This takes an error type ε, a monad m, and the ok type α and uses the monad m to transform the type Except ε α to create a new return type, whatever type is defined by the monad we choose to use here. The T in ExceptT is short for "transformer", so ExceptT is a monad based type transformer.

side note
Now remember that in Lean, any function f (x) (y) (z) can be turned into compositional subfunctions, so f x y is a function that returns a function that takes a z and f x is a function that returns a function that takes y z and so on. This is a type of currying. This means:
  • ExceptT String is a monad transformer.
  • ExceptT String Id is a monad.
  • ExceptT String Id Float is a monadic action in the monad ExceptT String Id which produces a Float when you call the run method on that action.

Yes, ExceptT also provides a run method as follows:

def ExceptT.run {ε : Type u} {m : Type u → Type v} {α : Type u} (x : ExceptT ε m α) : m (Except ε α) := x

Now the #eval command can implicitly call run for you in some cases, so when we did #eval test we were really doing this:

#eval test.run

Now the divide function is using the Id monad, which is the identity transform, so the return type in this case is simply Except String Float. This Id monad might seem a bit weird right now, but it is a placeholder that will allow us to chain monads, which we will do later.

So the divide function can return an Except object containing an error of type String or an ok result of type Float, which is exactly what we wanted.
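
These claims about partial application and about Id can be checked by asking Lean to accept a few type ascriptions. This is only a sketch; each example elaborates only if the stated type is correct:

```lean
-- Each partial application has the kind described in the side note.
example : (Type → Type) → Type → Type := ExceptT String  -- a monad transformer
example : Type → Type := ExceptT String Id               -- a monad
example : Type := ExceptT String Id Float                -- type of a monadic action

-- Lean can find a Monad instance for the partially applied type.
example : Monad (ExceptT String Id) := inferInstance

-- With `Id` as the inner monad the transformer adds nothing:
-- both sides unfold to the same type, so `rfl` is accepted.
example : ExceptT String Id Float = Except String Float := rfl
```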

The divide function also used this pure function, which is defined as:

namespace Except
def pure {ε : Type u} (a : α) : Except ε α :=
  Except.ok a

So in the case of Except String Float the implicit error type ε is String and the pure value is a Float and the Except.pure implementation then simply uses the Except.ok constructor passing the pure value to be wrapped in an Except object. So pure (x / y) converts the pure value x / y into something that matches the return type Except String Float.
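
A quick way to see this is to give pure an explicit expected type and compare it with the constructor. This sketch uses Nat in the proof so the equality can be closed by rfl:

```lean
-- With the expected type fixed, `pure` resolves to `Except.pure`,
-- which is just the `Except.ok` constructor.
example : (pure 42 : Except String Nat) = Except.ok 42 := rfl

#eval (pure 2.5 : Except String Float)  -- Except.ok 2.500000
```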

All this is built on the Monad type class which is defined as follows:

class Monad (m : Type u → Type v) extends Applicative m, Bind m : Type (max (u+1) v) where
  map      f x := bind x (Function.comp pure f)
  seq      f x := bind f fun y => Functor.map y (x ())
  seqLeft  x y := bind x fun a => bind (y ()) (fun _ => pure a)
  seqRight x y := bind x fun _ => y ()

Monad

Remember that type classes in Lean allow the compiler to do type inference using the declared instances. So the instance Monad (Except ε) allows the compiler to invoke Except.pure when we call the pure method, and Except.bind if we call bind and so on.

Notice the instance : Monad (Except ε) doesn't mention the ok type (α : Type v). If you #check (Except String) you will find you get a function that operates on types, which is exactly the shape the Monad class expects:

Except String : Type u_1 → Type (max 0 u_1)

Notice this matches the inputs to the Monad class: (m : Type u → Type v) and is why you can think of a Monad as something that transforms types.

The next method to consider is the Monad bind method which is defined on the Bind type class as:

class Bind (m : Type u → Type v) where
  /-- If `x : m α` and `f : α → m β`, then `x >>= f : m β` represents the
  result of executing `x` to get a value of type `α` and then passing it to `f`. -/
  bind : {α β : Type u} → m α → (α → m β) → m β

Here you can see that bind is using a function to transform the return type from m α to m β and is specialized in the case of the Except.bind as follows:

namespace Except
@[inline] protected def bind (ma : Except ε α) (f : α → Except ε β) : Except ε β :=
  match ma with
  | Except.error err => Except.error err
  | Except.ok v      => f v

So this bind function can be used to transform the type Except String Float to Except String String. You supply a function f of type Float → Except String String. bind then unwraps the given Except String Float into its error and ok cases: an error is passed through unchanged as Except.error err, and an ok value v is transformed by applying f v, which returns the new type Except String String.
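
Here is a minimal sketch of that transformation using the >>= operator, which is notation for bind. The function floatToText is a hypothetical name introduced only for this example:

```lean
-- `floatToText` is a hypothetical function of type Float → Except String String.
def floatToText (v : Float) : Except String String :=
  .ok (toString v)

-- The ok case: `bind` applies `floatToText` to the wrapped value.
#eval (Except.ok 4.0 : Except String Float) >>= floatToText
-- Except.ok "4.000000"

-- The error case: `floatToText` is never called, the error passes through.
#eval (Except.error "boom" : Except String Float) >>= floatToText
-- Except.error "boom"
```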

This transformation happened automatically in the testCatch function earlier because we used the do notation which is a powerful tool that can chain monad actions finding and applying the right bind operations automatically when needed. In the testCatch function the following line of code shows this in action:

    return toString r -- ExceptT String Id String

Here the toString function was composed into something that constructs an Except.ok String result. So this monad type inference and composition of binding operations is pretty powerful.

Now you might be wondering why testCatch doesn't have return type String? Lean does this as a convenience since you could have a rethrow in or after the catch block. If you really want to stop the ExceptT type from bubbling up you can do this:

def testUnwrap : String := Id.run do
    let r := divide 8 0 -- r is type ExceptT String Id Float = Except String Float
    match r with
    | .ok a => toString a -- 'a' is type Float
    | .error e => s!"Caught exception: {e}"

#check testUnwrap -- String
#eval testUnwrap -- "Caught exception: can't divide by zero"

Alternatively you could solve this using coercions, although that is not always recommended.

instance : Coe (ExceptT α Id α) α where
  coe a := match a with
    | .ok v => v
    | .error v => v

def testCoerce : String :=
  let act : ExceptT String Id String := do
    let r ← divide 8 0
    return r.toString
  act

#check testCoerce -- String
#eval testCoerce  -- "can't divide by zero"

You can also use bind manually if you want to control how it works which we'll see below.

Monad Composition

This is great, but how do you add another dimension to your program using monads? Well, it turns out that in Lean monads compose very nicely, so their side effects can be chained.

Suppose now you want to add some logging to your program so you know how many times divide succeeds without throwing an exception. Logging is very useful in large complex programs to figure out what is really happening.

You have probably already used the IO monad to do terminal IO like IO.println "Hello, world!" but that's not the kind of logging we want here. Sometimes you need something more structured, and more light weight, and easier to consume programmatically. So let's create a counter that is simply incremented every time divide succeeds and pass that "logging state" into our program so you can then also read that state when the program is finished.

There is a monad already defined for this kind of stateful side effect, it is called StateT:

def StateT (σ : Type u) (m : Type u → Type v) (α : Type u) : Type (max u v) :=
  σ → m (α × σ)

which has some additional functions available through MonadStateOf:

instance [Monad m] : MonadStateOf σ (StateT σ m) where
  get       := StateT.get
  set       := StateT.set
  modifyGet := StateT.modifyGet

Notice it provides a get function to read the state, set function to update it, and a modifyGet that does a read and an update.
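
A small sketch of these three functions in use, with Id as the inner monad. The names tick and tickOnce are hypothetical. Both actions return the old counter value and store the incremented one:

```lean
-- Read the state with `get`, then write a new value with `set`.
def tick : StateT Nat Id Nat := do
  let before ← get
  set (before + 1)
  return before

-- The same thing in one step: `modifyGet` returns (result, new state).
def tickOnce : StateT Nat Id Nat :=
  modifyGet fun s => (s, s + 1)

#eval tick.run 5      -- (5, 6)
#eval tickOnce.run 5  -- (5, 6)
```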

If your "context state" is a simple natural number (the count of the number of times divide succeeds) then you could create a divide function that logs this state information as follows:

def divideLog (x: Float ) (y: Float): StateT Nat Id Float :=
  if y == 0 then do
    return 0
  else do
    modify fun s => s + 1
    return x / y

But how does adding a return type of StateT allow stateful "inputs" to be passed to the divideLog function? How can a return type add an input? The answer is "currying" again, which you can see by checking the reduced type:

#check divideLog              -- Float → Float → StateT Nat Id Float
#reduce StateT Nat Id Float   -- Nat → Float × Nat

So effectively StateT has changed our function into:

Float → Float → Nat → Float × Nat

So StateT has turned the return type into a function that takes a natural number as input and returns the updated state in the pair Float × Nat. It has essentially added an input parameter to our divideLog function. This parameter can be accessed through the StateT interface: get, set, and modifyGet. modify is a helper that makes it easier to use modifyGet.
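
To make the "extra input parameter" concrete, here is a sketch of the same logic written as a plain function that threads the counter by hand. The name divideLogPlain is hypothetical:

```lean
-- The StateT type really is just a function type once `Id` is unfolded.
example : StateT Nat Id Float = (Nat → Float × Nat) := rfl

-- Hand-written state passing: the counter comes in as an argument
-- and goes back out as the second component of the pair.
def divideLogPlain (x y : Float) (count : Nat) : Float × Nat :=
  if y == 0 then (0, count) else (x / y, count + 1)

#eval divideLogPlain 8 4 0  -- (2.000000, 1)
```

The monadic version saves you from writing the count argument and the pair at every call site.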

The following 3 ways of calling this function are equivalent:

#eval divideLog 8 4 0         -- (2.000000, 1)
#eval (divideLog 8 4).run 0   -- (2.000000, 1)
#eval divideLog 8 4 |>.run 0  -- (2.000000, 1)

The first way of simply tacking on the state argument is not recommended because it makes your code hard to maintain if the divideLog parameters change or the StateT monad is changed and so on. It is better to use the run method explicitly, which you can do using parentheses, but if you have lots of monads in your chain the pipeline syntax '|>.' is more convenient as it drops the need for parentheses.

An added benefit of calling run explicitly is that the Lean type checker will always ensure for you that the result of divideLog 8 4 is a type that has some run method and so on, if this is not the case it will highlight the right section of your code instead of giving a confusing big error on your entire application that contains some weird monad stack mess. So this is the recommended pattern.

You can call this function in a test function, passing in the state on each call, and storing the updates in a mutable state variable and with the nested for loops this will divide by zero exactly 10 times, which means the result of successful divides should be 90:

def testDivideLog := do
  let mut state := 0
  for x in [0:10] do
    for y in [0:10] do
      let r ← (divideLog x.toFloat y.toFloat) |>.run state
      state := r.2
  state

#eval testDivideLog -- 90

Great, a completely different example of adding an orthogonal dimension to our code (logging). But now what if we want both logging and exception handling? Well, you can chain StateT and ExceptT, and this is why the *T monad transformers take a monad as input. We were passing Id before, but now we can pass the monad ExceptT String Id instead, resulting in the return type StateT Nat (ExceptT String Id) Float! Phew, that's a mouthful. Lean allows very sophisticated type construction.

So this means StateT will transform Except String Float into some new return type. In this case it will become Nat → Except String (Float × Nat) because we need to be able to pass in the state and get back the modified state, so it returns the division result and the updated state as a tuple.

You can now use bind manually to chain the 2 monadic functions (in this case modify from StateT and pure from ExceptT). The bind function on StateT is defined as:
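
The unfolded type can be checked directly, and a hand-written function of that shape shows what the stacked monads stand for. This is a sketch, and divideItPlain is a hypothetical name:

```lean
-- Unfolding StateT, ExceptT and Id gives the function type from the text.
example : StateT Nat (ExceptT String Id) Float
    = (Nat → Except String (Float × Nat)) := rfl

-- The same behavior with no monad machinery at all.
def divideItPlain (x y : Float) (count : Nat) : Except String (Float × Nat) :=
  if y == 0 then .error "can't divide by zero" else .ok (x / y, count + 1)

#eval divideItPlain 5 2 0  -- Except.ok (2.500000, 1)
#eval divideItPlain 5 0 0  -- Except.error "can't divide by zero"
```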

@[inline] protected def bind (x : StateT σ m α) (f : α → StateT σ m β) : StateT σ m β :=
  fun s => do let (a, s) ← x s; f a s

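So bind unpacks the pair (a, s) produced by x and passes both the value and the state on to f. In this case β is Float, and because a StateT σ m β action returns m (β × σ), the final result is Except String (Float × Nat). Here's the combined monadic function using a manual bind: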

def divideIt (x:Float) (y:Float) : StateT Nat (ExceptT String Id) Float :=
  if y == 0 then
    throw "can't divide by zero"
  else
    bind (modify fun s => s + 1) (fun _ => pure (x / y))

#check divideIt -- Float → Float → StateT Nat (ExceptT String Id) Float
#eval divideIt 5 2 |>.run 0 -- Except.ok (2.500000, 1)

The run parameter passed here is the initial value of the StateT Nat being passed in. The function incremented this state and returned it as the second member of the pair (2.500000, 1).

You can test this new composite divideIt function in a very similar way to testDivideLog and you can add a try/catch so the test doesn't stop when it hits a divide by zero:

def testIt := do
  let mut log := 0
  for x in [0:10] do
    for y in [0:10] do
      try
        let r ← divideIt x.toFloat y.toFloat |>.run log
        log := r.2
      catch _  =>
        pure ()
  pure log

#eval testIt -- 90

Notice here the extracted value in r is the pair Float × Nat so r.2 then is the updated state which we want to keep around so we get the running tally of the number of times we did a successful divide and we get the same result: 90 good divides.

Ok, now to bring it all together, you don't need to use bind manually like this because the do notation can chain monadic actions using bind automatically, so you can rewrite the divideIt function as:

def divideDo (x:Float) (y:Float) : (StateT Nat (ExceptT String Id)) Float := do
  if y == 0 then
    throw "can't divide by zero"
  else
    modify fun s => s + 1
    pure (x / y)

So here the do notation generated the code bind (modify fun s => s + 1) (fun _ => pure (x / y)) for you. Pretty neat. Note that we used the do notation in divideLog to do some chaining also.

So an imperative program can be modelled in a functional language as a chain of monadic actions, and Lean's do notation makes this style convenient to write.

Add one more for fun!

ReaderT is like StateT but it is read only, so it is ideal for "context" or "global state". We can use it to pass around our command line arguments so different parts of our program can behave differently as a result of those arguments. ReaderT provides the additional function read to access that read only context.
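
Before combining it with the other monads, here is a minimal sketch of ReaderT on its own, with Id as the inner monad. The name hasLimitFlag is hypothetical:

```lean
-- `read` returns the read-only context, here a list of arguments.
def hasLimitFlag : ReaderT (List String) Id Bool := do
  let args ← read
  return args.contains "--limit"

#eval hasLimitFlag.run ["--limit"]  -- true
#eval hasLimitFlag.run []           -- false
```

Unlike StateT, running a ReaderT action returns only the result, since there is no updated context to hand back.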

Let's first see the manual binding so you get a better idea of how they compose:

def divideWithArgs (x:Float) (y:Float) : ReaderT (List String) (StateT Nat (ExceptT String Id)) Float :=
  bind (modify fun s => s + 1) fun _ =>
    bind (get) fun s =>
      bind (read) fun args =>
        if (s > 10 && args.contains "--limit") then
          throw "too many divides"
        else if y == 0 then
          throw "can't divide by zero"
        else
          pure (x / y)

/-
List String → Nat → Except String (Float × Nat)
-/
#reduce ReaderT (List String) (StateT Nat (ExceptT String Id)) Float

#eval divideWithArgs 5 2 |>.run [] |>.run 0 -- Except.ok (2.500000, 1)
#eval divideWithArgs 5 0 |>.run [] |>.run 0 -- Except.error "can't divide by zero"
#eval divideWithArgs 5 2 |>.run ["--limit"] |>.run 10 -- Except.error "too many divides"

Notice that because we have now added 2 monads, ReaderT and StateT, we need 2 run method calls: the first supplies the ReaderT context and the second supplies the initial StateT state.

Fortunately, the do notation cleans this up very nicely:

def divideWithArgsDo (x:Float) (y:Float) : ReaderT (List String) (StateT Nat (ExceptT String Id)) Float := do
  modify fun s => s + 1
  let s ← get
  let args ← read
  if (s > 10 && args.contains "--limit") then
    throw "too many divides"
  else if y == 0 then
    throw "can't divide by zero"
  else
    pure (x / y)

#eval divideWithArgsDo 5 2 |>.run [] |>.run 0 -- Except.ok (2.500000, 1)
#eval divideWithArgsDo 5 0 |>.run [] |>.run 0 -- Except.error "can't divide by zero"
#eval divideWithArgsDo 5 2 |>.run ["--limit"] |>.run 10 -- Except.error "too many divides"

Oooh, isn't that loverly. You can even prove that these functions are equivalent:

example : divideWithArgs x y = divideWithArgsDo x y := by
  simp[divideWithArgs, divideWithArgsDo]    -- Goals accomplished 🎉

Monad Lifting

An important part of any program is functional decomposition, breaking large functions up into smaller ones. When you do that you don't always want the smaller functions to have the big complex return types of the outer function. Let's take a look at an example. Remember our first divide function that throws on divide by zero?

def divide (x: Float ) (y: Float): ExceptT String Id Float :=
  if y == 0 then
    throw "can't divide by zero"
  else
    pure (x / y)

Well we can reuse this smaller function in our divideWithArgsDo with some refactoring like this:

def divideRefactored (x:Float) (y:Float) : ReaderT (List String) (StateT Nat (ExceptT String Id)) Float := do
  modify fun s => s + 1
  let s ← get
  let args ← read
  if (s > 10 && args.contains "--limit") then
    throw "too many divides"
  else
    divide x y

#eval divideRefactored 5 2 |>.run [] |>.run 0 -- Except.ok (2.500000, 1)
#eval divideRefactored 5 0 |>.run [] |>.run 0 -- Except.error "can't divide by zero"
#eval divideRefactored 5 2 |>.run ["--limit"] |>.run 10 -- Except.error "too many divides"

Very cool - but some magic happened here. The smaller divide function has a different return type, ExceptT String Id Float, yet you returned its value with no problem and the compiler turned that into ReaderT (List String) (StateT Nat (ExceptT String Id)) Float for you, somehow.

This is called "monad lifting" and is another secret sauce that Lean provides in order to make monads super easy to use. You could imagine manual monad lifting would be very tedious indeed. You can see this in action with the following test:

def lift1 (x : ExceptT String Id Float) : (StateT Nat (ExceptT String Id)) Float :=
  x

#eval lift1 (divide 5 1) |>.run 3 -- Except.ok (5.000000, 3)

You can see this is lifted to StateT because we were able to then .run that with initial state 3 and the StateT monad converted that into the pair (5.000000, 3). lift1 didn't modify the state so it came back to us unmodified. Now we need a second lift to ReaderT:

def lift2 (x : StateT Nat (ExceptT String Id) Float) : ReaderT (List String) (StateT Nat (ExceptT String Id)) Float :=
  x

#eval lift2 (lift1 (divide 5 1)) |>.run ["discarded", "state"] |>.run 4 -- Except.ok (5.000000, 4)

So you can see how the lifts compose nicely, we can pass in the ReaderT args, and the initial state, and we get back the divide result and the returned state. In this case lift2 does nothing with the ReaderT args so they are discarded.

So what Lean did for you in divideRefactored is a transitive closure of monad lifting operations!

Lifting Deep Dive

Let's see how that works. If you #print lift1 you will see it is implemented as fun x => liftM x and liftM is an abbreviation for monadLift from MonadLiftT.

class MonadLiftT (m : Type u → Type v) (n : Type u → Type w) where
  /-- Lifts a value from monad `m` into monad `n`. -/
  monadLift : {α : Type u} → m α → n α

The T in MonadLiftT stands for "transitive". It is able to transitively lift monadic computations by chaining MonadLift instances, where each MonadLift is a function for lifting a computation from an inner monad to an outer monad.
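
The transitive lift can also be requested explicitly with monadLift. The following sketch lifts through two layers in a single call; baseAction and liftedTwice are hypothetical names:

```lean
-- An action in the innermost monad.
def baseAction : ExceptT String Id Nat := pure 1

-- One `monadLift` call; Lean chains the StateT and ReaderT lifts for us.
def liftedTwice : ReaderT (List String) (StateT Nat (ExceptT String Id)) Nat :=
  monadLift baseAction

#eval liftedTwice |>.run [] |>.run 0  -- Except.ok (1, 0)
```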

So now we can check all the MonadLift instances defined in Lean, and in our case we will be using:

instance : MonadLift m (StateT σ m)  -- to lift to StateT
instance : MonadLift m (ReaderT ρ m) -- to lift to ReaderT

These instances override the lift function for these types, showing the compiler how to generate that code. Let's see how that works for StateT:

@[inline] protected def lift {α : Type u} (t : m α) : StateT σ m α :=
  fun s => do let a ← t; pure (a, s)

So this is very generic. Given a type α and an action t : m α in some monad m, it produces a StateT σ m α by returning a function that takes some state s, runs t to get its result a, and returns the pair (a, s). And this is what we saw: the state 3 passed in resulted in a pair coming back out as (5.000000, 3). It is not inventing the state, it is passing that state through so it is both an input and an output, resulting in a valid StateT Nat (ExceptT String Id) action.
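
You can call StateT.lift yourself to watch the state pass through untouched. This is a small sketch with hypothetical names innerAction and outerAction:

```lean
-- An action that knows nothing about state.
def innerAction : ExceptT String Id Nat := pure 7

-- Lift it into StateT; the state type is taken from the expected type.
def outerAction : StateT Nat (ExceptT String Id) Nat :=
  StateT.lift innerAction

-- Whatever state goes in comes back out unchanged next to the result.
#eval outerAction.run 3   -- Except.ok (7, 3)
#eval outerAction.run 10  -- Except.ok (7, 10)
```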

Similarly, for ReaderT we find:

instance  : MonadLift m (ReaderT ρ m) where
  monadLift x := fun _ => x

This one is a bit simpler, remember ReaderT is about passing in some read-only state, but to lift from something that does not know about ReaderT then the thing we are calling obviously doesn't care about this read only state, so we can throw it away using the underscore _ and simply call the inner function, which is what we saw with this eval:

#eval lift2 (lift1 (divide 5 1)) |>.run ["discarded", "state"] |>.run 4 -- Except.ok (5.000000, 4)

The ReaderT state here is ["discarded", "state"] and it is thrown away when calling lift1 because lift1 doesn't know (or care) about ReaderT.

So looking at divideRefactored again, you get an appreciation for what is going on under the covers to make that monadic code nice and composable, both on the way in with monads like ReaderT and StateT adding additional input parameters, and on the way out with automatic transitive monad lifting. Lift happens very often in Lean.

Add your own Custom Lifting

You can now build the app with lake build and try out this main function:

def main (args: List String): IO Unit := do
  try
    let ret ← divideRefactored 5 0 |>.run args |>.run 10
    IO.println (toString ret)
  catch e =>
    IO.println e

Without some help, this function does not compile, and Lean reports:

typeclass instance problem is stuck, it is often due to metavariables

and you can see this error if you add this line before the main function:

set_option autoLift false

divideRefactored returns the big ReaderT (List String) (StateT Nat (ExceptT String Id)) Float and the problem is that it cannot be automatically transformed into the main return type of IO Unit unless we give it some help.

The following custom MonadLift solves this problem:

def liftIO (t : ExceptT String Id α) : IO α := do
  match t with
  | .ok r => EStateM.Result.ok r
  | .error s => EStateM.Result.error s

instance : MonadLift (ExceptT String Id) IO where
  monadLift := liftIO

This instance makes it possible to lift a result of type ExceptT String Id α into IO α, which is what main needs. Fortunately this lift is relatively easy because IO is built on EStateM, whose result type EStateM.Result is very similar to Except in that it also has an ok and an error case. The difference is that each Result case carries one more data member, the state that is threaded through the computation.

If we have an instance MonadLiftT m n, that means there is a way to turn a computation that happens inside of m into one that happens inside of n and (this is the key part) usually without the instance itself creating any additional data that feeds into the computation. This means you can in principle declare lifting instances from any monad to any other monad. It does not, however, mean that you should do this in all cases. You can get a report from Lean of how all this was done by uncommenting the line set_option trace.Meta.synthInstance true in before main and moving the cursor to the end of the first line after do, where you will see a nice detailed report.

Now lake build will create a binary named simpleMonads which you can run from the command line and you can pass command line parameters as follows:

> simpleMonads
can't divide by zero

> simpleMonads --limit
too many divides

So we were able to influence the behavior of our program by passing some command line arguments and some logging state, we added some exception handling, and we did it all in a purely functional way using monads. Then we also showed how monad lifting makes functional decomposition nice and manageable.