In [1]:
import numpy as np
import gym
import pyprind
from time import time

In [2]:
P_a0 = np.array([
        [0.5, 0.5],
        [0.0, 1.0]
    ])
P_a1 = np.array([
        [0.0, 1.0],
        [0.0, 1.0]
    ])
P_a2 = np.array([
        [1.0, 0.0],
        [0.0, 1.0]
    ])
P_all = np.concatenate([P_a0, P_a1, P_a2], axis=0)
P_list = [P_a0, P_a1, P_a2]

In [3]:
R_all = np.array([
        [5.0, 10.0, -np.inf],
        [-np.inf, -np.inf, -1.0]
    ])

In [4]:
gamma = 0.95
tol = 1e-10

In [5]:
V = np.zeros(2)

In [6]:
start = time()
for i in range(10000):
    V_old = np.copy(V)
    V = np.max(R_all + gamma*np.dot(P_all, V).reshape(R_all.shape, order='F'), axis=1)
    error = np.linalg.norm(V-V_old, 2)
    if error < tol:
        print('Converged after {:} iterations with error {:}'.format(i+1, error))
        break
print('Time taken:', time() - start, 's')
V


Converged after 457 iterations with error 9.828401969972982e-11
Time taken: 0.014935016632080078 s


array([ -8.57142857, -20.        ])

In [7]:
Q = -1000*np.ones((2,3))
A = [[0, 1], [2]]
start = time()
for i in range(1000):
    for s in range(2):
        for a in A[s]:
            res = R_all[s, a]
            for sprime in range(2):
                #print(res)
                res += gamma*P_list[a][s, sprime]*np.max(Q[sprime])
            Q[s,a] = res
V = np.max(Q, axis=1)
print('Time taken:', time() - start, 's')
V

Time taken: 0.053730010986328125 s


array([ -8.57142857, -20.        ])

In [8]:
V_pi = np.zeros(2)
pi = np.zeros(2, dtype=int)

In [9]:
for i in range(1000):
    pi_old = np.copy(pi)
    P_pi = np.array([P_list[pi[0]][0],P_list[pi[1]][1]])
    R_pi = np.array([R_all[0, pi[0]], R_all[1, pi[1]]])
    #V_pi = np.linalg.solve(a=np.eye(len(P_pi))-gamma*P_pi, b=R_pi) #Policy Iteration
    V_pi = np.max(R_all + gamma*np.dot(P_all, V_pi).reshape(R_all.shape, order='F'), axis=1) #Modified Policy Iteration
    pi = np.argmax(R_all + gamma*np.dot(P_all, V_pi).reshape(R_all.shape, order='F'), axis=1)
    if np.allclose(pi, pi_old):
        print('Converged after {:} iterations'.format(i+1))
        break
        pass
pi

Converged after 2 iterations


array([0, 2])

In [10]:
V_pi

array([ 9.275, -1.95 ])

In [11]:
env = gym.make("FrozenLake-v0")
#env = gym.make("Taxi-v2")

In [12]:
N = 10000
Q = np.zeros((env.env.nS, env.env.nA))
gamma = 0.99
tol = 1e-6
bar = pyprind.ProgBar(N)
for i in range(N):
    bar.update()
    Q_old = np.copy(Q)
    for s in range(env.env.nS):
        for a in list(env.env.P[s].keys()):
            res = 0.0
            for j, sprime in enumerate([k[1] for k in env.env.P[s][a]]):
                res += env.env.P[s][a][j][0]*(env.env.P[s][a][j][2] + gamma*np.max(Q[sprime]))
            Q[s,a] = res
    error = np.linalg.norm(Q-Q_old,2)
    if error < tol:
        print('Converged after {:} iterations and error is {:}'.format(i+1,error))
        break
print(bar)

Converged after 250 iterations and error is 9.789660166799322e-07
Title: 
  Started: 02/01/2018 23:09:26
  Finished: 02/01/2018 23:09:27
  Total time elapsed: 00:00:00


In [13]:
Q

array([[ 0.54202216,  0.52775829,  0.52775829,  0.52233794],
       [ 0.34347063,  0.33419487,  0.31993087,  0.49879819],
       [ 0.43818494,  0.43361628,  0.42434043,  0.47068984],
       [ 0.30608654,  0.30608654,  0.30151778,  0.45684543],
       [ 0.55844761,  0.37958035,  0.37415995,  0.36315502],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.35834555,  0.2030179 ,  0.35834555,  0.15532765],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.37958035,  0.4075083 ,  0.39650337,  0.59179601],
       [ 0.44005996,  0.64307786,  0.44778518,  0.39831058],
       [ 0.61520584,  0.49695181,  0.40299015,  0.33046972],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.45698297,  0.52950339,  0.74171908,  0.49695181],
       [ 0.73252134,  0.86283675,  0.82108739,  0.78111856],
       [ 0.        ,  0.        ,  0.        ,  0.        ]])

In [14]:
V = np.max(Q, axis=1)
V

array([ 0.54202216,  0.49879819,  0.47068984,  0.45684543,  0.55844761,
        0.        ,  0.35834555,  0.        ,  0.59179601,  0.64307786,
        0.61520584,  0.        ,  0.        ,  0.74171908,  0.86283675,  0.        ])