#### Sutton and Barto, Reinforcement Learning 2nd. Edition, page 130.
![Sutton and Barto, Reinforcement Learning 2nd. Edition.](./figures/QLearning.png)

Q-learning (off-policy TD control) for estimating $\pi \approx \pi_*$

In [1]:
import numpy as np
from rlgridworld.standard_grid import create_standard_grid

Code to play one episode of the game, updating and returning the Q values -- Q(state, action)

In [2]:
def play_game(gw, Q, epsilon=0.05, gamma=0.9, alpha=0.1, start_state=(0, 0)):
    """Play one episode of tabular Q-learning, updating ``Q`` in place.

    Starting from ``start_state``, an epsilon-greedy action is taken at each
    step and the Q-learning update is applied, until the move lands on a
    terminal state.

    Parameters
    ----------
    gw : grid-world object
        Must provide ``valid_decisions(state)``,
        ``get_reward_for_action(state, action)`` and ``is_terminal(state)``.
    Q : dict
        Maps each state to a dict ``{action: value}``. It is mutated in place.
    epsilon : float, optional
        Exploration setting passed to ``random_action``.
    gamma : float, optional
        Discount factor applied to the best Q value of the next state.
    alpha : float, optional
        Fraction of the TD error added to the current Q value on each update.
    start_state : tuple, optional
        State in which the episode begins.

    Returns
    -------
    dict
        The same ``Q`` object that was passed in.
    """
    state = start_state
    while True:
        # best known action at this state
        greedy_action, _ = max_dict(Q[state])
        # all valid actions at this state
        all_actions = gw.valid_decisions(state)
        # possibly replace the greedy action with a random one (exploration)
        action = random_action(greedy_action, all_actions, epsilon)
        reward = gw.get_reward_for_action(state, action)
        stateprime = move(state, action)

        if gw.is_terminal(stateprime):
            # init_Q creates no entry for terminal states, so there is no
            # future value to bootstrap from: the target is the reward alone.
            Q[state][action] = Q[state][action] + alpha*(reward - Q[state][action])
            return Q

        # Off-policy target: reward plus the discounted value of the best
        # action at the next state, whichever action is actually taken next.
        _, destvalue = max_dict(Q[stateprime])
        Q[state][action] = Q[state][action] + alpha*(reward + gamma*destvalue - Q[state][action])
        state = stateprime

def move(state, action):  # only valid actions at states are sent to move
    """Return the (i, j) cell reached by applying ``action`` to ``state``.

    'left'/'right' decrease/increase the second index; 'down'/'up'
    decrease/increase the first index. Any other action leaves the
    coordinates unchanged. No bounds or barrier checks are done here.
    """
    # (change in i, change in j) for each recognised action
    offsets = {
        'left': (0, -1),
        'right': (0, 1),
        'down': (-1, 0),
        'up': (1, 0),
    }
    row, col = state
    d_row, d_col = offsets.get(action, (0, 0))
    return (row + d_row, col + d_col)

def random_action(action, all_actions, epsilon):
    """Epsilon-greedy choice between ``action`` and a random valid action.

    A number is drawn with ``np.random.random_sample()``; if it is below
    ``1 - epsilon`` the supplied ``action`` is returned, otherwise an
    element of ``all_actions`` is drawn with ``np.random.choice`` (which may
    turn out to be ``action`` itself).
    """
    draw = np.random.random_sample()
    keep_given_action = draw < (1 - epsilon)
    # np.random.choice is only called when exploring, so the random stream
    # is consumed exactly as an if/else would consume it
    return action if keep_given_action else np.random.choice(all_actions)

def max_dict(d):
    """Return ``(key, value)`` for the largest value in dictionary ``d``.

    Ties go to the key encountered first in iteration order. If no value
    is greater than negative infinity (for example when ``d`` is empty),
    ``(None, float('-inf'))`` is returned.
    """
    best = (None, float('-inf'))
    for key, val in d.items():
        # strict comparison: a later equal value never displaces the current best
        if val > best[1]:
            best = (key, val)
    return best

def init_Q(gw):
    """Build the initial Q table for grid world ``gw``.

    Every cell ``(i, j)`` with ``i < gw.M`` and ``j < gw.N`` that is neither
    a barrier nor a terminal state gets an entry mapping each action from
    ``gw.valid_decisions`` to 0. Barrier and terminal cells get no entry.
    """
    table = {}
    for row in range(gw.M):
        for col in range(gw.N):
            cell = (row, col)
            # no action values are kept for barrier or terminal cells
            if gw.is_barrier(cell) or gw.is_terminal(cell):
                continue
            table[cell] = {decision: 0 for decision in gw.valid_decisions(cell)}
    return table


Create the standard grid

In [3]:
# Build the grid-world environment that is passed to init_Q and play_game below.
gw = create_standard_grid()

Initialize the Q dictionary

In [4]:
# Every action value starts at 0; barrier and terminal states get no entry.
Q = init_Q( gw )

See what is in the initial Q dictionary. The outer keys are the states, written as tuples (the pair of numbers at the beginning of each line). Each outer value is itself a dictionary: its keys are the decisions (actions) available in that state, and its values are the Q values for taking each action in that state.

In [5]:
Q

{(0, 0): {'right': 0, 'up': 0},
 (0, 1): {'left': 0, 'right': 0},
 (0, 2): {'left': 0, 'right': 0, 'up': 0},
 (0, 3): {'left': 0, 'up': 0},
 (1, 0): {'down': 0, 'up': 0},
 (1, 2): {'right': 0, 'down': 0, 'up': 0},
 (2, 0): {'right': 0, 'down': 0},
 (2, 1): {'left': 0, 'right': 0},
 (2, 2): {'left': 0, 'right': 0, 'down': 0}}

Play one episode of the game (from the starting state until a terminal state is reached)

In [6]:
# play_game mutates Q in place and returns the same dict, so the reassignment only rebinds the name.
Q = play_game(gw, Q)

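See what the Q values are. Note the effect of the alpha factor in the updates of the Q values: each update moves a Q value only a fraction alpha of the way toward its target. For example, the entry Q[(1, 2)]['right'] = -0.1 is alpha (0.1) times the reward of -1 received on the final move of the episode.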

In [7]:
Q

{(0, 0): {'right': 0.0, 'up': 0.0},
 (0, 1): {'left': 0.0, 'right': 0.0},
 (0, 2): {'left': 0.0, 'right': 0.0, 'up': 0.0},
 (0, 3): {'left': 0.0, 'up': 0},
 (1, 0): {'down': 0.0, 'up': 0.0},
 (1, 2): {'right': -0.1, 'down': 0, 'up': 0},
 (2, 0): {'right': 0.0, 'down': 0.0},
 (2, 1): {'left': 0.0, 'right': 0.0},
 (2, 2): {'left': 0.0, 'right': 0, 'down': 0}}

Play another iteration of the game and see what the Q values are

In [8]:
# Second episode: learning continues from the Q values left by the first one.
Q = play_game(gw, Q)

In [9]:
Q

{(0, 0): {'right': 0.0, 'up': 0.0},
 (0, 1): {'left': 0.0, 'right': 0.0},
 (0, 2): {'left': 0.0, 'right': 0.0, 'up': 0.0},
 (0, 3): {'left': 0.0, 'up': -0.1},
 (1, 0): {'down': 0.0, 'up': 0.0},
 (1, 2): {'right': -0.1, 'down': 0.0, 'up': 0.0},
 (2, 0): {'right': 0.0, 'down': 0.0},
 (2, 1): {'left': 0.0, 'right': 0.0},
 (2, 2): {'left': 0.0, 'right': 0, 'down': 0.0}}

Play the game 10000 times and see what the Q values are

In [10]:
# Run 10000 more episodes; every call keeps updating the same Q table.
for _ in range(10000):
    Q = play_game(gw, Q)

In [11]:
Q

{(0, 0): {'right': 0.5314409999997758, 'up': 0.6560999999999979},
 (0, 1): {'left': 0.5904899999999977, 'right': 0.27195020511481655},
 (0, 2): {'left': 0.5314409996065254,
  'right': 0.004304247247376447,
  'up': 0.08178868036170878},
 (0, 3): {'left': 0.09087191542298795, 'up': -0.1},
 (1, 0): {'down': 0.5904899999998027, 'up': 0.7289999999999983},
 (1, 2): {'right': -0.271, 'down': 0.47829689046419815, 'up': 0.0},
 (2, 0): {'right': 0.8099999999999987, 'down': 0.6560999999697452},
 (2, 1): {'left': 0.7289999999997339, 'right': 0.899999999999999},
 (2, 2): {'left': 0.8099998051361166,
  'right': 0.9999999999999996,
  'down': 0.43046708985120546}}

Extract value function and policy function from the Q table.

In [12]:
# Greedy policy: for each state keep the action with the largest Q value,
# and store that largest Q value on the grid as the state's value.
policy = {}
for i in range(gw.M):
    for j in range(gw.N):
        state = (i, j)
        if gw.is_barrier(state) or gw.is_terminal(state):
            # barrier and terminal cells have no entry in Q, so no action is recorded
            policy[state] = ''
        else:
            action, value = max_dict(Q[state])
            gw.set_value(state, value)
            policy[state] = action

In [13]:
# First grid: the values stored on gw (set via gw.set_value in the previous cell).
# Second grid: the greedy action per state from `policy`.
gw.print_values()
gw.print_policy(policy)

-------------------------------------
|   0.81 |   0.90 |   1.00 |   0.00 |
-------------------------------------
|   0.73 |   0.00 |   0.48 |   0.00 |
-------------------------------------
|   0.66 |   0.59 |   0.53 |   0.09 |
-------------------------------------
-------------------------------------
|  Right |  Right |  Right |        |
-------------------------------------
|     Up |        |   Down |        |
-------------------------------------
|     Up |   Left |   Left |   Left |
-------------------------------------
