This notebook implements dynamic programming algorithms (policy evaluation and policy iteration) for MDP problems.

In [121]:
import random

class DP:
    """Dynamic-programming routines (policy evaluation / policy iteration) for a tabular MDP.

    The environment passed to these methods must provide ``gamma()``,
    ``num_states()``, ``default_policy()`` and
    ``step_enum(state=..., action=...)``; the latter yields
    ``(next_state, reward, prob)`` triples.  A policy is indexed by state and
    gives a dict ``{action: probability}``; an empty dict means no action is
    taken in that state, so its value is 0.
    """

    def compute_value(self, env, policy, values, state):
        """Return the one-step Bellman expectation backup of ``state`` under ``policy``."""
        value = 0
        gamma = env.gamma()
        for action, pi in policy[state].items():
            for (next_state, reward, prob) in env.step_enum(state=state, action=action):
                value += pi * prob * (reward + gamma * values[next_state])
        return value

    def compute_value_from_qvalues(self, env, policy, qvalues, state):
        """Return the state value as the policy-weighted sum of ``qvalues[state]``."""
        value = 0
        for action, pi in policy[state].items():
            value += pi * qvalues[state][action]
        return value

    def compute_qvalue(self, env, policy, qvalues, state, action):
        """Return the one-step backup of q(state, action) from the successor q-values."""
        value = 0
        gamma = env.gamma()
        for (next_state, reward, prob) in env.step_enum(state=state, action=action):
            value += prob * reward
            for next_action, pi in policy[next_state].items():
                value += prob * gamma * pi * qvalues[next_state][next_action]
        return value

    def compute_qvalue_from_values(self, env, policy, values, state, action):
        """Return q(state, action) computed from state values.

        The successor value is discounted by ``env.gamma()`` so the result is
        consistent with ``compute_value``.
        """
        value = 0
        gamma = env.gamma()
        for (next_state, reward, prob) in env.step_enum(state=state, action=action):
            value += prob * (reward + gamma * values[next_state])
        return value

    def make_decision(self, policy, state):
        """Sample an action from ``policy[state]`` in proportion to its weights.

        Returns ``None`` when the state has no actions (or all weights are 0).
        """
        total = 0
        cumulative = []
        for action, pi in policy[state].items():
            total += pi
            cumulative.append((action, total))
        # Scale by the total so weights that do not sum to 1 still work.
        sample = random.random() * total
        for (action, upper) in cumulative:
            if sample < upper:
                return action
        return None

    def evaluate(self, env, policy, values):
        """Run one in-place sweep of policy evaluation over all states.

        ``values`` is updated in place; the largest absolute change of the
        sweep is returned.
        """
        max_delta = 0
        for state in range(len(values)):
            new_value = self.compute_value(env, policy, values, state)
            delta = abs(new_value - values[state])
            max_delta = max(delta, max_delta)
            values[state] = new_value
        return max_delta

    def qevaluate(self, env, policy, qvalues):
        """Run one in-place sweep of policy evaluation over all (state, action) pairs.

        ``qvalues`` is updated in place; the largest absolute change of the
        sweep is returned.
        """
        max_delta = 0
        for state in range(len(qvalues)):
            for action in qvalues[state]:
                new_value = self.compute_qvalue(env, policy, qvalues, state, action)
                delta = abs(new_value - qvalues[state][action])
                max_delta = max(delta, max_delta)
                qvalues[state][action] = new_value
        return max_delta

    def policy_improvment(self, env, policy, values, actions=None):
        """Make ``policy`` greedy with respect to ``values`` (in place).

        (The misspelled method name is kept so existing callers keep working.)

        Args:
            env: environment, see the class docstring.
            policy: policy to improve; each visited state is replaced by
                ``{best_action: 1}``.
            values: state values of the current policy.
            actions: optional ``actions[state]`` -> iterable of every action
                available in ``state``.  When omitted, only the actions already
                present in ``policy[state]`` are considered, which cannot
                recover an action that an earlier improvement step dropped.

        Returns:
            ``True`` if no state's policy entry changed, ``False`` otherwise.
        """
        stable = True
        for state in range(len(values)):
            current = policy[state]
            candidates = list(current) if actions is None else list(actions[state])
            if not candidates:
                # No action available (e.g. a terminal state): nothing to improve.
                continue
            if current:
                # Evaluate the currently favoured action first: with the strict
                # ">" below, ties then keep the current choice, so the policy
                # does not flip between equally good actions forever.
                preferred = max(current, key=current.get)
                if preferred in candidates:
                    candidates.remove(preferred)
                    candidates.insert(0, preferred)
            best_value = None
            best_action = None
            for candidate in candidates:
                value = self.compute_qvalue_from_values(env, policy, values, state, candidate)
                if best_value is None or value > best_value:
                    best_value = value
                    best_action = candidate
            greedy = {best_action: 1}
            # Deterministic stability test: did this state's entry change?
            if current != greedy:
                stable = False
            policy[state] = greedy
        return stable

    def policy_iteration(self, env, theta=0.001):
        """Run policy iteration starting from ``env.default_policy()``.

        Args:
            env: environment, see the class docstring.
            theta: evaluation sweeps repeat until the largest value change of a
                sweep is at most ``theta``.

        Returns:
            ``(policy, values)`` once an improvement step leaves the policy
            unchanged.
        """
        policy = env.default_policy()
        # Remember every action the default policy offers per state, so that
        # improvement can still choose actions after the policy becomes
        # deterministic.
        actions = {state: list(policy[state]) for state in range(env.num_states())}
        values = [0] * env.num_states()
        stable = False
        times = 0
        while not stable:
            # policy evaluation
            delta = theta + 1
            while delta > theta:
                delta = self.evaluate(env, policy, values)
            # policy improvement
            stable = self.policy_improvment(env, policy, values, actions)
            times += 1
            print("policy_iteration: times=%d" % times)
        return (policy, values)

