In [24]:
import copy


class CliffWalkingEnv:
    """ 悬崖漫步环境"""
    def __init__(self, ncol=12, nrow=4):
        self.ncol = ncol  # 定义网格世界的列
        self.nrow = nrow  # 定义网格世界的行
        # 转移矩阵P[state][action] = [(p, next_state, reward, done)]包含下一个状态和奖励
        self.P = self.createP()

    def createP(self):
        # 初始化
        P = [[[] for j in range(4)] for i in range(self.nrow * self.ncol)]
        # 4种动作, change[0]:上,change[1]:下, change[2]:左, change[3]:右。坐标系原点(0,0)
        # 定义在左上角
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        for i in range(self.nrow):
            for j in range(self.ncol):
                for a in range(4):
                    # 位置在悬崖或者目标状态,因为无法继续交互,任何动作奖励都为0
                    if i == self.nrow - 1 and j > 0:
                        P[i * self.ncol + j][a] = (1, i * self.ncol + j, 0,
                                                    True)
                        continue
                    # 其他位置
                    next_x = min(self.ncol - 1, max(0, j + change[a][0]))
                    next_y = min(self.nrow - 1, max(0, i + change[a][1]))
                    next_state = next_y * self.ncol + next_x
                    reward = -1
                    done = False
                    # 下一个位置在悬崖或者终点
                    if next_y == self.nrow - 1 and next_x > 0:
                        done = True
                        if next_x != self.ncol - 1:  # 下一个位置在悬崖
                            reward = -100
                    P[i * self.ncol + j][a] = (1, next_state, reward, done)
        return P

In [31]:
class PolicyIteration:
    
    def __init__(self,env,theta,gamma):
        self.env = env
        self.v = [0] * self.env.ncol * self.env.nrow
        self.pi = [[0.25,0.25,0.25,0.25] for i in range(self.env.ncol *self.env.nrow)]
        self.theta = theta
        self.gamma = gamma
    
    def policy_evaluation(self):
        cnt = 1
        while 1:
            max_diff = 0
            new_v = [0] * self.env.ncol * self.env.nrow
            for s in range(self.env.ncol * self.env.nrow):
                qsa_list = []
                for a in range(4):
                    p,next_state,r,done = self.env.P[s][a]
                    qsa = p * (r + self.gamma * self.v[next_state]*(1-done))
                    qsa_list.append(self.pi[s][a]*qsa)
                new_v[s] = sum(qsa_list)
#                 print(new_v[s])
                max_diff = max(max_diff,abs(new_v[s] - self.v[s]))
            self.v = new_v
            if max_diff < self.theta:
                break
            cnt += 1
        print('policy finished %d' % cnt)
    
    def policy_improvement(self):
        for s in range(self.env.ncol * self.env.nrow):
            qsa_list = []
            for a in range(4):
                p,next_state,r,done = self.env.P[s][a]
                qsa = p * (r + self.gamma * self.v[next_state]*(1-done))
                qsa_list.append(qsa)
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)
            self.pi[s] = [1/cntq if q == maxq else 0 for q in qsa_list]
        print('policy improve finished')
        return self.pi
    
    def policy_iteration(self):
        while 1:
            self.policy_evaluation()
            old_pi = copy.deepcopy(self.pi)
            new_pi = self.policy_improvement()
            if old_pi == new_pi:
                break

In [32]:
def print_agent(agent, action_meaning, disaster=[], end=[]):
    print("状态价值：")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # 为了输出美观,保持输出6个字符
            print('%6.6s' % ('%.3f' % agent.v[i * agent.env.ncol + j]), end=' ')
        print()

    print("策略：")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # 一些特殊的状态,例如悬崖漫步中的悬崖
            if (i * agent.env.ncol + j) in disaster:
                print('****', end=' ')
            elif (i * agent.env.ncol + j) in end:  # 目标状态
                print('EEEE', end=' ')
            else:
                a = agent.pi[i * agent.env.ncol + j]
                pi_str = ''
                for k in range(len(action_meaning)):
                    pi_str += action_meaning[k] if a[k] > 0 else 'o'
                print(pi_str, end=' ')
        print()


env = CliffWalkingEnv()
action_meaning = ['^', 'v', '<', '>']
theta = 0.001
gamma = 0.9
agent = PolicyIteration(env, theta, gamma)
agent.policy_iteration()
print_agent(agent, action_meaning, list(range(37, 47)), [47])

policy finished 60
policy improve finished
policy finished 72
policy improve finished
policy finished 44
policy improve finished
policy finished 12
policy improve finished
policy finished 1
policy improve finished
状态价值：
-7.712 -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 
-7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 
-7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 -1.000 
-7.458  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 
策略：
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 


In [46]:
class ValueIteration:
    
    def __init__(self,env,theta,gamma):
        
        self.env = env
        self.v = [0] * self.env.ncol * self.env.nrow
        self.theta = theta
        self.gamma = gamma
        self.pi = [None for i in range(self.env.ncol * self.env.nrow)]
    
    def value_iteration(self):
        cnt = 0
        while 1:
            max_diff = 0
            new_v = [0] * self.env.ncol * self.env.nrow
            for s in range(self.env.ncol * self.env.nrow):
                qsa_list = []
                for a in range(4):
                    p,next_state,r,done = self.env.P[s][a]
                    qsa = p * (r + self.gamma * self.v[next_state]*(1-done))
                    qsa_list.append(qsa)
                new_v[s] = max(qsa_list)
                max_diff = max(max_diff,abs(new_v[s]-self.v[s]))
            self.v = new_v
            if max_diff < self.theta:
                break
            cnt += 1
        print('value iteration finished',cnt)
        self.get_policy()
    
    def get_policy(self):
        for s in range(self.env.ncol * self.env.nrow):
            qsa_list = []
            for a in range(4):
                p,next_state,r,done = self.env.P[s][a]
                qsa = p * (r + self.gamma * self.v[next_state]*(1-done))
                qsa_list.append(qsa)
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)
            self.pi[s] = [1/cntq if q == maxq else 0 for q in qsa_list]
                
                

