In [1]:
import numpy as np
import gym

In [2]:
def initialization(env):
    values=np.zeros(env.nS)
    policy_val=np.zeros((env.nS,env.nA))
    return values,policy_val

In [3]:
def policy_evaluation(env,values,policy_val,threshold,discount):
    while True:
        delta=0
        for s in range(env.nS):
            value=0
            for a, action_prob in enumerate(policy_val[s]):
                for transition, next_state, reward, done in env.P[s][a]:
                    value += action_prob * transition * (reward + discount * values[next_state])
            delta=max(delta,np.abs(value-values[s]))
            values[s]=value
        if delta < threshold:
            break
    return values

In [4]:
def policy_improvement(env,values,policy_val,discount):
    stability=True
    for s in range(env.nS):
        old_action=np.argmax(policy_val[s])
        action_values=[]
        for a in range(env.nA):
            action=0
            for transition, next_state, reward, done in env.P[s][a]:
                action += transition * (reward + discount * values[next_state])
            action_values.append(action)
        new_act=np.argmax(action_values)
        new_probs=np.zeros(env.nA)
        new_probs[new_act]=1
        policy_val[s]=new_probs
        if old_action != new_act:
            stability=False
    return stability, policy_val

In [5]:
def policy_iteration(env):
    stability=False
    values, policy_val=initialization(env)
    while not stability:
        values=policy_evaluation(env,values,policy_val,threshold=0.01,discount=0.9)
        stability,policy_val=policy_improvement(env,values,policy_val,discount=0.9)
    return policy_val,values

In [15]:
env = gym.make("FrozenLake-v0")

policy_val,values=policy_iteration(env)
values=np.reshape(values,(4,4))
print(values)
state=env.reset()
env.render()
print("above initial state")
done=False
while not done:
    state, reward, done, _ = env.step(np.argmax(policy_val[state]))
    env.render()

[[0.03604488 0.03951734 0.06163689 0.04376275]
 [0.06103737 0.         0.10497305 0.        ]
 [0.12074496 0.23236492 0.29076484 0.        ]
 [0.         0.36893238 0.6335765  0.        ]]

[41mS[0mFFF
FHFH
FFFH
HFFG
above initial state
  (Down)
S[41mF[0mFF
FHFH
FFFH
HFFG
  (Up)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Down)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Down)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Left)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Left)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
FHFH
F[41mF[0mFH
HFFG
  (Down)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Left)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
FHFH
F[41mF[0mFH
HFFG
  (Down)
SFFF
FHFH
FFFH
H[41mF[0mFG
  (Right)
SFFF
FHFH
FFFH
H[41mF[0mFG
  (Right)
SFFF
FHFH
F[41mF[0mFH
HFFG
  (Down)
SFFF
FHFH
FFFH
H[41mF[0mFG
  (Right)
SFFF
FHFH
FFFH
HF[41mF[0mG
  (Down)
SFFF


In [7]:
policy_val

array([[0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 0., 1.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.]])