Run examples

Example 1: 4x4 gridworld shown below:

|  | 1 | 2 | 3 |
|---|---|---|---|
| 4 | 5 | 6 | 7 |
| 8 | 9 | 10| 11|
| 12| 13| 14|   |

1. $R_t = -1$ on all transitions.
2. The nonterminal states S = {1,2,...,14}.
3. There are four actions A = {up, down, left, right}, which deterministically cause the corresponding state transitions, except that actions that would take the agent off the grid in fact leave the state unchanged.
4. This is an undiscounted, episodic task, and the terminal states are the two empty cells (top-left and bottom-right corners).

Values for equiprobable random policy:

|0|-14|-20|-22|
|---|---|---|---|
|-14|-18|-20|-20|
|-20|-20|-18|-14|
|-22|-20|-14|0|

In [129]:
from enum import Enum


class GridWorld4x4:
    """Deterministic grid world with a reward of -1 on every transition.

    States are numbered row by row: state ``s`` sits at row
    ``s // COLUMNS`` and column ``s % COLUMNS``.  The first and the last
    state get an empty action dict in ``default_policy`` (they are the
    terminal corners); ``step_enum`` itself does not special-case them.
    """

    ROWS = 4
    COLUMNS = ROWS

    class Action(Enum):
        UP = 1
        DOWN = 2
        LEFT = 3
        RIGHT = 4

    def gamma(self):
        """Return the discount factor (1: the task is undiscounted)."""
        return 1

    def num_states(self):
        """Return the total number of grid cells, terminal ones included."""
        return self.ROWS * self.COLUMNS

    def step_enum(self, state, action):
        """Enumerate the outcomes of taking ``action`` in ``state``.

        Returns a one-element list ``[(next_state, reward, prob)]``: moves
        are deterministic, and a move that would leave the grid keeps the
        state unchanged.  Any action other than LEFT, RIGHT or UP is
        handled as DOWN.
        """
        column = int(state % self.COLUMNS)
        row = int(state // self.COLUMNS)
        reward = -1
        if action == self.Action.LEFT:
            return [(state if column == 0 else self.make_state(row, column-1), reward, 1)]
        elif action == self.Action.RIGHT:
            return [(state if column+1 == self.COLUMNS else self.make_state(row, column+1), reward, 1)]
        elif action == self.Action.UP:
            return [(state if row == 0 else self.make_state(row-1, column), reward, 1)]
        else:
            return [(state if row+1 == self.ROWS else self.make_state(row+1, column), reward, 1)]

    def make_state(self, row, column):
        """Return the state index of the cell at (``row``, ``column``)."""
        return row * self.COLUMNS + column

    def _is_terminal(self, state):
        """Return True for the two terminal corners (first and last state)."""
        # Derived from the grid size instead of a hard-coded 15 so that it
        # stays correct if ROWS / COLUMNS are changed.
        return state == 0 or state == self.num_states() - 1

    def default_policy(self):
        """Return the equiprobable random policy.

        Terminal states map to an empty dict; every other state maps each
        ``Action`` member (in definition order) to the same probability.
        """
        policy = {}
        prob = 1.0 / len(self.Action)
        for state in range(self.num_states()):
            if self._is_terminal(state):
                policy[state] = {}
            else:
                policy[state] = {action: prob for action in self.Action}
        return policy

    def print_values(self, title, values, formatter = None):
        """Print ``values`` as a markdown table laid out like the grid.

        Args:
            title: line printed before the table.
            values: per-state sequence or mapping indexed by state.
            formatter: optional callable turning one entry into text; without
                it entries are rounded to 2 decimals and printed with ``%g``.
        """
        print(title)
        for i in range(self.ROWS):
            print('|', end='')
            for j in range(self.COLUMNS):
                state = self.make_state(i, j)
                if formatter is not None:
                    print("%s|" % formatter(values[state]), end='')
                else:
                    print("%g|" % round(values[state], 2), end='')
            print()
            if (i == 0):
                # Markdown header separator goes right after the first row.
                print('|', end='')
                for j in range(self.COLUMNS):
                    print('---|', end='')
                print()

In [131]:
def gridworld_evaluate():
    """Evaluate the default policy of the 4x4 grid world and print the results.

    State values are iterated first and printed as a table; the
    action values are then iterated the same way and two sample
    q-values are printed.
    """
    theta = 0.0001
    max_times = 100000
    env = GridWorld4x4()
    agent = DP()
    values = [0] * (env.num_states())
    policy = env.default_policy()

    # State-value sweeps: stop once a sweep changes nothing by more than
    # theta, or after max_times sweeps.
    delta = 1
    times = 0
    for times in range(1, max_times + 1):
        delta = agent.evaluate(env, policy, values)
        if not delta > theta:
            break
    env.print_values("evaluate values:", values)

    # Action-value sweeps, starting from all-zero q-values for every action
    # the policy offers in each state.
    qvalues = [dict.fromkeys(policy[state], 0) for state in range(env.ROWS * env.COLUMNS)]
    delta = 1
    times = 0
    for times in range(1, max_times + 1):
        delta = agent.qevaluate(env, policy, qvalues)
        if not delta > theta:
            break
    q_11_down = round(qvalues[11][env.Action.DOWN], 2)
    q_7_down = round(qvalues[7][env.Action.DOWN], 2)
    print("evaluate qvalues: delta=%.6f, times=%d, q(11,down)=%g, q(7,down)=%g" % (
        delta, times, q_11_down, q_7_down))

def format_policy(policy):
    """Return an arrow for the first action found in ``policy``.

    Only the first action yielded by iterating ``policy`` is looked at.
    An empty policy, or a first action that is not one of the four
    grid-world moves, gives an empty string.
    """
    first_action = next(iter(policy), None)
    if first_action is None:
        return ""
    arrows = (
        (GridWorld4x4.Action.LEFT, "←"),
        (GridWorld4x4.Action.RIGHT, "→"),
        (GridWorld4x4.Action.UP, "↑"),
        (GridWorld4x4.Action.DOWN, "↓"),
    )
    for move, arrow in arrows:
        if first_action == move:
            return arrow
    return ""

def gridworld_policy_improvment():
    """Run policy iteration on the 4x4 grid world and print the policy and its values."""
    world = GridWorld4x4()
    solver = DP()
    result = solver.policy_iteration(world)
    found_policy, found_values = result
    # The policy table uses format_policy to show each state's action as an arrow.
    world.print_values("policy improvment policy:", found_policy, format_policy)
    world.print_values("policy improvment values:", found_values)

# Run policy iteration on the 4x4 grid world; prints the resulting policy and value tables.
gridworld_policy_improvment()

policy_iteration: times=1
policy_iteration: times=2
policy improvment policy:
||←|←|←|
|---|---|---|---|
|↑|←|←|↓|
|↑|↑|↓|↓|
|↑|→|→||
policy improvment values:
|0|-1|-2|-3|
|---|---|---|---|
|-1|-2|-3|-2|
|-2|-3|-2|-1|
|-3|-2|-1|0|


policy improvment policy:

||←|←|←|
|---|---|---|---|
|↑|←|←|↓|
|↑|↑|↓|↓|
|↑|→|→||

policy improvment values:

|0|-1|-2|-3|
|---|---|---|---|
|-1|-2|-3|-2|
|-2|-3|-2|-1|
|-3|-2|-1|0|