In [4]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
import numpy as np

In [6]:
states = np.load('rook_final.npy').item()

In [7]:
def get_deterministic_policy(states):
    pi = {}
    for state, value in states.items():
        pi[state] = list(value.keys())[0]
    return pi

In [8]:
def get_deterministic_policy_uniform(states):
    pi = {}
    for state, value in states.items():
        pi[state] = np.random.choice(list(value.keys()))
    return pi

In [9]:
def deterministic_policy_eval_step(states_actions, V, pi):
    # Evaluation in place (in contrast with evaluation with 2 arrays).
    # Needs less memory and converges too
    # pi is a dict and pi[s] is the best action for that state. (The most probable action)
    delta = 0
    for state, actions in states_actions.items():
        action = pi[state]
        next_node = actions[action]['next_state']
        reward = actions[action]['status']
        V_updated = 0
        if next_node in V:
            V_updated = -(reward + V[next_node])
        else:
            V_updated = -reward
        delta = max(delta, np.abs(V_updated - V[state]))
        V[state] = V_updated
    return V, delta

In [10]:
def policy_improve(V, states_actions):
    pi = {}
    for state, actions in states_actions.items():
        actions_list = [] # list(actions.keys())
        expected_rewards = [] #np.zeros(len(actions))
        for i, (action, data) in enumerate(actions.items()):
            actions_list.append(action)
            next_state = data['next_state']
            reward = data['status']
            if next_state in V:
                expected_rewards.append(-(reward + V[next_state]))
            else:
                expected_rewards.append(-reward)

        pi[state] = actions_list[np.argmax(expected_rewards)]
        if state == '4k3/8/4K2R/8/8/8/8/8 w':
            print(np.argmax(expected_rewards))
            print(actions_list)
            print(expected_rewards)
    return pi

In [11]:
from dynamic_programming import policy_iteration

In [13]:
pi = get_deterministic_policy(states)
# pi = get_deterministic_policy_uniform(states)
pi, V = policy_iteration(states, 
                             pi, 
                             deterministic_policy_eval_step = deterministic_policy_eval_step, 
                             policy_improve=policy_improve, 
                             verbose = 1)

Iteration number:  1 2 3 4 5 
0
['h6h8', 'h6h7', 'h6g6', 'h6f6', 'h6h5', 'h6h4', 'h6h3', 'h6h2', 'h6h1', 'e6f6', 'e6d6', 'e6f5', 'e6e5', 'e6d5']
[1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
Number of differences of new policy vs old policy: 3905
---------------------------
Iteration number:  1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 
0
['h6h8', 'h6h7', 'h6g6', 'h6f6', 'h6h5', 'h6h4', 'h6h3', 'h6h2', 'h6h1', 'e6f6', 'e6d6', 'e6f5', 'e6e5', 'e6d5']
[1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
Number of differences of new policy vs old policy: 53045
---------------------------
Iteration number:  1 2 3 4 5 6 7 8 9 10 11 
0
['h6h8', 'h6h7', 'h6g6', 'h6f6', 'h6h5', 'h6h4', 'h6h3', 'h6h2', 'h6h1', 'e6f6', 'e6d6', 'e6f5', 'e6e5', 'e6d5']
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Number of differences of new policy vs old policy: 56678
---------------------------
Iteration number:  1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 
0
['h6h8', 'h6h7', 'h6g6', 'h6f6', 'h6h5', 'h6h4', 'h6h3', 'h6h2', '

KeyboardInterrupt: 