In [4]:
import pandas
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [2]:
states = list(range(0,16))
naive_value_function = {} 
for state in states:
    naive_value_function[state] = 0

init_policy_dict = {}
for state in states:
    if state not in [0, 15]:
        init_policy_dict[state] = ['up', 'down', 'left', 'right']
    else:
        init_policy_dict[state] = []

In [168]:
def update_state(old_state, action):
    translate_dict = {
        'up': -4, 
        'down': 4, 
        'left': -1, 
        'right': 1
    }
    
    restricted_moves = {
        'up': [1,2,3],
        'down': [12, 13, 14],
        'left': [4, 8, 12],
        'right': [3, 7, 11]
    }
    
    if (old_state==0) or (old_state==15) or (old_state in restricted_moves[action]):
        return old_state
    else:
        return np.clip(old_state+translate_dict[action], 0, 15)

def calc_reward(state, action):
    if state not in [0, 15]:
        return -1
    else:
        return 0
    
def state_evaluation_max_action_selection(value_function, state, actions):
    act_rew = {}
    for act in actions:
        new_state = update_state(state, act)
        act_rew[act] = calc_reward(state, act)+value_function[new_state]    
    if act_rew:
        return {'max': [max(act_rew, key=act_rew.get)], 'mean': numpy.mean(list(act_rew.values()))}
    else:
        return {'max': [], 'mean': 0}

def policy_evaluation(policy_dict, value_function):
    new_val_func = value_function.copy()
    all_difs = []
    for count in range(100):
        for state in new_val_func:
            actions = policy_dict[state]
            old_val = new_val_func[state]

            new_state_val = state_evaluation_max_action_selection(new_val_func, state, actions)['mean']
            all_difs.append(abs(old_val-new_state_val))
            new_val_func[state] = new_state_val
        if not count%20: print(count, numpy.mean(all_difs))
#         print(new_val_func[(5,5)])
    return new_val_func

def improve_policy(policy_dict, value_function):
    new_val_func = value_function.copy()
    new_policy_dict = {}
    for state in new_val_func:
        actions = policy_dict[state]
        best_act = state_evaluation_max_action_selection(new_val_func, state, actions)['max']
        new_policy_dict[state] = best_act
    return new_policy_dict

def iterate_policy(policy_dict, value_function):
    flag = True
    all_policies = [policy_dict]
    all_val_funcs = [value_function]
    counter=0
    while flag:
        all_val_funcs.append(policy_evaluation(all_policies[-1], all_val_funcs[-1]))
        all_policies.append(improve_policy(all_policies[-1], all_val_funcs[-1]))
        
        if all_policies[-1] == all_policies[-2]:
            flag = False
        counter+=1
        print(counter)
        print('^'*10)
    return all_policies[-1]
        

In [149]:
new_value_function = policy_evaluation(init_policy_dict, naive_value_function)

0 1.3212890625
20 0.6399909591413854
40 0.3793993618226454
60 0.26102921516646205
80 0.19736530410717915


In [150]:
pd.DataFrame(np.reshape(list(new_value_function.values()), (int(np.sqrt(len(states))), -1)))

Unnamed: 0,0,1,2,3
0,0.0,-13.997658,-19.996634,-21.996295
1,-13.997658,-17.997127,-19.99688,-19.996916
2,-19.996634,-19.99688,-17.997367,-13.998034
3,-21.996295,-19.996916,-13.998034,0.0


In [151]:
new_policy = improve_policy(init_policy_dict, new_value_function)

In [161]:
sum(list(new_policy.values()), [])

['left',
 'left',
 'left',
 'up',
 'up',
 'left',
 'down',
 'up',
 'up',
 'down',
 'down',
 'up',
 'right',
 'right']

In [163]:
words_to_arrows={
    None : '',
    'left': '<-',
    'right': '->',
    'up': '↑',
    'down': '↓'
}
pd.DataFrame(np.reshape(list(new_policy.values()), (int(np.sqrt(len(states))), -1))).replace(words_to_arrows)
# 

  result = getattr(asarray(obj), method)(*args, **kwds)


Unnamed: 0,0,1,2,3
0,[],[left],[left],[left]
1,[up],[up],[left],[down]
2,[up],[up],[down],[down]
3,[up],[right],[right],[]


In [153]:
new_new_value_function = policy_evaluation(new_policy, new_value_function)

0 11.810326532230837
20 0.6784503804466153
40 0.34749897535070545
60 0.2335648850717856
80 0.17589454307875213


In [154]:
pd.DataFrame(np.reshape(list(new_new_value_function.values()), (int(np.sqrt(len(states))), -1)))

Unnamed: 0,0,1,2,3
0,0.0,-1.0,-2.0,-3.0
1,-1.0,-2.0,-3.0,-2.0
2,-2.0,-3.0,-2.0,-1.0
3,-3.0,-2.0,-1.0,0.0


In [155]:
new_new_policy = improve_policy(new_policy, new_new_value_function)

In [156]:
pd.DataFrame(np.reshape(list(new_new_policy.values()), (int(np.sqrt(len(states))), -1))).replace(words_to_arrows)

  result = getattr(asarray(obj), method)(*args, **kwds)


Unnamed: 0,0,1,2,3
0,[],[left],[left],[left]
1,[up],[up],[left],[down]
2,[up],[up],[down],[down]
3,[up],[right],[right],[]


In [169]:
final_policy = iterate_policy(init_policy_dict, naive_value_function)

0 1.3212890625
20 0.6399909591413854
40 0.3793993618226454
60 0.26102921516646205
80 0.19736530410717915
1
^^^^^^^^^^
0 11.810326532230837
20 0.6784503804466153
40 0.34749897535070545
60 0.2335648850717856
80 0.17589454307875213
2
^^^^^^^^^^
