In [60]:
import io
import sys

import numpy as np
import matplotlib

from gym.envs.toy_text import discrete
from collections import defaultdict

In [9]:
# Action indices used by the environment's transition table and by the policies:
# UP=0, RIGHT=1, DOWN=2, LEFT=3.
UP, RIGHT, DOWN, LEFT = 0, 1, 2, 3

In [10]:
class GridworldEnv(discrete.DiscreteEnv):
    """
    Grid World environment from Sutton's Reinforcement Learning book chapter 4.
    You are an agent on an MxN grid and your goal is to reach the terminal
    state at the top left or the bottom right corner.
    For example, a 4x4 grid looks as follows:
    T  o  o  o
    o  x  o  o
    o  o  o  o
    o  o  o  T
    x is your position and T are the two terminal states.
    You can take actions in each direction (UP=0, RIGHT=1, DOWN=2, LEFT=3).
    Actions going off the edge leave you in your current state.
    You receive a reward of -1 at each step until you reach a terminal state.
    """

    metadata = {'render.modes': ['human', 'ansi']}

    def __init__(self, shape=None):
        """
        Build the full transition model for an MxN grid.

        Args:
            shape: [rows, cols] of the grid as a list or tuple of length 2.
                Defaults to a fresh [4, 4] list. A fresh list is used rather
                than a shared mutable default because the object is stored on
                ``self.shape``.

        Raises:
            ValueError: if ``shape`` is not a list/tuple of length 2.
        """
        if shape is None:
            shape = [4, 4]
        if not isinstance(shape, (list, tuple)) or len(shape) != 2:
            raise ValueError('shape argument must be a list/tuple of length 2')

        self.shape = shape

        nS = np.prod(shape)
        nA = 4

        MAX_Y = shape[0]
        MAX_X = shape[1]

        def is_done(state):
            # The two terminal states: flat index 0 (top-left) and nS - 1 (bottom-right).
            return state == 0 or state == (nS - 1)

        # P[s][a] = [(prob, next_state, reward, is_done)]
        P = {}
        for s in range(nS):
            # Row-major (C-order) flat index -> (row, col).
            y, x = divmod(s, MAX_X)

            if is_done(s):
                # Terminal states are absorbing: every action loops back with reward 0.
                P[s] = {a: [(1.0, s, 0.0, True)] for a in range(nA)}
                continue

            # Moving off an edge keeps the agent where it is.
            next_states = {
                UP: s if y == 0 else s - MAX_X,
                RIGHT: s if x == (MAX_X - 1) else s + 1,
                DOWN: s if y == (MAX_Y - 1) else s + MAX_X,
                LEFT: s if x == 0 else s - 1,
            }
            P[s] = {a: [(1.0, ns, -1.0, is_done(ns))] for a, ns in next_states.items()}

        # Initial state distribution is uniform
        isd = np.ones(nS) / nS

        # We expose the model of the environment for educational purposes
        # This should not be used in any model-free learning algorithm
        self.P = P

        super(GridworldEnv, self).__init__(nS, nA, P, isd)

    def _render(self, mode='human', close=False):
        """ Renders the current gridworld layout
         For example, a 4x4 grid with the mode="human" looks like:
            T  o  o  o
            o  x  o  o
            o  o  o  o
            o  o  o  T
        where x is your position and T are the two terminal states.

        In mode="human" the grid is written to stdout and None is returned.
        In mode="ansi" the grid is returned as a string instead.
        """
        if close:
            return None

        outfile = io.StringIO() if mode == 'ansi' else sys.stdout
        n_cols = self.shape[1]

        for s in range(self.nS):
            x = s % n_cols

            if self.s == s:
                output = " x "
            elif s == 0 or s == self.nS - 1:
                output = " T "
            else:
                output = " o "

            # Trim padding at row edges so rows are not indented/trailing-spaced.
            if x == 0:
                output = output.lstrip()
            if x == n_cols - 1:
                output = output.rstrip()

            outfile.write(output)

            if x == n_cols - 1:
                outfile.write("\n")

        # The 'ansi' rendering went into a buffer: return it rather than
        # silently discarding it.
        if mode == 'ansi':
            return outfile.getvalue()
        return None

