<a href="https://colab.research.google.com/github/deguc/Shannon/blob/main/405_DP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import numpy as np
from collections import defaultdict

class GridWorld:
    """Deterministic grid world used by the dynamic-programming demo.

    Each cell of ``map`` holds the reward received when entering that cell;
    ``None`` marks a wall. Actions are (row, col) offsets:
    0 = up, 1 = down, 2 = left, 3 = right. Moving off the grid or into a
    wall leaves the agent where it was.
    """

    def __init__(self, grid=None, start=(2,0), goal=(0,3)):
        """Build the world.

        Args:
            grid: 2-D nested list of per-cell rewards (``None`` = wall).
                Defaults to the original 3x4 demo map.
            start: start cell (row, col); defaults to (2, 0).
            goal: goal cell (row, col); defaults to (0, 3).
        """
        self.action = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}
        if grid is None:
            grid = [
                [0, 0, 0, 1],
                [0, None, 0, -1],
                [0, 0, 0, 0],
            ]
        # dtype=object lets integer rewards and None (walls) coexist, and lets
        # get_map() store glyph strings in an array of the same dtype even when
        # the map contains no walls (an int array would reject the strings).
        self.map = np.array(grid, dtype=object)
        self.goal = goal
        self.start = start
        self.agent_state = self.start
        # Backward-compatible alias for the originally misspelled attribute.
        self.agetn_state = self.agent_state
        self.h, self.w = self.map.shape
        self.states = [(i, j) for i in range(self.h) for j in range(self.w)]

    def reward(self, s):
        """Return the reward stored at cell ``s`` (``None`` for a wall)."""
        return self.map[s]

    def isGoal(self, s):
        """Return True if ``s`` is the goal cell."""
        return s == self.goal

    def isWall(self, s):
        """Return True if ``s`` is a wall cell."""
        return self.reward(s) is None

    def isOut(self, s):
        """Return True if ``s`` lies outside the grid."""
        i, j = s
        return not (0 <= i < self.h and 0 <= j < self.w)

    def next_state(self, s, a):
        """Return the cell reached from ``s`` by action ``a``.

        The agent stays in ``s`` if the move would leave the grid or enter a wall.
        """
        di, dj = self.action[a]
        candidate = (s[0] + di, s[1] + dj)
        if self.isOut(candidate) or self.isWall(candidate):
            return s
        return candidate

    def get_map(self):
        """Return an object array of display glyphs, one per cell.

        '■' wall, '◎' goal, '✕' cell with reward -1, '口' any other cell.
        """
        m = np.zeros_like(self.map, dtype=object)
        for s in self.states:
            if self.isWall(s):
                m[s] = '■'
            elif self.isGoal(s):
                m[s] = '◎'
            elif self.reward(s) == -1:
                m[s] = '✕'
            else:
                m[s] = '口'
        return m

def disp_map(m):
    """Print a 2-D array of strings, concatenating each row onto one line."""
    for row in m:
        print(''.join(row))


def delta(x1,x2):
    """Return the largest absolute per-key difference between two mappings.

    Every key of ``x1`` is looked up in both ``x1`` and ``x2``. A key missing
    from ``x2`` raises ``KeyError``, unless ``x2`` is a ``defaultdict``.
    Returns 0.0 when ``x1`` is empty: there is nothing to compare, whereas
    ``np.max`` of an empty array would raise ``ValueError``.
    """
    keys = list(x1.keys())
    if not keys:
        return 0.0
    d1 = np.array([x1[k] for k in keys])
    d2 = np.array([x2[k] for k in keys])
    return np.max(np.abs(d1 - d2))


def V_disp(V,env):
    """Print the value function laid out as the env.h x env.w grid.

    Each value is shown with two decimals followed by a single space.
    """
    for row in range(env.h):
        print(''.join(f'{V[(row, col)]:.2f} ' for col in range(env.w)))


def eval(pi,V,env,gamma=0.9):
    """Run one in-place policy-evaluation sweep over ``env.states``.

    For each non-wall, non-goal state ``s``:
        V[s] = sum_a pi[s][a] * (reward(next_state(s, a)) + gamma * V[next_state(s, a)])
    The reward is the one stored at the cell being entered. Updates are
    written straight into ``V``, so states later in the sweep already see
    values updated earlier in the same sweep (Gauss-Seidel style). Goal
    states are pinned to 0. Wall cells are skipped: they are never
    occupied, and from a wall on the grid border ``next_state`` returns the
    wall itself, whose reward is ``None``.

    NOTE: the name shadows the built-in ``eval``. It is kept because
    ``policy_eval`` calls it by this name.

    Args:
        pi: mapping state -> {action: probability}.
        V: mapping state -> value; mutated in place.
        env: environment providing states, isWall, isGoal, next_state, reward.
        gamma: discount factor (default 0.9, the original hard-coded value).

    Returns:
        ``V`` (the same object, updated).
    """
    for s in env.states:
        if env.isWall(s):
            continue
        if env.isGoal(s):
            V[s] = 0
            continue

        V_ = 0
        for a, p in pi[s].items():
            s_ = env.next_state(s, a)
            V_ += p * (env.reward(s_) + gamma * V[s_])
        V[s] = V_

    return V


