In [1]:
import sys
sys.path.append("../scripts/")
from GridMap import *
from mdp import *
import math
import copy
from matplotlib.animation import PillowWriter    #アニメーション保存用

In [2]:
class DynamicProgramming():
    """Dynamic-programming path planner drawn on top of a grid-map world.

    Every call to ``draw`` performs one value sweep under the current policy
    followed by one greedy policy-improvement pass, and then optionally paints
    the policy, the state values and the path obtained by following the policy
    from the start cell.

    Cell codes read from ``grid_map``: ``'0'`` is an obstacle, ``'2'`` the
    start and ``'3'`` the goal; every other code is treated as free space.

    Actions are integers: 0 means "stay", 1..8 select one of the eight
    neighbouring cells (see ``_ACTION_DELTAS``).
    """

    # Action id -> (dx, dy) displacement on the grid.
    _ACTION_DELTAS = {
        0: (0, 0),
        1: (1, 0),
        2: (1, 1),
        3: (0, 1),
        4: (-1, 1),
        5: (-1, 0),
        6: (-1, -1),
        7: (0, -1),
        8: (1, -1),
    }

    # Action id -> colour used when the policy is painted. Ids without an
    # entry are painted white.
    _POLICY_COLORS = {
        0: "black",
        1: "saddlebrown",
        2: "magenta",
        3: "blue",
        4: "cyan",
        5: "green",
        6: "lime",
        7: "yellow",
        8: "orange",
        9: "red",
    }

    def __init__(self, grid_map_world, drawPath=True, drawV=False, drawPI=False, cost_adj=6):
        """Set up value and policy tables for ``grid_map_world``.

        Args:
            grid_map_world: world object; this class reads its ``grid_map``,
                ``grid_step``, ``start_index`` and ``goal_index`` attributes.
            drawPath: paint the path that follows the current policy.
            drawV: paint the state-value function.
            drawPI: paint the policy.
            cost_adj: multiplier applied to ``100 - V`` before it is mapped to
                a colour. The notebook used 6 for map1 and 7 for map2/map3.
        """
        self.world = grid_map_world
        self.grid_map = grid_map_world.grid_map
        # State values must be floats: an integer table would silently truncate
        # the sqrt(2) cost of diagonal moves on every assignment.
        self.v_map = np.zeros(self.grid_map.shape, dtype=float)
        # Policy per cell; the initial action 1 moves one cell in +x.
        self.pi_map = np.full(self.grid_map.shape, 1)
        self.drawPath = drawPath
        self.drawV = drawV
        self.drawPI = drawPI
        self.cost_adj = cost_adj

        for s in self._allStates():
            if self.isGoal(s):
                self.v_map[s[0]][s[1]] = 100    # the goal starts with a high value
                self.pi_map[s[0]][s[1]] = 0     # the policy at the goal is "stay"

    def _allStates(self):
        """Yield every cell index ``[x, y]`` with x as the outer loop."""
        for x in range(len(self.grid_map)):
            for y in range(len(self.grid_map[0])):
                yield [x, y]

    def _actionValue(self, s, a):
        """Return ``(total transition probability, expected return)`` of ``a`` in ``s``.

        The expected return is the sum of ``p * (R + V(s'))`` over the
        neighbours of ``s``. A total probability of 0 means the action cannot
        reach any cell from ``s``.
        """
        prob_sum = 0.0
        value = 0.0
        for s_ in self.listNeigbor(s):
            prob = self.p(s, a, s_)
            prob_sum += prob
            value += prob * (self.R(s, a, s_) + self.V(s_))
        return prob_sum, value

    def _addCellPatch(self, ax, elems, s, color):
        """Paint cell ``s`` with a translucent rectangle and record the artist in ``elems``."""
        step_x = self.world.grid_step[0]
        step_y = self.world.grid_step[1]
        r = patches.Rectangle(
            xy=(s[0]*step_x, s[1]*step_y),
            # The x extent of a cell is grid_step[0], so that is the width.
            width=step_x,
            height=step_y,
            color=color,
            fill=True,
            alpha=0.5
        )
        elems.append(ax.add_patch(r))

    def _valueColor(self, v):
        """Map a state value to a ``'#rrggbb'`` string (black -> blue -> cyan -> green -> yellow -> red)."""
        c_num = int(-1*v)+100
        c_num = int(c_num * self.cost_adj) #Black -> Blue
        if(c_num > 0xff): #Blue -> Cyan
            c_num = (c_num-0xff)*16*16 + 0xff
            if(c_num > 0xffff): #Cyan -> Green
                c_num = 0xffff - int((c_num-0x100ff)*4/256)
                if(c_num < 0xff00): #Green -> Yellow
                    c_num = (0xff00-c_num)*65536+0xff00
                    if(c_num > 0xffff00): #Yellow -> Red
                        c_num = 0xffff00 - int((c_num-0xffff00)*0.5/65536)*256
        return '#' + format(int(c_num), 'x').zfill(6)

    def draw(self, ax, elems):
        """Advance the planner by one iteration and paint the requested layers.

        Args:
            ax: object providing ``add_patch`` (a matplotlib axes in the notebook).
            elems: list that receives every artist created here.
        """
        self.sweep()       # update the state-value function
        self.updatePI()    # improve the policy

        # Paint the policy and/or the state values.
        if(self.drawPI or self.drawV):
            for s in self._allStates():
                if(self.isObstacle(s)):
                    continue
                if(self.drawPI):
                    c = self._POLICY_COLORS.get(int(self.PI(s)), "white")
                    self._addCellPatch(ax, elems, s, c)
                if(self.drawV):
                    self._addCellPatch(ax, elems, s, self._valueColor(self.V(s)))

        # Paint the path.
        if(self.drawPath):
            self.showPath(ax, elems)

    # Value-function update
    def sweep(self):
        """Update the value of every cell in place under the current policy.

        A policy action that cannot reach any cell contributes a value of 0.
        """
        for s in self._allStates():
            _, value = self._actionValue(s, self.PI(s))
            self.v_map[s[0]][s[1]] = value

    # Policy improvement
    def updatePI(self):
        """Make the policy greedy with respect to the current state values.

        Actions that cannot reach any cell (off the map or into an obstacle)
        are scored ``-inf``; otherwise their empty sum of 0 would beat every
        real action whenever the neighbouring values are negative. Ties go to
        the lowest action id. The goal cell keeps its policy.
        """
        n_actions = len(self._ACTION_DELTAS)
        for s in self._allStates():
            if(self.isGoal(s)):
                continue
            a_list = []
            for a in range(n_actions):
                prob_sum, value = self._actionValue(s, a)
                a_list.append(value if prob_sum > 0.0 else -math.inf)
            self.pi_map[s[0]][s[1]] = np.argmax(a_list)

    # State-transition probability
    def p(self, s, a, s_):
        """Return 1.0 if action ``a`` moves from ``s`` to ``s_``, else 0.0.

        Transitions are deterministic. A target outside the map or an unknown
        action id yields 0.0.
        """
        if(self.isOutOfBounds(s_)):
            return 0.0
        delta = self._ACTION_DELTAS.get(a)
        if(delta is None):
            return 0.0
        if(s_[0]-s[0] == delta[0] and s_[1]-s[1] == delta[1]):
            return 1.0
        return 0.0

    # Reward function
    def R(self, s, a, s_):
        """Return the reward for taking ``a`` in ``s`` and arriving at ``s_``.

        Leaving from the goal costs nothing, staying elsewhere costs 5, and a
        move costs the Euclidean distance between the two cells.
        """
        if(s == self.world.goal_index):
            return 0.0
        if(a == 0):
            return -5.0
        return self.moveCost(s, s_)

    # Movement cost
    def moveCost(self, s, s_):
        """Return the negated Euclidean distance between cells ``s`` and ``s_``."""
        return -math.sqrt((s[0]-s_[0])**2+(s[1]-s_[1])**2)

    # State-value function
    def V(self, s):
        """Return the stored value of cell ``s``."""
        return self.v_map[s[0]][s[1]]

    # Policy
    def PI(self, s):
        """Return the action id stored for cell ``s``."""
        return self.pi_map[s[0]][s[1]]

    def isStart(self, s):
        """Return True if cell ``s`` carries the start code ``'2'``."""
        return bool(self.grid_map[s[0]][s[1]] == '2')

    def isGoal(self, s):
        """Return True if cell ``s`` carries the goal code ``'3'``."""
        return bool(self.grid_map[s[0]][s[1]] == '3')

    def isObstacle(self, s):
        """Return True if cell ``s`` carries the obstacle code ``'0'``."""
        return bool(self.grid_map[s[0]][s[1]] == '0')

    def isOutOfBounds(self, s):
        """Return True if ``s`` lies outside the map on either axis."""
        if(s[0]<0 or s[0]>self.grid_map.shape[0]-1):
            return True
        elif(s[1]<0 or s[1]>self.grid_map.shape[1]-1):
            return True
        return False

    def listNeigbor(self, s):
        """Return the in-bounds, non-obstacle cells of the 3x3 block centred on ``s``.

        ``s`` itself is included when it is not an obstacle.
        """
        neigbors = []
        for i in range(-1, 2):
            for j in range(-1, 2):
                candidate = [s[0]+i, s[1]+j]
                if(self.isOutOfBounds(candidate) or self.isObstacle(candidate)):
                    continue
                neigbors.append(candidate)
        return neigbors

    def showPath(self, ax, elems):
        """Paint in red the cells visited by following the policy from the start.

        The walk stops at the goal, at an obstacle, at the map border, at a
        "stay" action, or as soon as it returns to a cell it already visited.
        The start cell itself is not painted.
        """
        # Visited flags: the policy may contain a cycle before it has
        # converged, and following it blindly would never terminate.
        self.pathFlag = np.full(self.grid_map.shape, False)
        s = copy.copy(self.world.start_index)
        a = self.PI(s)
        while(not(self.isOutOfBounds(s) or self.isObstacle(s) or self.isGoal(s) or a == 0)):
            if(self.pathFlag[s[0]][s[1]]):
                break
            self.pathFlag[s[0]][s[1]] = True

            if not(self.isStart(s)):
                self._addCellPatch(ax, elems, s, "red")

            a = self.PI(s)
            if(a == 0):
                break
            dx, dy = self._ACTION_DELTAS.get(a, (0, 0))
            s[0] += dx
            s[1] += dy

In [3]:
if __name__ == "__main__":
    # Timing arguments handed to the world (presumably total duration and step; confirm against GridMapWorld)
    time_span, time_interval = 8, 0.1

    # Grid geometry arguments: per-axis cell size and (presumably) per-axis cell count
    grid_step = np.array([0.1, 0.1])
    grid_num = np.array([30, 30])

    # Map definition file
    map_data = "../csvmap/map1.csv"

    world = GridMapWorld(grid_step, grid_num, time_span, time_interval, map_data, debug=False)

    # Planner showing the path and the state values, but not the policy
    planner = DynamicProgramming(world, drawPath=True, drawV=True, drawPI=False)
    world.append(planner)

    world.draw()
    # To save the animation, uncomment:
    #world.ani.save('dynamicProgramming_map2_policy.gif', writer='pillow', fps=100)

<IPython.core.display.Javascript object>