## Dynamic programming



In [27]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import poisson

%matplotlib inline

In [197]:
class Gridworld:
    """Toy grid-world MDP with known dynamics (textbook example 4.1).

    States are numbered row by row, from 0 in the top-left cell to
    ``n_states - 1`` in the bottom-right cell; these two corner states are
    terminal. Actions are UP=0, DOWN=1, LEFT=2, RIGHT=3. Every action moves
    deterministically (probability 1) to the neighbouring cell, and a move
    that would leave the grid keeps the agent in its current cell.
    """

    def __repr__(self):
        return "Instance of a Gridworld:\n{}".format(self.world)

    def __init__(self, size=(4,4)):
        """Create a grid of ``size = (n_rows, n_cols)`` cells.

        Actions are UP=0, DOWN=1, LEFT=2, RIGHT=3.

        Raises:
            ValueError: if ``size`` does not have exactly two entries.
        """
        if len(size) != 2:
            raise ValueError('size must be 2 dimensional, '
                             'got size={}'.format(size))
        # states
        self.size = size
        # np.prod rather than the deprecated np.product alias, which no
        # longer exists in NumPy 2.0.
        self.n_states = np.prod(self.size)
        self.states = np.arange(self.n_states)
        # first and last state, i.e. the top-left and bottom-right cells
        self.terminal_states = np.array([0, self.n_states-1])
        self.nonterminal_states = self.states[1:-1]
        # world[row, col] holds the number of the state in that cell
        self.world = self.states.reshape(self.size)

        # dynamics
        self.UP = 0
        self.DOWN = 1
        self.LEFT = 2
        self.RIGHT = 3
        self.n_actions = 4
        self._define_dynamics()

    def coordinates(self, state):
        """Return the ``(row, col)`` grid coordinates of ``state``.

        Raises:
            ValueError: if ``state`` is not a state of this grid.
        """
        # States are numbered row-major (``world`` is a reshaped arange), so
        # the coordinates follow directly from the index instead of scanning
        # the whole grid on every call.
        row, col = np.unravel_index(state, self.size)
        return row, col

    def reward(self, state, action, next_state):
        """Return the reward for taking ``action`` in ``state``.

        Every transition out of a non-terminal state is rewarded -1, including
        the one that arrives in a terminal state; transitions out of a
        terminal state are rewarded 0. ``action`` and ``next_state`` do not
        influence the reward.
        """
        return 0 if state in self.terminal_states else -1

    def next_state(self, state, action):
        """Return ``(next_state, probability)`` for ``action`` in ``state``."""
        row, col = self.coordinates(state)
        next_coords, probability = self.dynamics[row][col][action]
        return self.world[next_coords], probability

    def _define_dynamics(self):
        """Build ``self.dynamics[row][col][action] -> ((row', col'), prob)``.

        Moves are deterministic; a move that would cross the border of the
        grid leaves the agent in its current cell.
        """
        prob = 1  # always move deterministically
        n_rows, n_cols = self.size
        self.dynamics = []
        for i in range(n_rows):
            row_dynamics = []
            for j in range(n_cols):
                # clamp the target cell to the grid borders
                row_dynamics.append({
                    self.UP: ((max(i - 1, 0), j), prob),
                    self.DOWN: ((min(i + 1, n_rows - 1), j), prob),
                    self.LEFT: ((i, max(j - 1, 0)), prob),
                    self.RIGHT: ((i, min(j + 1, n_cols - 1)), prob),
                })
            self.dynamics.append(row_dynamics)


In [198]:
gridworld = Gridworld(size=(4, 4))  # the 4x4 world of textbook example 4.1

### Policy evaluation