In [11]:
# The classic 4x4 gridworld from Sutton & Barto, chapter 4.
env = GridworldEnv(shape=[4, 4])

# 1. On-policy first-visit Monte Carlo control with an ε-greedy policy

In [12]:
def make_epsilon_greedy_policy(Q, epsilon, nA):
    """
    Build an epsilon-greedy policy over the action values stored in ``Q``.

    Args:
        Q: Mapping from observation to an array of nA action values. It is read
            on every call, so later updates to Q change the policy's output.
        epsilon: Exploration probability, spread evenly over all nA actions.
        nA: Number of actions.

    Returns:
        ``policy_fn(observation)``: returns a float numpy array of length nA.
        Every action gets epsilon / nA. The first action with the highest
        value in Q[observation] also gets 1 - epsilon.
    """
    explore_share = epsilon / nA

    def policy_fn(observation):
        probs = np.full(nA, explore_share, dtype=float)
        probs[np.argmax(Q[observation])] += 1.0 - epsilon
        return probs

    return policy_fn

In [56]:
def mc_control_epsilon_greedy(env, num_episodes, discount_factor=1.0, epsilon=0.1, max_steps=100):
    """
    Monte Carlo Control using Epsilon-Greedy policies (on-policy, first-visit).
    Finds an optimal epsilon-greedy policy.

    Args:
        env: Environment with ``reset()``, ``step(action)`` returning
            ``(next_state, reward, done, info)``, and ``action_space.n``.
        num_episodes: Number of episodes to sample.
        discount_factor: Discount (gamma) applied to future rewards.
        epsilon: Exploration probability of the epsilon-greedy policy.
        max_steps: Maximum number of steps per episode (default 100). An
            episode that has not terminated by then is cut off. It is still
            used for the update, so its returns omit the rewards after the cut.

    Returns:
        A tuple (Q, policy).
        Q is a dictionary mapping state -> action values.
        policy is a function that takes an observation as an argument and returns
        action probabilities
    """
    # Running sum and count of first-visit returns per (state, action) pair,
    # used to compute the sample-average action values.
    returns_sum = defaultdict(float)
    returns_count = defaultdict(float)

    # The final action-value function: state -> array of action values.
    Q = defaultdict(lambda: np.zeros(env.action_space.n))

    # The policy reads Q on every call, so updating Q improves it implicitly.
    policy = make_epsilon_greedy_policy(Q, epsilon, env.action_space.n)

    for i_episode in range(1, num_episodes + 1):

        if i_episode % 500 == 0:
            print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="")
            sys.stdout.flush()

        # Roll out one episode as a list of (state, action, reward) tuples.
        episode = []
        state = env.reset()
        for t in range(max_steps):
            probs = policy(state)
            action = np.random.choice(np.arange(len(probs)), p=probs)
            next_state, reward, done, _ = env.step(action)
            episode.append((state, action, reward))
            if done:
                break
            state = next_state

        # Walk the episode backwards, accumulating the discounted return G_t.
        # Each visit overwrites the entry for its pair, so the value left is
        # the return from the pair's FIRST occurrence (first-visit MC).
        # This is a single O(len(episode)) pass instead of rescanning the
        # episode for every distinct pair.
        first_visit_return = {}
        G = 0.0
        for state, action, reward in reversed(episode):
            G = discount_factor * G + reward
            first_visit_return[(state, action)] = G

        # Update the running average return for every pair seen this episode.
        for sa_pair, G in first_visit_return.items():
            state, action = sa_pair
            returns_sum[sa_pair] += G
            returns_count[sa_pair] += 1.0
            Q[state][action] = returns_sum[sa_pair] / returns_count[sa_pair]

    return Q, policy

In [74]:
# Seed both sources of randomness so the run is reproducible.
# np.random drives the action sampling inside mc_control_epsilon_greedy.
# env.seed seeds the environment's own RNG, presumably the one env.reset()
# uses to pick the start state (confirm for your gym version).
SEED = 0
NUM_EPISODES = 500000
np.random.seed(SEED)
env.seed(SEED)
Q, policy = mc_control_epsilon_greedy(env, num_episodes=NUM_EPISODES, epsilon=0.1)

