# Haskell Kōan: weighted coin tossing in < 100 LoC

We implement a tiny embedded domain-specific language which allows us to _sample from random variables_ and
build computations from them. We also build the less-known but equally important ability to
use _conditional reasoning_, where we can condition a random variable on another random variable.

Implementing this within Haskell with a monadic interface allows us to leverage the full power of the language,
making our computations compositional, and our implementation of _conditioning_ concise.

In [1]:
{-# LANGUAGE GADTs #-}
import Control.Monad (ap, forM_, replicateM)
import System.Random (getStdRandom, randomR)
import qualified Data.Map as M

First, we create a new GADT called `R`, for *random variable*.
It supports only three operations:

- `Return` lifts a pure value into a random variable which takes on only one value.
- `Uniform` is a uniformly distributed random variable over `[0, 1]`.
- `Weigh` scales the probability of getting a value by a custom scaling factor.

It supports a `Functor`, `Applicative`, and `Monad` instance which we implement, to easily
build up random computations.

Then, we write `runUnweighted`, which shows how to run an `R a` computation to get a _random_ `a` value,
without taking into account the weights (hence, `unweighted`).

In [2]:
data R a where
  Return :: a -> R a -- ^ lift a pure value
  Uniform :: (Float -> R a) -> R a -- ^ uniformly distributed random variable over [0, 1]
  Weigh :: Float -> R a -> R a   -- ^ scale the probability of the computation by a factor
  
instance Functor R where
  fmap f (Return x) = Return (f x)
  fmap f (Uniform rand2m) = Uniform $ \r -> fmap f (rand2m r)
  fmap f (Weigh w m) = Weigh w (fmap f m) 
  
instance Monad R where
  return = Return
  Return a >>= f = f a
  Uniform rand2m >>= f = Uniform $ \r -> rand2m r >>= f
  (Weigh w m) >>= f = Weigh w (m >>= f)
  
instance Applicative R where
  pure = return
  (<*>) = ap
  
-- | Run a random computation, *not taking into account the weights*.
-- | This will be used to run the _traced computation_, which
-- | does take into account the weights.
runUnweighted :: R a -> IO a
runUnweighted (Return a) = return a
runUnweighted (Weigh w m) = runUnweighted m
runUnweighted (Uniform rand2m) = do
  r <- getStdRandom $ randomR (0, 1)
  runUnweighted (rand2m r) 
  

-- | A value that is uniformly distributed over (0, 1). Convenient constructor
uniform01 :: R Float
uniform01 = Uniform Return

-- | Quick run of the uniform sampler to see what it outputs.
runUnweighted (replicateM 5 uniform01) >>= \xs -> putStrLn $ "uniform random values: " <> show xs

-- | Change the weight of the rest of the computation. As of now, we cannot
-- interpret this.
weigh :: Float -> R ()
weigh w = Weigh w (Return ())

uniform random values: [0.79149634,0.7387544,0.7090056,0.11956179,0.1219663]

Now that we have a way to get random numbers _uniformly_ from the range (0, 1) we'll use this to build
more complicated random variables --- namely, coins with different biases.

`coin p` creates a coin that returns `1` with probability `p`, and `0` with probability `(1 - p)`.
We pick a number `r` uniformly from `(0, 1)`. If this number is less than `p`, we return `1`, otherwise we
return `0`. We can convince ourselves that this is right by looking at some extreme cases:

- Consider `coin 0`. Now, `r < 0` is impossible since `r` is in `[0, 1]`. Hence, we _never_ return a `1` (that is,
we return `1` with probability `0`).
- Vice versa, consider `coin 1`. `r < 1` is satisfied for every `r` in `[0, 1]` except the boundary value `r = 1`, which is vanishingly unlikely to be drawn. Hence, we will (almost) _always_ return a `1`.

Similar reasoning for `coin 0.5` leads us to argue that we return both `1` and `0` with probability `0.5`, and this argument generalizes to any `coin p`.
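
The general case can be written out in one line. This is a sketch that assumes `r` is exactly uniform on `[0, 1]` and that `0 <= p <= 1`; it ignores floating-point granularity:

$$
\begin{align}
P(\texttt{coin } p = 1) &= P(r < p) = \int_0^p 1 \, dr = p \\
P(\texttt{coin } p = 0) &= 1 - P(r < p) = 1 - p
\end{align}
$$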

In [3]:
-- | 'coin p' returns 1 with probability p, 0 with probability (1 - p) 
coin :: Num a => Float -> R a
coin p = do
  r <- uniform01
  return $ if r < p then 1 else 0
  
runUnweighted (replicateM 10 (coin 0)) >>= \xs -> putStrLn $ "coin 0: " <> show xs
runUnweighted (replicateM 10 (coin 1)) >>= \xs -> putStrLn $ "coin 1: " <> show xs
runUnweighted (replicateM 10 (coin 0.5)) >>= \xs -> putStrLn $ "coin 0.5: " <> show xs


-- | Choose a value from the list with uniform probability
discrete :: [a] -> R a
discrete as = do
  r <- uniform01
  -- clamp the index so that r == 1 cannot index past the end of the list
  let ix = min (length as - 1) (floor $ r * fromIntegral (length as))
  return $ as !! ix
  
runUnweighted (replicateM 10 (discrete [1, 10, 100])) >>= \xs -> putStrLn $ "discrete: " <> show xs

coin 0: [0,0,0,0,0,0,0,0,0,0]

coin 1: [1,1,1,1,1,1,1,1,1,1]

coin 0.5: [1,1,1,1,1,0,1,1,0,1]

discrete: [100,10,10,1,10,10,100,1,10,10]

So, while the above code for `coin` works, we need to think about the correct condition of `r < p`, and it
is not immediate. If we had access to `weigh`, here is an alternate way we could implement `coin`. We first
implement this `coin'`, which is `coin` implemented using `weigh`, and then build the infrastructure
necessary for this to work.
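
Here is a short check of why weighing a fair coin should give the right bias. It assumes that a weight acts as a multiplicative factor on the probability of an outcome, and that the result is renormalized afterwards. A fair coin lands on each side with probability $\tfrac{1}{2}$, and we scale the `1` outcome by $b$ and the `0` outcome by $1 - b$:

$$
\begin{align}
P(1) &= \frac{\tfrac{1}{2} b}{\tfrac{1}{2} b + \tfrac{1}{2}(1 - b)} = b \\
P(0) &= \frac{\tfrac{1}{2}(1 - b)}{\tfrac{1}{2} b + \tfrac{1}{2}(1 - b)} = 1 - b
\end{align}
$$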

In [4]:
-- | A biased coin, created from a fair coin.
coin' :: (Eq a, Num a) => Float -> R a
coin' b = do
  -- | pick heads or tails with uniform probability
  fair <- discrete [0, 1]
  -- | if the fair coin landed 1...
  if fair == 1
  then weigh $ b -- weigh the outcome by `b`.
  else weigh $ (1 - b) -- otherwise weigh the outcome by `1 - b`.
  -- | return the value that was tossed, with the new weight.
  return fair
  
runUnweighted (replicateM 10 (coin' 0)) 
  >>= \xs -> putStrLn $ "coin' 0 (will not work, since we do not have weigh): " <> show xs

coin' 0 (will not work, since we do not have weigh): [0,1,1,0,1,1,0,1,0,1]

So, for this, we need the ability to take into account the calls to `weigh` we have in the code.
For this, we use a technique first described in the [Church programming language paper](https://web.stanford.edu/~ngoodman/papers/churchUAI08_rev2.pdf), and also explained in the paper [Denotational validation of higher-order Bayesian inference](https://arxiv.org/abs/1711.03219).

The idea is to sample from the space of "program traces", where a `Trace` keeps track of:
- The final output value --- `traceval`
- All the randomness used in producing this output value (the results of all `Uniform` invocations) --- `tracerands`
- All weighting that has been done on the output value (the product of all `weigh`s found along this computational path) --- `traceweight`.

We store the traces in a `Trace a` object. 

In [21]:
-- | Trace of computation 
data Trace a = 
  Trace { traceval :: a
        , tracerands :: [Float]
        , traceweight :: Float
        }

-- | Lift a value to a trace. Start it with weight 1.0, no randomness, and the given value
liftTrace :: a -> Trace a
liftTrace a = Trace a [] 1.0

-- | Weigh a trace by the given weight 
weighTrace :: Float -> Trace a -> Trace a
weighTrace w tr = tr {traceweight=(traceweight tr)*w}

-- | Record the use of randomness along the trace.
recordRandomnessTrace :: Float -> Trace a -> Trace a
recordRandomnessTrace r tr = tr {tracerands=r:(tracerands tr)}

-- | given a regular computation, edit the computation to trace
-- | the computation. 
traceR :: R a -> R (Trace a)
traceR (Return x) = Return $ liftTrace x
traceR (Uniform rand2m) = 
  Uniform $ \r -> recordRandomnessTrace r <$> (traceR $ rand2m r)
traceR (Weigh w m) = 
 Weigh w $ weighTrace w <$> (traceR m) 


-- | Run the random variable, using the randomness provided until the
-- | randomness is exhausted
runRWithRandomness :: [Float] -> R a -> R a
runRWithRandomness _ (Return a) = Return a
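-- | Feed the Uniform sampling the randomness we have, and continue running
-- with the rest of the randomness. We apply the continuation to the stored
-- value `r` directly; wrapping it in a fresh `Uniform $ \r -> ...` would
-- shadow `r` and silently discard the recorded randomness.
runRWithRandomness (r:rs) (Uniform rand2m) =
  runRWithRandomness rs (rand2m r)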
-- | ran out of randomness
runRWithRandomness [] (Uniform rand2m) = Uniform rand2m 
runRWithRandomness rs (Weigh w m) = 
  Weigh w (runRWithRandomness rs m)

-- | sample from the computation till we find a trace with
-- non zero weight
nonZeroWeightTrace :: R (Trace a) -> R (Trace a)
nonZeroWeightTrace mt = do
  t <- mt
  if traceweight t == 0
  then nonZeroWeightTrace mt
  else return $ t
  
-- | Take samples from the traced random variable using traced Monte Carlo
tracedMhStep :: R (Trace a) -> Trace a -> R (Trace a)
tracedMhStep mt t = do
  -- | Pick a random position in the randomness of the original trace
  ix <- discrete [0..(length $ tracerands t) - 1]
  -- | Edit the trace at this position by changing the randomness
  r <- uniform01
  let (randl, randr) = splitAt ix (tracerands t)
  -- | replace the randomness of the trace at this position with this
  -- new random value, and now re-run the computation
  let newrand = randl ++ [r] ++ drop 1 randr
  -- | re-run the old computation, by feeding it the new randomness
  t' <- runRWithRandomness newrand mt
  -- acceptance ratio; the parentheses keep the whole second line in the
  -- denominator (without them, `a / b * c` parses as `(a / b) * c`)
  let ratio = (traceweight t' * (fromIntegral . length . tracerands $ t')) /
              (traceweight t * (fromIntegral . length . tracerands $ t))
  accept <- uniform01
  return $ if accept < ratio then t' else t

-- | Repeat the monadic computation n times
loopM :: Monad m => Int -> (a -> m a) -> a -> m a
loopM 0 _ a = return a
loopM n f a = f a >>= loopM (n - 1) f


-- | Take samples from a random variable by using traced Metropolis-Hastings
tracedMH :: Int -> R a -> R [a]
tracedMH n m = do
  -- | create the traced randomness source, and sample from it till we get an acceptable
  -- | computation
  let tm = traceR m
  t <- nonZeroWeightTrace tm
  -- | Int -> Trace a -> R [Trace a]
  let go 0 t = pure []
      go n t = do
         t' <- loopM 10 (tracedMhStep tm) $ t
         ts <- go (n - 1) t'
         return $ t:ts
  traces <- go n t
  return $ map traceval traces
  
-- | get N samples from a random variable that uses `weigh`, by sampling using traced Metropolis-Hastings.
runsWeighted :: Int -> R a -> IO [a]
runsWeighted n m = runUnweighted $ tracedMH n m
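
The trace bookkeeping can be checked deterministically, without any real randomness. The following is a minimal standalone sketch, not part of the notebook: `replay` and `weightedCoin` are hypothetical names introduced here, and `R` is repeated so that the block compiles on its own. `replay` feeds a fixed list of numbers to the `Uniform` nodes and reports the value, the randomness consumed, and the product of the weights, which are the three fields a `Trace` stores.

```haskell
{-# LANGUAGE GADTs #-}

-- Repeated from the notebook so that this sketch is self-contained.
data R a where
  Return  :: a -> R a
  Uniform :: (Float -> R a) -> R a
  Weigh   :: Float -> R a -> R a

-- | Hypothetical helper: replay a computation against a fixed list of
-- random numbers. Returns (value, randomness consumed, product of weights),
-- or Nothing if the list runs out before the computation finishes.
replay :: [Float] -> R a -> Maybe (a, [Float], Float)
replay _ (Return a) = Just (a, [], 1)
replay (r:rs) (Uniform k) = do
  (a, used, w) <- replay rs (k r)
  Just (a, r : used, w)
replay [] (Uniform _) = Nothing
replay rs (Weigh w m) = do
  (a, used, w') <- replay rs m
  Just (a, used, w * w')

-- | Hypothetical stand-in for coin', written with constructors directly:
-- toss a fair coin, then weigh 1 by b and 0 by (1 - b).
weightedCoin :: Float -> R Int
weightedCoin b = Uniform $ \r ->
  let v = if r < 0.5 then 0 else 1
  in Weigh (if v == 1 then b else 1 - b) (Return v)
```

For instance, `replay [0.25] (weightedCoin 0.75)` lands on `0` and carries weight `0.25`, while `replay [0.9] (weightedCoin 0.75)` lands on `1` with weight `0.75`. The sampler's job is to turn those weights into frequencies.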

Let's also write a tiny utility to plot many values onto the command line with fancy
ASCII-art. This is useful when we want to sample a _large number of things_ and look at
histograms with `histogram`, or look at the values, with `printvals`.

In [22]:
-- | List of characters that represent sparklines
sparkchars :: String
sparkchars = "_▁▂▃▄▅▆▇█"

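-- | Convert a value to a sparkline character
num2spark :: RealFrac a => a -- ^ Max value
  -> a -- ^ Current value
  -> Char
num2spark maxv curv
  -- when every value is 0 (e.g. `coin 0`), avoid dividing by zero
  | maxv == 0 = sparkchars !! 0
  | otherwise =
      sparkchars !!
        (floor $ (curv / maxv) * (fromIntegral (length sparkchars - 1)))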

-- | Print sparklines with title
printvals :: RealFrac a => String -> [a] -> IO ()
printvals title vs = do 
  let maxv = maximum vs
  putStrLn $ title ++ " " ++ map (num2spark maxv) vs
  
-- | Create a histogram from values.
histogram :: RealFrac a
          => String -- ^ title
          -> Int -- ^ number of buckets
          -> [a] -- ^ values
          -> IO ()
histogram title nbuckets vs = do
        let minv = minimum vs
            maxv = maximum vs
            perbucket = (maxv - minv) / (fromIntegral nbuckets)
            bucket v = floor ((v - minv) / perbucket)
            bucketed = M.fromListWith (+) [(bucket v, 1) | v <- vs]
        printvals title $ M.elems $ bucketed
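
To see the glyph mapping without any `IO`, here is a small pure sketch of the same scaling idea. The names `ramp` and `sparkline` are hypothetical and not part of the notebook; the sketch assumes non-negative inputs and clamps the index defensively.

```haskell
-- | Same glyph ramp as the notebook's `sparkchars`.
ramp :: String
ramp = "_▁▂▃▄▅▆▇█"

-- | Hypothetical pure variant of `printvals`: scale each value by the
-- maximum and pick the matching glyph. If the maximum is not positive,
-- every value maps to the lowest glyph.
sparkline :: [Double] -> String
sparkline [] = ""
sparkline vs = map glyph vs
  where
    maxv = maximum vs
    top  = length ramp - 1
    glyph v
      | maxv <= 0 = '_'
      | otherwise = ramp !! max 0 (min top (floor (v / maxv * fromIntegral top)))
```

With this definition, `sparkline [0, 4, 8]` evaluates to `"_▄█"`: the smallest value gets the lowest glyph, the maximum gets the full block, and the midpoint lands halfway up the ramp.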

We can now draw our previous coins. A vertical bar means that we got a `1`, and not having a vertical bar
means that we got a `0`. We would expect `bias 0` to have no vertical bars (since we should never get a `1`). Similarly, we would expect `bias 1` to have only vertical bars (since we should always get a `1`).


Let's use the machinery we have built to bring `coin'` online (recall that `coin'` was defined using `weigh`,
and did not work with our previous `runUnweighted`). We will run both `coin` and `coin'` for different
biases to compare, and check that their outputs look roughly the same.

Also note that from this point onward, we will _only use_ `runsWeighted`, since our weighted sampler
completely subsumes the unweighted sampler.

In [23]:
runsWeighted 20 (coin 0) >>=  printvals "coin: bias 0"
runsWeighted 20 (coin' 0) >>=  printvals "coin': bias 0"

putStrLn "---"
runsWeighted 20 (coin 0.2) >>=  printvals "coin: bias 0.2"
runsWeighted 20 (coin' 0.2) >>=  printvals "coin': bias 0.2"

putStrLn "---"
runsWeighted 20 (coin 0.5) >>=  printvals "coin: 0.5"
runsWeighted 20 (coin' 0.5) >>=  printvals "coin': 0.5"

putStrLn "---"
runsWeighted 20 (coin 0.8) >>=  printvals "coin: 0.8"
runsWeighted 20 (coin' 0.8) >>=  printvals "coin': 0.8"

putStrLn "---"
runsWeighted 20 (coin 1) >>=  printvals "coin: 1"
runsWeighted 20 (coin' 1) >>=  printvals "coin': 1"

coin: bias 0 ____________________

coin': bias 0 ____________________

---

coin: bias 0.2 █_█_█________█__██__

coin': bias 0.2 █_______█____█______

---

coin: 0.5 █__█████_█_█_█__█___

coin': 0.5 _█____████_█████__██

---

coin: 0.8 ███_██████████_█████

coin': 0.8 _████__█████_███████

---

coin: 1 ████████████████████

coin': 1 ████████████████████

Next, we use the `weigh` mechanism to sample from _any shape of distribution we want_. The idea is this: if we want to sample points with a distribution `d :: Float -> Float`, we will sample uniformly a value `r` in the range `(lo, hi)`, and then _weigh this `r`_ by the distribution `d`.  

This allows us to sample from shapes such as:
- $f(x) = x^2$
- $f(x) = |\sin x|$
- $f(x) = e^{-x^2}$ (Gaussian)
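
Stated as a formula, the density we expect the weighted samples to follow is the shape $d$ normalized over the sampling range. This sketch assumes $d$ is non-negative on $(lo, hi)$ and that its integral over that range is finite and positive:

$$
p(x) = \frac{d(x)}{\int_{lo}^{hi} d(u) \, du}, \qquad lo < x < hi
$$

The sampler never needs the denominator, since only ratios of weights enter the accept/reject step.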

In [20]:
distributionToR :: (Float, Float) -> (Float -> Float) -> R Float
distributionToR (lo, hi) d = do
  r <- uniform01
  let val = lo + r * (hi - lo)
  -- | weigh the sample `val` with weight `d val`
  weigh $ d val
  -- | return the value, with the new weight applied.
  return $ val


runsWeighted 1000 (distributionToR (0, 6) (^2)) >>= histogram "x^2" 25
runsWeighted 1000 (distributionToR (0, 6) (abs . sin)) >>= histogram "|sin x|" 25
runsWeighted 1000 (distributionToR (-6, 6) (\x -> exp (-1.0 * x * x))) >>= histogram "e^{-x^2}" 25
runsWeighted 1000 (distributionToR (-2, 2) (\x -> abs (1 - x*x))) >>= histogram "|1-x^2|" 25

x^2 ______▁__▁▁▁▂▂▃▂▃▃▄▄▅▇▆▇█_

|sin x| _▂▄▅▆▇█▆▅▅▃▃_▁▃▃▅▇▇▅▇▅▅▄▃_

e^{-x^2} ___▁▂▄▅█▆▅▄▃▂____

|1-x^2| █▅▄▃▂__▁▂▂▂▃▂▂▃▂▂▁__▂▂▅▅▇_

Great, all of them seem to work. Let's now use similar ideas to estimate the bias of a coin. The idea is as follows:

We have a model of a coin, and we know how likely heads or tails is given the model of a coin. Let's call this
$P(d|m)$ (probability of the $d$ata given the $m$odel). This is:

$$
\begin{align}
P(1|\text{bias}) &= \text{bias} \\
P(0|\text{bias}) &= 1 - \text{bias}
\end{align}
$$

What we want to do is the _inverse problem_.
Given observations about coin flips from a coin with an _unknown bias_, we wish to _predict its bias_.
That is, we want to find $P(\text{bias}|\text{data})$.

We solve this problem using Bayes' theorem. We know that

$$P(\text{bias}|\text{data}) = \frac{P(\text{data}|\text{bias}) P(\text{bias})}{P(\text{data})}$$

The denominator is a normalization factor that is constant for a fixed dataset. Thus, we write:

$$P(\text{bias}|\text{data}) \propto P(\text{data}|\text{bias}) P(\text{bias})$$

Thus, if $P(\text{bias})$ is our _prior belief_ about the biases, then $P(\text{data}|\text{bias})$ is how much we need to multiply the
prior with to get the _posterior belief_.

In our case, as mentioned above, the value of $P(data|bias)$ is:

$$
\begin{align}
P(1|\text{bias}) &= \text{bias} \\
P(0|\text{bias}) &= 1 - \text{bias}
\end{align}
$$
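
For a whole list of observations, the per-flip factors multiply. This assumes the flips are independent given the bias. With $h$ ones and $t$ zeros in the data:

$$
P(\text{data}|\text{bias}) = \text{bias}^{h} \, (1 - \text{bias})^{t}
$$

As a worked case, the data `[1, 0]` gives $\text{bias}\,(1 - \text{bias})$, which is zero at both ends and largest at $\text{bias} = \tfrac{1}{2}$.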

In [26]:
-- | Given a list of observations from a coin, sample a candidate bias and weigh it
-- in proportion to how likely the observations are under that bias: multiply by bias
-- for each 1 we see, and by (1 - bias) for each 0.
estimateBias :: [Int] -> R Float
estimateBias obs = do
  b <- uniform01 -- ^ Uniform prior
  forM_ obs $ \ob -> 
    if ob == 1 then weigh b else weigh (1 - b)
  return $ b
  
replicateList :: Int -> [a] -> [a]
replicateList n as = mconcat $ replicate n as 

runsWeighted 1000 (estimateBias []) >>= histogram "estimate with no data" 10
runsWeighted 1000 (estimateBias [1]) >>= histogram "estimate with [1]" 10
runsWeighted 1000 (estimateBias [0]) >>= histogram "estimate with [0]" 10
runsWeighted 1000 (estimateBias [0, 1]) >>= histogram "estimate with [0, 1]" 10
runsWeighted 1000 (estimateBias [1, 0]) >>= histogram "estimate with [1, 0]" 10
runsWeighted 1000 (estimateBias [1, 0, 1, 0]) >>= histogram "estimate with [1, 0]x2" 10
runsWeighted 1000 (estimateBias (replicateList 8 [1, 0])) >>= histogram "estimate with [1, 0]x8" 10
runsWeighted 1000 (estimateBias (replicateList 20 [1, 0])) >>= histogram "estimate with [1, 0]x20" 10

estimate with no data ▇▇▆▇█▆▆▇▆▇_

estimate with [1] _▁▂▃▃▅▅▇▇█_

estimate with [0] █▆▅▅▄▃▂▁▁__

estimate with [0, 1] ▁▄▆▇█▆▇▅▃▁_

estimate with [1, 0] ▁▄▅▇▇█▇▆▄▁_

estimate with [1, 0]x2 _▂▄▅▇█▆▅▂▁_

estimate with [1, 0]x8 ___▂▅█▇▄▂__

estimate with [1, 0]x20 __▂▆█▅▂__

Notice the way in which we update the bias: We start with the notion that _any bias is equally likely_,
since we pick `b <- uniform01`. Then, as we read the observations, we update the _weight of that given bias `b`_
based on Bayes' rule --- if the observation is `1`, we scale with `b`, and if not, we scale with `1 - b`. 
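
The histograms can be sanity-checked against a closed form. With a uniform prior, the posterior over the bias is a Beta distribution with parameters (number of ones + 1, number of zeros + 1), so its mean is (ones + 1) / (n + 2). The following is a minimal sketch of that check; `likelihood` and `posteriorMean` are hypothetical helpers and are not used by the sampler.

```haskell
-- | Unnormalized likelihood of the observations for a given bias:
-- multiply by b for every 1 and by (1 - b) for every 0.
likelihood :: Double -> [Int] -> Double
likelihood b = product . map (\ob -> if ob == 1 then b else 1 - b)

-- | Mean of the exact posterior under a uniform prior, which is
-- Beta(ones + 1, zeros + 1) and therefore has mean (ones + 1) / (n + 2).
posteriorMean :: [Int] -> Double
posteriorMean obs = (ones + 1) / (n + 2)
  where
    ones = fromIntegral (length (filter (== 1) obs))
    n    = fromIntegral (length obs)
```

For example, `posteriorMean [1]` is `2/3`, which agrees with the `[1]` histogram leaning to the right, and any balanced list such as `[1, 0, 1, 0]` has mean `0.5`, in line with the histograms that tighten around the middle as more balanced data arrives.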