# A weekend replication of [STOKE](http://stoke.stanford.edu/), a stochastic superoptimiser

A superoptimizer is a type of compiler that takes a program and finds
the _optimal_ sequence of instructions that performs the same task as
the original program.

A _stochastic_ thing `X` is a thing `X` that uses randomness.

STOKE is a superoptimizer that finds optimal instruction sequences
to perform a particular task by using numerical techniques which
rely on randomness (Markov chain Monte Carlo, or MCMC) to explore the
space of "all possible programs".

Here, we re-implement `STOKE` for a tiny stack machine language whose instruction
set is `Push`, `Add`, `Mul`, `And`, `Dup`, and `Swap`.

### Sample output

```
*** original: (nparams: 0 | [IPush 2,IPush 3,IAdd])***
[IPush 5] | score: 2.5 // constant folding: 2 + 3 -> 5

*** original: (nparams: 1 | [IPush 2,IMul])***
[IDup,IAdd] | score: 2.25 // strength reduction: 2 * x -> x + x

*** original: (nparams: 1 | [IDup,IAnd])***
[] | score: 3.0 // algebraic simplification: x & x == x
```

That is, we automatically discover those program transformations by randomly trying different
programs (in a somewhat smart way).

### High level algorithm

What we do is to "MCMC-sample the space of possible programs, with a
scoring function based on program equivalence and performance". Broken
down, the steps are:

- Start with the original program `p`
- Perturb it slightly to get a program `q` (add an instruction, remove some instructions, change an instruction)
- Assign a score to `q` by sending `p` and `q` random inputs, and checking
  how often they answer the same.
- If they answered the same for _all_ random inputs, ask an SMT solver
  nicely if `p` and `q` are _equivalent_. If she's in a good mood and the
  universe is kind, the answer is yes.
- Now, score `q` based on three factors:
    1. Are `p` and `q` semantically equivalent?
    2. On how many inputs did `p` and `q` give the same answer?
    3. Is `q` faster than `p`?
- Now, either pick `q` to be the new `p` or stay at `p` depending on `q`'s
  score.
- Repeat ~10,000 times.
- Sort the `q`s visited by score and report the best ones.


The results are quite impressive: Even with code as naive as what I've written,
we are able to regain peephole optimisations that compilers perform essentially
"for free".


As usual, since we code in Haskell, we begin with incantations to GHC:

In [5]:
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}

import GHC.Stack
import Data.Maybe (catMaybes)
import Data.List (nub, sortOn, sortBy)
import Data.Bits
import Data.SBV
import Data.SBV.Internals (CV)
import Data.Word
import Control.Monad
import System.Random
import Control.Applicative
import Control.Monad.State
import Control.Monad.Fail
import Debug.Trace

Next, we build some primitives to get our hands on randomness: `randint`, `randint8`, `randfloat`, and `randbool`.

In [6]:
-- | provide a random integer in [lo, hi]
randint :: (Int, Int) -> IO Int
randint (lo, hi) = liftIO $ getStdRandom $ randomR (lo, hi)

-- | random uniform int8 (any int8)
randint8 :: IO Int8
randint8 = liftIO $ getStdRandom $ randomR (-128, 127)

-- | random uniform float in (lo, hi)
randfloat :: (Float, Float) -> IO Float
randfloat (lo, hi) = liftIO $ getStdRandom $ randomR (lo, hi)

-- | random boolean
randbool :: IO Bool
randbool = getStdRandom $ random

#### Introduction to stack machines

We now introduce our instruction set and our notion of a program, and we show how to execute a program
to understand the instruction set. Our language is going to be a stack-based language.

At the beginning of the program, we assume that the parameters are present on the stack.

We have instructions to push values onto the stack, and any operations like `Add` or `Mul`
will always operate on values that are at the top of the stack.

At the end of the program, we assume that only one value is left on the stack:
the output of the program.


#### An example

To implement `\x -> x + 1`, we use the sequence of instructions `[IPush 1, IAdd]`.

This corresponds to the following sequence of events:
- Initially, the stack contains the parameter to our program `[x] <- top of stack`
- Next, we execute `Push 1` which pushes a `1` to the top: `[x, 1] <- top of stack`
- Next, we execute `Add` which pops `x` and `1` from the top of the stack, and pushes down `x + 1`.
  So, our stack will be `[x + 1] <- top of stack`.
- The computation has ended, since we have run all instructions, and we have a single value left on the stack,
  `x + 1`.
  
#### Some more quick examples:

- `\x -> (x + 1) * 2` will be the instructions `[Push 1, Add, Push 2, Mul]`.
- `\x -> x + x` will be the instructions `[Dup, Add]` where `Dup` duplicates the value at the top of the stack.
- `\x -> x` will be the instructions `[]`. We don't need to do anything, since the program starts with the parameter on the stack.
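
The examples above can be replayed step by step. The following is a minimal sketch, not the interpreter used in the rest of this post: the names `Op`, `step`, and `traceStacks` are made up for illustration, values are plain `Int`s, and the head of the list is treated as the top of the stack.

```haskell
-- Hypothetical mini instruction set, mirroring the examples above.
data Op = Push Int | Add | Mul | Dup
  deriving (Show, Eq)

-- | Run one instruction. The head of the list is the top of the stack.
step :: [Int] -> Op -> Maybe [Int]
step st       (Push n) = Just (n : st)
step (a:b:st) Add      = Just (a + b : st)
step (a:b:st) Mul      = Just (a * b : st)
step (a:st)   Dup      = Just (a : a : st)
step _        _        = Nothing  -- not enough values on the stack

-- | All intermediate stacks, or Nothing if any instruction fails.
traceStacks :: [Op] -> [Int] -> Maybe [[Int]]
traceStacks ops st0 =
  sequence (scanl (\mst op -> mst >>= (`step` op)) (Just st0) ops)

main :: IO ()
main = do
  print (traceStacks [Push 1, Add, Push 2, Mul] [5]) -- Just [[5],[1,5],[6],[2,6],[12]]
  print (traceStacks [Dup, Add] [5])                 -- Just [[5],[5,5],[10]]
  print (traceStacks [] [5])                         -- Just [[5]]
  print (traceStacks [Add] [5])                      -- Nothing
```

With `x = 5`, the first trace ends in `(5 + 1) * 2 = 12`, and the last one fails because `Add` needs two values on the stack.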

#### Real-world examples

The Java Virtual Machine is stack-based. Its instruction set can be [found here](https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-6.html). It is far more extensive
since it has support for arrays, method invocations, and whatnot. However, the _idea_ is
exactly the same.

#### Our instruction set

`Inst` is our datatype which describes instructions. `Program` contains a sequence of instructions `progInsts`, and
the number of parameters this program requires, `progNParams`.

`interpInst` interprets a single instruction, returning a `Just` if it is successful, and `Nothing`
otherwise.

`interpInsts` interprets a sequence of instructions, returning a `Just a`
of the final value on top of the stack after the program is executed, and `Nothing` otherwise. 

In [35]:
data Inst = 
  IPush Int8 -- ^ Push a value onto the stack
  | IAdd -- ^ Add the top two values of the stack
  | IMul -- ^ Multiply the top two values of the stack
  | IDup -- ^ Duplicate the value on top of the stack
  | IAnd -- ^ Take the bitwise AND of the top two  values of the stack
  | ISwap -- ^ Swap the top two values of the stack
  deriving(Eq, Show, Ord)

data Program = Program { progNParams :: Int, progInsts :: [Inst] }
  deriving (Eq, Ord)

instance Show Program where
  show (Program nparams insts) =
     "(" <> "nparams: " <> show nparams <> " | " <> show insts <> ")"

-- | Given input stack and instruction, return output stack
interpInst :: Num a => Bits a =>  [a] -> Inst -> Maybe [a]
interpInst as (IPush x) = Just $ (fromIntegral x):as
interpInst (a:a':as) (IAdd) = Just $ (a+a':as)
interpInst (a:a':as) (IMul) = Just $ (a*a':as)
interpInst (a:as) (IDup) = Just $ (a:a:as)
interpInst (a:a':as)(IAnd) = Just $ (a .&. a':as)
interpInst (a:a':as) (ISwap) = Just $ (a':a:as)
interpInst _ _ = Nothing

-- | Given input stack and a sequence of instructions in a program,
-- run the sequence of instructions and return the output stack. 
interpInsts :: Num a => Bits a => [Inst] -> [a] -> Maybe a
interpInsts insts as =
  case foldM interpInst as insts of
    Just [a] -> Just a
    _ -> Nothing

-- | Given program and stack, compute output
interpProgram :: Num a => Bits a => Program -> [a] -> Maybe a
interpProgram Program{..} as = interpInsts progInsts as

-- | compute 1 + 2
p1plus3 :: Program
p1plus3 = Program 0 [IPush 1, IPush 2, IAdd]

putStrLn $ "p1plus3: " <> show p1plus3 <> " | " <> show (interpProgram p1plus3 []) 

-- | compute \x -> x * 2 with x = 10
pmul2 :: Program
pmul2 = Program 1 [IPush 2, IMul]

putStrLn $ "pmul2: " <> show pmul2 <> " | " <> show (interpProgram pmul2 [10]) 
putStrLn $ "pmul2 with too little on the stack: " <> show pmul2 <> " | " <> show (interpProgram pmul2 [])

p1plus3: (nparams: 0 | [IPush 1,IPush 2,IAdd]) | Just 3

pmul2: (nparams: 1 | [IPush 2,IMul]) | Just 20

pmul2 with too little on the stack: (nparams: 1 | [IPush 2,IMul]) | Nothing

#### Cost model for our instruction set

We introduce a simplistic cost model, which charges multiplication `4`, other operations `1`.
We have `costInst` that calculates the cost of a single `Inst`, and `costProgram` which sums up
the costs of the instructions of the program.

We will refer back to this when we are assigning a score to a program. A program with higher costs
will be penalized, since we want to _optimise_ a program: we are looking for a program with the lowest cost.

In [89]:
-- | Cost per instruction
costInst :: Inst -> Float
costInst (IPush _) = 1
costInst IAdd = 1
costInst IMul = 4
costInst IDup = 1
costInst IAnd = 1
costInst ISwap = 1

-- | Sum of costs of instructions in the program
costProgram :: Program -> Float
costProgram p = sum $ map costInst $ progInsts p

putStrLn $ "mul2: " <> show pmul2 <> " | cost: " <> show (costProgram pmul2)
putStrLn $ "1plus3: " <> show p1plus3 <> " | cost: " <> show (costProgram p1plus3)

mul2: (nparams: 1 | [IPush 2,IMul]) | cost: 5.0

1plus3: (nparams: 0 | [IPush 1,IPush 2,IAdd]) | cost: 3.0

#### Randomly perturbing programs

Next, we need a way to _edit_ a given program to get a program similar to it. We want to gradually
edit a program, to explore the "space of all possible programs". For this, we will build functions to:

- Generate a random instruction: `randInst`
- Add an instruction to a program: `addListElem`
- Remove a section of instructions from a program: `dropListElems`
- Edit an instruction in a program: `replaceListElem`

To edit a program, we will apply one of these choices randomly: `perturbProgram`

In [43]:
-- | generate a random instruction.
randInst :: IO Inst
randInst = do
  r <- randint (1, 6)
  case r of
    1 -> IPush <$> randint8
    2 -> pure $ IAdd
    3 -> pure $ IMul
    4 -> pure $ IDup
    5 -> pure $ IAnd
    6 -> pure $ ISwap


-- | drop the list elements between the specified indices (inclusive)
dropListElems :: [a] -> Int -> Int -> [a]
dropListElems as ixbegin ixend = take ixbegin as ++ drop (ixend + 1) as

-- | replace a list element at the specified index
replaceListElem :: [a] -> Int -> a -> [a]
replaceListElem as ix a = take ix as ++ [a] ++ drop (ix+1) as

-- | insert into a list *at* the specified index, before the element currently there
addListElem :: [a] -> Int -> a -> [a]
addListElem as ix a = take ix as ++ [a] ++ drop ix as


-- | Edit the program with a single random change: delete a run of
-- instructions, replace one instruction, or insert one instruction.
perturbProgram :: Program -> IO Program
perturbProgram Program{..} = do
  r <- randint (1, 3) -- ^ pick a random choice
  ix <- randint (0, length progInsts - 1) -- ^ pick a random index
  ix' <- (ix +) <$> randint (0, length progInsts - 1) -- ^ and another one
  progInsts <- case r of
                 1 -> pure $ dropListElems progInsts ix ix'
                 2 -> replaceListElem progInsts ix <$> randInst
                 3 -> addListElem progInsts ix <$> randInst

  return $ Program{..}
  
putStrLn $ "original pmul2: " <> show pmul2
replicateM_ 3 $ perturbProgram pmul2 >>= \p -> putStrLn $ "perturbed pmul2: " <> show p

original pmul2: (nparams: 1 | [IPush 2,IMul])

perturbed pmul2: (nparams: 1 | [IMul])
perturbed pmul2: (nparams: 1 | [IPush 78,IMul])
perturbed pmul2: (nparams: 1 | [IPush 2,IAdd,IMul])
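
The index conventions of the three list helpers are easiest to see on concrete inputs. The sketch below repeats their definitions from the cell above so that it runs on its own, and applies them to a short string instead of a list of instructions.

```haskell
-- Definitions repeated from the cell above so this sketch is self-contained.
dropListElems :: [a] -> Int -> Int -> [a]
dropListElems as ixbegin ixend = take ixbegin as ++ drop (ixend + 1) as

replaceListElem :: [a] -> Int -> a -> [a]
replaceListElem as ix a = take ix as ++ [a] ++ drop (ix + 1) as

addListElem :: [a] -> Int -> a -> [a]
addListElem as ix a = take ix as ++ [a] ++ drop ix as

main :: IO ()
main = do
  putStrLn (dropListElems "abcde" 1 2)     -- ade    (indices 1 and 2 removed)
  putStrLn (dropListElems "abcde" 1 9)     -- a      (end index past the list is fine)
  putStrLn (replaceListElem "abcde" 1 'X') -- aXcde
  putStrLn (addListElem "abcde" 1 'X')     -- aXbcde (new element lands at index 1)
  putStrLn (addListElem "abcde" 0 'X')     -- Xabcde
```

Note that `addListElem` places the new element at index `ix`. Since `perturbProgram` draws `ix` from `0` to `length progInsts - 1`, an insertion into a non-empty program never lands after the last instruction.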

#### Optimising using perturbed programs

Sweet. Now, we need methods to decide which perturbed programs to keep and which to drop. For this, we have two criteria:
- **Correctness**: Do the original program and the new perturbed program _behave_ the same way?
- **Performance**: Does the perturbed program run _faster_ than the original program?

We answer the correctness question in two ways. First, we run both programs on a small number of random inputs. If they
give different answers on any of these inputs, then we are sure that the programs are different. If they give the
same answers, we still need a _proof_ that they are equal. Remember, compilers are not allowed to change the
_meaning_ of a program!

For performance, we build a _cost model_: A mapping of instructions to how much time they take, and we optimise
on this cost model --- we try to reduce the total time our program takes.

Let's do correctness first, since it's a little more involved, and will require us to use `SBV`, a
neat Haskell library that allows us to communicate with the `Z3` solver.

#### Proving equivalence part 1: Running two programs on the same inputs

We implement `proportionAgreeingRuns`, which, given two programs, gives us the proportion of random runs on which
they produced the same output. It's a surprisingly effective way to weed out programs that do different things.

We first check that both programs take the same number of inputs. If they don't we immediately return a `0`.
Once we know that they take the same number of inputs, we generate random inputs `10` times and compare their outputs. We return the proportion of times their outputs were the same.

Indeed, this kind of checking of "using random inputs to check if two programs are the same" is a big deal in complexity
theory. It's a subfield unto itself called [property testing](https://en.wikipedia.org/wiki/Property_testing), and 
we have proofs that this kind of checking is remarkably effective in many constrained situations.

In [57]:
proportionAgreeingRuns :: Program -> Program -> IO Float
proportionAgreeingRuns p1 p2 = do
  if progNParams p1 /= progNParams p2
  then return 0
  else do
    let nruns = 10
    scores <- replicateM nruns $ do
      ps <- replicateM (progNParams p1) randint8
      let l = interpProgram p1 ps
      let r = interpProgram p2 ps
      return $ if l == r then 1 else 0
    return $ fromIntegral (sum scores) / fromIntegral nruns

-- | Test between mul2 and its perturbed variant
replicateM_ 3 $ do 
  pmul2' <- perturbProgram pmul2
  putStrLn $ "*** mul2: " <> show pmul2 <> " | mul2': " <> show pmul2'
  prop <- proportionAgreeingRuns pmul2 pmul2'
  putStrLn $ "    proportion of agreeing runs: " <> show prop

*** mul2: (nparams: 1 | [IPush 2,IMul]) | mul2': (nparams: 1 | [IAdd,IPush 2,IMul])
    proportion of agreeing runs: 0.0
*** mul2: (nparams: 1 | [IPush 2,IMul]) | mul2': (nparams: 1 | [IPush 2,IDup])
    proportion of agreeing runs: 0.0
*** mul2: (nparams: 1 | [IPush 2,IMul]) | mul2': (nparams: 1 | [IAnd,IMul])
    proportion of agreeing runs: 0.0
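
Because a one-parameter program over `Int8` has only 256 possible inputs, the proportion that `proportionAgreeingRuns` estimates by sampling can also be counted exactly. The sketch below does this for plain Haskell functions standing in for programs; `agreeCount` is a hypothetical helper and is not part of the code above.

```haskell
import Data.Bits ((.&.))
import Data.Int (Int8)

-- | Number of Int8 inputs (out of 256) on which two one-parameter
-- functions give the same output. Int8 arithmetic wraps around.
agreeCount :: (Int8 -> Int8) -> (Int8 -> Int8) -> Int
agreeCount f g = length [x | x <- [minBound .. maxBound], f x == g x]

main :: IO ()
main = do
  print (agreeCount (\x -> x * 2) (\x -> x + x)) -- 256: agree everywhere
  print (agreeCount (\x -> x .&. x) id)          -- 256: agree everywhere
  print (agreeCount (\x -> x * x) id)            -- 2: agree only at 0 and 1
```

The last pair agrees on only 2 of 256 inputs, so ten random inputs are very unlikely to all agree. The first two pairs agree everywhere, which is exactly the situation in which we go on to ask the solver for a proof.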

#### Proving equivalence part 2: Creating a symbolic expression for whether two programs are equal

We use the [SBV package](http://hackage.haskell.org/package/sbv) for this part.

We build the code that calls the solver. Here too, we use the interpreter. Note that we had written it
with a little too much generality, needing only `Num a` and `Bits a`. Here, we use that generality to run the interpreter on the type `SInt8`, which is a type of _symbolic `int8`_ values. This allows us to neatly get a symbolic expression for the effect
our program has.

If both programs have the same number of parameters, we create `params`, a list of symbolic values, one for each parameter, which
are universally quantified.

Next, we run both programs `p1` and `p2` on this _symbolic stack_, consisting of the symbols `[p-1, p-2, ..., p-n]`. This will give
us a _symbolic representation_ of the programs `p1` and `p2`.

If all went well, we would have the expressions `s1` and `s2` with the symbolic effects of the program. We return the condition that `s1 .== s2`. That is, the effect of the first program must be equal to the effect of the second program.

Whenever we hit some kind of failure mode (the programs do not have the same number of parameters, or are unable to return a final value),
we return the constraint `1 .== 0`, which indicates failure: `1 = 0` can never be satisfied, and thus
the solver will always reject it.

#### Quick example of proving program equivalence

Let us say we want to prove that the program `p1 = [Push 2, Mul]`, which takes 1 parameter, is equivalent to `p2 = [Dup, Add]`. The steps are:
- create the list `params`, which holds one universally quantified parameter `x`
- run the program `p1`, which will give us the formula `2 * x`
- run the program `p2`, which will give us the formula `x + x`
- assert that `forall x. 2 * x = x + x`


This is sensible, since if the programs have the same effect _for all choices of parameters `x`_, then one can state that they _really are equal_. The universal quantification (`forall x.`) is critical here.

In [81]:
smtQueryEquivProgram :: Program -> Program -> Symbolic SBool
smtQueryEquivProgram p1 p2 = do
  if progNParams p1 /= progNParams p2
  then return $ 1 .== (0 :: SInt8)
  else do
    params <- sequence $ [forall $ "p-" <> show i | i <- [1..progNParams p1]]
    let ms1 = interpInsts (progInsts p1) params :: Maybe SInt8
    let ms2 = interpInsts (progInsts p2) params :: Maybe SInt8
    case liftA2 (,) ms1 ms2 of
      Nothing -> return $ 1 .== (0 :: SInt8)
      Just (s1, s2) -> return $ s1 .== s2

#### Scoring function
We combine the two correctness metrics and our cost metric into `scoreProgram`, which scores how close a program is to another program. 
- Higher scores are better.
- Scores above `2.0` indicate that we passed all correctness checks, and the extra points are present for performance. 
- `3.0` is the highest score possible
- `0.1` is the lowest score possible.

We provide `0.1` as the lowest score, since we use the score later to decide whether this new program
deserves to be picked. We want _all_ programs to have _some_ chance of being picked.

- First, we check if two programs agree on all random inputs with `proportionAgreeingRuns`. `nagree`
  is the proportion of times the two programs _actually_ agree.
- Only if they do (`nagree == 1.0`) do we make the expensive query to the solver to check that they are equal with `smtQueryEquivProgram`.
- If the programs are really equal, we return `2.0 + <cost penalty>` where a cost of `0` scores `1`, and higher costs score
  exponentially less.
- If the programs are not equal, we return `nagree + 0.1`. This is to ensure that even if the two programs do not agree
  at all (i.e., `nagree = 0.0`), we still have some chance of picking the other program.
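
The arithmetic of this scoring rule can be checked in isolation. The sketch below is a hypothetical pure version, `scoreFrom`, which takes the results of the random tests, the solver, and the cost model as plain arguments instead of computing them; it is meant only to reproduce the numbers, not to replace the real scoring function.

```haskell
-- | Hypothetical pure version of the scoring rule described above.
scoreFrom :: Float -- ^ proportion of agreeing runs, in [0, 1]
          -> Bool  -- ^ did the solver confirm equivalence?
          -> Float -- ^ cost of the candidate program
          -> Float
scoreFrom nagree provedEqual cost
  | nagree == 1.0 && provedEqual = 2.0 + 2.0 ** negate cost
  | otherwise                    = 0.1 + nagree

main :: IO ()
main = do
  print (scoreFrom 1.0 True 0)  -- 3.0      (empty program, the maximum)
  print (scoreFrom 1.0 True 1)  -- 2.5      (e.g. a single push)
  print (scoreFrom 1.0 True 2)  -- 2.25     (e.g. dup followed by add)
  print (scoreFrom 1.0 True 5)  -- 2.03125  (e.g. push followed by mul)
  print (scoreFrom 0.0 False 2) -- 0.1      (the minimum)
  print (scoreFrom 1.0 False 2) -- passed the random tests, but no proof
```

Under the cost model above, `[IPush 2, IMul]` costs `5` and `[IDup, IAdd]` costs `2`, which is why the second scores `2.25` against `2.03125` for the first.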

In [95]:
-- | Type synonym for readability
type Score = Float

-- | Higher score is better.
scoreProgram :: Program -- ^ baseline program `p`
             -> Program -- ^ newly proposed program `q` whose score we are computing
             -> IO Score -- ^ score of the new program `q`
scoreProgram p q = do
  nagree <- proportionAgreeingRuns p q
  if nagree /= 1.0
  then return $ 0.1 + nagree
  else do
    res <- sat $ setTimeOut 100 >> smtQueryEquivProgram p q
    if not $ modelExists res
    then return $ 0.1 + nagree
    else return $ 2.0 + 2.0 ** (-1.0 * costProgram q)

#### MCMC

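We implement a Metropolis-Hastings-style sampler. The proposal distribution is `perturbProgram`, and `scoreProgram` plays
the role of the (unnormalised) target density, so correct programs which run fast are visited more often. The acceptance
ratio used here is just the ratio of scores: it omits the Hastings correction for asymmetric proposals, so it treats
`perturbProgram` as if it were symmetric.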

We have `c` as the original program we are trying to optimise.

In `mhStep`, we sample a new program `p'` close to the current program `p` using `p' <- perturbProgram p`. Then we calculate the
acceptance ratio as `score p' / score p`, draw a uniform random number in `(0, 1)`, and accept `p'` if that number is below the ratio.

We memoise the computation of the score of the current program `p` by passing a tuple of `(Score, Program)` in `mhStep`.


Finally, `mhTrace` is the main entry point into the MCMC algorithm, which takes the number of samples we want and an initial configuration, and returns to us a list of programs along with their scores.

In [120]:
-- | Take a step of Metropolis-Hastings
mhStep :: Program -- ^ ground truth (concrete)
          -> (Score, Program) -- ^ current position
          -> IO (Score, Program) -- ^ next position
mhStep c (score, p) = do
  p' <- perturbProgram p
  score' <- scoreProgram c p'
  let accept = score' / score
  r <- randfloat (0, 1)
  pure $ if r < accept then (score', p') else (score, p)


-- | Repeat a monadic computation
mRepeat :: Monad m => Int -> (a -> m a) -> (a -> m a)
mRepeat 0 _ = pure
mRepeat n f = f >=> mRepeat (n - 1) f

-- | Return the trace of programs seen and their scores
mhTrace :: Int -- ^ number of samples
        -> Program -- ^ original program (concrete)
        -> IO [(Score, Program)] -- ^ scores
mhTrace n c =
  let nsteps = 10
      -- go :: Int -> (Score, Program) -> M (Score, Program)
      go 0 (s, p) = pure [(s, p)]
      go n (s, p) = do
                      (s', p') <- mRepeat nsteps (mhStep c) $ (s, p)
                      rest <- go (n - 1) (s', p')
                      return $ (s', p'):rest
  in do
    let beginp = c
    s <- scoreProgram c beginp
    go n (s, beginp)

-- | Sample run on mul2
do 
  ps <- mhTrace 8 pmul2
  forM_ ps print

(2.03125,(nparams: 1 | [IPush 2,IMul]))
(2.03125,(nparams: 1 | [IPush 2,IMul]))
(2.015625,(nparams: 1 | [IPush 2,ISwap,IMul]))
(0.1,(nparams: 1 | [IAnd]))
(0.1,(nparams: 1 | []))
(0.1,(nparams: 1 | [IAdd]))
(0.1,(nparams: 1 | [IPush (-6),IDup]))
(0.1,(nparams: 1 | [IAnd]))
(0.1,(nparams: 1 | [IAnd]))
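
The trace above drifts from correct programs into programs that score `0.1`, and the acceptance rule explains why. The sketch below isolates that rule; `acceptProb` and `decide` are hypothetical names, and the uniform draw `r` is passed in explicitly so that the functions stay pure.

```haskell
-- | Probability of accepting a move from score `cur` to score `new`
-- under the rule `r < new / cur`, with r uniform in (0, 1).
acceptProb :: Float -> Float -> Float
acceptProb cur new = min 1 (new / cur)

-- | One accept/reject decision for a given uniform draw `r`.
decide :: Float -> (Float, a) -> (Float, a) -> (Float, a)
decide r (cur, p) (new, p')
  | r < new / cur = (new, p')
  | otherwise     = (cur, p)

main :: IO ()
main = do
  print (acceptProb 2.03125 2.25)               -- 1.0: improvements are always accepted
  print (acceptProb 0.1 0.1)                    -- 1.0: equally bad programs too
  print (acceptProb 2.03125 0.1)                -- roughly 0.049
  print (decide 0.5  (2.03125, "p") (0.1, "q")) -- (2.03125,"p")
  print (decide 0.01 (2.03125, "p") (0.1, "q")) -- (0.1,"q")
```

A wrong proposal is accepted from a correct program about 5% of the time. Once the chain sits at `0.1`, every other `0.1` program is accepted, and it only climbs back when a perturbation happens to produce a program that agrees with the original.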

#### Putting it all together

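We implement `optimise`, a thin wrapper around `mhTrace`, which asks `mhTrace` for 1000 samples (samples are 10 steps apart,
so about 10,000 steps in total), keeps the programs that passed all correctness checks (score of at least `2.0`), sorts
them in descending order of score, and prints the best `4` distinct ones.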

We use `optimise` in `main` to check the optimised versions of some sample programs, and they are indeed what we would expect: 
- `2 + 3` gets [constant folded](https://en.wikipedia.org/wiki/Constant_folding) to `5`
- `\x -> x * 2` gets [strength reduced](https://en.wikipedia.org/wiki/Strength_reduction) to `x + x`
- `\x -> x & x` gets simplified to `\x -> x`, since bitwise AND is idempotent.

In [127]:
optimise :: Program -> IO ()
optimise c = do
  liftIO $ putStrLn $ "*** original: " <> show c <> "***"
  steps <- mhTrace 1000 c
  let descendingScore (s, _) (s', _) = compare s' s
  let opts = take 4 $ nub $
        sortBy descendingScore [(s, p) | (s, p) <- steps, s >= 2.0]
  forM_ opts $ \(s, p) -> do
    liftIO $ putStrLn $ show (progInsts p) <> " | " <> "score: " <> show s


main :: IO ()
main = do
  optimise $ Program 0 [IPush 2, IPush 3, IAdd]
  optimise $ Program 1 [IPush 2, IMul]
  optimise $ Program 1 [IDup, IAnd]
  
main

*** original: (nparams: 0 | [IPush 2,IPush 3,IAdd])***
[IPush 5] | score: 2.5
[IPush 2,IPush 3,ISwap,IAdd] | score: 2.0625
[IPush 2,IPush 3,ISwap,ISwap,IAdd] | score: 2.03125
*** original: (nparams: 1 | [IPush 2,IMul])***
[IDup,IAdd] | score: 2.25
[IDup,ISwap,IAdd] | score: 2.125
[IDup,ISwap,ISwap,IAdd] | score: 2.0625
[IDup,IDup,IAnd,IAdd] | score: 2.0625
*** original: (nparams: 1 | [IDup,IAnd])***
[] | score: 3.0
[IDup,IAnd] | score: 2.25
[IDup,ISwap,IAnd] | score: 2.125

#### Conclusions

My takeaways from this quick weekend hacking project:
- MCMC is quite powerful at exploring search spaces, even when the cost function isn't all that "smooth".
- Z3 is quite a good solver, and `SBV` is an _amazingly well designed_ library: it took me all of 15 minutes to get the library working!
- Guiding MCMC with extra knowledge seems to be really useful.
- I don't ever need to code up basic compiler transformations again `:)`

I really enjoyed reproducing the `STOKE` paper: It's well written, extremely clear about what they did,
and plain fun to replicate. 

I'd like to expand this to add more instructions, or alternatively, cannibalize the code written here
into my [`tiny-optimising-compiler`](https://github.com/bollu/tiny-optimising-compiler)
project which hopes to be the go-to reference for all newfangled compiler optimisation techniques.