In [1]:
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
import Control.Exception
import Control.Monad.Except
import Control.Monad.State
import qualified Data.Text as T
import GHC.Generics
import Grisette
import Grisette.Lib.Control.Monad.Except
import Grisette.Lib.Control.Monad.Trans.Class
import Grisette.Lib.Control.Monad.Trans.State.Lazy

# Using Monad Transformers with Grisette

## Introduction

In the [previous tutorial](./2_union.ipynb), we discussed how to use the `Union` monad to model multi-path execution. In this tutorial, we'll explore handling failure paths and stateful computations using monads and monad transformers. We'll use `Either` and `ExceptT` to represent and handle errors and use `State` and `StateT` to model stateful computations.

We will extend our expression reasoning tool to support a more expressive language with variables, division, and assignment statements, using monad transformers to add errors and states to our multi-path execution. We will then do some verification and test generation for this language.

In this tutorial, you will learn how to:

- Handle failure paths and stateful computations using monads and monad transformers
- Extend the expression reasoning tool to support a more expressive language with states and exceptions
- Use Grisette's `solveExcept` function to find inputs that trigger specific exceptions or satisfy predicates over the program's result

By the end of this tutorial, you'll have a deeper understanding of handling errors and states in solver-aided applications using monads and monad transformers, and you'll be able to create more expressive and verifiable languages using Grisette.

Please make sure that you have `z3` (https://github.com/Z3Prover/z3) installed and accessible in `$PATH`.

Note that some inline `code` spans link to the documentation. These links may not be visibly rendered as links in Jupyter notebooks.

## Paths That May Fail

When working with programs that involve complex computations, it's common to encounter scenarios where certain operations may fail. These failures can occur for various reasons, such as invalid input or unexpected conditions.

In the context of verification and symbolic execution, handling failure paths is crucial to ensure the correctness of our programs. We need a way to represent and reason about these potential failures in a structured and composable manner.

In Haskell, a common approach to handling failure paths is to use the `Either` type. The `Either` type is defined as follows:

```haskell
data Either e a
  = Left e
  | Right a
```

It represents a value that can be either a `Left` value of type `e` or a `Right` value of type `a`.
By convention, the `Left` constructor is used to represent failure or error cases, while the `Right` constructor is used to represent successful computations.

For example, let's implement a safe division function for machine integers (signed bit vectors):

In [2]:
safeDiv :: Int -> Int -> Either ArithException Int
safeDiv x 0 = Left DivideByZero
safeDiv x (-1) | x == minBound = Left Overflow
safeDiv x y = Right (x `div` y)

safeDiv 5 2            -- Right 2
safeDiv 5 0            -- Left divide by zero
safeDiv minBound (-1)  -- Left arithmetic overflow

Right 2

Left divide by zero

Left arithmetic overflow

In this function, we use `Either` to handle invalid inputs. The function returns a `DivideByZero` error when dividing by zero, or an `Overflow` error when dividing `minBound` by `-1` (since `-minBound` is not representable as a 2's complement integer).
This allows us to handle the errors and record what has happened in a purely functional way, without resorting to exceptions or null values.

`Left` and `Right` could also be replaced by `throwError` and `return`. These are more general operations for all the monads that could handle failures, and we will discuss them later.

### Short-Circuiting the Error

Unsurprisingly, `Either e` is a monad.
One of the key benefits of using `Either` in combination with monadic operations is the ability to sequence computations that may fail.
The bind (`>>=`) operation allows us to chain computations together, but with a special behavior when it comes to handling failures.

The bind operation for `Either` values only continues execution when there's no error:
when we bind an `Either` value to a function using `>>=`, the function is applied only if the `Either` value is a `Right` (success) value.
If the `Either` value is a `Left` (failure) value, the bind operation short-circuits and returns the `Left` value without applying the function.

This property is powerful because it allows us to write code that looks like a sequence of computations, but automatically handles failures along the way. If any computation in the sequence fails (i.e., returns a `Left` value), the subsequent computations will be skipped, and the failure will be propagated to the final result.
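
The short-circuiting behavior described above can be sketched as a standalone function. The names `bindEither`, `halve`, `twice`, and `stuck` are hypothetical and exist only for this illustration; `bindEither` mirrors how `>>=` behaves for `Either e`, but it is not the library's source code.

```haskell
-- A hand-written stand-in for (>>=) on Either, for illustration only.
bindEither :: Either e a -> (a -> Either e b) -> Either e b
bindEither (Left e) _ = Left e -- failure: the continuation is never called
bindEither (Right a) k = k a -- success: pass the value to the continuation

-- A hypothetical step that fails on odd numbers.
halve :: Int -> Either String Int
halve n
  | even n = Right (n `div` 2)
  | otherwise = Left ("odd: " ++ show n)

-- 12 -> 6 -> 3: both steps succeed.
twice :: Either String Int
twice = Right 12 `bindEither` halve `bindEither` halve

-- 6 -> 3 -> failure: the third step is skipped.
stuck :: Either String Int
stuck = Right 6 `bindEither` halve `bindEither` halve `bindEither` halve
```

Here `twice` evaluates to `Right 3`, while `stuck` evaluates to `Left "odd: 3"` because the second `halve` fails and the last one never runs.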

Let's look at an example to illustrate this property:

In [3]:
f :: Int -> Int -> Int -> Either ArithException Int
f a b c = do
  d <- safeDiv a b
  safeDiv d c

In this example, we have the `safeDiv` function from before, which performs division and returns an `Either` value.
The new `f` function takes three integers `a`, `b`, and `c`, divides `a` by `b`, and then divides the result by `c`.

The `f` function uses `do` notation, which is syntactic sugar for monadic operations, as we've learned in previous tutorials.
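
To make the sugar concrete, here is a sketch of the same computation written once with `do` and once with an explicit `>>=`. The names `checkedDiv`, `viaDo`, and `viaBind` are hypothetical; `checkedDiv` is just a renamed copy of the `safeDiv` shown earlier so that the snippet stands on its own.

```haskell
import Control.Exception (ArithException (..))

-- Renamed copy of the tutorial's safeDiv, so this snippet is self-contained.
checkedDiv :: Int -> Int -> Either ArithException Int
checkedDiv _ 0 = Left DivideByZero
checkedDiv x (-1) | x == minBound = Left Overflow
checkedDiv x y = Right (x `div` y)

-- Version using do notation.
viaDo :: Int -> Int -> Int -> Either ArithException Int
viaDo a b c = do
  d <- checkedDiv a b
  checkedDiv d c

-- The same computation with the bind written out by hand.
viaBind :: Int -> Int -> Int -> Either ArithException Int
viaBind a b c = checkedDiv a b >>= \d -> checkedDiv d c
```

Both versions behave identically: if `checkedDiv a b` is a `Left`, the lambda after `>>=` is never invoked.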

Now, let's see how the short-circuiting property comes into play:

In [4]:
-- Failure in the first division
f minBound (-1) 0 -- Left arithmetic overflow

-- Failure in the second division
f 2 (-1) 0        -- Left divide by zero

-- Successful computation
f 16 3 2          -- Right 2

Left arithmetic overflow

Left divide by zero

Right 2

In the first case, `f minBound (-1) 0` fails because the first division (`minBound` divided by `-1`) results in a `Left` value.
The bind operation short-circuits and returns the `Left Overflow` value without proceeding to the second division.

In the second case, `f 2 (-1) 0` fails because the first division succeeds with `-2`, but the second division (by `0`) results in a `Left` value.
That `Left DivideByZero` becomes the result of the whole function.

The third case succeeds because both divisions are valid. The bind operation applies the division to the `Right` values, and the final result is `Right 2`.

Why is this property useful? Recall that in our expression synthesizer, we defined the `Value` type as follows (extended with some more errors that may be used in this tutorial):

In [5]:
data Value
  = IValue SymInteger
  | BValue SymBool
  | BadValue
  | DivisionError
  deriving (Show)

To work with this `Value` type, we need to pattern-match on the operands to see whether either of them is a failure, and manually short-circuit the computation:

In [6]:
add :: Value -> Value -> Value
add (IValue l) (IValue r) = IValue $ l + r
add DivisionError _ = DivisionError
add BadValue _ = BadValue
add _ DivisionError = DivisionError
add _ _ = BadValue

add DivisionError (IValue 1)       -- DivisionError
add (IValue 2) (IValue 3)          -- IValue 5
add (IValue 2) (BValue (con True)) -- BadValue

DivisionError

IValue 5

BadValue

If we define the errors as a dedicated type and use them as the `Left` value, we can then implement the `add` function cleanly:

In [7]:
data Error
  = BadValue
  | DivisionError

data Value
  = IValue SymInteger
  | BValue SymBool

add :: Value -> Value -> Either Error Value
add (IValue l) (IValue r) = return $ IValue $ l + r
add _ _ = throwError BadValue

By using `Either`, we can separate the concerns of successful and failure cases, making our code more modular and easier to reason about. This approach also aligns well with the concept of verification conditions, as we can capture the different paths separately and reason about them by asking whether a specific result or error can or cannot happen. We will demonstrate this in the verifier at the end of this tutorial.
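
As a small, non-symbolic preview of "asking whether a specific error can happen", the sketch below checks a predicate over `Either` results by brute force on a tiny input range. The names `checkedDiv`, `hitsDivByZero`, and `witnesses` are hypothetical, and the exhaustive search is only a stand-in for what a solver does symbolically over all inputs.

```haskell
import Control.Exception (ArithException (..))

-- Renamed copy of the tutorial's safeDiv, so this snippet is self-contained.
checkedDiv :: Int -> Int -> Either ArithException Int
checkedDiv _ 0 = Left DivideByZero
checkedDiv x (-1) | x == minBound = Left Overflow
checkedDiv x y = Right (x `div` y)

-- A predicate that selects exactly one failure path.
hitsDivByZero :: Either ArithException a -> Bool
hitsDivByZero (Left DivideByZero) = True
hitsDivByZero _ = False

-- Brute-force search over a small range for inputs x where 10 / (x + 1) fails.
witnesses :: [Int]
witnesses = [x | x <- [-3 .. 3], hitsDivByZero (checkedDiv 10 (x + 1))]
```

In this range the only witness is `-1`, the single value that makes the divisor `x + 1` zero.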

## Stateful Computation

In addition to handling failure paths, another common scenario in programming is dealing with stateful computations.
In Haskell, stateful computations are typically modeled using the `State` monad.

The `State` type can be thought of as defined as follows (the library version is built from `StateT`, but it behaves the same way):

```haskell
newtype State s a = State { runState :: s -> (a, s) }
```

Here, `s` represents the type of the state, and `a` represents the result type of the computation.
A stateful computation is a computation that computes the result from an initial state, and returns the new state along with the result.
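
Before bringing in the monad, it may help to see this state-passing idea spelled out with plain functions. This is a sketch: `Step`, `tick`, and `twoTicks` are hypothetical names that are not part of any library.

```haskell
-- A stateful computation as a plain function from a state to (result, new state).
type Step s a = s -> (a, s)

-- Return the current counter and increment it.
tick :: Step Int Int
tick n = (n, n + 1)

-- Run tick twice, threading the state through by hand.
twoTicks :: Step Int (Int, Int)
twoTicks s0 =
  let (a, s1) = tick s0
      (b, s2) = tick s1
   in ((a, b), s2)
```

Evaluating `twoTicks 5` gives `((5,6),7)`. The `State` monad automates exactly this manual threading of `s0`, `s1`, and `s2`.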

The `State` monad provides two primary operations:

- `get :: State s s`: Retrieves the current state.
- `put :: s -> State s ()`: Sets the current state to a new value.

These operations allow us to read from and modify the state within a computation. Here's a simple example that demonstrates the usage of the `State` monad:

In [8]:
f :: Int -> State Int (Int, Int)
f v = do
  a1 <- get
  put $ v + a1
  a2 <- get
  return (a1, a2)

initialState = 2
runState (f 100) initialState -- ((2,102),102)

((2,102),102)

In this example, the `f` function uses the `State` monad to manage an `Int` state. It retrieves the current state using `get`, sets the state to the old state plus the argument, and then retrieves the new state.
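
Besides `runState`, the `mtl` library also provides `evalState` (keep only the result), `execState` (keep only the final state), and `modify` (apply a function to the state). A brief sketch, where `bump` is a hypothetical example function:

```haskell
import Control.Monad.State (State, evalState, execState, get, modify, runState)

-- Add v to the state and return the state as it was before the update.
bump :: Int -> State Int Int
bump v = do
  old <- get
  modify (+ v)
  return old

resultAndState :: (Int, Int)
resultAndState = runState (bump 100) 2 -- (2,102)

resultOnly :: Int
resultOnly = evalState (bump 100) 2 -- 2

stateOnly :: Int
stateOnly = execState (bump 100) 2 -- 102
```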

For more details on the state monad, see https://en.wikibooks.org/wiki/Haskell/Understanding_monads/State.

## Combining Monads

In the previous sections, we explored the `Either` monad for handling failure paths and the `State` monad for managing stateful computations. However, in real-world scenarios, we often need to combine multiple monadic effects within a single computation. This is where monad transformers come into play.

Monad transformers allow us to combine the effects of multiple monads in a modular way. They provide a way to stack monads on top of each other, creating a new monad that incorporates the effects of all the individual monads.

For example, the following are monad transformers for `Either` and `State`:

- `ExceptT`: Adds exception handling to a monad.
- `StateT`: Adds state management to a monad.

To use a monad transformer, we stack it on top of an existing monad. For example, to combine the effects of `Either` and `State`, we can use the `StateT` transformer with `Either` as the underlying monad:

In [9]:
type MyMonad = StateT Int (Either ArithException)

In this example, `MyMonad` is a monad that incorporates both state management and exception handling. The `Int` represents the state type, and `ArithException` represents the error type. The `throwError`, `return`, and `get` operations all work on this monad, thanks to the `MonadError` and `MonadState` type classes provided by the `mtl` library.
With these type classes, we can write functions that work with any monad that implements the required instances. This allows us to write more generic and reusable code.

Below, we extend the `safeDiv` function to work with monads that implement the `MonadError` class and implement the `f` function that performs division with the left-hand side from function arguments and the right-hand side from the state.

In [10]:
safeDiv :: (MonadError ArithException m) => Int -> Int -> m Int
safeDiv x 0 = throwError DivideByZero
safeDiv x (-1) | x == minBound = throwError Overflow
safeDiv x y = return (x `div` y)

f :: Int -> MyMonad Int
f lhs = do
  rhs <- get
  safeDiv lhs rhs

runStateT (f 8) 3 -- Right (2,3)
runStateT (f 8) 0 -- Left divide by zero

Right (2,3)

Left divide by zero

In the `f` function, we use `get` to retrieve the current state value, which represents the right-hand side of the division.
We then pass the left-hand side (`lhs`) and the right-hand side (`rhs`) to the `safeDiv` function,
which performs the division and returns the result (or exception) in the `MyMonad` context.

The `runStateT` function is used to run the stateful computation with an initial state value. Because `Either` is the underlying monad, it returns an `Either` value: `Right` with the result and the final state, or `Left` with the error.
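
The order in which transformers are stacked affects what happens to the state when an error is thrown. The sketch below runs one generic function on two different stacks; `countedDiv`, `stateOverEither`, and `exceptOverState` are hypothetical names introduced only for this illustration.

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Control.Exception (ArithException (..))
import Control.Monad.Except (MonadError (..), runExceptT)
import Control.Monad.State (MonadState (..), runState, runStateT)

-- Count the call in the state, then divide (failing on a zero divisor).
countedDiv :: (MonadState Int m, MonadError ArithException m) => Int -> Int -> m Int
countedDiv x y = do
  n <- get
  put (n + 1)
  if y == 0 then throwError DivideByZero else return (x `div` y)

-- StateT over Either: on failure the whole result is a Left, the state is lost.
stateOverEither :: Either ArithException (Int, Int)
stateOverEither = runStateT (countedDiv 8 0) 0

-- ExceptT over State: the error sits next to the state, which survives.
exceptOverState :: (Either ArithException Int, Int)
exceptOverState = runState (runExceptT (countedDiv 8 0)) 0
```

With these definitions, `stateOverEither` is `Left DivideByZero`, whereas `exceptOverState` is `(Left DivideByZero, 1)`: the counter update is still visible. `MyMonad` above uses the first arrangement.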

In the next section, we will apply these concepts to our `Union` monad, enabling symbolic evaluation with both error handling and state management. 

## Extend our Expression Reasoning Tool

Now that we've seen how to handle failures and states using monads and monad transformers, let's apply these concepts to extend our expression reasoning tool.
We'll create a language with effectful statements for reading from and writing to a global state, while also handling potential failures like division by zero. Monad transformers will help us manage the complexity of combining errors and state.

First, let's enhance our expression type by adding two new constructors:

- `Var`: Looks up variables in the global environment, raising an undefined variable error if the variable is not found.
- `Div`: Performs division, raising a division error if the divisor is zero.

We will also define a custom `Error` type to represent the various errors that can occur in our language, such as `TypeMismatchError` for mismatched types and `NoStatementError` for empty programs.

In [11]:
:r -- Clear everything we've defined before. We will use Grisette's safeDiv which works for symbolic integers.
type VarId = T.Text

data Expr
  = I SymInteger
  | B SymBool
  | Var VarId
  | Add (Union Expr) (Union Expr)
  | Mul (Union Expr) (Union Expr)
  | Div (Union Expr) (Union Expr)
  | Eq (Union Expr) (Union Expr)
  deriving (Show, Eq, Generic)
  deriving (Mergeable, ExtractSym, EvalSym) via (Default Expr)

data Error
  = UndefinedVarError VarId
  | TypeMismatchError
  | DivisionError
  | NoStatementError
  deriving (Show, Eq, Generic)
  deriving (Mergeable, SymEq) via (Default Error)

data Value
  = IValue SymInteger
  | BValue SymBool
  deriving (Show, Eq, Generic)
  deriving (Mergeable, SymEq) via (Default Value)

To keep things simple, we'll only include assignment statements in our language.
An assignment statement evaluates an expression and stores the result in the global environment.
The program's overall result is determined by the value of the expression in the final assignment.

In [12]:
data Stmt = Assign VarId (Union Expr)

type Prog = [Stmt]

Next, we'll implement the semantics of our language by writing an interpreter. The first step is to define operations for querying and updating the environment, which we'll represent as an association list mapping (symbolic) variable IDs to (symbolic) values.

To look up a variable in the environment, we compare the given variable ID to the IDs in the list, starting from the head. If we find a pair with a matching ID (under symbolic equivalence), we return the associated value. If no match is found after going through the entire list, we throw an `UndefinedVarError`.
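
The lookup procedure can first be sketched without any symbolic values. This is a simplified analogue, not the tutorial's implementation: `PlainEnv`, `lookupVar`, and `bindVar` are hypothetical names, IDs are ordinary strings, and the error is a plain message rather than an `Error` value.

```haskell
-- A concrete environment: newest bindings are at the head of the list.
type PlainEnv = [(String, Int)]

-- Walk the list from the head and return the first matching binding.
lookupVar :: String -> PlainEnv -> Either String Int
lookupVar name [] = Left ("undefined variable: " ++ name)
lookupVar name ((k, v) : rest)
  | name == k = Right v
  | otherwise = lookupVar name rest

-- Adding a binding prepends it, so it shadows any older binding of the same name.
bindVar :: String -> Int -> PlainEnv -> PlainEnv
bindVar name v = ((name, v) :)
```

For example, `lookupVar "a" (bindVar "a" 2 [("a", 1)])` is `Right 2`, and `lookupVar "z" []` is `Left "undefined variable: z"`. The symbolic version differs in that the comparison `name == k` may be undecided, so both branches have to be kept and merged.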

We'll use the `ExceptT Error` transformer to add error-throwing and handling capabilities to our `Union` monad. This provides us with `mrgThrowError` for throwing errors while still allowing multi-path execution with `mrgIf`. (Remember that we always want to use the `mrg*` variants.)

There is also a call to `mrgLift`. It is a function that "lifts" a computation from a monad (like `Union Value`) into a transformed monad (like `ExceptT Error Union Value`).

In [13]:
type Env = [(Union VarId, Union Value)]

getVar :: VarId -> Env -> ExceptT Error Union Value
getVar varId [] = mrgThrowError $ UndefinedVarError varId
getVar varId ((i,v):ivs) = mrgIf (return varId .== i) (mrgLift v) (getVar varId ivs)

sampleEnv = [
  (mrgIf "cond1" (return "a") (return "b"), mrgReturn $ IValue "i"),
  (mrgIf "cond2" (return "a") (return "b"), mrgReturn $ BValue "b")
  ]
getVar "a" sampleEnv

ExceptT {If (! (|| cond1 cond2)) (Left (UndefinedVarError "a")) (If cond1 (Right (IValue i)) (Right (BValue b)))}

To add a new binding to the environment, we simply prepend the variable ID and value pair to the association list:

In [14]:
setVar :: VarId -> Value -> Env -> Env
setVar varId value = ((mrgReturn varId, mrgReturn value):)

setVar "a" (IValue 1) sampleEnv

[({"a"},{IValue 1}),({If cond1 "a" "b"},{IValue i}),({If cond2 "a" "b"},{BValue b})]

Our interpreter's execution context needs to keep track of both the environment state and any exceptions that might occur.
We'll represent this using the `StateT` and `ExceptT` monad transformers:

In [15]:
type Context = StateT Env (ExceptT Error Union)

The evaluation logic uses a new operation, `mrgModifyError`, because Grisette's `safeDiv` uses `ArithException` as its exception type, and we want to convert that to our `Error` type.
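
The idea of converting one error type into another can be sketched in plain Haskell with `withExceptT`, which plays a similar role for ordinary `ExceptT` computations. This is only an analogue under that assumption, not Grisette's implementation; `LangError`, `rawDiv`, and `langDiv` are hypothetical names.

```haskell
import Control.Exception (ArithException (..))
import Control.Monad.Except (Except, runExcept, throwError, withExceptT)

-- A stand-in for the language's own error type.
data LangError = LangDivisionError | LangTypeMismatch
  deriving (Show, Eq)

-- A division that reports failures as ArithException.
rawDiv :: Int -> Int -> Except ArithException Int
rawDiv _ 0 = throwError DivideByZero
rawDiv x y = return (x `div` y)

-- The same division, with every ArithException mapped to LangDivisionError.
langDiv :: Int -> Int -> Except LangError Int
langDiv x y = withExceptT (const LangDivisionError) (rawDiv x y)
```

Here `runExcept (langDiv 7 0)` is `Left LangDivisionError`, and `runExcept (langDiv 7 2)` is `Right 3`.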

With the tools in place, we can define the evaluation logic:

In [16]:
binOp :: Union Expr -> Union Expr -> ((Value, Value) -> Context Value) -> Context Value
binOp l r f = do
  el <- onUnion eval l
  er <- eval .# r
  f (el, er)

eval :: Expr -> Context Value
eval (I i) = mrgReturn $ IValue i
eval (B b) = mrgReturn $ BValue b
eval (Var varId) = do
  env <- get
  mrgLift $ getVar varId env
eval (Add l r) = binOp l r $ \case
  (IValue il, IValue ir) -> mrgReturn $ IValue $ il + ir
  _ -> mrgThrowError TypeMismatchError
eval (Mul l r) = binOp l r $ \case
  (IValue il, IValue ir) -> mrgReturn $ IValue $ il * ir
  _ -> mrgThrowError TypeMismatchError
eval (Div l r) = binOp l r $ \case
  (IValue il, IValue ir) -> do
    res <- mrgModifyError (\(e :: ArithException) -> DivisionError) $ safeDiv il ir
    mrgReturn $ IValue res
  _ -> mrgThrowError TypeMismatchError
eval (Eq l r) = binOp l r $ \case
  (IValue il, IValue ir) -> mrgReturn $ BValue $ il .== ir
  (BValue il, BValue ir) -> mrgReturn $ BValue $ il .== ir
  _ -> mrgThrowError TypeMismatchError

Evaluating statements and programs is fairly straightforward (`mrgModify` modifies the state with a function):

In [17]:
evalStmt :: Stmt -> Context Value
evalStmt (Assign varId expr) = do
  value <- eval .# expr
  mrgModify $ setVar varId value
  mrgReturn value

evalProg :: Prog -> Context Value
evalProg [] = mrgThrowError NoStatementError
evalProg [s] = evalStmt s
evalProg (s:ss) = do
  evalStmt s
  evalProg ss
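
For comparison, here is a miniature, purely concrete interpreter with the same shape: a `StateT` environment on top of an error monad. It is a sketch under simplifying assumptions (integer values only, no symbolic paths, no overflow check), and every name in it (`MiniExpr`, `MiniStmt`, `MiniError`, `MiniCtx`, `evalE`, `evalS`, `evalP`, `runMini`) is hypothetical.

```haskell
import Control.Monad.Except (throwError)
import Control.Monad.State (StateT, evalStateT, get, modify)

data MiniExpr = Lit Int | Ref String | Quot MiniExpr MiniExpr

data MiniStmt = Set String MiniExpr

data MiniError = Undefined String | DivByZero | EmptyProgram
  deriving (Show, Eq)

-- State (the environment) layered over Either (the errors).
type MiniCtx = StateT [(String, Int)] (Either MiniError)

evalE :: MiniExpr -> MiniCtx Int
evalE (Lit n) = return n
evalE (Ref v) = do
  env <- get
  -- lookup returns the first (most recent) binding, if any.
  maybe (throwError (Undefined v)) return (lookup v env)
evalE (Quot l r) = do
  a <- evalE l
  b <- evalE r
  if b == 0 then throwError DivByZero else return (a `div` b)

-- Evaluate the right-hand side, record the binding, return the value.
evalS :: MiniStmt -> MiniCtx Int
evalS (Set v e) = do
  x <- evalE e
  modify ((v, x) :)
  return x

-- The program's result is the value of its last assignment.
evalP :: [MiniStmt] -> MiniCtx Int
evalP [] = throwError EmptyProgram
evalP [s] = evalS s
evalP (s : ss) = evalS s >> evalP ss

runMini :: [MiniStmt] -> Either MiniError Int
runMini p = evalStateT (evalP p) []
```

For instance, `runMini [Set "a" (Lit 6), Set "b" (Quot (Ref "a") (Lit 2))]` is `Right 3`. The symbolic interpreter follows the same structure, but it explores and merges all paths instead of following a single one.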

To make it easier to work with our language, let's define some helper functions for constructing `Union Expr` values.

These functions can all be generated with Template Haskell by calling `mkMergeConstructor` at the top level:

```haskell
data Expr
  = ...

mkMergeConstructor "mrg" ''Expr
```

However, this isn't available in the IHaskell environment, so we write them out by hand here.

In [18]:
mrgAdd :: (MonadUnion m) => Union Expr -> Union Expr -> m Expr
mrgAdd l r = mrgReturn $ Add l r

mrgMul :: (MonadUnion m) => Union Expr -> Union Expr -> m Expr
mrgMul l r = mrgReturn $ Mul l r

mrgDiv :: (MonadUnion m) => Union Expr -> Union Expr -> m Expr
mrgDiv l r = mrgReturn $ Div l r

mrgVar :: (MonadUnion m) => VarId -> m Expr
mrgVar = mrgReturn . Var

mrgI :: (MonadUnion m) => SymInteger -> m Expr
mrgI = mrgReturn . I

To run a program, we need to provide an initial environment, which serves as the initial state for the `StateT` transformer:

In [19]:
runProg :: Env -> Prog -> ExceptT Error Union Value
runProg env prog = flip mrgEvalStateT env $ evalProg prog

runProg [(mrgReturn "input", mrgReturn $ IValue "input")] []

runProg [(mrgReturn "input", mrgReturn $ IValue "input")] [
  Assign "a" (mrgI 1),
  Assign "b" (mrgAdd (mrgVar "a") (mrgI 2)),
  Assign "c" (mrgMul (mrgVar "b") (mrgVar "input"))
  ]

runProg [(mrgReturn "input", mrgReturn $ IValue "input")] [
  Assign "res" (mrgDiv (mrgI "b") (mrgVar "input"))
  ]

runProg [] [
  Assign "res" (mrgDiv (mrgI "b") (mrgVar "input"))
  ]

ExceptT {Left NoStatementError}

ExceptT {Right (IValue (* 3 input))}

ExceptT {If (= input 0) (Left DivisionError) (Right (IValue (div b input)))}

ExceptT {Left (UndefinedVarError "input")}

One of the benefits of using a custom error type is that we can leverage the solver to find inputs that trigger specific exceptions. This is a common use case for symbolic execution, and Grisette makes it easy to define custom error types and create predicates over errors and values.

In the following example, we attempt to find an input `x` that causes a division error. We do this by converting the `ExceptT` result into a symbolic boolean formula, where `DivisionError` is mapped to `True`, and all other outcomes are mapped to `False`. The `simpleMerge` call converts the resulting `Union SymBool` to a `SymBool`, which we pass to the solver so that it searches for inputs that lead to a `DivisionError`.

In [20]:
prog :: Prog
prog = [
  Assign "a" (mrgAdd (mrgI 1) (mrgI "x")),
  Assign "res" (mrgDiv (mrgI "b") (mrgVar "a"))
  ]

solve (precise z3) $ simpleMerge $ do
  res <- runExceptT $ runProg [] prog
  case res of
    Left DivisionError -> mrgReturn $ con True
    _ -> mrgReturn $ con False

Right (Model {x -> -1 :: Integer})

Grisette provides a function called `solveExcept` that allows us to specify a predicate over the result of an `ExceptT` computation. The predicate should return symbolic `True` for the outcomes we're interested in and `False` for all other outcomes. `solveExcept` then uses the solver to find inputs that satisfy the predicate.

In the following example, we again look for an input `x` that causes a division error, this time using `solveExcept`. We define a predicate that maps `DivisionError` to `True`, and all other outcomes to `False`. We then pass this predicate to `solveExcept`, along with our program and the desired solver (in this case, Z3). For this program, the solver concludes that a `DivisionError` can never occur: the divisor is `1 + x * x`, which is never zero for any integer `x`.

In [21]:
expectedPath :: Either Error a -> SymBool
expectedPath (Left DivisionError) = con True
expectedPath _ = con False

prog1 :: Prog
prog1 = [
  Assign "a" (mrgAdd (mrgI 1) (mrgMul (mrgI "x") (mrgI "x"))),
  Assign "res" (mrgDiv (mrgI "b") (mrgVar "a"))
  ]

solveExcept (precise z3) expectedPath $ runProg [] prog1

Left Unsat
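
The same function can also be pointed at successful results instead of errors. The cell below is a sketch that assumes the earlier cells of this notebook have been run (so `Error`, `Value`, `Prog`, `Assign`, `mrgAdd`, `mrgI`, and `runProg` are in scope); `resultIsFive` and `prog2` are hypothetical names introduced only for this example.

```haskell
-- Assumes the notebook cells above have been evaluated in the same session.

-- Select only the successful paths whose integer result equals 5.
resultIsFive :: Either Error Value -> SymBool
resultIsFive (Right (IValue v)) = v .== 5
resultIsFive _ = con False

-- A one-statement program computing x + 2.
prog2 :: Prog
prog2 = [
  Assign "res" (mrgAdd (mrgI "x") (mrgI 2))
  ]

solveExcept (precise z3) resultIsFive $ runProg [] prog2
```

If everything is wired up as in the cells above, the solver should report a model in which `x` is `3`, since `x + 2 = 5` has exactly one integer solution.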

## Conclusion

In this tutorial, we've explored how to handle failure paths and stateful computations using monads and monad transformers in Grisette.
We learned how to use `Either` and `ExceptT` to represent and handle errors, and `State` and `StateT` to model stateful computations.

We extended our expression reasoning tool to support a more expressive language with variables, division, and assignment statements,
using monad transformers to combine error handling and state management in a modular and composable way.

Finally, we leveraged Grisette's `solveExcept` function to find inputs that trigger specific exceptions or satisfy predicates over the program's result,
enabling us to test and verify the correctness of our programs.

By mastering these techniques, you can create more expressive and verifiable languages using Grisette,
and effectively handle errors and state in your solver-aided applications. With the power of monads, monad transformers, and symbolic execution,
you can tackle complex problems and build reliable, correct software systems.