Episode 500000/500000.

In [75]:
# Learned action values: state -> array with one entry per action (UP, RIGHT, DOWN, LEFT).
Q

defaultdict(<function __main__.mc_control_epsilon_greedy.<locals>.<lambda>()>,
            {15: array([0., 0., 0., 0.]),
             4: array([-1.        , -3.26460072, -3.29984301, -2.18683565]),
             5: array([-2.2254395 , -4.36114044, -4.36159346, -2.14337241]),
             1: array([-2.17423313, -3.34546959, -3.28265348, -1.        ]),
             2: array([-3.39568345, -4.50450958, -4.32522796, -2.14241506]),
             6: array([-3.27774436, -3.40570175, -3.37681159, -3.34276387]),
             3: array([-4.57028571, -4.57079646, -3.27943782, -3.47486631]),
             7: array([-4.43246311, -3.29567723, -2.14101293, -4.30432403]),
             13: array([-4.31723338, -2.14182553, -3.38035363, -4.39157325]),
             9: array([-3.39932127, -3.3877095 , -3.28208942, -3.37024973]),
             14: array([-3.27928202, -1.        , -2.14427552, -3.2758512 ]),
             10: array([-4.37921348, -2.14439054, -2.18821096, -4.36067298]),
             12: array([-3.31

In [76]:
# State-value function from the action-value function:
# V(s) = max_a Q(s, a), the value of acting greedily in each visited state.
V = defaultdict(float)
V.update((s, np.max(action_values)) for s, action_values in Q.items())

In [77]:
# Greedy state values derived from Q: state -> max over actions of Q[state].
V

defaultdict(float,
            {15: 0.0,
             4: -1.0,
             5: -2.1433724075743914,
             1: -1.0,
             2: -2.1424150552160732,
             6: -3.277744360902256,
             3: -3.279437820955584,
             7: -2.1410129250224714,
             13: -2.141825525211195,
             9: -3.2820894163353262,
             14: -1.0,
             10: -2.144390543102487,
             12: -3.278265327663681,
             8: -2.1431812681510163,
             11: -1.0,
             0: 0.0})

# 2. Off-policy Monte Carlo control with weighted importance sampling

In [68]:
def create_random_policy(nA):
    """
    Build a uniformly random policy over nA actions.

    The returned ``policy_fn(observation)`` ignores its argument. It always
    returns the same float numpy array of length nA, filled with 1 / nA.
    """
    probabilities = np.ones(shape=nA, dtype=float)
    probabilities /= nA

    def policy_fn(observation):
        return probabilities

    return policy_fn

In [69]:
def create_greedy_policy(Q):
    """
    Build a deterministic greedy policy from the action values in ``Q``.

    The returned ``policy_fn(state)`` reads Q[state] on every call. It returns
    a float one-hot array shaped like Q[state], with 1.0 at the first index of
    the maximum value.
    """
    def policy_fn(state):
        action_values = Q[state]
        one_hot = np.zeros_like(action_values, dtype=float)
        one_hot[np.argmax(action_values)] = 1.0
        return one_hot

    return policy_fn

In [70]:
def mc_control_importance_sampling(env, num_episodes, behavior_policy, discount_factor=1.0, max_steps=100):
    """
    Monte Carlo Control Off-Policy Control using Weighted Importance Sampling.
    Finds an optimal greedy policy.

    Args:
        env: Environment with ``reset()``, ``step(action)`` returning
            ``(next_state, reward, done, info)``, and ``action_space.n``.
        num_episodes: Number of episodes to sample with the behavior policy.
        behavior_policy: Function mapping an observation to action probabilities.
        discount_factor: Discount (gamma) applied to future rewards.
        max_steps: Maximum number of steps per sampled episode (default 100).
            Episodes that have not terminated by then are discarded without
            any update. Their tail "returns" would be missing every reward
            after the cut-off, which corrupts the estimates.

    Returns:
        A tuple (Q, target_policy). Q maps state -> action values and
        target_policy is the greedy policy with respect to Q.
    """

    # The final action-value function.
    Q = defaultdict(lambda: np.zeros(env.action_space.n))

    # Cumulative denominator of the weighted importance sampling formula (across all episodes).
    C = defaultdict(lambda: np.zeros(env.action_space.n))

    # Greedy target policy; it reads Q live, so it improves as Q is updated.
    target_policy = create_greedy_policy(Q)

    for i_episode in range(1, num_episodes + 1):
        if i_episode % 500 == 0:
            print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="")
            sys.stdout.flush()

        # Roll out one episode with the behavior policy.
        episode = []
        state = env.reset()
        done = False
        for t in range(max_steps):
            probs = behavior_policy(state)
            action = np.random.choice(np.arange(len(probs)), p=probs)
            next_state, reward, done, _ = env.step(action)
            episode.append((state, action, reward))
            if done:
                break
            state = next_state

        # Monte Carlo returns are only defined for complete episodes.
        # In a cut-off episode the last steps look close to the end even
        # though they are not. That can make non-terminating greedy loops
        # (e.g. walking into a wall) appear cheap, so skip such episodes.
        # Caveat: this slightly favours episodes that end early; raise
        # max_steps to reduce that effect.
        if not done:
            continue

        # Sum of discounted returns.
        G = 0.0

        # Importance-sampling ratio (weight of the current return).
        W = 1.0

        # Process the episode from the last step backwards.
        for state, action, reward in reversed(episode):
            # Return from this step onward.
            G = discount_factor * G + reward

            # Update the weighted importance sampling denominator.
            C[state][action] += W

            # Incremental weighted-average update of the target policy's value.
            Q[state][action] += (W / C[state][action]) * (G - Q[state][action])

            # The target policy is greedy (deterministic): once the behavior
            # action differs from it, earlier steps have importance ratio 0.
            if action != np.argmax(target_policy(state)):
                break

            # Update the importance ratio: pi(a|s) = 1 for the greedy action.
            W = W * 1. / behavior_policy(state)[action]

    return Q, target_policy

In [71]:
# A fresh 4x4 gridworld for the off-policy experiment.
env = GridworldEnv(shape=[4, 4])

In [72]:
# Seed numpy, which drives action sampling from the behavior policy, and the
# environment's own RNG, presumably used by env.reset() for start states
# (confirm for your gym version), so the off-policy run is reproducible.
SEED = 0
NUM_EPISODES = 500000
np.random.seed(SEED)
env.seed(SEED)
random_policy = create_random_policy(env.action_space.n)
Q, policy = mc_control_importance_sampling(env, num_episodes=NUM_EPISODES, behavior_policy=random_policy)

Episode 500000/500000.

In [73]:
# Action values learned off-policy: state -> array with one entry per action (UP, RIGHT, DOWN, LEFT).
Q

defaultdict(<function __main__.mc_control_importance_sampling.<locals>.<lambda>()>,
            {4: array([-1.        , -2.9990933 , -2.06666667, -1.99978464]),
             11: array([-2.26666667, -1.99979441, -1.        , -2.01149425]),
             1: array([-1.99971154, -1.82051282, -2.99883459, -1.        ]),
             0: array([0., 0., 0., 0.]),
             14: array([-2.24      , -1.        , -1.99962826, -1.94117647]),
             15: array([0., 0., 0., 0.]),
             8: array([-1.99971567, -1.8       , -1.77876106, -2.20588235]),
             12: array([-2.04347826, -1.6       , -1.97175141, -1.79279279]),
             3: array([-2.        , -1.67326733, -2.        , -1.89256198]),
             6: array([-2.14685315, -1.68421053, -1.77669903, -2.99897825]),
             9: array([-2.9991589 , -1.71111111, -1.71910112, -2.03225806]),
             10: array([-1.45714286, -1.99966875, -1.99965809, -2.15702479]),
             7: array([-2.0877193 , -1.8358209 , -1.9997475