# Solving a counterintuitive probability problem with Markov chains

The weekly Riddler column at FiveThirtyEight posted a **[fascinating (and delicious) chocolate-based probability question](https://fivethirtyeight.com/features/can-you-eat-all-the-chocolates/)**:

>I have 10 chocolates in a bag: Two are milk chocolate, while the other eight are dark chocolate. One at a time, I randomly pull chocolates from the bag and eat them — that is, until I pick a chocolate of the other kind. When I get to the other type of chocolate, I put it back in the bag and start drawing again with the remaining chocolates. I keep going until I have eaten all 10 chocolates.

> For example, if I first pull out a dark chocolate, I will eat it. (I’ll always eat the first chocolate I pull out.) If I pull out a second dark chocolate, I will eat that as well. If the third one is milk chocolate, I will not eat it (yet), and instead place it back in the bag. Then I will start again, eating the first chocolate I pull out.

> What are the chances that the last chocolate I eat is milk chocolate?

Before starting work, I took an intuitive guess. With 20% of the bag starting off as milk chocolate, I figured that the probability of ending in milk was greater than 20% because of the "incumbency disadvantage": the type of chocolate with the higher proportion is more likely to be picked, and so eaten, early. I also estimated an upper bound of 50% because the dark chocolates, being more numerous, still had to be more likely than the milk chocolates to be eaten last. For a single point estimate, I averaged the upper and lower bounds for a guess of 35%.

What was the answer? As it turns out, **the probability of eating a milk chocolate last is 50%!** In fact, while the code below does not offer an inductive proof of the general case, it appears that any starting mix of milk and dark chocolates results in an equal probability of either type being eaten last. That is, whether the starting distribution is 2 milk and 8 dark or 2 milk and 30 dark, milk still has a 50% probability of being last!

How did we get there? Below, I'll walk through my Markov chain approach, implemented in Python, to solve the problem, and then confirm the result with a Monte Carlo simulation. While there are simpler approaches to solving the problem (e.g., dynamic programming), taking the roundabout route and being able to draw out the full transition matrix between states has a certain completeness (if lacking elegance).


## Framework of state transitions

We can think about the problem as a transition between *states*. At any point, the bag of chocolates is in a certain state, and this state determines the probabilities of transitioning to each possible next state. Each state represents three pieces of information: the number of remaining milk chocolates, the number of remaining dark chocolates, and what types of chocolate can be eaten on the next draw. What do we mean by that last point? Each state meets one of three conditions:

- A milk chocolate was just eaten, so only a milk chocolate can be eaten in the next state. If a dark chocolate is selected, it will be returned to the bag.
- A dark chocolate was just eaten, so only a dark chocolate can be eaten in the next state. If a milk chocolate is selected, it will be returned to the bag.
- The bag has been "reset" and either milk or dark can be eaten next. 

We'll call the first condition "m" (milk), the second condition "d" (dark), and the last condition "r" (reset). We'll name each state by filling in the shorthand $x$|$y$|$c$, where $x$ denotes the number of milk chocolates, $y$ denotes the number of dark chocolates, and $c$ denotes the current condition of the bag. For example, $2$|$2$|$r$ represents a bag that currently has two milk chocolates, two dark chocolates, and is in the reset state.

Now that we know how to describe each state, we can start thinking about how states transition to other states, and the associated probabilities.  If the bag is in state $2$|$2$|$r$, what states can it transition to, and with what frequency?

<img src="markov_files/fig1.png" align="center"/>

We're in a "reset" condition, so either milk chocolate or dark chocolate can be eaten next. Because there are an equal number of milk and dark chocolates, they are eaten with equal probability, resulting in a 50% chance of state $1$|$2$|$m$ and state $2$|$1$|$d$.

What happens when we go down one more level?

<img src="markov_files/fig2.png" align="center"/>

If we're in state $1$|$2$|$m$, there are three chocolates left in the bag, and we have a 1/3 chance of drawing a milk chocolate, and a 2/3 chance of drawing a dark chocolate. If we draw a milk chocolate, we eat a milk chocolate, leading to state $0$|$2$|$m$. If we draw a dark chocolate, we don't eat any chocolate, and transition to the reset state $1$|$2$|$r$. The right half of the tree follows a similar logic.
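The transition rules described so far fit in a few lines of Python. The following is a minimal sketch, not the code used later in this post: `next_states` is a hypothetical helper name, states are written as `(milk, dark, condition)` tuples rather than the $x$|$y$|$c$ strings, and exact fractions are used so the probabilities can be compared against the diagrams.

```python
from fractions import Fraction


def next_states(milk, dark, cond):
    """Return {next_state: probability} for a single draw from the bag.

    Hypothetical helper for illustration. `cond` is 'm', 'd', or 'r',
    matching the three conditions defined above.
    """
    total = milk + dark
    out = {}
    if milk > 0:
        p_milk = Fraction(milk, total)
        if cond == 'd':
            # Dark was just eaten: the milk chocolate goes back, bag resets
            out[(milk, dark, 'r')] = p_milk
        else:
            # Milk can be eaten (condition 'm' or 'r')
            out[(milk - 1, dark, 'm')] = p_milk
    if dark > 0:
        p_dark = Fraction(dark, total)
        if cond == 'm':
            # Milk was just eaten: the dark chocolate goes back, bag resets
            out[(milk, dark, 'r')] = p_dark
        else:
            # Dark can be eaten (condition 'd' or 'r')
            out[(milk, dark - 1, 'd')] = p_dark
    return out


print(next_states(2, 2, 'r'))  # (1, 2, 'm') and (2, 1, 'd'), each with probability 1/2
print(next_states(1, 2, 'm'))  # (0, 2, 'm') with 1/3, (1, 2, 'r') with 2/3
```

Each state has at most two successors, one per type of chocolate that can be drawn, which is why the trees above branch in two at every level.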

Now with this framework of state transition in mind, we can talk about Markov chains and the solution to the problem.

## Implementation: Using Markov chains

Our diagram above is an example of a **Markov chain**: a model of a potential sequence of states and the associated probabilities of transitioning between states. The key condition for a process to be a *Markov process* is that the probability of the next state depends only on the current state. This is true in our Markov model because each of our states includes whether we are in an $m$, $d$, or $r$ condition, and so encodes the necessary information to generate the transitions without looking backward.

The full set of possible transitions and their probabilities is known as a transition matrix. Our chain will start in state $2$|$8$|$r$, as described in the problem instructions. Our final states in the model will have no chocolates left. There are two potential final states of the model, $0$|$0$|$m$ and $0$|$0$|$d$, denoting whether the last chocolate eaten was milk or dark when we have no chocolates remaining. These final states are known in Markov parlance as *absorbing states*. Once these states are reached, they cannot be left, so the Markov process is over. As it happens, our Markov chain also has the property that once any particular state is reached, it will not be reached again (except for absorbing states).

We want to construct the transition matrix because once we have it, we can run some (relatively simple) linear algebra operations, and the result will tell us the probability of transitioning to each of the two absorbing states from any possible state, answering the problem!

The transition matrix will be an $N \times N$ matrix, where $N$ represents the number of potential states. Each row represents a state, and the value in each column represents the probability of transitioning to the state that matches the index of that column. Here is the transition matrix for the smaller sequence starting with $2$|$2$|$r$:

<img src="markov_files/fig3.png" align="center"/>

The top left white cell at the intersection of $2$|$2$|$r$ is 0 because state $2$|$2$|$r$ has a 0% probability of transitioning to $2$|$2$|$r$ (itself). However, it has a 50% probability of transitioning to $2$|$1$|$d$. Highlighted in yellow are the absorbing states. Each absorbing state has a 100% probability of transitioning to itself.

Most of the cells in the matrix will be 0 because each state only has a nonzero probability of transitioning to either one or two other states.


## Generating the transition matrix

Now we get to the fun. The image above was a sample; we need to generate the full transition matrix starting with $2$|$8$|$r$ (in fact, we will generalize the code so that it generates the transition matrix for any starting state). We'll do this via a Python script that finds, for each state, the possible states it can transition to. We can do this fairly simply because the rules are deterministic, and each state can transition to either one or two other states (i.e., only milk chocolate or dark chocolate can be picked from the bag, and we know the probabilities from the ratio of milk chocolate to dark chocolate).

In the code (available below), we begin by generating the two possible states resulting from $2$|$8$|$r$, which are $1$|$8$|$m$ (with probability 20%) and $2$|$7$|$d$ (with probability 80%). We then treat these new states as inputs and generate the states that can result from them. For instance, $2$|$7$|$d$ can generate either $2$|$7$|$r$ or $2$|$6$|$d$. We loop and repeat this process of generating states until we have generated all possible states.

What results is a $68 \times 68$ matrix containing 4,624 cells (with most of them equal to zero). We can then use this matrix to find the probability of eating a milk chocolate last.
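The state count can be checked independently of the pandas-based code further down. The sketch below is an assumption-laden alternative, not the author's implementation: `step` and `enumerate_states` are hypothetical helpers, states are `(milk, dark, condition)` tuples, and the reachable states are found with a breadth-first search before being laid out in a NumPy array.

```python
from collections import deque

import numpy as np


def step(state):
    """Hypothetical helper: {next_state: probability} for one draw."""
    milk, dark, cond = state
    total = milk + dark
    if total == 0:
        return {state: 1.0}  # absorbing state: transitions only to itself
    out = {}
    if milk > 0:
        nxt = (milk, dark, 'r') if cond == 'd' else (milk - 1, dark, 'm')
        out[nxt] = milk / total
    if dark > 0:
        nxt = (milk, dark, 'r') if cond == 'm' else (milk, dark - 1, 'd')
        out[nxt] = dark / total
    return out


def enumerate_states(start):
    """Breadth-first search over every state reachable from `start`."""
    seen = {start}
    queue = deque([start])
    while queue:
        for nxt in step(queue.popleft()):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return sorted(seen, reverse=True)


states = enumerate_states((2, 8, 'r'))
index = {s: k for k, s in enumerate(states)}

# Fill one row per state with its outgoing transition probabilities
P = np.zeros((len(states), len(states)))
for s in states:
    for nxt, p in step(s).items():
        P[index[s], index[nxt]] = p

print(P.shape)                          # (68, 68)
print(np.allclose(P.sum(axis=1), 1.0))  # True: every row is a probability distribution
print(np.count_nonzero(P))              # number of nonzero cells
```

Sorting tuples numerically, rather than sorting string labels, keeps the starting state in the first row even when a count has more than one digit.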


## Finding absorption probabilities from each starting state

We can perform linear algebra to find a matrix $B$ that contains, for each state (not just the starting state), the probability of transitioning to each of the two absorbing states. The solution formula and its proof are given in chapter 3 of [*Finite Markov Chains*](https://docs.ufpr.br/~lucambio/CE222/1S2014/Kemeny-Snell1976.pdf) by Kemeny and Snell (1960).

Given that:
- $Q$ is the submatrix of the full transition matrix containing the transition probabilities between non-absorbing states
- $I$ is the [identity matrix](https://en.wikipedia.org/wiki/Identity_matrix) with the same dimensions as $Q$
- $R$ is the submatrix containing the transition probabilities from each non-absorbing state (rows) to each absorbing state (columns)

$B = (I - Q)^{-1} R$
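To see where the formula comes from, condition on the first step out of a non-absorbing state: the chain is either absorbed immediately (the $R$ term) or moves to another non-absorbing state and is absorbed from there (the $QB$ term), so $B = R + QB$, which rearranges to $B = (I - Q)^{-1} R$. The sketch below applies this to the smallest non-trivial bag, one milk and one dark chocolate. The matrices are written out by hand and the row ordering is an arbitrary choice made for this example; `np.linalg.solve` is used in place of an explicit inverse.

```python
import numpy as np

# Non-absorbing states, in row order (ordering chosen for this sketch)
transient = ['1|1|r', '0|1|m', '0|1|r', '1|0|d', '1|0|r']
# Absorbing states, in column order of R
absorbing = ['0|0|m', '0|0|d']

# Q[i, j]: probability of moving from transient state i to transient state j
Q = np.array([
    [0, 0.5, 0, 0.5, 0],  # 1|1|r -> 0|1|m or 1|0|d, each with probability 1/2
    [0, 0,   1, 0,   0],  # 0|1|m -> 0|1|r (the dark chocolate is put back)
    [0, 0,   0, 0,   0],  # 0|1|r -> absorbed (see R)
    [0, 0,   0, 0,   1],  # 1|0|d -> 1|0|r (the milk chocolate is put back)
    [0, 0,   0, 0,   0],  # 1|0|r -> absorbed (see R)
], dtype=float)

# R[i, k]: probability of moving from transient state i to absorbing state k
R = np.array([
    [0, 0],
    [0, 0],
    [0, 1],  # 0|1|r -> 0|0|d
    [0, 0],
    [1, 0],  # 1|0|r -> 0|0|m
], dtype=float)

I = np.eye(len(Q))

# Solve (I - Q) B = R, which is equivalent to B = (I - Q)^{-1} R
B = np.linalg.solve(I - Q, R)

print(B[0])  # [0.5 0.5]
```

The first row of `B` gives the absorption probabilities from $1$|$1$|$r$: a 50% chance of finishing on milk and a 50% chance of finishing on dark, as symmetry requires for this bag.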

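Because we sorted our transition matrix, the absorbing states occupy the last two rows and columns, with $0$|$0$|$m$ as the first column of $B$. The initial state $2$|$8$|$r$ sorts to the first row, so the first cell (index [0,0]) of $B$ is the probability of ending in the absorbing state $0$|$0$|$m$ from initial state $2$|$8$|$r$. This single value is what is returned by our Python model coded below. We can call the function `build_model(m, d)`, where `m` is the number of starting milk chocolates and `d` is the number of starting dark chocolates, to compute it.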

```python
build_model(2, 8)
```

0.5

Other initial starting conditions also return a probability of 50%:

```python
build_model(4, 8)
```

0.5

```python
build_model(8, 8)
```

0.5

```python
build_model(20, 8)
```

0.5
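As an exact cross-check over a wider range of starting bags, the same state transitions can be evaluated with a memoized recursion (the dynamic programming route mentioned at the start) instead of a matrix inversion. This is a sketch under stated assumptions, separate from `build_model`: `p_milk_last` is a hypothetical helper, and the 12-by-12 grid of starting counts is an arbitrary choice.

```python
from fractions import Fraction
from functools import lru_cache


@lru_cache(maxsize=None)
def p_milk_last(milk, dark, cond='r'):
    """Exact probability that the last chocolate eaten is milk.

    Hypothetical helper; `cond` is 'm', 'd', or 'r' as in the state shorthand.
    """
    if milk == 0 and dark == 0:
        # Empty bag: `cond` records the type that was eaten last
        return Fraction(1 if cond == 'm' else 0)
    total = milk + dark
    prob = Fraction(0)
    if milk > 0:
        # Drawing milk: eaten unless dark was just eaten, in which case reset
        nxt = (milk, dark, 'r') if cond == 'd' else (milk - 1, dark, 'm')
        prob += Fraction(milk, total) * p_milk_last(*nxt)
    if dark > 0:
        # Drawing dark: eaten unless milk was just eaten, in which case reset
        nxt = (milk, dark, 'r') if cond == 'm' else (milk, dark - 1, 'd')
        prob += Fraction(dark, total) * p_milk_last(*nxt)
    return prob


results = {(m, d): p_milk_last(m, d) for m in range(1, 13) for d in range(1, 13)}

print(p_milk_last(2, 8))                                   # 1/2
print(all(p == Fraction(1, 2) for p in results.values()))  # True
```

Because the arithmetic is done in exact fractions, every bag in the grid comes out at exactly 1/2 rather than approximately 0.5. This is still a finite check and not a proof of the general case.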

At first I couldn't believe the 50% probability! I wrote a Monte Carlo simulation (also given below) to confirm the result. It's certainly a counterintuitive result, and I don't yet have an elegant explanation of why this is the case. I'm looking forward to the Riddler's solutions writeup next week and learning more.

## Code

```python
# Markov transition model

import numpy as np
import pandas as pd


def build_model(m: int, d: int) -> float:
    """
    Arguments:

    m (int): The number of milk chocolates starting in the bag. Must be greater than 0.
    d (int): The number of dark chocolates starting in the bag. Must be greater than 0.

    Assumes the starting state is "reset", i.e., either chocolate can be eaten first.

    ---

    Outputs:

    p (float): Probability that milk chocolate is the last chocolate eaten

    ---

    We start by building the transition matrix. We use the function
    'add_to_transitions_list' to add the first two entries to the list from
    the starting state.

    We refer to each state by a shorthand name, m|d|s, meaning 'milk', 'dark', 'state'.

    For example, '2|7|d' refers to a state with 2 milk chocolates left in the
    bag, 7 dark chocolates left, and a dark chocolate (d) as the last one
    eaten. For each state, we want to find the states that it transitions to,
    and the associated probability of each potential state.
    """

    assert m > 0 and d > 0, "m and d arguments must be greater than 0"

    transitions = []
    add_to_transitions_list(m, d, 'r', transitions)

    # Each transition added to the list produces an output state. We treat
    # these output states as inputs and generate their transitions in turn,
    # repeating until there are no more entries to add to the list.

    # Initialization of tracker of list length
    trans_len = 1

    # Stop generating once an iteration adds no more rows
    while len(transitions) > trans_len:

        trans_len = len(transitions)

        # The identifier of each destination state is at index 7 of an entry.
        # Expand a destination only if it has not been expanded yet and it
        # still has chocolates remaining (index 4 or 5 is greater than 0).
        expanded = {t[0] for t in transitions}
        new_seqs = [t[7] for t in transitions
                    if t[7] not in expanded and max(t[4], t[5]) > 0]
        for seq in new_seqs:
            m_str, d_str, state = seq.split('|')
            add_to_transitions_list(int(m_str), int(d_str), state, transitions)

    # Our absorbing states are missing from the transition list. We add them,
    # then convert the list of transitions into a Markov transition matrix.
    transitions.append(['0|0|m', 0, 0, 'm', 0, 0, 'm', '0|0|m', 1])
    transitions.append(['0|0|d', 0, 0, 'd', 0, 0, 'd', '0|0|d', 1])

    # Convert to matrix format
    df = pd.DataFrame(transitions, columns=[
        'identifier', 'milk_remaining', 'dark_remaining', 'state',
        'milk_remaining_new', 'dark_remaining_new', 'state_new',
        'identifier_new', 'p'])
    df = df.drop_duplicates()
    df_pivot = df.pivot_table(index='identifier', columns='identifier_new',
                              values='p', fill_value=0, dropna=False)

    # Rectangularize the matrix: the starting state is never a destination,
    # so it needs its own all-zero column
    df_pivot[str(m) + '|' + str(d) + '|r'] = 0

    # Sort rows and columns so that the two absorbing states come last
    df_pivot = df_pivot.sort_index(ascending=False)
    df_pivot = df_pivot.sort_index(ascending=False, axis=1)

    # Matrix algebra to find the probability of absorption from each state.
    # 2 appears because it represents the number of absorbing states.
    matrix = df_pivot.to_numpy()

    I = np.eye(len(matrix) - 2)
    Q = matrix[:-2, :-2]
    R = matrix[:-2, -2:]
    B = np.matmul(np.linalg.inv(I - Q), R)

    # The sort is on string labels, so the starting state is the first row
    # only when both counts are single digits; look its row up by label.
    start_row = list(df_pivot.index).index(str(m) + '|' + str(d) + '|r')

    # And the answer is ... (probability of ending in milk from the initial state)
    return round(float(B[start_row][0]), 6)


def add_to_transitions_list(m, d, state, transitions):
    """
    Takes an input state and appends its output states and their
    probabilities to a transition list.

    Arguments:

    m (int): The number of milk chocolates currently in the bag
    d (int): The number of dark chocolates currently in the bag
    state (str): Either 'm', 'd', or 'r', representing the state before transition
    transitions (list): List to append transitions to

    ---

    This function finds the transitions from any state. We use the knowledge
    that there are only two potential transitions (one for milk being
    selected and one for dark being selected).

     If previous selection was milk, the only two results are "milk" and "reset"
     If previous selection was dark, the only two results are "dark" and "reset"
     If previous selection was reset, the only two results are "milk" and "dark"
    """
    # What happens if next draw is milk?
    m_new = m if state == 'd' else m - 1
    d_new = d
    p = m / (m + d)

    if state == 'm':
        state_new = 'm'
    elif state == 'd':
        state_new = 'r'
    elif state == 'r':
        state_new = 'm'

    if p > 0:
        append_transitions(m, d, state, m_new, d_new, state_new, p, transitions)

    # What happens if next draw is dark?
    m_new = m
    d_new = d if state == 'm' else d - 1
    p = 1 - (m / (m + d))

    if state == 'm':
        state_new = 'r'
    elif state == 'd':
        state_new = 'd'
    elif state == 'r':
        state_new = 'd'

    if p > 0:
        append_transitions(m, d, state, m_new, d_new, state_new, p, transitions)


def append_transitions(m, d, state, m_new, d_new, state_new, p, transitions):
    """
    Appends collected information about a transition to a list.

    Arguments:

    m (int): The number of milk chocolates in the bag before transition
    d (int): The number of dark chocolates in the bag before transition
    state (str): Either 'm', 'd', or 'r', representing the state before transition
    m_new (int): The number of milk chocolates in the bag after transition
    d_new (int): The number of dark chocolates in the bag after transition
    state_new (str): Either 'm', 'd', or 'r', representing the state after transition
    p (float): Probability of this transition occurring
    transitions (list): List to append transitions to
    """
    transitions.append([
        str(m) + '|' + str(d) + '|' + state,
        m,
        d,
        state,
        m_new,
        d_new,
        state_new,
        str(m_new) + '|' + str(d_new) + '|' + state_new,
        p,
    ])


print(build_model(2, 8))
```



0.5


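```python
# Monte Carlo simulation

import random


def find_last_chocolate(chocolates):
    # 1 = milk, 0 = dark; -1 means the bag is in the "reset" condition
    prev_selection = -1
    while len(chocolates) > 1:
        i = random.choice(chocolates)
        if prev_selection != i and prev_selection != -1:
            # Different type from the one just eaten: put it back and reset
            prev_selection = -1
        else:
            # Same type as the one just eaten, or first draw after a reset: eat it
            prev_selection = i
            chocolates.remove(i)
    # The single remaining chocolate is necessarily the last one eaten
    return chocolates[0]


iters = 1000000
milk_last = 0
for _ in range(iters):
    milk_last += find_last_chocolate([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
print(milk_last / iters)
```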

0.49974
