In [2]:
import numpy as np
import gym

import pyprind
from time import time

In [177]:
P_a0 = np.array([
        [0.5, 0.5],
        [0.0, 1.0]
    ])
P_a1 = np.array([
        [0.0, 1.0],
        [0.0, 1.0]
    ])
P_a2 = np.array([
        [1.0, 0.0],
        [0.0, 1.0]
    ])
P_all = np.concatenate([P_a0, P_a1, P_a2], axis=0)
P_list = [P_a0, P_a1, P_a2]

In [178]:
R_all = np.array([
        [5.0, 10.0, -np.inf],
        [-np.inf, -np.inf, -1.0]
    ])

In [179]:
gamma = 0.95
tol = 1e-10

In [180]:
V = np.zeros(2)

In [16]:
gamma*np.dot(P_all, V).reshape(R_all.shape, order='F')

array([[ 0.,  0.,  0.],
       [ 0.,  0.,  0.]])

In [17]:
start = time()
for i in range(1000):
    V_old = np.copy(V)
    V = np.max(R_all + gamma*np.dot(P_all, V).reshape(R_all.shape, order='F'), axis=1)
    error = np.linalg.norm(V-V_old, 2)
    if error < tol:
        print('Converged after {:} iterations with error {:}'.format(i+1, error))
        break
print('Time taken:', time() - start, 's')
V


Converged after 457 iterations with error 9.828401969972982e-11
Time taken: 0.014429569244384766 s


array([ -8.57142857, -20.        ])

In [18]:
Q = -1000*np.ones((2,3))
A = [[0, 1], [2]]
start = time()
for i in range(1000):
    for s in range(2):
        for a in A[s]:
            res = R_all[s, a]
            for sprime in range(2):
                #print(res)
                res += gamma*P_list[a][s, sprime]*np.max(Q[sprime])
            Q[s,a] = res
V = np.max(Q, axis=1)
print('Time taken:', time() - start, 's')
V

Time taken: 0.0657036304473877 s


array([ -8.57142857, -20.        ])

In [183]:
V_pi = np.zeros(2)
pi = np.zeros(2, dtype=int)
for i in range(900000):
    pi_old = np.copy(pi)
    P_pi = np.array([P_list[pi[0]][0],P_list[pi[1]][1]])
    R_pi = np.array([R_all[0, pi[0]], R_all[1, pi[1]]])
    #V_pi = np.linalg.solve(a=np.eye(len(P_pi))-gamma*P_pi, b=R_pi) #Policy Iteration
    V_pi = np.max(R_all + gamma*np.dot(P_all, V_pi).reshape(R_all.shape, order='F'), axis=1) #Modified Policy Iteration
    pi = np.argmax(R_all + gamma*np.dot(P_all, V_pi).reshape(R_all.shape, order='F'), axis=1)
    if np.allclose(pi, pi_old):
        print('Converged after {:} iterations'.format(i+1))
        break
        pass
pi

Converged after 2 iterations


array([0, 2])

In [184]:
V_pi

array([ 9.275, -1.95 ])

In [167]:
np.array([P_list[pi[0]][0],P_list[pi[1]][1]])

array([[ 0.,  1.],
       [ 0.,  1.]])

In [107]:
env = gym.make("FrozenLake-v0")
#env = gym.make("Taxi-v2")

In [71]:
N = 10000
Q = np.zeros((env.env.nS, env.env.nA))
gamma = 0.1
tol = 1e-6
bar = pyprind.ProgBar(N)
for i in range(N):
    bar.update()
    Q_old = np.copy(Q)
    for s in range(env.env.nS):
        for a in list(env.env.P[s].keys()):
            res = 0.0
            for j, sprime in enumerate([k[1] for k in env.env.P[s][a]]):
                res += env.env.P[s][a][j][0]*(env.env.P[s][a][j][2] + gamma*np.max(Q[sprime]))
            Q[s,a] = res
    error = np.linalg.norm(Q-Q_old,2)
    if error < tol:
        print('Converged after {:} iterations and error is {:}'.format(i+1,error))
        break
