In [1]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from gym.envs.registration import register
import random

In [2]:
def rargmax(vector):
    """Return the index of a maximum of `vector`, breaking ties uniformly at random.

    Unlike ``np.argmax``, which always returns the first maximal position, this
    chooses uniformly among every position equal to the maximum, so a tied
    Q-row does not always favour the lowest-numbered action.

    Parameters
    ----------
    vector : array_like
        Values to search. Plain Python sequences are accepted as well as NumPy
        arrays; multi-dimensional input is searched in flattened (C) order,
        matching ``np.argmax`` called without ``axis``.

    Returns
    -------
    numpy integer
        Flat index of one of the maximal entries.

    Raises
    ------
    ValueError
        If `vector` is empty (raised by ``np.amax``).
    IndexError
        If `vector` contains NaN: the maximum is then NaN, no element compares
        equal to it, and ``random.choice`` receives an empty sequence.
    """
    values = np.asarray(vector)
    # Converting first makes `==` an element-wise comparison even for plain
    # lists; flatnonzero gives flat indices for any input rank.
    best = np.flatnonzero(values == np.amax(values))
    return random.choice(best)

In [3]:
# Register a custom FrozenLake id: 4x4 map with slippery (stochastic) moves.
# Guarded so that re-running this cell does not register the same id twice
# (NOTE(review): some gym versions raise an error on a duplicate id, others
# only warn).
_registry = gym.envs.registry
# Older gym releases keep specs in `registry.env_specs`; newer ones make the
# registry itself a dict keyed by id — look in whichever exists.
_registered_ids = getattr(_registry, 'env_specs', _registry)
if 'FrozenLake-v3' not in _registered_ids:
    register(
        id='FrozenLake-v3',
        entry_point='gym.envs.toy_text:FrozenLakeEnv',
        kwargs={
            'map_name' : '4x4',
            'is_slippery' : True,
        }
    )

In [4]:
# NOTE(review): this builds the built-in 'FrozenLake-v0', not the custom
# 'FrozenLake-v3' id registered in the previous cell — confirm which is intended.
env = gym.make(id='FrozenLake-v0')

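## Stochastic world
 - Listen to Q(s') only a little: in a slippery world, a single observed transition is just a noisy sample of where an action leads.
 - Update Q(s, a) only a little at each step, by an amount controlled by the learning rate $\alpha$:
   $Q(s,a) \leftarrow (1-\alpha)\,Q(s,a) + \alpha\left[r + \gamma \max_{a'} Q(s',a')\right]$
 - Like consulting many mentors, averaging over many noisy updates gives a reliable estimate.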

In [40]:
# Tabular Q-learning on the FrozenLake environment created above.
# The hyperparameters are fixed for the whole run, so they are set once here
# rather than re-assigned inside every episode.
num_episodes = 20000
discount_rate = 0.99  # gamma: weight given to the bootstrapped future value
alpha = 0.1           # learning rate: fraction of the way each update moves


def choose_action(q_row, episode, action_space):
    """Pick an action for one step, using two decaying exploration mechanisms.

    With probability ``1 / (episode / 100 + 1)``, a random action is drawn from
    `action_space`. Otherwise the greedy action is taken on `q_row` perturbed
    by Gaussian noise scaled by ``1 / (episode + 1)``. The noise breaks ties
    among equal Q-values and fades as training progresses.

    Parameters
    ----------
    q_row : numpy.ndarray
        Q-values of the current state, one entry per action.
    episode : int
        Zero-based episode index; larger values mean less exploration.
    action_space :
        Action space exposing ``sample()`` and ``n``.

    Returns
    -------
    The action to pass to ``env.step``.
    """
    e_rate = 1. / ((episode / 100) + 1)  # exploration rate (random action)
    if np.random.rand(1) < e_rate:
        return action_space.sample()
    noise = np.random.randn(1, action_space.n) / (episode + 1)
    return np.argmax(q_row + noise)


def run_q_learning(env, num_episodes, discount_rate, alpha, log_every=1000):
    """Train a Q-table on `env`; return it together with per-episode rewards.

    Each step blends the old estimate with the new target instead of
    overwriting it (the deterministic-world rule would be
    ``Q[s, a] = r + discount_rate * max Q[s']``), because in a stochastic
    world one transition is only a noisy sample:

        Q[s, a] <- (1 - alpha) * Q[s, a] + alpha * (r + discount_rate * max Q[s'])

    Assumes the classic gym API: ``env.reset()`` returns the observation and
    ``env.step()`` returns ``(obs, reward, done, info)``.

    Parameters
    ----------
    env :
        Environment with discrete ``observation_space.n`` and ``action_space.n``.
    num_episodes : int
        Number of training episodes.
    discount_rate : float
        Discount factor gamma applied to the bootstrapped value.
    alpha : float
        Learning rate in [0, 1].
    log_every : int, optional
        Print the episode index every `log_every` episodes; 0 disables it.

    Returns
    -------
    Q : numpy.ndarray
        Table of shape ``(env.observation_space.n, env.action_space.n)``.
    reward_ls : list
        Total reward collected in each episode, in order.
    """
    Q = np.zeros([env.observation_space.n, env.action_space.n])
    reward_ls = []

    for i in range(num_episodes):
        if log_every and i % log_every == 0:
            print(i)

        state = env.reset()
        r_all = 0
        done = False

        while not done:
            action = choose_action(Q[state, :], i, env.action_space)
            new_state, reward, done, _ = env.step(action)

            # The bootstrap also reads Q[new_state] on the final step.
            # NOTE(review): this is only harmless if rows of episode-ending
            # states stay at their zero initialisation.
            target = reward + discount_rate * np.max(Q[new_state, :])
            Q[state, action] = (1 - alpha) * Q[state, action] + alpha * target

            r_all += reward
            state = new_state

        reward_ls.append(r_all)

    return Q, reward_ls


Q, reward_ls = run_q_learning(env, num_episodes, discount_rate, alpha)

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000


In [41]:
# Mean total reward per episode, averaged over every training episode
# (including the early, exploration-heavy ones).
np.asarray(reward_ls).mean()

0.63675

In [43]:
# Learned Q-table: one row per state (env.observation_space.n), one column per
# action (env.action_space.n).
Q

array([[0.528404  , 0.50741397, 0.50770333, 0.50583094],
       [0.34189007, 0.28714577, 0.27608253, 0.49266976],
       [0.38040395, 0.39802621, 0.39608381, 0.47822258],
       [0.28769833, 0.23721839, 0.30984258, 0.47199703],
       [0.54342758, 0.38892451, 0.38608399, 0.39901984],
       [0.        , 0.        , 0.        , 0.        ],
       [0.1855003 , 0.0975719 , 0.28421536, 0.12249907],
       [0.        , 0.        , 0.        , 0.        ],
       [0.26618867, 0.4467343 , 0.38241505, 0.56753255],
       [0.4328771 , 0.63534374, 0.36918807, 0.45948355],
       [0.64856715, 0.20794779, 0.38144504, 0.35997765],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.59358351, 0.3607154 , 0.75196581, 0.56039329],
       [0.73702809, 0.92627018, 0.81288763, 0.74177775],
       [0.        , 0.        , 0.        , 0.        ]])