In [None]:
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)
%matplotlib inline

import time
import seaborn
import matplotlib.pyplot as plt
import numpy as np
import math

from gym.envs.toy_text.frozen_lake import LEFT, RIGHT, DOWN, UP
from gym.envs.toy_text import frozen_lake, discrete

import gym
from gym.envs.registration import register

# Drop any spec left over from a previous run of this cell before registering
# again: older gym versions raise "Cannot re-register id" otherwise, and
# dropping it also means an edited `is_slippery` actually takes effect when
# the cell is re-run (a skip-if-registered guard would keep the stale spec).
# Old gym keeps specs in `registry.env_specs`; newer gym's registry is a dict.
getattr(gym.envs.registration.registry, 'env_specs',
        gym.envs.registration.registry).pop('D4x4-FrozenLake-v0', None)

register(
    id='D4x4-FrozenLake-v0',
    entry_point='gym.envs.toy_text.frozen_lake:FrozenLakeEnv',
    kwargs={'map_name': '4x4',
            'is_slippery': False}) # Note: You have to solve by changing this bool to True.


# Environment reference (the gym toy-text API used below)
#
#   env: gym.Env
#       Environment to play on. gym.make may return it wrapped (e.g. in a
#       TimeLimit); if the wrapper does not forward the attributes below,
#       reach them through env.unwrapped.
#
#   env.P: dict
#       Transition model built by the toy-text `discrete.DiscreteEnv` base.
#       P[state][action] is a list of (probability, nextstate, reward, terminal)
#       tuples, one per possible outcome of taking `action` in `state`.
#
#   env.nS: int
#       Number of states.
#
#   env.nA: int
#       Number of actions available in each state.
#
#   action_space: Discrete(nA)
#       LEFT = 0
#       DOWN = 1
#       RIGHT = 2
#       UP = 3
#
#   ENVIRONMENT (S = start, F = frozen, H = hole, G = goal):
#       "SFFF",
#       "FHFH",
#       "FFFH",
#       "HFFG"

def print_policy(policy, action_names, row_len=None):
    """
    Print and return the policy in human-readable format.

    Parameters
    ----------
    policy : array-like
        Action index chosen in each state, in state order.
    action_names : dict
        Maps action index -> display name (e.g. {LEFT: 'LEFT', ...}).
        Values missing from the dict are shown as their string form.
    row_len : int, optional
        Number of states printed per row. Defaults to the side length of
        the square grid (sqrt of the number of states) when that is an
        integer; otherwise all states are printed on one row.

    Returns
    -------
    numpy.ndarray of str
        The action name for each state, in the same order as ``policy``.
    """
    policy = np.asarray(policy).ravel()
    # Build from Python strings instead of policy.astype('str'): a narrow
    # integer dtype (e.g. int8 -> '<U4') would silently truncate 'RIGHT'.
    names = [action_names.get(action, str(action)) for action in policy.tolist()]
    str_policy = np.array(names, dtype=str)

    n_states = len(str_policy)
    if row_len is None:
        side = int(round(np.sqrt(n_states)))
        row_len = side if side > 0 and side * side == n_states else max(n_states, 1)

    for start in range(0, n_states, row_len):
        print(str_policy[start:start + row_len])

    return str_policy

action_names = {code: label for code, label in ((LEFT, 'LEFT'), (RIGHT, 'RIGHT'), (DOWN, 'DOWN'), (UP, 'UP'))}

In [None]:
env = gym.make('D4x4-FrozenLake-v0')
gamma = 0.9  # Discount factor. Change this to play with it
grid = 4  # Side length of the square map registered above (4x4)

In [None]:
def fancy_visual(value_func, policy_str,
                 title='Heatmap of policy iteration with value function values and directions'):
    """
    Plot the value function as an annotated heatmap with the policy's action
    drawn as an arrow in every cell.

    Parameters
    ----------
    value_func : array-like
        One value per state; reshaped row-major to (grid, grid) using the
        notebook-level ``grid``.
    policy_str : sequence of str
        Action name per state (e.g. as returned by ``print_policy``).
        'DOWN', 'UP' and 'LEFT' get their own arrows; anything else is drawn
        as a right arrow.
    title : str, optional
        Figure title; the default keeps the original policy-iteration title.
    """
    _fig, ax = plt.subplots(figsize=(11, 9))
    cmap = seaborn.diverging_palette(220, 10, as_cmap=True)
    reshaped = np.reshape(value_func, (grid, grid))
    # NOTE(review): an int tick-label argument makes seaborn label only every
    # n-th row/column, so grid+1 shows just the first label — confirm intent.
    seaborn.heatmap(reshaped, cmap=cmap, vmax=1.1,
                    square=True, xticklabels=grid + 1, yticklabels=grid + 1,
                    linewidths=.5, cbar_kws={"shrink": .5}, ax=ax, annot=True, fmt="f")

    arrows = {"DOWN": u'\u2193', "UP": u'\u2191', "LEFT": u'\u2190'}
    # Iterate over all grid*grid states (not a hard-coded 4x4) so the arrows
    # stay aligned with the reshape above for any grid size.
    for state in range(grid * grid):
        row, col = divmod(state, grid)
        # The +0.7 row offset puts the arrow below the centred value annotation.
        ax.text(col + 0.5, row + 0.7, arrows.get(policy_str[state], u'\u2192'), fontsize=12)

    ax.set_title(title)


In [None]:
########################################################################
######################### Policy iteration #############################
########################################################################


def _transition_model(env):
    """Return ``(P, n_states, n_actions)`` for a discrete toy-text gym env.

    ``gym.make`` may wrap the environment; the transition table lives on the
    innermost env, so unwrap when an ``unwrapped`` attribute is available.
    """
    base = getattr(env, 'unwrapped', env)
    P = base.P
    n_states = base.nS if hasattr(base, 'nS') else len(P)
    n_actions = base.nA if hasattr(base, 'nA') else len(P[0])
    return P, n_states, n_actions


def _action_value(P, value_func, gamma, state, action):
    """One-step lookahead value of taking ``action`` in ``state``.

    Sums prob * (reward + gamma * V(next_state)) over the listed outcomes.
    Transitions flagged terminal do not bootstrap from V(next_state).
    """
    return sum(prob * (reward + (0.0 if terminal else gamma * value_func[next_state]))
               for prob, next_state, reward, terminal in P[state][int(action)])


def policy_evaluation(env, gamma, policy, value_func_old, max_iterations=int(1e3), tol=1e-3):
    """
        Evaluate the value of a policy.
        See section 4.1 of Reinforcement Learning: An Introduction (Adaptive Computation and Machine Learning) by Sutton and Barto

        Runs synchronous sweeps of the Bellman expectation backup
            V(s) <- sum_{s'} p(s' | s, policy[s]) * (r + gamma * V(s'))
        starting from ``value_func_old``. It stops when the largest per-state
        change in a sweep is below ``tol``, or after ``max_iterations`` sweeps.

        Parameters
        ----------
        env : discrete gym env exposing ``P`` (see the environment notes above)
        gamma : float, discount factor
        policy : indexable, action index per state
        value_func_old : array-like, initial value estimate per state (not modified)
        max_iterations : int, maximum number of sweeps
        tol : float, convergence threshold on max |V_new - V_old|

        Returns
        -------
        value_func : numpy.ndarray of float, one value per state
        num_iterations : int, number of sweeps performed
    """
    P, n_states, _ = _transition_model(env)
    value_func = np.array(value_func_old, dtype=float)  # copy: caller's array untouched
    num_iterations = 0
    for num_iterations in range(1, max_iterations + 1):
        value_new = np.array([_action_value(P, value_func, gamma, s, policy[s])
                              for s in range(n_states)], dtype=float)
        delta = np.max(np.abs(value_new - value_func))
        value_func = value_new
        if delta < tol:
            break
    return value_func, num_iterations


def policy_improvement(env, gamma, value_func, policy):
    """
      Given a policy and value function, improve the policy.
      Returns true if policy is unchanged. Also returns the new policy.
      See section 4.2 of Reinforcement Learning: An Introduction (Adaptive Computation and Machine Learning) by Sutton and Barto

      Each state takes the action with the highest one-step lookahead value
      under ``value_func``. The input ``policy`` is not modified.

      Returns
      -------
      policy_stable : bool, True if no state's action changed
      new_policy : numpy.ndarray of int, improved action per state
    """
    P, n_states, n_actions = _transition_model(env)
    new_policy = np.array(policy, dtype=int)  # copy: caller's policy untouched
    for s in range(n_states):
        q_values = np.array([_action_value(P, value_func, gamma, s, a)
                             for a in range(n_actions)])
        current = new_policy[s]
        # Keep the current action when it is (numerically) tied with the best:
        # flipping between equally good actions could keep the outer loop from
        # ever reporting a stable policy.
        if not np.isclose(q_values[current], q_values.max()):
            new_policy[s] = int(np.argmax(q_values))
    policy_stable = np.array_equal(new_policy, np.asarray(policy, dtype=int))
    return policy_stable, new_policy


def policy_iteration(env, gamma, max_iterations=int(1e3), tol=1e-3, plot=True):
    """
       Runs policy iteration.
       Returns the optimal policy and its value function (one value per state).
       The number of policy improvement iterations and the total number of
       policy evaluation sweeps are printed.
       See section 4.3 of Reinforcement Learning: An Introduction (Adaptive Computation and Machine Learning) by Sutton and Barto

       Starts from the all-LEFT (action 0) policy and all-zero values. It
       alternates evaluation and improvement until the policy is stable or
       ``max_iterations`` improvement steps have run.

       When ``plot`` is True, it plots max_s |V_{pi_k}(s) - V_{pi_{k-1}}(s)|
       against the improvement step k; V_{pi_1} is compared with the initial
       all-zero values.
    """
    _, n_states, _ = _transition_model(env)
    policy = np.zeros(n_states, dtype=int)
    value_func = np.zeros(n_states)
    value_deltas = []
    improvement_steps = 0
    evaluation_sweeps = 0
    for improvement_steps in range(1, max_iterations + 1):
        # Warm-start evaluation from the previous policy's values; consecutive
        # policies are typically similar, so this usually needs fewer sweeps.
        new_value_func, sweeps = policy_evaluation(env, gamma, policy, value_func,
                                                   max_iterations, tol)
        evaluation_sweeps += sweeps
        value_deltas.append(np.max(np.abs(new_value_func - value_func)))
        value_func = new_value_func
        policy_stable, policy = policy_improvement(env, gamma, value_func, policy)
        if policy_stable:
            break

    print("Policy improvement steps: %d, policy evaluation sweeps: %d"
          % (improvement_steps, evaluation_sweeps))

    if plot:
        import matplotlib.pyplot as plt  # already a notebook dependency
        _fig, ax = plt.subplots()
        ax.plot(range(1, len(value_deltas) + 1), value_deltas, marker='o')
        ax.set_xlabel('Policy improvement step k')
        ax.set_ylabel(r'$\max_s |V_{\pi_k}(s) - V_{\pi_{k-1}}(s)|$')
        ax.set_title('Policy iteration convergence')

    # Optimal Policy, corresponding value function values (one per state).
    return policy, value_func


print("Doing Policy Iteration")
start_time = time.time()
policy, value_func = policy_iteration(env, gamma)
elapsed = time.time() - start_time
print("Total time taken: " + str(elapsed))

print("Policy:")
# Print the policy in human-readable form and keep the string version for plotting.
policy_str = print_policy(policy, action_names)
# Heatmap of the value function, with the chosen action drawn as an arrow in each cell.
fancy_visual(value_func, policy_str)

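1. What do you infer from the convergence plot of $\|V_{\pi_k} - V_{\pi_{k-1}}\|$ across policy-improvement steps $k$?
2. What do the numbers in the heatmap cells mean? (Hint: check which array is passed to `fancy_visual` and plotted.)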

In [None]:
########################################################################
#################### Final policy animation ############################
########################################################################

flag = input("\nEnter 'Y' if you want to see the final animation of the policy achieved. Else enter something else.\n")
if flag in ("Y", "y"):
    print("Final Policy Animation")
def run_policy(env, gamma, policy, delay=1, max_steps=100):
    """
    Play one episode following ``policy``, rendering every step.

    Parameters
    ----------
    env : gym environment with reset(), render() and a 4-tuple step().
    gamma : float
        Unused; kept so existing calls keep working.
    policy : indexable
        Action to take in each state.
    delay : float, optional
        Seconds to pause after each non-final step (default 1, as before).
    max_steps : int or None, optional
        Stop after this many steps even if no terminal state was reached;
        None restores the unbounded loop.
        NOTE(review): the spec registered in this notebook sets no
        max_episode_steps, so depending on the gym version the env may not be
        time-limited, and a policy that never reaches a hole or the goal would
        otherwise animate forever.

    Returns
    -------
    float
        Sum of the rewards collected during the episode.
    """
    current_state = env.reset()
    env.render()
    total_reward = 0.0
    steps = 0
    while max_steps is None or steps < max_steps:
        current_state, reward, done, _info = env.step(policy[current_state])
        steps += 1
        total_reward += reward
        env.render()
        if done:
            break
        time.sleep(delay)
    return total_reward

if flag in ("Y", "y"):
    run_policy(env, gamma, policy)