In [199]:
def evaluate_policy(policy, mdp, discount=0.99, threshold=0.001,
                    max_iterations=None):
    """Policy evaluation algorithm given a finite mdp with known dynamics.

    Sweeps over the non-terminal states, replacing each state value in place
    by its expected one-step return under ``policy``, until the largest
    change during a sweep is below ``threshold``. Terminal states keep the
    value 0.

    Args:
        policy: array indexed ``policy[state][action]`` holding the
            probability of taking ``action`` in ``state``.
        mdp: finite MDP providing ``states``, ``nonterminal_states``,
            ``next_state(state, action) -> (next_state, probability)`` and
            ``reward(state, action, next_state)``.
        discount: discount factor applied to the next state's value.
        threshold: convergence tolerance on the largest per-sweep change.
        max_iterations: optional upper bound on the number of sweeps. With
            ``discount=1`` the values of a policy that never reaches a
            terminal state may never converge, so without a bound the loop
            would not terminate. ``None`` (default) means no bound.

    Returns:
        float32 array of state values, indexed by state.
    """
    V = np.zeros_like(mdp.states, dtype=np.float32)
    sweeps = 0
    while max_iterations is None or sweeps < max_iterations:
        delta = 0
        for state in mdp.nonterminal_states:
            v = 0
            for action, prob_action in enumerate(policy[state]):
                if prob_action == 0:
                    continue  # contributes nothing; skip the model lookup
                # mdp.next_state yields a single (next_state, probability) pair
                for next_state, prob_next_state in [mdp.next_state(state, action)]:
                    reward = mdp.reward(state, action, next_state)
                    value_next_state = discount * V[next_state]
                    v += prob_action * prob_next_state * (reward + value_next_state)
            delta = max(delta, np.abs(v - V[state]))
            V[state] = v
        sweeps += 1
        if delta < threshold:
            break
    return V

In [200]:
# uniformly random policy: every action is equally likely in non-terminal
# states, while terminal states get all-zero rows
random_policy = np.full((gridworld.n_states, gridworld.n_actions),
                        1 / gridworld.n_actions)
random_policy[gridworld.terminal_states] = 0
print(random_policy)

[[0.   0.   0.   0.  ]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.25 0.25 0.25 0.25]
 [0.   0.   0.   0.  ]]


In [201]:
# evaluate the random policy without discounting
state_values = evaluate_policy(
    random_policy, gridworld, discount=1, threshold=0.0001)
print(state_values)

[  0.       -13.999311 -19.99901  -21.99891  -13.999311 -17.999155
 -19.999083 -19.999092 -19.99901  -19.999083 -17.999226 -13.999422
 -21.99891  -19.999092 -13.999422   0.      ]


In [202]:
# state values laid out on the grid, rounded to one decimal
grid_values = np.round(state_values.reshape(gridworld.size), decimals=1)
print(grid_values)

[[  0. -14. -20. -22.]
 [-14. -18. -20. -20.]
 [-20. -20. -18. -14.]
 [-22. -20. -14.   0.]]


In [203]:
# the true values of the random policy are known for the 4x4 grid only
if gridworld.size == (4,4):
    true_values = [
        [  0, -14, -20, -22],
        [-14, -18, -20, -20],
        [-20, -20, -18, -14],
        [-22, -20, -14,   0],
    ]
    assert np.allclose(true_values, grid_values), 'oh no!'
    print('Hurray!')

Hurray!


### Policy improvement

In [204]:
def improve_policy(policy, mdp, state_values, discount, until_stable=True):
    """Greedy policy improvement given state values under the current policy.

    For every non-terminal state the action values
    ``Q(s, a) = sum_{s'} p(s'|s, a) * (r(s, a, s') + discount * V(s'))``
    are computed from ``state_values``. If the current policy already puts
    all of its probability on maximising actions, its row is kept unchanged;
    otherwise the row becomes uniform over the maximising actions.

    Args:
        policy: array of shape (n_states, n_actions) with action
            probabilities per state. It is copied, never modified.
        mdp: finite MDP providing ``nonterminal_states``, ``n_actions``,
            ``next_state(state, action) -> (next_state, probability)`` and
            ``reward(state, action, next_state)``.
        state_values: values of ``policy``, indexed by state.
        discount: discount factor applied to the next state's value.
        until_stable: if True, return only the improved policy; if False,
            return ``(policy, is_stable)``.

    Returns:
        The improved policy, or ``(policy, is_stable)`` when
        ``until_stable`` is False. ``is_stable`` is True when the input
        policy was already greedy with respect to ``state_values``.
    """
    # Copy so the caller's policy (e.g. the random baseline) is not silently
    # overwritten; float dtype so probabilities are never truncated.
    new_policy = np.array(policy, dtype=float)
    is_stable = True

    for state in mdp.nonterminal_states:
        # action values for this state
        q = np.zeros(mdp.n_actions)
        for action in range(mdp.n_actions):
            for next_state, prob_next_state in [mdp.next_state(state, action)]:
                reward = mdp.reward(state, action, next_state)
                value_next_state = discount * state_values[next_state]
                q[action] += prob_next_state * (reward + value_next_state)

        is_best = q == q.max()
        support = new_policy[state] > 0
        if support.any() and np.all(is_best[support]):
            # Already greedy: keep the row, so equally good policies do not
            # flip-flop and policy iteration can terminate.
            continue
        is_stable = False
        # uniform over the maximising actions (independent of the old row,
        # so it never divides by zero)
        new_policy[state] = is_best / is_best.sum()

    if until_stable:
        # A second greedy pass against the same state values would leave
        # every row unchanged, so this policy is already stable.
        return new_policy
    return new_policy, is_stable