In [47]:
env = CliffWalkingEnv()
action_meaning = ['^', 'v', '<', '>']
theta = 0.001
gamma = 0.9
agent = ValueIteration(env, theta, gamma)
agent.value_iteration()
print_agent(agent, action_meaning, list(range(37, 47)), [47])

value iteration finished 14
状态价值：
-7.712 -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 
-7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 
-7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 -1.000 
-7.458  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 
策略：
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 


In [77]:
class ValueIteration:
    
    def __init__(self,env):
        self.env = env
        self.gamma = 0.9
        self.v = [0] * env.ncol*env.nrow
        self.pi = [[0.25,0.25,0.25,0.25] for _ in range(env.ncol*env.nrow)]
    
    def value_iteration(self):
        cnt = 0
        while 1:
            diff = 0
            for s in range(self.env.ncol * self.env.nrow):
                qsa_list = []
                for a in range(4):
                    p,next_state,r,done = self.env.P[s][a]
                    qsa = p * r + p * self.gamma * self.v[next_state]
                    qsa_list.append(qsa)
                maxa = max(qsa_list)
                diff = max(diff,abs(maxa-self.v[s]))
                cnta = qsa_list.count(maxa)
                new_v = [1/cnta if a == maxa else 0 for a in qsa_list]
                self.pi[s] = new_v
                self.v[s] = maxa
            cnt += 1
            if diff < 0.001:
                break
        print('iterate times',cnt)

In [78]:
env.P[0][1]

(1, 12, -1, False)

In [79]:
env = CliffWalkingEnv()
action_meaning = ['^', 'v', '<', '>']
theta = 0.001
gamma = 0.9
agent = ValueIteration(env)
agent.value_iteration()
print_agent(agent, action_meaning, list(range(37, 47)), [47])

iterate times 15
状态价值：
-7.712 -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 
-7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 
-7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 -1.000 
-7.458  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 
策略：
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 


In [82]:
agent.pi

[[0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 1.0, 0, 0],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 0.5, 0, 0.5],
 [0, 1.0, 0, 0],
 [0, 0, 0, 1.0],
 [0, 0, 0, 1.0],
 [0, 0, 0, 1.0],
 [0, 0, 0, 1.0],
 [0, 0, 0, 1.0],
 [0, 0, 0, 1.0],
 [0, 0, 0, 1.0],
 [0, 0, 0, 1.0],
 [0, 0, 0, 1.0],
 [0, 0, 0, 1.0],
 [0, 0, 0, 1.0],
 [0, 1.0, 0, 0],
 [1.0, 0, 0, 0],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25]]

In [80]:
class PolicyIteration:
    
    def __init__(self,env):
        self.env = env
        self.gamma = 0.9
        self.v = [0 for _ in range(env.ncol*env.nrow)]
        self.pi = [[0.25,0.25,0.25,0.25] for _ in range(env.ncol*env.nrow)]
    
    def policy_evaluate(self):
        cnt = 0
        while 1:
            diff = 0
            new_v = [0] * self.env.ncol * self.env.nrow
            for s in range(self.env.ncol*self.env.nrow):
                qsa_list = []
                for a in range(4):
                    p,next_state,r,done = self.env.P[s][a]
                    qsa = p * r + p * self.gamma * self.v[next_state]
                    qsa_list.append(self.pi[s][a]*qsa)
                new_v[s] = sum(qsa_list)
                diff = max(diff,abs(self.v[s]-new_v[s]))
            self.v = new_v
            cnt += 1
            if diff < 0.001:
                break
        print('policy evaluate',cnt)
        return self.pi
    
    def policy_improvement(self):
        for s in range(self.env.ncol * self.env.nrow):
            qsa_list = []
            for a in range(4):
                p,next_state,r,done = self.env.P[s][a]
                qsa = p * r + p * self.gamma * self.v[next_state]
                qsa_list.append(qsa)
            maxa = max(qsa_list)
            cnta = qsa_list.count(maxa)
            self.pi[s] = [1/cnta if a == maxa else 0 for a in qsa_list]
        return self.pi
        
    
    def policy_iterate(self):
        cnt = 0
        while 1:
            pi_old = self.policy_evaluate()
            pi_old = copy.deepcopy(pi_old)
            pi_new = self.policy_improvement()
            cnt += 0
            if pi_new == pi_old:
                break
        print('total process times',cnt)
            

In [76]:
env = CliffWalkingEnv()
action_meaning = ['^', 'v', '<', '>']
theta = 0.001
gamma = 0.9
agent = PolicyIteration(env)
agent.policy_iterate()
print_agent(agent, action_meaning, list(range(37, 47)), [47])

policy evaluate 60
policy evaluate 72
policy evaluate 44
policy evaluate 12
policy evaluate 1
total process times 0
状态价值：
-7.712 -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 
-7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 
-7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 -1.000 
-7.458  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 
策略：
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 