def policy_eval(pi,V,env,tol=1e-3):
    """Evaluate policy ``pi`` by repeated sweeps until ``V`` stops changing.

    Calls the module's ``eval`` sweep repeatedly. It stops once ``delta``,
    the largest absolute per-state change between consecutive sweeps, drops
    below ``tol``.

    Args:
        pi: mapping state -> {action: probability}.
        V: mapping state -> value; ``eval`` updates it in place.
        env: environment passed through to ``eval``.
        tol: convergence threshold (default 1e-3, the original hard-coded value).

    Returns:
        The converged value mapping.
    """
    while True:
        # Snapshot before the sweep: eval mutates V in place.
        V_prev = V.copy()
        V = eval(pi, V, env)
        if delta(V, V_prev) < tol:
            return V


def greedy_policy(V,env,gamma=0.9):
    """Return the deterministic policy that is greedy with respect to ``V``.

    For each state ``s``, the one-step look-ahead value
        Q[a] = reward(next_state(s, a)) + gamma * V[next_state(s, a)]
    is computed for every action key of ``env.action``. The best action gets
    probability 1 and all others 0. Ties go to the first action in
    ``env.action`` order, because ``np.argmax`` returns the first maximum.

    Wall cells are never occupied, so they get a placeholder policy (the
    first action) instead of a look-ahead. From a wall on the grid border,
    ``next_state`` returns the wall itself, whose reward is ``None``.

    Args:
        V: mapping state -> value.
        env: environment providing action, states, isWall, next_state, reward.
        gamma: discount factor (default 0.9, the original hard-coded value).

    Returns:
        dict state -> {action: 0 or 1}, with one key per action in ``env.action``.
    """
    actions = list(env.action)
    pi = {}

    for s in env.states:
        p = dict.fromkeys(actions, 0)
        if env.isWall(s):
            p[actions[0]] = 1
        else:
            Q = np.zeros(len(actions))
            for i, a in enumerate(actions):
                s_ = env.next_state(s, a)
                Q[i] = env.reward(s_) + gamma * V[s_]
            # Map the argmax position back to the action key as a plain int.
            p[actions[int(np.argmax(Q))]] = 1
        pi[s] = p

    return pi



def get_action(pi,s):
    """Return the most probable action of policy ``pi`` in state ``s``.

    Returns the action key itself rather than its position in the dict. Keys
    that are not 0..n-1 in insertion order therefore still map to the right
    action. Ties go to the first key in the dict's iteration order.
    """
    probs = pi[s]
    # max() returns the first maximal item, matching np.argmax tie-breaking.
    return max(probs, key=probs.get)


def policy_iter(env):
    """Run policy iteration on ``env`` and return the final greedy policy.

    The loop works as follows:
    - It starts from a zero value function and the uniform random policy
      over the action keys of ``env.action``.
    - It alternates ``policy_eval`` and ``greedy_policy``.
    - It stops when the improved policy equals the current one.

    NOTE(review): evaluation is only approximate (tolerance-based). With
    exactly tied actions, the greedy step could in principle flip between
    equally good policies. Confirm this terminates for new maps.
    """
    actions = list(env.action)
    uniform = 1 / len(actions)

    # defaultdicts: the first evaluation sweep looks states up lazily.
    V = defaultdict(float)
    pi = defaultdict(lambda: {a: uniform for a in actions})

    while True:
        V = policy_eval(pi, V, env)
        pi_new = greedy_policy(V, env)
        if pi_new == pi:
            return pi
        pi = pi_new



env = GridWorld()
pi = policy_iter(env)

m = env.get_map()

# Follow the greedy policy from the start cell, drawing an arrow on every
# visited cell. Policy and transitions are deterministic, so a walk that
# takes more steps than there are cells has revisited a cell and will cycle
# forever instead of reaching the goal.
s = env.start

k = 0  # number of steps taken so far
h = {0:'↑',1:'↓',2:'←',3:'→'}
while not env.isGoal(s):
    if k >= len(env.states):
        raise RuntimeError(f'greedy policy did not reach the goal within {k} steps')
    a = get_action(pi,s)
    m[s] = h[a]
    s = env.next_state(s,a)
    k += 1


disp_map(m)


→→→◎
↑■口✕
↑口口口
