In [1]:
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
import GHC.Generics
import Grisette

# UnionM and Custom Data Types

## Introduction

In the [previous tutorial](./1_symbolic_type.ipynb), we discussed the basic usage of Grisette with symbolic types.
In this tutorial, we will discuss how to write complex programs and use data abstractions that seamlessly integrate with the Grisette system.

In this tutorial, you will learn how to:
- Utilize the `UnionM` monadic container for representing choices and path conditions
- Understand path merging and its importance for efficient symbolic execution
- Work with user-defined data types in Grisette
- Build a program synthesizer using CounterExample-Guided Inductive Synthesis (CEGIS)

By the end of this tutorial, you will have a solid understanding of how to leverage Grisette's features to build sophisticated solver-aided applications.

## UnionM Monadic Container

At the core of Grisette lies the `UnionM` monadic container. `UnionM` is used to represent choices and path conditions in a program. It allows you to wrap values and introduce branches based on symbolic conditions. It is also what lets us integrate user-defined data types into the Grisette system.

To wrap a value in `UnionM`, with a path condition that is always true, you can use the `mrgReturn` function:

In [2]:
mrgReturn (1, "a") :: UnionM (SymInteger, SymBool)
mrgReturn (Left 10) :: UnionM (Either Integer SymBool)

{(1,a)}

{Left 10}

The `mrgReturn` function takes a value and wraps it in the `UnionM` container. It is similar to the `return` function in the standard `Monad` type class, but it also handles the merging of values, which we will discuss later.

To introduce path conditions, you can use the `mrgIf` function:

In [3]:
unionList :: UnionM [SymInteger]
unionList = mrgIf "cond" (mrgReturn ["a1"]) (mrgReturn ["a2", "a3"])
unionList

{If cond [a1] [a2,a3]}

In this example, `unionList` represents a choice between two lists of symbolic integers, depending on the value of the symbolic condition `cond`. If `cond` is true, the value of `unionList` will be `[a1]`, and if `cond` is false, the value will be `[a2,a3]`.

In [4]:
evaluateSym False (buildModel ("cond" ::= True, "a1" ::= (1 :: Integer))) unionList
evaluateSym False (buildModel ("cond" ::= False, "a2" ::= (1 :: Integer))) unionList

{[1]}

{[1,a3]}

Conceptually, a `UnionM` value represents an if-then-else tree, where each branch is associated with a path condition. The structure of the `unionList` can be visualized as follows:

```text
unionList:
            +------+
            | cond |
            +------+
       cond /      \ !cond
           /        \
       +------+ +---------+
       | [a1] | | [a2,a3] |
       +------+ +---------+
```
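The tree view can be made concrete with a small toy model in plain Haskell. This is only a sketch for intuition: `UnionTree`, `resolve`, and `unionListTree` are hypothetical names, symbolic values are stood in for by strings, and none of this is Grisette's actual representation of `UnionM`.

```haskell
-- Toy model only: NOT Grisette's actual UnionM representation.
data UnionTree a
  = Leaf a -- a single value, reached under a true path condition
  | Branch String (UnionTree a) (UnionTree a) -- condition name, then-branch, else-branch
  deriving (Show, Eq)

-- Pick the leaf selected by a concrete assignment of the conditions.
-- Assumption: conditions missing from the assignment are treated as False.
resolve :: [(String, Bool)] -> UnionTree a -> a
resolve _ (Leaf v) = v
resolve env (Branch c t e) =
  case lookup c env of
    Just True -> resolve env t
    _ -> resolve env e

-- The shape of unionList from the diagram, with symbol names as plain strings.
unionListTree :: UnionTree [String]
unionListTree = Branch "cond" (Leaf ["a1"]) (Leaf ["a2", "a3"])
```

In this toy model, `resolve [("cond", True)] unionListTree` walks down the left edge of the diagram and `resolve [("cond", False)] unionListTree` walks down the right edge.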

When working with `UnionM`, you can use the standard monadic operations, such as `>>=` (bind), or the `do`-notation syntactic sugar, to chain computations and manipulate the values inside the container. For example:

In [5]:
do l <- unionList
   mrgReturn $ 1 : l

{If cond [1,a1] [1,a2,a3]}

In this example, we use the `do` notation to bind the value of `unionList` to the variable `l`, then in each branch, we prepend `1` to the list.

### Monads and UnionM

For those not familiar with monads, `Monad` is a type class that defines two operations: `return` and bind (written `>>=` in Haskell). In the context of `UnionM`:

`return` wraps a value in the `UnionM` container. Conceptually, it transforms a single value into a `UnionM` value with a true path condition.

```text
a:
        a

return a:
    +-------+
    |   a   |
    +-------+
```

`bind` allows you to chain operations on `UnionM` values, effectively splitting the execution paths based on the path conditions. It takes a `UnionM a` value and a function `a -> UnionM b`, and returns a `UnionM b` value. Conceptually, it applies the function to each branch of the input `UnionM` value, potentially introducing new branches.

In Haskell, `do` notation is syntactic sugar for the bind and `return` operations. It allows you to write monadic code in a more imperative style, making it easier to read and understand.

For example, the following code using `do` notation:

```haskell
do l1 <- mrgIf "cond" (return ["a"]) (return ["b","c"])
   mrgIf "cond2" (return $ "d" : l1) (return $ "e" : l1) :: UnionM [SymInteger]
```

is syntactic sugar for

```haskell
mrgIf "cond" (return ["a"]) (return ["b","c"])
  >>= \l1 -> mrgIf "cond2" (return $ "d" : l1) (return $ "e" : l1) :: UnionM [SymInteger]
```

In the `do` notation, the left arrow `<-` is used to bind the result of a monadic computation to a variable, which can be used in subsequent computations. The evaluation of the code example could be understood as follows:

```text
Step 1 (l1):
      +------+
      | cond |
      +------+
       /    \
      /      \
  +-----+ +-------+
  | [a] | | [b,c] |
  +-----+ +-------+
Step 2 (apply function in the leaves):
                    +------+
                    | cond |
                    +------+
                     /    \
               +----+      +-----+
              /                   \
         +-------+             +-------+       
         | cond2 |             | cond2 |       
         +-------+             +-------+       
           /   \                /     \        
          /     \              /       \       
    +-------+ +-------+ +---------+ +---------+
    | [d,a] | | [e,a] | | [d,b,c] | | [e,b,c] |
    +-------+ +-------+ +---------+ +---------+
Step 3 (merge, preview for the next section):
                    +------+
                    | cond |
                    +------+
                     /    \
              +-----+      +-----+
             /                    \
+---------------------+ +-----------------------+   
| [(ite cond2 d e),a] | | [(ite cond2 d e),b,c] |  
+---------------------+ +-----------------------+   
```
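Steps 1 and 2 of the diagram can be reproduced with a toy tree monad in plain Haskell. This is a sketch under stated assumptions: `UnionTree` and `example` are hypothetical names, symbols are plain strings, and the model performs no merging, so it stops at Step 2 rather than Step 3.

```haskell
-- Toy model only: NOT Grisette's UnionM, and it performs no merging.
data UnionTree a
  = Leaf a
  | Branch String (UnionTree a) (UnionTree a)
  deriving (Show, Eq)

instance Functor UnionTree where
  fmap f (Leaf v) = Leaf (f v)
  fmap f (Branch c t e) = Branch c (fmap f t) (fmap f e)

instance Applicative UnionTree where
  pure = Leaf -- wrap a value as a single leaf
  tf <*> tv = tf >>= \f -> fmap f tv

instance Monad UnionTree where
  -- Bind applies the continuation at every leaf and keeps the conditions above it.
  Leaf v >>= k = k v
  Branch c t e >>= k = Branch c (t >>= k) (e >>= k)

-- The same do-block as in the text, written against the toy tree.
example :: UnionTree [String]
example = do
  l1 <- Branch "cond" (Leaf ["a"]) (Leaf ["b", "c"])
  Branch "cond2" (Leaf ("d" : l1)) (Leaf ("e" : l1))
```

Here `example` is a `cond` node whose two children are both `cond2` nodes, giving the four leaves shown in Step 2.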

In [6]:
do l1 <- mrgIf "cond" (return ["a"]) (return ["b","c"])
   mrgIf "cond2" (return $ "d" : l1) (return $ "e" : l1) :: UnionM [SymInteger]

{If cond [(ite cond2 d e),a] [(ite cond2 d e),b,c]}

## Path Merging

Path explosion is a common problem in symbolic execution, where the number of paths grows exponentially with the number of branching conditions. This can lead to poor performance and scalability issues when dealing with complex programs.

To understand the path explosion problem, consider the following example:

```haskell
union = do
  a <- union1
  b <- union2
  c <- union3
  return (a, b, c)

furtherCode = do
  v <- union
  return $ g v
```

In this code, `union1`, `union2`, and `union3` are assumed to be `UnionM` values representing choices, and `furtherCode` applies another function `g` to the result of `union`.

If each of `union1`, `union2`, and `union3` has 2 branches, the total number of paths in `union` will be 2^3 = 8. This means that the function `g` will be executed 8 times, once for each possible combination of choices.

This path explosion problem can quickly become intractable as the number of branches and depth of the computation increase.
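The growth in the number of paths can be seen with a rough analogy in plain Haskell, using the list monad as a stand-in for a union that never merges. The concrete values and the body of `g` below are arbitrary placeholders chosen for illustration.

```haskell
-- Each list stands in for a union with two branches (a rough analogy only).
union1, union2, union3 :: [Int]
union1 = [0, 1]
union2 = [0, 10]
union3 = [0, 100]

-- Every combination of choices becomes its own path.
paths :: [(Int, Int, Int)]
paths = do
  a <- union1
  b <- union2
  c <- union3
  return (a, b, c)

-- A placeholder for the further computation g, run once per path.
g :: (Int, Int, Int) -> Int
g (a, b, c) = a + b + c

-- Number of paths: 2 * 2 * 2 = 8.
pathCount :: Int
pathCount = length paths

-- g is applied once for each of the paths.
results :: [Int]
results = map g paths
```

Adding a fourth two-way choice would double `pathCount` again, which is the exponential growth described above.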

Grisette addresses this problem through path merging.
Instead of exploring each path separately, Grisette merges the paths and represents the result as a symbolic expression.
This allows for efficient symbolic execution without the need to explicitly enumerate all possible paths.

To enable path merging, you should use the `mrgReturn` function instead of the plain `return` function.
Grisette will then merge the branches when possible in the result.

In [7]:
do l <- unionList
   mrgReturn $ sum l

{(ite cond a1 (+ a2 a3))}

In this example, the sums from the two branches both have the type `SymInteger`, so they can be merged into a single value with the SMT `ite` operator.
The `mrgReturn` function is the key to merging in Grisette.
If we use the vanilla `return` function or `fmap` function, the result will instead have two branches and will not be merged.
The angle brackets indicate that the result isn't merged.

In [8]:
sum <$> unionList

<If cond a1 (+ a2 a3)>

Let's examine the type signatures of `return` and `mrgReturn`.

In [9]:
:t return
:t mrgReturn

The `return` function has a simple type signature.
It takes a value of type `a` and wraps it in a monadic context `m`.

On the other hand, `mrgReturn` has additional constraints.
It requires the value to be an instance of the `Mergeable` type class, and the monad `m` to be an instance of the `MonadTryMerge` type class.

The `Mergeable` type class defines the merging strategy for different types. By deriving instances of `Mergeable` for your custom types, you can enable path merging and avoid the path explosion problem.

The `MonadTryMerge` type class is used to control the merging behavior of the monad. Those monads capable of merging will cache the merging strategy when calling `mrgReturn`. This will be used to merge the results of the whole `do`-block.

It's okay if you don't understand everything here.
The key takeaway is to always use `mrgReturn` instead of `return`, unless you know what you are doing and want to manually control where to merge the values.
We also provide `mrg*` (or sometimes, named `sym*`) variants for many combinators from GHC's base library, including those working with `Monad`, `Applicative`, `Functor`, `Foldable`, `Traversable` and lists.
You may want to check out the documentation.

In [10]:
return 1 :: UnionM Int
mrgReturn 1 :: UnionM Int

<1>

{1}

## Deriving Mergeable Instance

Different types have different merging strategies. For the type `[SymInteger]`, Grisette tries to merge lists of the same length. For the type `Either Integer SymBool`, `Left` values are merged only if they are exactly the same, whereas `Right` values are always merged.

(It is okay to use `return` here because `mrgIf` will handle the merging.)

In [11]:
mrgIf "cond" (return ["a","b"]) (return ["c","d"]) :: UnionM [SymInteger]
mrgIf "cond" (return $ Left 1) (return $ Left 2) :: UnionM (Either Integer SymBool)
mrgIf "cond" (return $ Left 1) (return $ Left 1) :: UnionM (Either Integer SymBool)
mrgIf "cond" (return $ Left 1) (return $ Right "x") :: UnionM (Either Integer SymBool)
mrgIf "cond" (return $ Right "x") (return $ Right "y") :: UnionM (Either Integer SymBool)

{[(ite cond a c),(ite cond b d)]}

{If cond (Left 1) (Left 2)}

{Left 1}

{If cond (Left 1) (Right x)}

{Right (ite cond x y)}
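The list strategy can be imitated with a toy model in plain Haskell. This is only a sketch of the idea, not Grisette's `Mergeable` implementation: `Term`, `ListUnion`, `mergeTerm`, and `mergeLists` are hypothetical names introduced for illustration.

```haskell
-- Toy symbolic integer terms; NOT Grisette's SymInteger.
data Term
  = Var String
  | Ite String Term Term -- (ite cond then else)
  deriving (Show, Eq)

-- Toy union of lists: either one merged list or an unmerged two-way choice.
data ListUnion
  = Single [Term]
  | Choice String [Term] [Term]
  deriving (Show, Eq)

-- Merge two terms under a condition, skipping the ite when both sides agree.
mergeTerm :: String -> Term -> Term -> Term
mergeTerm c t e
  | t == e = t
  | otherwise = Ite c t e

-- Lists of equal length merge element-wise; otherwise both branches are kept.
mergeLists :: String -> [Term] -> [Term] -> ListUnion
mergeLists c ts es
  | length ts == length es = Single (zipWith (mergeTerm c) ts es)
  | otherwise = Choice c ts es
```

With this model, merging `[a, b]` and `[c, d]` under `cond` yields one list of two `Ite` terms, while merging `[a1]` and `[a2, a3]` keeps both branches, mirroring the outputs shown earlier.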

As we mentioned earlier, merging is controlled by the `Mergeable` type class. This means that we can define `Mergeable` instances for our custom types to make them compatible with Grisette.

Defining a merging strategy is beyond the scope of this tutorial, but we have provided a default instance for non-GADT data types, which you can access using `DerivingVia` with a `Generic` instance. Automatic deriving instances for GADTs will require `TemplateHaskell`, and we may provide this in a future release.

In [12]:
data A = X Int SymBool | Y SymInteger
  deriving (Show, Generic)
  deriving (Mergeable) via (Default A)

mrgIf "cond" (return $ X 1 "a") (mrgIf "cond2" (return $ X 2 "b") (return $ X 1 "c")) :: UnionM A

{If (|| cond (! cond2)) (X 1 (ite cond a c)) (X 2 b)}

Note that the same `Default`-based deriving mechanism is provided for most of Grisette's type classes, such as `EvaluateSym`, so that custom types can be made fully compatible with Grisette.

In [13]:
deriving via (Default A) instance (EvaluateSym A)
evaluateSym False (buildModel ("a" ::= True)) $ X 1 "a"

X 1 true

## A Rewriting Rule Synthesizer

### Overview

Let's extend the expression equivalence verifier to build a simple rewriting rule synthesizer. The goal is to find an alternative expression that implements the same functionality as the original expression.

We introduce the concept of *program sketches* and *holes*. A *program sketch* represents a *program space* with *holes* that can be instantiated to generate different expressions. *Holes* serve as placeholders for constants. A synthesizer then searches the program space by searching for an instantiation of the holes.

For example, consider the following program sketch for our problem:

```haskell
If hole1 (Add x hole2) (Mul x hole3)
```

This program sketch represents a program space that includes, but is not limited to, the following programs:

- `x + 1` (with `hole1` set to true and `hole2` set to 1)
- `x + 2` (with `hole1` set to true and `hole2` set to 2)
- `x * 3` (with `hole1` set to false and `hole3` set to 3)
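To make the relationship between a sketch, its holes, and the programs it denotes concrete, here is a toy model in plain Haskell with concrete (not symbolic) hole values. The names `E`, `Holes`, `instantiate`, and `evalE` are hypothetical and are not the DSL defined later in this tutorial.

```haskell
-- Toy concrete expressions over a single variable x (no symbolic values).
data E
  = X
  | Lit Integer
  | Add E E
  | Mul E E
  deriving (Show, Eq)

-- Hypothetical record of hole values for the sketch above.
data Holes = Holes
  { hole1 :: Bool
  , hole2 :: Integer
  , hole3 :: Integer
  }

-- Instantiate the sketch: hole1 picks the branch, hole2/hole3 fill the constants.
instantiate :: Holes -> E
instantiate h
  | hole1 h = Add X (Lit (hole2 h))
  | otherwise = Mul X (Lit (hole3 h))

-- Evaluate an expression for a concrete value of x.
evalE :: Integer -> E -> Integer
evalE x X = x
evalE _ (Lit n) = n
evalE x (Add l r) = evalE x l + evalE x r
evalE x (Mul l r) = evalE x l * evalE x r
```

For instance, `instantiate (Holes True 1 0)` is the program `x + 1`, and `instantiate (Holes False 0 3)` is `x * 3`. A synthesizer searches over the `Holes` values instead of enumerating programs directly.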

Our synthesis goal can be formulated as the following formula:

$p_1 = \exists~\mathrm{hole}\in \mathrm{consts}(e_\mathrm{sketch})\setminus \mathrm{consts}(e_\mathrm{orig}). \forall~\mathrm{var}\in e_\mathrm{orig}. \mathrm{eval}(e_\mathrm{orig}) = \mathrm{eval}(e_\mathrm{sketch})$

Note that none of the symbolic constants already present in the original expression are treated as holes. They are the universally quantified variables, and our synthesized expression should have the same semantics as the original expression regardless of the values assigned to them.

If you are using Grisette to build a program synthesizer, you may want to check out https://github.com/lsrcz/grisette-synth-lib, which provides more efficient encoding and an easier-to-use interface focused on synthesis.

### The DSL

To represent a program sketch, we need to extend our expression type. We modify the operands of `Add`, `Mul` and `Eq` to `UnionM Expr` to represent a choice among multiple programs. In this tutorial, we will not use GADTs, as we want to simply derive the necessary instances for the type.

In [14]:
data Expr
  = I SymInteger
  | B SymBool
  | Add (UnionM Expr) (UnionM Expr)
  | Mul (UnionM Expr) (UnionM Expr)
  | Eq (UnionM Expr) (UnionM Expr)
  deriving (Show, Eq, Generic)
  deriving (Mergeable, ExtractSymbolics, EvaluateSym) via (Default Expr)

Additionally, we introduce a sum type `Value` to represent all possible evaluation results.
The evaluation result can be a symbolic integer, a symbolic boolean, or a special `BadValue` if an ill-typed expression is evaluated.

In [15]:
data Value
  = IValue SymInteger
  | BValue SymBool
  | BadValue
  deriving (Show, Eq, Generic)
  deriving (Mergeable, SEq) via (Default Value)

The `eval` function is updated to handle `UnionM Expr` and perform dynamic type checking.
The type of our `eval` function is `Expr -> UnionM Value`, which means that we interpret an expression, and the result is a choice among the value types. To evaluate a `UnionM Expr`, we can use the `onUnion` or `.#` combinators provided by Grisette to lift the `eval` function.

In [16]:
binOp :: UnionM Expr -> UnionM Expr -> ((Value, Value) -> Value) -> UnionM Value
binOp l r f = do
  el <- onUnion eval l
  er <- eval .# r
  mrgReturn $ f (el, er)

eval :: Expr -> UnionM Value
eval (I i) = mrgReturn $ IValue i
eval (B b) = mrgReturn $ BValue b
eval (Add l r) = binOp l r $ \case
  (IValue il, IValue ir) -> IValue $ il + ir
  _ -> BadValue
eval (Mul l r) = binOp l r $ \case
  (IValue il, IValue ir) -> IValue $ il * ir
  _ -> BadValue
eval (Eq l r) = binOp l r $ \case
  (IValue il, IValue ir) -> BValue $ il .== ir
  (BValue il, BValue ir) -> BValue $ il .== ir
  _ -> BadValue

Now we can define some sketches and evaluate them. Our `eval` function is capable of simultaneously evaluating the entire program space!

In [17]:
sketch :: Expr
sketch = Add (mrgReturn $ I "a") (mrgReturn $ I "b")
sketch
eval sketch

sketch :: UnionM Expr
sketch = do
  let a = mrgReturn $ I "a"
  let b = mrgReturn $ I "b"
  mrgIf "c" (mrgReturn $ Add a b) (mrgReturn $ Mul a b)
sketch
eval .# sketch

Add {I a} {I b}

{IValue (+ a b)}

{If c (Add {I a} {I b}) (Mul {I a} {I b})}

{IValue (ite c (+ a b) (* a b))}

### The Synthesizer

Finally, let's write the synthesizer. Recall that our formulation is an exists-forall formula. It cannot be directly transformed into a purely existential formula, so we use the CounterExample-Guided Inductive Synthesis (CEGIS) algorithm, which handles it by making multiple solver calls, each of which solves an existential formula.

The semantics of `cegisForAll` is to solve the following formula:

$\exists P.(\exists I. \mathrm{pre}(P, I))\wedge(\forall I.\mathrm{pre}(P, I)\Rightarrow\mathrm{post}(P, I))$

You can view $P$ as the space of the program and $I$ as the inputs to the program. In our synthesizer, $P$ represents all the holes, and $I$ represents all the variables that exist in the original expression.

The `cegisForAll` function takes three arguments. The first is the configuration of the solver. The second argument controls which symbolic constants are in $I$. With the `ExtractSymbolics` instance, we extract all the symbolic constants in the original expression and use them as the set $I$. The third argument specifies the preconditions and postconditions. Here, our precondition is simply true, so we can omit it and use the convenience function `cegisPostCond`.
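The general shape of a CEGIS loop can be sketched in plain Haskell over small finite domains, with brute-force search standing in for the solver calls. This is an illustration of the algorithm only, under the assumption of a trivially true precondition. It is not how `cegisForAll` is implemented, and `cegis` and `doubling` are hypothetical names.

```haskell
import Data.List (find)

-- Toy CEGIS over finite domains; NOT how cegisForAll is implemented.
-- spec c x is the postcondition for candidate c on input x.
cegis :: [c] -> [i] -> (c -> i -> Bool) -> Maybe c
cegis candidates inputs spec = go []
  where
    go cexs =
      -- Synthesis step: pick any candidate consistent with the counterexamples so far.
      case find (\c -> all (spec c) cexs) candidates of
        Nothing -> Nothing -- no candidate fits: synthesis fails
        Just c ->
          -- Verification step: look for an input that breaks the candidate.
          case find (not . spec c) inputs of
            Nothing -> Just c -- no counterexample: c works for all inputs
            Just cex -> go (cex : cexs)

-- Find a constant k with k * x == x + x for every x in the input domain.
doubling :: Maybe Integer
doubling = cegis [0 .. 5] [-10 .. 10] (\k x -> k * x == x + x)
```

Each round adds a counterexample that rules out the current candidate, so with a finite candidate list the loop terminates. Here `doubling` evaluates to `Just 2` after a single counterexample.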

In [18]:
synthesisRewriteTarget :: Expr -> UnionM Expr -> IO ()
synthesisRewriteTarget expr sketch = do
  let lhs = eval expr
  let rhs = eval .# sketch
  r <- cegisForAll (precise z3) expr $ cegisPostCond $ lhs .== rhs
  case r of
    (_, CEGISSuccess model) -> do
      putStrLn "Successfully synthesized RHS:"
      print $ evaluateSym False model sketch
    (cex, failure) -> do
      putStrLn $ "Synthesis failed with error: " ++ show failure
      putStrLn $ "Counter example list: " ++ show cex

We can now synthesize an expression. The following example tries to determine whether we can rewrite $2 * x$ as $x + x$ or $x * x$.

In [19]:
x :: UnionM Expr
x = mrgReturn $ I "x"

lhs :: Expr
lhs = Mul (mrgReturn $ I 2) x

sketch :: UnionM Expr
sketch =
  mrgIf "c"
    (mrgReturn $ Add x x)
    (mrgReturn $ Mul x x)
synthesisRewriteTarget lhs sketch

Successfully synthesized RHS:
{Add {I x} {I x}}

The next example uses a larger sketch. We want to see whether $(a * b) + (b * c)$ can be rewritten.

The sketch we are using is:

```text
(?{a,b,c} ?{+,*} ?{a,b,c}) ?{+,*} ?{a,b,c}
```

The question mark indicates that we are selecting among the choices, which is why we call the operator that performs this selection `choose`.

In [20]:
a, b, c :: UnionM Expr
a = mrgReturn $ I "a"
b = mrgReturn $ I "b"
c = mrgReturn $ I "c"
lhs :: Expr
lhs = Add (mrgReturn $ Mul a b) (mrgReturn $ Mul b c)

sketch :: UnionM Expr
sketch = do
  let lhs1 = chooseUnion [a, b, c] "lhs1"
  let rhs1 = chooseUnion [a, b, c] "rhs1"
  let rhs = chooseUnion [a, b, c] "rhs"
  let lhs = choose [Add lhs1 rhs1, Mul lhs1 rhs1] "lhs"
  choose [Add lhs rhs, Mul lhs rhs] "sketch"

synthesisRewriteTarget lhs sketch

Successfully synthesized RHS:
{Mul {Add {I a} {I c}} {I b}}

## Conclusion

In this tutorial, we explored the core construct of Grisette, the `UnionM` monadic container. We learned how to work with `UnionM` to represent choices and introduce path conditions, and how Grisette merges execution paths to improve efficiency and avoid the path explosion problem.

We also discovered how to derive instances of the `Mergeable` typeclass for our custom data types, enabling seamless integration with Grisette's features. By deriving instances using `Generic` and `DerivingVia`, we can quickly make our types compatible with Grisette without manually writing the instances.

Furthermore, we extended our expression equivalence verifier to build a simple rewriting rule synthesizer. We introduced the concept of program sketches and holes, and demonstrated how to use CounterExample-Guided Inductive Synthesis (CEGIS) provided by Grisette to search for a program that satisfies a given specification.

However, there are still some areas where the code can be improved. For example, we had to introduce the `BadValue` constructor to handle ill-typed expressions, which adds boilerplate to our code. Additionally, constructing sketches can be verbose and difficult to reuse.

In future tutorials, we will explore more advanced features of Grisette that can help us address these issues and simplify our code. We will learn about constructs that can reduce boilerplate and make our code more concise and reusable.

By mastering the concepts introduced in this tutorial and leveraging the power of Grisette, you will be well-equipped to tackle complex problems involving symbolic execution, program synthesis, and verification.