In [1]:
import gym
import random
import numpy as np
import time
from gym.envs.registration import register
from IPython.display import clear_output

In [2]:
try:
    register(
        id='FrozenLakeNoSlip-v0',
        entry_point='gym.envs.toy_text:FrozenLakeEnv',
        kwargs={'map_name': '4x4', 'is_slippery': False},
        max_episode_steps=100,
        reward_threshold=0.78, # optimum = .8196
    )
except:
    pass

In [3]:
env_name = 'FrozenLake-v0'
env_name = 'FrozenLakeNoSlip-v0'
env = gym.make(env_name)
print("Observation Space: ", env.observation_space)
print("Action Space: ", env.action_space)
type(env.action_space)

Observation Space:  Discrete(16)
Action Space:  Discrete(4)


gym.spaces.discrete.Discrete

In [4]:
class Agent():
    def __init__(self, env):
        self.is_discrete = \
            type(env.action_space) == gym.spaces.discrete.Discrete
        
        if self.is_discrete:
            self.action_size = env.action_space.n
            print("Action size: ", self.action_size)
        else:
            self.action_low = env.action_space.low
            self.action_high = env.action_space.high
            self.action_shape = env.action_space.shape
            print("Action range: ", self.action_low, self.action_high)
            
    def get_action(self, state):
        if self.is_discrete:
            action = random.choice(range(self.action_size))
        else:
            action = np.random.uniform(self.action_low,
                                       self.action_high,
                                       self.action_shape)
            
        return action

In [5]:
class QAgent(Agent):
    def __init__(self, env, discount_rate=0.97, learning_rate=0.01):
        super().__init__(env)
        self.state_size = env.observation_space.n
        print("State Size:", self.state_size)
        
        self.eps = 1.0
        self.discount_rate = discount_rate
        self.learning_rate = learning_rate
        self.build_model()
        
    def build_model(self):
        self.q_table = 1e-4*np.random.random([self.state_size, self.action_size])
    
    def get_action(self, state):
        q_state = self.q_table[state]
        action_greedy = np.argmax(q_state)
        action_random = super().get_action(state)
        return action_random if random.random() < self.eps else action_greedy
    
    def train(self, experience):
        state, action, next_state, reward, done = experience
        
        q_next = self.q_table[next_state]
        q_next = np.zeros([self.action_size]) if done else q_next
        q_target = reward + self.discount_rate * np.max(q_next)
        
        q_update = q_target - self.q_table[state, action]
        self.q_table[state, action] += self.learning_rate * q_update
        
        if done:
            self.eps = self.eps * 0.99

In [12]:
agent = QAgent(env)

total_reward = 0
for ep in range(100):
    state = env.reset()
    done = False
    while not done:
        action = agent.get_action(state)
        next_state, reward, done, info = env.step(action)
        agent.train((state, action, next_state, reward, done))
        state = next_state
        total_reward += reward
        
        print("s:", state, "a:", action)
        print("ep:", ep, "tot_rew:", total_reward, "eps:", agent.eps)
        env.render()
        print(agent.q_table)
        time.sleep(.02)
        clear_output(wait=True)

s: 0 a: 0
ep: 83 tot_rew: 0.0
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
[[8.33561162e-06 5.12818837e-05 5.12259548e-05 4.07247380e-05]
 [4.66592077e-05 4.26974262e-05 9.51821899e-06 2.69056956e-05]
 [7.98525197e-05 8.38062792e-06 6.88922399e-05 2.43633561e-05]
 [5.79314120e-05 5.04567712e-05 7.44176197e-06 8.92869480e-05]
 [5.16280963e-05 2.83755963e-05 5.15184789e-05 5.16163533e-05]
 [4.09307785e-05 1.37950313e-05 1.63877185e-05 3.31444282e-05]
 [7.19612541e-05 4.22193278e-05 2.10878238e-05 8.27001860e-05]
 [8.12852505e-05 7.22922920e-05 9.96999674e-05 5.73798422e-05]
 [7.99810816e-05 3.47604224e-05 3.79205755e-05 9.28305032e-05]
 [9.10515412e-05 3.57529339e-05 9.89715552e-05 8.88576079e-05]
 [7.03631226e-05 8.96395872e-05 5.54959399e-05 8.66973197e-05]
 [8.74493926e-05 3.22738629e-06 4.59060735e-05 3.28876314e-05]
 [4.24003916e-05 9.04953823e-05 3.71108614e-05 4.36785863e-05]
 [4.34765470e-05 4.62316712e-05 1.78746240e-06 7.96503645e-05]
 [1.98381552e-05 2.96076134e-05 7.94099568e-05 7.3

KeyboardInterrupt: 