In [1]:
class CliffWalkingEnv:
    def __init__(self, ncol=12, nrow=4) -> None:
        self.ncol = ncol
        self.nrow = nrow
        # self.nstate = self.ncol * self.nrow
        # self.cliff_states = set()
        # self.end_states = set().add(self.nstate-1)
        self.P = self.createP()

    def createP(self):
        # states * actions
        nstate = self.ncol * self.nrow
        P = [[[] for j in range(4)] for i in range(nstate)]

        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]

        for i in range(self.nrow):
            for j in range(self.ncol):
                state = i * self.ncol + j
                # if i == self.nrow - 1 and j > 0 and j != self.ncol - 1:
                    # self.cliff_states.add(state)

                for a in range(4):
                    if i == self.nrow - 1 and j > 0 and j < self.ncol - 1:
                        P[state][a] = [(1, i*self.ncol + j, -100, True)]
                        continue
                    elif i == self.nrow - 1 and j == self.ncol - 1:
                        P[state][a] = [(1, i*self.ncol + j, 0, True)]
                        continue
                    
                    next_x = min(self.ncol-1, max(0, j+change[a][0]))
                    next_y = min(self.nrow-1, max(0, i+change[a][1]))
                    next_state = next_y * self.ncol + next_x
                    reward = -1
                    done = False

                    if next_y == self.nrow - 1 and next_x > 0:
                        done = True
                        if next_x != self.ncol - 1:
                            reward = -100
                    P[state][a] = [(1, next_state, reward, done)]
        return P

    def printP(self):
        for i in range(self.nrow):
            for j in range(self.ncol):
                for a in range(4):
                    print(self.P[i*self.ncol + j][a])

In [2]:
env = CliffWalkingEnv()
env.printP()

[(1, 0, -1, False)]
[(1, 12, -1, False)]
[(1, 0, -1, False)]
[(1, 1, -1, False)]
[(1, 1, -1, False)]
[(1, 13, -1, False)]
[(1, 0, -1, False)]
[(1, 2, -1, False)]
[(1, 2, -1, False)]
[(1, 14, -1, False)]
[(1, 1, -1, False)]
[(1, 3, -1, False)]
[(1, 3, -1, False)]
[(1, 15, -1, False)]
[(1, 2, -1, False)]
[(1, 4, -1, False)]
[(1, 4, -1, False)]
[(1, 16, -1, False)]
[(1, 3, -1, False)]
[(1, 5, -1, False)]
[(1, 5, -1, False)]
[(1, 17, -1, False)]
[(1, 4, -1, False)]
[(1, 6, -1, False)]
[(1, 6, -1, False)]
[(1, 18, -1, False)]
[(1, 5, -1, False)]
[(1, 7, -1, False)]
[(1, 7, -1, False)]
[(1, 19, -1, False)]
[(1, 6, -1, False)]
[(1, 8, -1, False)]
[(1, 8, -1, False)]
[(1, 20, -1, False)]
[(1, 7, -1, False)]
[(1, 9, -1, False)]
[(1, 9, -1, False)]
[(1, 21, -1, False)]
[(1, 8, -1, False)]
[(1, 10, -1, False)]
[(1, 10, -1, False)]
[(1, 22, -1, False)]
[(1, 9, -1, False)]
[(1, 11, -1, False)]
[(1, 11, -1, False)]
[(1, 23, -1, False)]
[(1, 10, -1, False)]
[(1, 11, -1, False)]
[(1, 0, -1, False)]
[(

In [3]:
import copy

class Policy:
    def __init__(self, env) -> None:
        self.env = env
        self.nstate = self.env.nrow * self.env.ncol
        self.v = [0] * self.nstate 
        self.pi = [[] for i in range(self.nstate)]

    def print_value(self):
        for i in range(env.nrow):
            for j in range(env.ncol):
                print(f"{self.v[i*env.ncol+j]:8.3f}", end=' ')
            print()

    def print_policy(self, holes=[], ends=[], action_symbol=['^', 'v', '<', '>']):
        for i in range(self.env.nrow):
            for j in range(self.env.ncol):
                state = i * self.env.ncol + j
                if state in holes:
                    print('****', end=' ')
                elif state in ends:
                    print('EEEE', end=' ')
                else:
                    a = self.pi[state]
                    pi_str = ''
                    for k in range(len(action_symbol)):
                        if a[k] > 0:
                            pi_str += action_symbol[k]
                        else:
                            pi_str += 'o'
                    print(pi_str, end=' ')
            print()


class PolicyIteration(Policy):
    def __init__(self, env, theta, gamma) -> None:
        super().__init__(env)
        self.pi = [[0.25, 0.25, 0.25, 0.25] for i in range(self.nstate)]
        self.theta = theta
        self.gamma = gamma

    def policy_evaluate(self):
        cnt = 1
        while True:
            max_diff = 0
            new_v = [0] * self.nstate
            for s in range(self.nstate):
                qsa_list = []
                for a in range(4):
                    qsa = 0
                    for res in self.env.P[s][a]:
                        p, next_state, r, done = res
                        qsa += r + p * self.gamma * self.v[next_state] * (1-done)
                    qsa_list.append(self.pi[s][a] * qsa)
                
                new_v[s] = sum(qsa_list)
                max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
            
            self.v = new_v
            if max_diff < self.theta:
                break
            cnt += 1
        print(f"Policy evaluation finished after {cnt} iterations.")

    def policy_improve(self):
        for s in range(self.nstate):
            qsa_list = []
            for a in range(4):
                qsa = 0
                for res in self.env.P[s][a]:
                    p, next_state, r, done = res
                    qsa += r + p * self.gamma * self.v[next_state] * (1-done)
                qsa_list.append(qsa)
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)
            self.pi[s] = [1/cntq if q == maxq else 0 for q in qsa_list]
    
    def policy_iterate(self):
        while True:
            self.policy_evaluate()
            old_pi = copy.deepcopy(self.pi)
            self.policy_improve()

            if old_pi == self.pi:
                break

In [4]:
policy = PolicyIteration(env, 0.001, 0.9)
policy.policy_iterate()
policy.print_value()
policy.print_policy(holes=list(range(37, 47)), ends=[47])

Policy evaluation finished after 60 iterations.
Policy evaluation finished after 72 iterations.
Policy evaluation finished after 44 iterations.
Policy evaluation finished after 12 iterations.
Policy evaluation finished after 1 iterations.
  -7.712   -7.458   -7.176   -6.862   -6.513   -6.126   -5.695   -5.217   -4.686   -4.095   -3.439   -2.710 
  -7.458   -7.176   -6.862   -6.513   -6.126   -5.695   -5.217   -4.686   -4.095   -3.439   -2.710   -1.900 
  -7.176   -6.862   -6.513   -6.126   -5.695   -5.217   -4.686   -4.095   -3.439   -2.710   -1.900   -1.000 
  -7.458 -100.000 -100.000 -100.000 -100.000 -100.000 -100.000 -100.000 -100.000 -100.000 -100.000    0.000 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 


In [5]:
class ValueIteration(Policy):
    def __init__(self, env, theta, gamma) -> None:
        super().__init__(env)
        self.theta = theta
        self.gamma = gamma
    
    def value_iterate(self):
        cnt = 0
        while True:
            max_diff = 0
            new_v = [0] * self.nstate
            for s in range(self.nstate):
                qsa_list = []
                for a in range(4):
                    qsa = 0
                    for res in self.env.P[s][a]:
                        p, next_state, r, done = res
                        qsa += r + p * self.gamma * self.v[next_state]*(1-done)
                    qsa_list.append(qsa)
                
                new_v[s] = max(qsa_list)
                max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
            
            self.v = new_v
            if max_diff < self.theta:
                break
            cnt += 1
        print(f"Policy iteration finished after {cnt} iterations.")
        self.get_policy()

    def get_policy(self):
        for s in range(self.nstate):
            qsa_list = []
            for a in range(4):
                qsa = 0
                for res in self.env.P[s][a]:
                    p, next_state, r, done = res
                    qsa += r + p * self.gamma * self.v[next_state] * (1-done)
                qsa_list.append(qsa)
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)
            self.pi[s] = [1/cntq if q == maxq else 0 for q in qsa_list]