print(bar)
V = np.max(Q, axis=1)

Converged after 6 iterations and error is 8.513109932571311e-08
Title: 
  Started: 02/02/2018 18:56:44
  Finished: 02/02/2018 18:56:44
  Total time elapsed: 00:00:00


In [48]:
env.env.P[5].keys()

dict_keys([0, 1, 2, 3])

In [44]:
env.env.nS

16

In [60]:
env.env.nS

500

In [73]:
V_policy = np.zeros(env.env.nS)
policy = np.zeros(env.env.nS, dtype=int)
 

In [77]:
env.env.P[10]

{0: [(0.3333333333333333, 6, 0.0, False),
  (0.3333333333333333, 9, 0.0, False),
  (0.3333333333333333, 14, 0.0, False)],
 1: [(0.3333333333333333, 9, 0.0, False),
  (0.3333333333333333, 14, 0.0, False),
  (0.3333333333333333, 11, 0.0, True)],
 2: [(0.3333333333333333, 14, 0.0, False),
  (0.3333333333333333, 11, 0.0, True),
  (0.3333333333333333, 6, 0.0, False)],
 3: [(0.3333333333333333, 11, 0.0, True),
  (0.3333333333333333, 6, 0.0, False),
  (0.3333333333333333, 9, 0.0, False)]}

In [108]:

nA, nS = env.env.nA, env.env.nS



In [109]:
T = np.zeros([nS, nA, nS])
R = np.zeros([nS, nA, nS])

In [110]:
for s in range(nS):
    for a in range(nA):
        transitions = env.env.P[s][a]
        for p_trans,next_s,rew,done in transitions:
            T[s,a,next_s] += p_trans
            R[s,a,next_s] = rew
        T[s,a,:]/=np.sum(T[s,a,:])

In [123]:
# for s in range(nS):
#      for a in range(nA):
print(env.env.P[10][2])

[(0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 11, 0.0, True), (0.3333333333333333, 6, 0.0, False)]


In [144]:
for s in range(nS):
      for a in range(nA):
             for j, sprime in enumerate([k[1] for k in env.env.P[s][a]]):
                    print(j,sprime)
                    

0 0
1 0
2 4
0 0
1 4
2 1
0 4
1 1
2 0
0 1
1 0
2 0
0 1
1 0
2 5
0 0
1 5
2 2
0 5
1 2
2 1
0 2
1 1
2 0
0 2
1 1
2 6
0 1
1 6
2 3
0 6
1 3
2 2
0 3
1 2
2 1
0 3
1 2
2 7
0 2
1 7
2 3
0 7
1 3
2 3
0 3
1 3
2 2
0 0
1 4
2 8
0 4
1 8
2 5
0 8
1 5
2 0
0 5
1 0
2 4
0 5
0 5
0 5
0 5
0 2
1 5
2 10
0 5
1 10
2 7
0 10
1 7
2 2
0 7
1 2
2 5
0 7
0 7
0 7
0 7
0 4
1 8
2 12
0 8
1 12
2 9
0 12
1 9
2 4
0 9
1 4
2 8
0 5
1 8
2 13
0 8
1 13
2 10
0 13
1 10
2 5
0 10
1 5
2 8
0 6
1 9
2 14
0 9
1 14
2 11
0 14
1 11
2 6
0 11
1 6
2 9
0 11
0 11
0 11
0 11
0 12
0 12
0 12
0 12
0 9
1 12
2 13
0 12
1 13
2 14
0 13
1 14
2 9
0 14
1 9
2 12
0 10
1 13
2 14
0 13
1 14
2 15
0 14
1 15
2 10
0 15
1 10
2 13
0 15
0 15
0 15
0 15


In [152]:
[k[1] for k in env.env.P[0][1][]]

TypeError: 'float' object is not subscriptable

In [153]:
j

0

In [148]:
a

3