In [205]:
# one greedy improvement step on the random policy's (undiscounted) values
policy = improve_policy(random_policy, gridworld, state_values, discount=1)
policy.argmax(axis=1).reshape(gridworld.size)

array([[0, 2, 2, 2],
       [0, 0, 2, 1],
       [0, 0, 1, 1],
       [0, 3, 3, 0]])

### Policy iteration

In [206]:
def policy_iteration(mdp, discount, threshold):
    """Policy iteration: alternate policy evaluation and greedy improvement.

    Starts from the uniformly random policy (terminal-state rows all zero)
    and repeats ``evaluate_policy`` followed by ``improve_policy`` until the
    improvement step reports that the policy is stable.

    Args:
        mdp: finite MDP with ``n_states``, ``n_actions`` and
            ``terminal_states``, as used by the two helpers.
        discount: discount factor.
        threshold: convergence tolerance handed to ``evaluate_policy``.

    Returns:
        Array of shape (n_states, n_actions) with the final policy.
    """
    policy = np.full((mdp.n_states, mdp.n_actions), 1 / mdp.n_actions)
    policy[mdp.terminal_states] = 0

    is_stable = False
    while not is_stable:
        state_values = evaluate_policy(policy, mdp, discount, threshold)
        policy, is_stable = improve_policy(
            policy, mdp, state_values, discount, until_stable=False)
    return policy

In [207]:
pi_star = policy_iteration(gridworld, discount=1, threshold=0.001)
# greedy action per cell (0=UP, 1=DOWN, 2=LEFT, 3=RIGHT)
pi_star.argmax(axis=1).reshape(gridworld.size)

array([[0, 2, 2, 2],
       [0, 0, 2, 1],
       [0, 0, 1, 1],
       [0, 3, 3, 0]])

### Value iteration

In [250]:
def value_iteration(mdp, discount, threshold):
    """Value iteration algorithm given a finite mdp with known dynamics.

    Applies the Bellman optimality backup in place to the non-terminal states
    until the largest change during a sweep falls below ``threshold``, then
    extracts the policy that is greedy with respect to the resulting values.

    Args:
        mdp: finite MDP providing ``states``, ``nonterminal_states``,
            ``n_states``, ``n_actions``,
            ``next_state(state, action) -> (next_state, probability)`` and
            ``reward(state, action, next_state)``.
        discount: discount factor applied to the next state's value.
        threshold: convergence tolerance on the largest per-sweep change.

    Returns:
        Array of shape (n_states, n_actions) with a deterministic policy:
        one 1 per non-terminal state on its greedy action (the lowest action
        index among ties); terminal-state rows are all zero.
    """
    def action_values(state, V):
        """Return the one-step lookahead value of every action in ``state``."""
        values = []
        for action in range(mdp.n_actions):
            v = 0
            # mdp.next_state yields a single (next_state, probability) pair
            for next_state, prob_next_state in [mdp.next_state(state, action)]:
                reward = mdp.reward(state, action, next_state)
                value_next_state = discount * V[next_state]
                v += prob_next_state * (reward + value_next_state)
            values.append(v)
        return values

    # compute state values (float32, terminal states stay 0)
    V = np.zeros_like(mdp.states, dtype=np.float32)
    while True:
        delta = 0.0
        for state in mdp.nonterminal_states:
            v_max = max(action_values(state, V))
            delta = max(delta, np.abs(v_max - V[state]))
            V[state] = v_max
        if delta < threshold:
            break

    # extract the greedy policy
    policy = np.zeros((mdp.n_states, mdp.n_actions))
    for state in mdp.nonterminal_states:
        policy[state, np.argmax(action_values(state, V))] = 1
    return policy

In [266]:
pi_star = value_iteration(gridworld, discount=1, threshold=0.001)
# greedy action per cell (0=UP, 1=DOWN, 2=LEFT, 3=RIGHT)
pi_star.argmax(axis=1).reshape(gridworld.size)

array([[0, 2, 2, 1],
       [0, 0, 0, 1],
       [0, 0, 1, 1],
       [0, 3, 3, 0]])

In [272]:
gridworld = Gridworld((8,8))
%timeit value_iteration(gridworld, 1, 0.001)
%timeit policy_iteration(gridworld, 1, 0.001)

81.5 ms ± 4.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
5.44 s ± 704 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