In [6]:
policy = ValueIteration(env, 0.001, 0.9)
policy.value_iterate()
policy.print_value()
policy.print_policy(holes=list(range(37, 47)), ends=[47])

Policy iteration finished after 14 iterations.
  -7.712   -7.458   -7.176   -6.862   -6.513   -6.126   -5.695   -5.217   -4.686   -4.095   -3.439   -2.710 
  -7.458   -7.176   -6.862   -6.513   -6.126   -5.695   -5.217   -4.686   -4.095   -3.439   -2.710   -1.900 
  -7.176   -6.862   -6.513   -6.126   -5.695   -5.217   -4.686   -4.095   -3.439   -2.710   -1.900   -1.000 
  -7.458 -100.000 -100.000 -100.000 -100.000 -100.000 -100.000 -100.000 -100.000 -100.000 -100.000    0.000 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 


In [7]:
import gym

env = gym.make("FrozenLake-v1")
env = env.unwrapped
# env.render()

holes = set()
ends = set()
for state in env.P:
    for action in env.P[state]:
        for s in env.P[state][action]:
            if s[2] == 1:
                ends.add(s[1])
            if s[3] is True:
                holes.add(s[1])
holes = holes - ends
print(holes)
print(ends)

  deprecation(
  deprecation(
If you want to render in human mode, initialize the environment in this way: gym.make('EnvName', render_mode='human') and don't call the render method.
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(


AttributeError: 'FrozenLakeEnv' object has no attribute 's'

In [35]:
agent = PolicyIteration(env, 1e-5, 0.9)
agent.policy_iterate()
agent.print_value()
agent.print_policy(holes=list(holes), ends=list(ends), action_symbol=['<', 'v', '>', '^'])

Policy evaluation finished after 29 iterations.
Policy evaluation finished after 66 iterations.
   0.207    0.184    0.223    0.167 
   0.276    0.000    0.337    0.000 
   0.436    0.742    0.899    0.000 
   0.000    1.140    1.917    0.000 
<ooo ooo^ <ooo ooo^ 
<ooo **** <o>o **** 
ooo^ ovoo <ooo **** 
**** oo>o ovoo EEEE 


In [36]:
agent = ValueIteration(env, 1e-5, 0.9)
agent.value_iterate()
agent.print_value()
agent.print_policy(holes=list(holes), ends=list(ends), action_symbol=['<', 'v', '>', '^'])

Policy iteration finished after 68 iterations.
   0.207    0.184    0.223    0.167 
   0.276    0.000    0.337    0.000 
   0.436    0.742    0.899    0.000 
   0.000    1.140    1.917    0.000 
<ooo ooo^ <ooo ooo^ 
<ooo **** <o>o **** 
ooo^ ovoo <ooo **** 
**** oo>o ovoo EEEE 
