### Introduction
In this notebook we will implement value iteration and use it to solve the gambler's problem and GridWorld. This is based on the book by Sutton-Barto, specifically this version: https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf

### Choosing a $\gamma$
Recall that $\gamma \in [0, 1]$ is the discount factor used to compute the discounted return
$$G_t = \sum_{k = 0}^\infty \gamma^k R_{t + k + 1}.$$
This hyperparameter $\gamma$ controls how farsighted the agent is, and it also ensures that $G_t$ is finite in the case of infinitely many bounded nonzero terms. (See Sections 3.3, 3.4 in Sutton-Barto.) Choose $\gamma$ for the present case of a finite MDP.
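
To make the effect of $\gamma$ concrete, here is a minimal sketch that evaluates the sum above for a finite reward sequence. The helper name `discounted_return` and the reward sequences are illustrative assumptions, not part of the notebook's code.

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**k * rewards[k] over a finite reward sequence."""
    return sum(gamma**k * r for k, r in enumerate(rewards))

# A reward of 1 that arrives on the third step.
episode = [0, 0, 1]
farsighted = discounted_return(episode, 1.0)  # delay is not penalized
myopic = discounted_return(episode, 0.5)      # worth 0.5**2 of its face value

# A long stream of constant rewards approaches the limit 1 / (1 - gamma) = 10.
long_stream = discounted_return([1] * 1000, 0.9)
```

With $\gamma = 1$ the delayed reward keeps its full value, which is only safe when every episode ends after finitely many steps.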

In [None]:
GAMMA = 1.0

### General objects
Here is a general class template that represents a finite Markov decision process (MDP).

In [None]:
import random
import math
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm import tqdm

class MDP:

    @property
    def states(self):
        '''
        Returns:
            (set) The states of this MDP.
        '''
        pass

    def actions(self, state):
        '''
        Returns:
            (set) The actions allowed from `state`.
        '''
        pass

    def psr(self, state, action):
        '''
        Returns:
            (set) The set of tuples (p, s, r) where `s` is a state reachable
                from `state` by performing `action`, `p` is the probability
                of reaching `s`, and `r` is the reward gained by reaching `s`.
                In particular, the sum of the `p` equals 1.
        '''
        pass

    @property
    def sink_states(self):
        return {
            state
            for state in self.states
            if all(
                psr[1] == state # every psr goes back to state
                for action in self.actions(state)
                for psr in self.psr(state, action)
            )
        }

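    def approximate_value(self, policy, state, n_samples=int(1e3)):
        '''
        Monte Carlo estimate of the undiscounted return of `policy` from `state`.
        '''
        if n_samples == 1:
            sink_states = self.sink_states  # compute once per episode, not per step
            curr_state = state
            total_reward = 0
            while curr_state not in sink_states: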
                # choose an action randomly according to the policy
                actions, probs = list(zip(*policy[curr_state].items()))
                action = random.choices(actions, weights=probs, k=1)[0]
                # choose a psr randomly according to the MDP
                p_list, s_list, r_list = list(zip(*self.psr(curr_state, action)))
                s, r = random.choices(list(zip(s_list, r_list)), weights=p_list, k=1)[0]
                total_reward += r
                curr_state = s  
            return total_reward
        else:
            return sum(
                self.approximate_value(policy, state, n_samples=1)
                for _ in tqdm(range(n_samples))
            ) / n_samples
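
One caveat worth checking when implementing `psr`: because the outcomes are returned as a set, two outcomes with identical `(p, s, r)` collapse into one, and the probabilities no longer sum to 1. The sketch below uses plain tuples (not the `MDP` class) to show the effect for a fair coin and a bet of 0; the helper `total_probability` is a hypothetical name introduced only for this check.

```python
def total_probability(outcomes):
    """Sum of the probabilities in an iterable of (p, s, r) tuples."""
    return sum(p for p, _, _ in outcomes)

# Fair coin, bet of 0 from state 0: heads and tails give the same tuple.
heads = (0.5, 0, 0)
tails = (0.5, 0, 0)

as_set = {heads, tails}    # duplicates collapse to a single element
as_list = [heads, tails]   # a list keeps both outcomes

set_total = total_probability(as_set)
list_total = total_probability(as_list)
```

In the gambler's problem below this appears to matter only at the terminal states when `p_h = 0.5`, where the value is 0 either way, but it is worth keeping in mind for other MDPs.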

### Implementing the gambler's problem
To make our implementation compatible with value iteration, we will include the state `self.goal`. Recall that the reward of reaching the goal is $1$ and that the reward of reaching any other state is $0.$ Moreover, recall that the MDP ends once either `0` or `self.goal` is reached. How should the state `self.goal` behave? In particular, what should the available actions and rewards from this state be? Complete the implementation below.

In [None]:
class GamblersProblem(MDP):
    '''
    The gambler's problem, as described in Example 4.3 of Sutton-Barto.
    '''
    
    def __init__(self, p_h, goal):
        assert 0 <= p_h <= 1
        self.p_h = p_h
        self.goal = goal

    @property
    def states(self):
        return set(range(self.goal + 1))

    def actions(self, state):
        if state in {0, self.goal}:
            # terminal states only allow the null bet of 0, which loops back
            # to the same state with reward 0; otherwise one could reach the
            # goal multiple times and gain repeated rewards
            return {0}
        else:
            # don't allow bets of zero
            return set(range(1, min(state, self.goal - state) + 1))

    def psr(self, state, action):
        return {
            (
                self.p_h, # heads, i.e. the gambler wins money
                state + action,
                1 if state + action == self.goal != state else 0
            ), (
                1 - self.p_h, # tails, i.e. the gambler loses money
                state - action,
                0
            ) 
        }

    def plot_policy(self, policy):
        for s, actions in policy.items():
            xy = [(s, action) for action in actions]
            plt.scatter(
                *zip(*xy),
                # purple if there are multiple actions
                color='purple' if len(actions) > 1 else 'blue', 
                marker='.'
            )

    def plot_value(self, value):
        xy = [(s, v) for s, v in value.items()]
        plt.plot(*zip(*xy))


assert GamblersProblem(0.45, 100).sink_states == {0, 100}

### Value functions

Here is a general class that represents a value function on a finite MDP. Implement the `truncated_evaluations` method, recalling that the truncated evaluation (see Section 4.4 in Sutton-Barto) for a value function $v,$ a state $s,$ and an action $a$ is
$$\mathbb{E}[R_{t + 1} + \gamma v(S_{t + 1}) \mid S_t = s, A_t = a] = \sum_{(s', r)} p(s', r \mid s, a)[r + \gamma v(s')].$$
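
As a worked instance of this formula, the sketch below evaluates a single state-action pair by hand. The function name `truncated_evaluation` and the explicit outcome list are assumptions made for illustration; they stand in for `mdp.psr(s, a)` and a `Value` object.

```python
def truncated_evaluation(outcomes, v, gamma):
    """Expected one-step return: sum of p * (r + gamma * v[s_next])."""
    return sum(p * (r + gamma * v[s_next]) for p, s_next, r in outcomes)

# Gambler's problem with goal 32 and p_h = 0.25: at capital 16, bet 16.
p_h, goal = 0.25, 32
v = {s: 0.0 for s in range(goal + 1)}  # initial value function, all zeros
outcomes = [
    (p_h, 32, 1),      # heads: reach the goal, reward 1
    (1 - p_h, 0, 0),   # tails: lose everything, reward 0
]
backup = truncated_evaluation(outcomes, v, gamma=1.0)
```

Here `backup` is $0.25 \cdot (1 + 0) + 0.75 \cdot (0 + 0) = 0.25,$ the probability of winning the single coin flip.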

In [None]:
class Value(dict):

    def truncated_evaluations(self, mdp):
        '''
        Returns:
            (dict) The truncated policy evaluations for every state and
                every action. The keys are the states, and the values are
                dicts keyed by actions with value the truncated policy
                evaluation of performing the action at the state.
        '''
        return {
            s: {
                a : sum(
                    p * (r + GAMMA * self[s_prime])
                    for p, s_prime, r in mdp.psr(s, a)
                )
                for a in mdp.actions(s)
            }
            for s in mdp.states
        }

    def greedy_policy(self, mdp, tiebreak='equal_split'):
        '''
        The greedy policy (with respect to the truncated policy evaluations).

        Args:
            tiebreak: Decides what to do in case there are multiple actions
                with the same truncated policy evaluation. The options are
                'equal_split' (of the probabilities), 'max' (of the actions),
                'min' (of the actions), 'random' (action).
        '''
        greedy_actions = {
            s: {a for a, v in evals.items() if v == max(evals.values())}
            for s, evals in self.truncated_evaluations(mdp).items()
        }
        tiebreaker = lambda actions: {
            'equal_split': {a: 1 / len(actions) for a in actions},
            'max': {max(actions): 1.0},
            'min': {min(actions): 1.0},
            'random': {random.choice(list(actions)): 1.0},
        }[tiebreak]
        return {
            s: tiebreaker(actions)
            for s, actions in greedy_actions.items()
        }

### Value iteration algorithm
Implement the value iteration algorithm in a Pythonic way. (See Figure 4.5 in Sutton-Barto on page 101, but implement it in a sane way.)
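
For orientation, here is a minimal, self-contained sketch of synchronous value iteration on a plain dictionary of transitions. It is not the notebook's implementation: the `transitions[s][a] = [(p, s_next, r), ...]` layout, the function name `value_iteration`, and the three-state chain are all assumptions chosen to keep the example small.

```python
def value_iteration(transitions, gamma, theta=1e-12, max_iter=1000):
    """Synchronous value iteration on transitions[s][a] = [(p, s_next, r), ...]."""
    v = {s: 0.0 for s in transitions}
    for _ in range(max_iter):
        # Back up every state from the previous sweep's values.
        new_v = {
            s: max(
                sum(p * (r + gamma * v[s_next]) for p, s_next, r in outcomes)
                for outcomes in actions.values()
            )
            for s, actions in transitions.items()
        }
        delta = max(abs(new_v[s] - v[s]) for s in transitions)
        v = new_v
        if delta < theta:
            break
    return v

# Hypothetical three-state chain: 0 -> 1 -> 2, where 2 is absorbing.
chain = {
    0: {'stay': [(1.0, 0, 0)], 'right': [(1.0, 1, 0)]},
    1: {'stay': [(1.0, 1, 0)], 'right': [(1.0, 2, 1)]},
    2: {'stay': [(1.0, 2, 0)]},
}
values = value_iteration(chain, gamma=0.9)
```

State 1 is one step from the reward and gets value 1, while state 0 is two steps away and gets the discounted value 0.9.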

In [None]:
def value_iteration_algorithm(mdp, theta=1e-15, max_iter=2000, tiebreak='equal_split'):
    '''
    The value iteration algorithm.

    Returns:
        (tuple) The greedy policy with respect to the final value function,
            and that value function.
    '''
    v = Value({state: 0 for state in mdp.states})
    pbar = tqdm(range(max_iter))
    for _ in pbar:
        tpe = v.truncated_evaluations(mdp)
        new_v = Value({s: max(tpe[s].values()) for s in mdp.states})
        Delta = max(abs(v[s] - new_v[s]) for s in mdp.states)
        v = new_v
        if float(Delta) < theta:
            print(f'{Delta=:.3} reached threshold {theta=:.3}')
            break
        mdp.plot_value(v)
        pbar.set_description(f'{Delta=:.3}')
    plt.show()
    return v.greedy_policy(mdp, tiebreak=tiebreak), v

### Playing around with the gambler's problem

We will use a goal of $32$ instead of $100$ because it is faster to compute. For $p_h = 0.25,$ value iteration converges within 10 iterations and produces a nontrivial policy that bets everything at the state $16.$ We can approximate the value of this policy in two ways: using the truncated evaluations of the policy and using Monte Carlo simulations (`MDP.approximate_value`). The truncated evaluation is exactly $0.25,$ whereas the Monte Carlo estimate is only approximately $0.25$ (it is just sampling a Bernoulli random variable); both correspond to the probability of winning the coin flip at the state $16.$
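
To get a feel for how close a Monte Carlo estimate of a Bernoulli probability should be, here is a small standalone sketch. It does not call `MDP.approximate_value`; the function `estimate_win_probability` and the fixed seed are assumptions made so that the example is reproducible.

```python
import random

def estimate_win_probability(p_h, n_samples, seed=0):
    """Fraction of simulated all-in coin flips that come up heads."""
    rng = random.Random(seed)  # local generator so the estimate is reproducible
    wins = sum(rng.random() < p_h for _ in range(n_samples))
    return wins / n_samples

estimate = estimate_win_probability(0.25, n_samples=10_000)
standard_error = (0.25 * 0.75 / 10_000) ** 0.5  # about 0.0043
```

With $10^4$ samples the estimate should typically land within a few standard errors of $0.25,$ so deviations in the second decimal place are expected.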

In [None]:
p_h = 0.25
mdp = GamblersProblem(p_h, 32)
policy, value = value_iteration_algorithm(mdp, theta=1e-2, tiebreak='equal_split')
mdp.plot_policy(policy)
plt.show()

print(f'truncated evaluation value is {value[16]}')
print(f'monte carlo approximated value is {mdp.approximate_value(policy, 16, n_samples=int(1e4))}')

For $p_h = 0.55,$ consider stopping the value iteration early so that it produces a nontrivial policy involving bets of greater than 1 dollar. The value of this non-optimal policy is significantly lower than that of the optimal policy of always betting exactly 1 dollar.
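
As a reference point, the policy of always betting exactly 1 dollar has a closed-form win probability, given by the classical gambler's ruin formula. The sketch below evaluates it; the function name `win_probability_unit_bets` is an assumption introduced for this example.

```python
def win_probability_unit_bets(p_h, capital, goal):
    """P(reach `goal` before 0) when betting 1 each flip (gambler's ruin)."""
    if p_h == 0.5:
        return capital / goal  # fair coin: the formula below degenerates
    ratio = (1 - p_h) / p_h
    return (1 - ratio**capital) / (1 - ratio**goal)

favorable = win_probability_unit_bets(0.55, capital=16, goal=32)    # roughly 0.96
unfavorable = win_probability_unit_bets(0.25, capital=16, goal=32)  # below 1e-6
```

The same formula also shows why small bets are a poor choice for $p_h = 0.25$: from the state $16$ they almost never reach the goal, whereas betting everything wins with probability $0.25.$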

In [None]:
p_h = 0.55
mdp = GamblersProblem(p_h, 32)
policy, value = value_iteration_algorithm(mdp, theta=1e-2, tiebreak='equal_split')
mdp.plot_policy(policy)
plt.show()

print(f'monte carlo approximated value is {mdp.approximate_value(policy, 16, n_samples=int(1e3)):.3}')
print(f'truncated evaluation value is {value[16]:.3}')

In [None]:
p_h = 0.55
mdp = GamblersProblem(p_h, 32)
policy, value = value_iteration_algorithm(mdp, theta=1e-3, tiebreak='equal_split')
mdp.plot_policy(policy)
plt.show()

print(f'monte carlo approximated value is {mdp.approximate_value(policy, 16, n_samples=200)}')
print(f'truncated evaluation value is {value[16]}')

### GridWorld
Here is an implementation of GridWorld, introduced in Example 3.8 on page 72 of https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf. Because this MDP lasts forever for reasonable policies, we will choose `GAMMA = 0.9` to get finite values. To do this, change `GAMMA` at the beginning of the notebook, and rerun all cells above this one (lmao).

In [None]:
import numpy as np

assert GAMMA < 1, 'Set GAMMA to 0.9 at the top of the notebook, then rerun all cells. (See text above.)'

class GridWorld(MDP):
    def __init__(self, n, portals):
        self.n = n
        self.portals = portals

    @property
    def states(self):
        return {(i, j) for i in range(self.n) for j in range(self.n)}

    def actions(self, state):
        return {(-1, 0), (0, 1), (1, 0), (0, -1)} # NESW

    def in_grid(self, i, j):
        return 0 <= i < self.n and 0 <= j < self.n

    def psr(self, state, action):
        # note there is only one psr because the result of an action is always deterministic
        p = 1.0
        
        if state in self.portals:
            s, r = self.portals[state] # travel through a portal and get the reward for it
        else:
            i, j = state
            di, dj = action
            if self.in_grid(i + di, j + dj):
                s = (i + di, j + dj) # simply move according to the action
                r = 0.0
            else:
                s = state # don't move
                r = -1.0 # get penalized by 1.0 for trying to move out of the grid
        return {(p, s, r)}

    def plot_value(self, value):
        '''Don't plot value functions'''
        pass

    def plot_policy(self, policy, value=None):
        '''It works, so don't worry about it.'''
        ij_to_plt = lambda i, j: (j, self.n - i)
        arrow = {(0, 1) : (1, 0), (0, -1) : (-1, 0), (-1, 0) : (0, 1), (1, 0) : (0, -1)}
        fig, ax = plt.subplots(figsize=(self.n + 1, self.n + 1))
        plt.axis('off')
        for (i, j), actions in policy.items():
            for action in actions:
                plt.arrow(
                    j, self.n - i,
                    0.2 * arrow[action][0], 0.2 * arrow[action][1],
                    head_width=0.05
                )
        colors = plt.get_cmap('spring')(np.linspace(0, 1, len(self.portals)))
        for (portal_start, (portal_end, reward)), color in zip(self.portals.items(), colors):
            plt.scatter(*ij_to_plt(*portal_start), s=100, color=color)
            plt.scatter(*ij_to_plt(*portal_end), s=400, color=color)
            plt.text(*ij_to_plt(*portal_start), str(reward), horizontalalignment='left', verticalalignment='bottom', size=20)
        if value is not None:
            for state, v in value.items():
                plt.text(*ij_to_plt(*state), str(round(v, 1)), horizontalalignment='right', verticalalignment='top', size=10)

We can easily solve the example given in lecture and in the book. Here is how to interpret the plots:
- colored disks represent the portals, where the small ones are the starts and the big ones are the destinations
- large numbers indicate the reward for entering the portal (i.e. being at the portal start and taking the unique action to the portal destination)
- arrows represent the equally-likely actions of an optimal policy
- small numbers are the values of that optimal policy.
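
The portal and wall rules behind these plots can be checked by hand. The sketch below is a standalone function that mirrors the logic of `GridWorld.psr` without the class; the name `grid_step` is an assumption, and the portals are those of the book example.

```python
def grid_step(state, action, n, portals):
    """Deterministic GridWorld transition; returns (next_state, reward)."""
    if state in portals:
        return portals[state]          # any action teleports and pays out
    i, j = state
    di, dj = action
    if 0 <= i + di < n and 0 <= j + dj < n:
        return (i + di, j + dj), 0.0   # ordinary move, no reward
    return state, -1.0                 # bumped into the wall

portals = {(0, 1): ((4, 1), 10), (0, 3): ((2, 3), 5)}
north, east = (-1, 0), (0, 1)
from_portal = grid_step((0, 1), north, 5, portals)
into_wall = grid_step((0, 0), north, 5, portals)
plain_move = grid_step((2, 2), east, 5, portals)
```

Note that the action is ignored at a portal start, which is why every arrow there is equally good.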

In [None]:
mdp = GridWorld(5, {(0, 1): ((4, 1), 10), (0, 3): ((2, 3), 5)})
policy, value = value_iteration_algorithm(mdp, theta=1e-5, tiebreak='equal_split')
mdp.plot_policy(policy, value=value)

Consider a one-way torus, and fix a point $p$ on it. What is the shortest path to $p$? Simply follow the arrows in the following policy.
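
The values also encode the path length. Assuming `GAMMA = 0.9`, a reward of 1 for every step spent at $p,$ and a reward of 0 everywhere else along an optimal path, a state that is $d$ steps from $p$ has value $\gamma^d / (1 - \gamma).$ The sketch below inverts this relation; the helper `steps_to_target` is a hypothetical name, and the result is rounded because value iteration only returns approximate values.

```python
import math

def steps_to_target(value, gamma=0.9, reward=1.0):
    """Invert value = gamma**d * reward / (1 - gamma) to recover d."""
    value_at_target = reward / (1 - gamma)
    return round(math.log(value / value_at_target) / math.log(gamma))

three_steps_away = steps_to_target(0.9**3 * 10)
```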

In [None]:
n = 16
p_i, p_j = (4, 3)

portals_from_right_to_left = {(i, n - 1): ((i, 0), 0) for i in range(n)}
portals_from_bottom_to_top = {(n - 1, i): ((0, i), 0) for i in range(n - 1)}
destination = {(p_i, p_j): ((p_i, p_j), 1)}
mdp = GridWorld(
    n,
    {**portals_from_right_to_left, **portals_from_bottom_to_top, **destination}
)
policy, value = value_iteration_algorithm(mdp, theta=1e-5, tiebreak='equal_split')
mdp.plot_policy(policy, value=value)

How did Walter White decide to pursue a life of crime? He took into account his life expectancy (`GAMMA`) and the expected pay. In the following cell, try adjusting the various pay levels:

In [None]:
n = 16
normal_pay = 1
small_crime_pay = 4
big_crime_pay = 16

small_crime = { # pay is a little better than usual
    (4, 13): ((11, 12), small_crime_pay),
    (3, 12): ((3, 13), normal_pay),
    (3, 14): ((4, 14), normal_pay),
    (5, 14): ((5, 13), normal_pay),
    (5, 12): ((4, 12), normal_pay),
}

big_crime = { # pay is way better than usual
    (11, 4): ((12, 11), big_crime_pay),
    (10, 3): ((10, 4), normal_pay),
    (10, 5): ((11, 5), normal_pay),
    (12, 5): ((12, 4), normal_pay),
    (12, 3): ((11, 3), normal_pay),
}


jail_bounds = { # you cannot leave
    (i, j): ((i, j), -100) 
    for i in range(n)
    for j in range(n)
    if 10 <= i <= 13 and 10 <= j <= 13 and (i in {10, 13} or j in {10, 13})
}

mdp = GridWorld(
    n,
    {**small_crime, **big_crime, **jail_bounds}
)
policy, value = value_iteration_algorithm(mdp, theta=1e-5, tiebreak='equal_split')
mdp.plot_policy(policy, value=value)

You can create your own grid-world here, if you are feeling creative:

In [None]:
n = 16
portals = {
    # start_location: (end_location, reward)
    (1, 5): ((13, 8), 10)
}
mdp = GridWorld(n, portals)
policy, value = value_iteration_algorithm(mdp, theta=1e-5, tiebreak='equal_split')
mdp.plot_policy(policy, value=value)