In [55]:
import numpy as np

ACTIONS = ('U', 'D', 'L', 'R')
DELTA_THRESHOLD = 1e-3
GAMMA = 0.9

In [56]:
class Grid():
    
    def __init__(self, rows, cols, start):
        self.rows = rows
        self.cols = cols
        self.i = start[0]
        self.j = start[1]
        
    def set(self, rewards, actions):
        self.rewards = rewards
        self.actions = actions
        
    def set_state(self, s):
        self.i = s[0]
        self.j = s[1]
        
    def current_state(self):
        return(self.i, self.j)
    
    def is_terminal(self, s):
        return s not in self.actions
    
    def move(self, action):
        if action in self.actions[(self.i, self.j)]:
            if action == 'U':
                self.i -= 1
            elif action == 'D':
                self.i += 1
            elif action == 'R':
                self.j += 1
            elif action == 'L':
                self.j -= 1
        return self.rewards.get((self.i, self.j), 0)
    
    def all_states(self):
        return set(self.actions.keys()) | set(self.rewards.keys())
    

In [81]:
hmm = {'a':1, 'b':2, 'c':3}
set(hmm.keys())

{'a', 'b', 'c'}

In [64]:
def standard_grid():
    grid = Grid(3, 4, (2,0))
    rewards = {(0,3):1, (1,3):-1}
    actions = {
		(0, 0): ('D', 'R'),
		(0, 1): ('L', 'R'),
		(0, 2): ('L', 'D', 'R'),
		(1, 0): ('U', 'D'),
		(1, 2): ('U', 'D', 'R'),
		(2, 0): ('U', 'R'),
		(2, 1): ('L', 'R'),
		(2, 2): ('L', 'R', 'U'),
		(2, 3): ('L', 'U'),
    }
    grid.set(rewards, actions)
    return grid

def print_values(V, grid):
    for i in range(grid.rows):
        print("--------------------------")
        for j in range(grid.cols):
            value = V.get((i, j), 0)
            if value >= 0:
                print(f"{value:.2f} | ", end ="")
            else:
                print(f"{value:.2f}| ", end ="")
        print("")
        
def print_policy(P, grid):
    for i in range(grid.rows):
        print("--------------------------")
        for j in range(grid.cols):
            action = P.get((i, j), ' ')
            print(f"  {action}  |", end="")
        print("")
        

In [65]:
smp = standard_grid()
print_values(smp.rewards, smp)
policy = {}
for i in smp.actions.keys():
    policy[i] = np.random.choice(ACTIONS)
print_policy(policy, smp)

--------------------------
0.00 | 0.00 | 0.00 | 1.00 | 
--------------------------
0.00 | 0.00 | 0.00 | -1.00| 
--------------------------
0.00 | 0.00 | 0.00 | 0.00 | 
--------------------------
  U  |  U  |  D  |     |
--------------------------
  L  |     |  U  |     |
--------------------------
  L  |  R  |  D  |  L  |


--------------------------
0.04 | 0.97 | 0.21 | 0.00 | 
--------------------------
0.55 | 0.00 | 0.10 | 0.00 | 
--------------------------
0.79 | 0.13 | 0.83 | 0.63 | 


In [87]:
V = {}
states = smp.all_states()
for s in states:
    if s in smp.actions:
        V[s] = np.random.random()
    else:
        V[s] = 0
print_values(V, smp)
print("")

loop = 0

while True:
    maxChange = 0
    for s in states:
        oldValue = V[s]
        
        if s in policy:
            newValue = float('-inf')
            for a in ACTIONS:
                smp.set_state(s)
                r = smp.move(a)
                v = r + GAMMA*V[smp.current_state()]
                if v > newValue:
                    newValue = v
            V[s] = newValue
            maxChange = max(maxChange, np.abs(oldValue - V[s]))
    print_values(V, smp)
    print("")
    loop += 1
    
    if maxChange < DELTA_THRESHOLD:
        break
    

--------------------------
0.24 | 0.76 | 0.79 | 0.00 | 
--------------------------
0.48 | 0.00 | 0.19 | 0.00 | 
--------------------------
0.67 | 0.72 | 0.83 | 0.45 | 

--------------------------
0.64 | 0.71 | 1.00 | 0.00 | 
--------------------------
0.61 | 0.00 | 0.75 | 0.00 | 
--------------------------
0.68 | 0.75 | 0.75 | 0.75 | 

--------------------------
0.81 | 0.90 | 1.00 | 0.00 | 
--------------------------
0.73 | 0.00 | 0.90 | 0.00 | 
--------------------------
0.61 | 0.68 | 0.81 | 0.68 | 

--------------------------
0.81 | 0.90 | 1.00 | 0.00 | 
--------------------------
0.73 | 0.00 | 0.90 | 0.00 | 
--------------------------
0.66 | 0.73 | 0.81 | 0.73 | 

--------------------------
0.81 | 0.90 | 1.00 | 0.00 | 
--------------------------
0.73 | 0.00 | 0.90 | 0.00 | 
--------------------------
0.66 | 0.73 | 0.81 | 0.73 | 



In [88]:
for s in policy.keys():
    bestAction = None
    bestValue = float('-inf')
    
    for a in ACTIONS:
        smp.set_state(s)
        r = smp.move(a)
        v = r + GAMMA * V[smp.current_state()]
        if v > bestValue:
            bestValue = v
            bestAction = a
        policy[s] = bestAction
        

In [89]:
print_values(V, smp)
print("")
print_policy(policy, smp)

--------------------------
0.81 | 0.90 | 1.00 | 0.00 | 
--------------------------
0.73 | 0.00 | 0.90 | 0.00 | 
--------------------------
0.66 | 0.73 | 0.81 | 0.73 | 

--------------------------
  R  |  R  |  R  |     |
--------------------------
  U  |     |  U  |     |
--------------------------
  U  |  R  |  U  |  L  |
