### Assignment : Week 2
## Finding best policies in simple MDPs

Great work making the MDPs in Week 1!

In this assignment, we'll use the simplest RL techniques - Policy and Value iteration to find the best policies (which maximize the discounted total reward) in our MDPs from last week.

Feel free to use your own MDPs, or import them from the OpenAI Gym library.

You can start this assignment during/after reading Grokking Ch-3.

Let us recall the equation to find the value function of agent's states under a policy $\pi$ -
$$v_{\pi}(s) = \sum _{a} \pi(a|s) ~ \left( ~ \sum _{s', r} ~ p(s', r | s, a) ~ \left[r + \gamma v_{\pi}(s') \right] ~ \right)$$

We can observe that the value function $v_{\pi}$ has a lot of circular dependencies on different states. 

To solve such equations, one of the ways is to iteratively calculate the RHS and replace the LHS by it until the $v_{\pi}(s)$ values start to converge. 

The point of convergence makes all the equations simultaneously true and hence is the required solution.

Let us calculate the value functions for some policies in the MDPs we created last week.

## Environment 0 - Bandit Walk

Again, we consider the BW environment on Page 39.

Let's consider what seems to be the most natural policy - always go Right.

This environment is so simple, that we can simply calculate the value functions by hand.

Note that by convention for the terminal states, 
$$v_{\pi}(0) = v_{\pi}(2) = 0$$

Now, 
$$v_{\pi}(1) = 1 + \gamma \cdot v_{\pi}(2) = 1$$

Note both the summations just have one term due to the deterministic nature of the environment and the policy (check which summation was corresponding to which stochastic variable)

## Environment 1 - Slippery Walk

Let's now try to solve the SWF environment from Page 67 for the naturally adversarial policy - always go Left.

Since we have 5 coupled equations for states 1-5 with 5 unknown variables, we'll use Python to bruteforce the solution.

To align with Grokking, let us consider an unusual $\gamma = 1$.

In [5]:
# Step 0 is to import stuff
!pip install gym
!pip install git+https://github.com/mimoralea/gym-walk#egg=gym-walk
# !pip install gym_walk

import gym
import gym_walk
import numpy as np
from gym.envs.toy_text.frozen_lake import generate_random_map

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Defaulting to user installation because normal site-packages is not writeable
Collecting gym-walk
  Cloning https://github.com/mimoralea/gym-walk to /tmp/pip-install-yvc1j8wg/gym-walk_100f008d3fb547389d19c3247058aef3
  Running command git clone --filter=blob:none --quiet https://github.com/mimoralea/gym-walk /tmp/pip-install-yvc1j8wg/gym-walk_100f008d3fb547389d19c3247058aef3
  Resolved https://github.com/mimoralea/gym-walk to commit 5999016267d6de2f5a63307fb00dfd63de319ac1
  Preparing metadata (setup.py) ... [?25ldone

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m

In [6]:
# Step 1 is to get the MDP

env = gym.make('SlipperyWalkFive-v0')
swf_mdp = env.P
# swf_mdp

# Note that in Gym, action "Left" is "0" and "Right" is "1"

In [7]:
# Step 2 is to write the policy



# Or you can do it randomly
# pi = dict()
# for state in mdp:
#     pi[state] = np.random.choice(mdp[state].keys())

In [8]:
# Step 3 is computing the value function for this envi and policy

# Let us start with a random value function

val = dict()
for state in swf_mdp:
    val[state] = np.random.random()

# Since 0 and 6 are terminal states, we know their values are 0

val[0] = 0
val[6] = 0

#Or you could do it randomly, remember to set the terminal states to 0. You can also implement this while evaluating the value function using 
# val = dict()
# for state in mdp:
#     val[state] = np.random.random()
#     if mdp[state][0][0][0] == 0: # if the first action in the first outcome of the first state is 0, then it is a terminal state
#         val[state] = 0

#instead of doing thsi you can simply intialize the value function to 0 for all states 
# for state in swf_mdp:
#   val[state] = 0

In [9]:
def get_new_value_fn(val, mdp, pi, gamma = 1.0):
    
    new_val = dict()
    # Complete this function to get the new value function given the old value function and the policy
    # v(k) = SUM_a SUM_r,s' p(r, s' | s, a) (r+gamma*v(k-1))
    for state in val:
        # only left action possible => sum1 is just one term
        # also since its deterministic only one reward and next state are possible
        # so its just r + Y*v(k-1)
        k = 0
        for t in mdp[state][0]:
            k += t[0]*(t[2] + gamma*(val[t[1]]))
        new_val[state] = k
            
    return new_val

In [10]:
def diff(val, new_val):
    m = 0
    for i in val:
        m = max(m, abs(val[i]-new_val[i]))
    return m        

In [12]:
#Use to above function to get the new value function, also print how many iterations it took to converge
def policy_evaluation(val, mdp, pi, epsilon=1e-10, gamma=1.0):
    count = 0
    # Complete this function to iteratively caluculate the value function until the difference between the new and old value function is less than epsilon
    # Also return the number of iterations it took to converge
    while True:
        new_val = get_new_value_fn(val, mdp, pi, gamma)
        count += 1
        if diff(val, new_val) <= epsilon:
            val = new_val
            break
        val = new_val
    
    return val, count 

# val, c = policy_evaluation(val, swf_mdp, pi)
# print(val, c)

In [13]:
def get_q(val, mdp, gamma):
    q = {}
    for state in mdp:
        q[state] = {}
        for action in mdp[state]:
            q[state][action] = 0
            for thing in mdp[state][action]:
                prob, next_state, reward, done = thing
                q[state][action] += prob*(reward + val[next_state]*gamma*(not done))
                
    return q

In [14]:
# Perform policy improvement using the policy and the value function and return a new policy, the action value function should be a nested dictionary
def policy_improvement(val, mdp, gamma=1.0):
    new_pi = dict()
    q = get_q(val, mdp, gamma)
    # Complete this function to get the new policy given the value function and the mdp
    for state in q:
        max_val = 0
        max_action = 0
        for action in q[state]:
            if q[state][action] >= max_val:
                max_action = action
                max_val = q[state][action]
        new_pi[state] = max_action          
    val = get_new_value_fn(val, mdp, new_pi)
#     q = get_q(val, mdp, gamma)
    return new_pi


In [15]:
def diff_pi(pi1, pi2):
    m = 0
    for i in pi1:
        m = max(m, abs(pi1[i] - pi2[i]))
    return m

In [20]:
# Use the above functions to get the optimal policy and optimal value function and return the total number of iterations it took to converge
# Create a random policy and value function to start with or use the ones defined above
def policy_iteration(mdp, epsilon=1e-10, gamma=1.0):
    pi = {
        0 : 0,
        1 : 0,
        2 : 0,
        3 : 0,
        4 : 0,
        5 : 0,
        6 : 0
    }
    
    val = dict()
    for state in mdp:
        val[state] = np.random.random()

    val[0] = 0
    val[len(val)-1] = 0
    count = 0
    # Complete this function to get the optimal policy and value function and return the total number of iterations it took to converge
    while True:   
        print(pi)
        val, c = policy_evaluation(val, mdp, pi)
        new_policy = policy_improvement(val, mdp, gamma)
        count += 1
        if diff_pi(pi, new_policy) == 0:
            break
        pi = new_policy
    return pi, val, count

pi, val, count = policy_iteration(swf_mdp)
print(pi, val, count)

{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}
{0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1}
{0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1} {0: 0.0, 1: 0.002747252786540573, 2: 0.010989011106874468, 3: 0.035714285950012664, 4: 0.10989011024370032, 5: 0.33241758277117284, 6: 0.0} 2


In [21]:
#Now perform value iteration, note that the value function is a dictionary and not a list, also return the number of iterations it took to converge
def value_iteration(mdp, gamma=1.0, epsilon=1e-10):
    val = {s: 0 for s in mdp}
    count = 0
    q = dict()
    # Complete this function to get the optimal policy, optimal value function and return the total number of iterations it took to converge
    
    while True:
        q = {}
        for state in mdp:
            q[state] = {}
            for action in mdp[state]:
                s = 0
                for thing in mdp[state][action]:
                    prob, nex, reward, done = thing
                    s += prob* (reward + gamma*(not done)*val[nex])
                q[state][action] = s
        new_v = {}
        new_pi = {}
        
        for state in q:
            max_q = 0
            max_action = 0
            for action in q[state]:
                if q[state][action] > max_q:
                    max_action = action
                    max_q = q[state][action]
                    
            new_pi[state] = max_action
            new_v[state] = max_q
            
        pi = new_pi
        count += 1
            
        if diff(new_v, val) <= epsilon:
            val = new_v
            break
        
        val = new_v                            
                    
#     pi = {s: max(q[s], key=q[s].get) for s in mdp}
    return pi, val, count
    
print(pi)
print(value_iteration(swf_mdp))

{0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1}
({0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 0}, {0: 0, 1: 0.667582417090299, 2: 0.8901098896177716, 3: 0.9642857139576353, 4: 0.9890109888469496, 5: 0.9972527471980676, 6: 0}, 122)


## Enviroment 2 - Frozen Lake

Repeat the above steps for the frozen lake environment. Don't create new functions , use the old functions.

You can also write a function `test_policy()` to test your policy after training to find the number of times you reached the goal state

In [44]:
env2 = gym.make('FrozenLake-v1',desc=generate_random_map(size=4))
mdp2 = env2.P

In [1]:
# print(mdp2)

In [45]:
pi1, val1, count1 = policy_iteration(mdp2)
pi2, val2, count2 = value_iteration(mdp2)

In [46]:
print(pi1)
print(val1)
print(count1)
print(pi2)
print(val2)
print(count2)

{0: 0, 1: 0, 2: 2, 3: 2, 4: 0, 5: 1, 6: 1, 7: 1, 8: 0, 9: 0, 10: 2, 11: 1, 12: 0, 13: 0, 14: 0, 15: 0}
{0: 0.43914104477415017, 1: 0.06271919566139295, 2: 0.19502242597415514, 3: 0.3131640839299231, 4: 0.43914104496688283, 5: 0.25797782817902487, 6: 0.3273256563050027, 7: 0.4313057419218617, 8: 0.43914104523944786, 9: 0.2720732439418666, 10: 0.5289767147874047, 11: 0.653427485563727, 12: 0.4391410454321804, 13: 0.35560714475516475, 14: 0.9875312441259391, 15: 0.0}
2
{0: 0, 1: 0, 2: 2, 3: 2, 4: 3, 5: 1, 6: 2, 7: 2, 8: 0, 9: 0, 10: 2, 11: 2, 12: 0, 13: 0, 14: 0, 15: 0}
{0: 0.42307692232997196, 1: 0, 2: 0.9230769226622877, 3: 0.9999999995853646, 4: 0.4230769224779011, 5: 0.42307692274446385, 6: 0.8461538458213345, 7: 0.9999999996674882, 8: 0.4230769219976559, 9: 0, 10: 0.6153846152000852, 11: 0.9999999998154698, 12: 0.4230769217311454, 13: 0.21153846079162325, 14: 0, 15: 0}
319


In [51]:
def test_policy(pi, env, goalstate):
    # Complete this function to test the policy
    state = 0
    while True:
        action = pi[state]
        states = []
        prob = []
        dones = {}
        for thing in env[state][action]:
            prob.append(thing[0])
            states.append(thing[2])
            dones[thing[2]] = thing[3]
            
        next_state = random.choices(states, prob, k=1)[0]
        state = next_state
        
        if state == goalstate:
            return True
        
        if dones[state] == 1:
            return False
        
        
    return False

In [22]:
# import random

# s = 0
# f = 0
# for _ in range(100):
#     if test_policy(pi2, mdp2, 15):
#         s += 1
#     else:
#         f += 1
        
# print("Agent reaches goal", s, "times")
# print("Agent fails", f, "times")
    