In [1]:
import gym
import gym_digger
import numpy as np
import random
from gym import envs
from IPython.display import clear_output
import time

In [2]:
# Create the environment and seed the Python/NumPy RNGs so agent-side
# randomness (exploration via `random`) is repeatable across Restart & Run All.
# NOTE(review): whether DiggerDiscrete itself draws from these RNGs is not
# visible here — confirm if fully deterministic episodes are required.
ENV_ID = 'DiggerDiscrete-v0'
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
env = gym.make(ENV_ID)

In [3]:
class Agent:
    """Baseline policy that picks a uniformly random discrete action.

    Args:
        env: Environment whose ``action_space.n`` gives the number of actions.
    """

    def __init__(self, env):
        n_actions = env.action_space.n
        self.action_size = n_actions
        print('Action size:', n_actions)

    def get_action(self):
        """Return an action index drawn uniformly from ``0 .. action_size - 1``."""
        action_indices = range(self.action_size)
        return random.choice(action_indices)

In [4]:
class QAgent(Agent):
    """Tabular Q-learning agent with epsilon-greedy exploration.

    The Q-table has one row per discrete observation and one column per
    discrete action, and starts at all zeros.

    Args:
        env: Environment exposing ``observation_space.n`` and
            ``action_space.n`` (discrete spaces).
        discount_rate: Discount factor (gamma) on the bootstrapped
            next-state value.
        learning_rate: Step size (alpha) of each tabular update.
        eps_decay: Factor epsilon is multiplied by at the end of every
            episode, i.e. whenever ``train`` receives ``done=True``.
        eps_min: Lower bound for epsilon. The default 0.0 reproduces the
            original, un-floored decay.
    """

    def __init__(self, env, discount_rate=0.97, learning_rate=0.01,
                 eps_decay=0.99, eps_min=0.0):
        super().__init__(env)
        self.state_size = env.observation_space.n
        print('State size:', self.state_size)

        self.eps = 1.0  # start fully exploratory
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.discount_rate = discount_rate
        self.learning_rate = learning_rate
        self.build_model()

    def build_model(self):
        """(Re)initialise the Q-table to zeros, shape (state_size, action_size)."""
        # Use the sizes captured in __init__, not the notebook-global `env`,
        # so the table always matches the environment this agent was built for.
        self.q_table = np.zeros((self.state_size, self.action_size))

    def get_action(self, state):
        """Pick an action for ``state`` epsilon-greedily.

        With probability ``eps`` a uniformly random action is returned.
        Otherwise the greedy action is returned, with ties broken uniformly at
        random. Plain ``np.argmax`` always returns the lowest tied index, so on
        a zero-initialised table the greedy policy would collapse onto action 0.
        """
        if random.random() < self.eps:
            return super().get_action()
        q_state = self.q_table[state]
        best_actions = np.flatnonzero(q_state == q_state.max())
        return int(random.choice(best_actions))

    def train(self, experience):
        """Apply one Q-learning update from a single transition.

        Args:
            experience: ``(state, action, next_state, reward, done)`` tuple.

        On a terminal transition (``done``), the next-state value is taken as
        0 and epsilon is multiplied by ``eps_decay``, never dropping below
        ``eps_min``.
        """
        state, action, next_state, reward, done = experience

        # Terminal transitions have no future value to bootstrap from.
        next_value = 0.0 if done else np.max(self.q_table[next_state])
        q_target = reward + self.discount_rate * next_value

        td_error = q_target - self.q_table[state, action]
        self.q_table[state, action] += self.learning_rate * td_error

        if done:
            self.eps = max(self.eps_min, self.eps * self.eps_decay)
        
# Fresh agent: zero-initialised Q-table and eps = 1.0. The training cell
# updates this object in place, so re-run this cell to start learning over.
agent = QAgent(env=env)

Action size: 5
State size: 64


In [10]:
# Train `agent` with epsilon-greedy Q-learning for N_EPISODES episodes.
# `agent` comes from an earlier cell and is updated in place, so re-running
# this cell continues from the current Q-table and epsilon rather than
# starting over (re-run the agent-construction cell to reset).
N_EPISODES = 100
RENDER = True      # per-step display; set False to train at full speed
STEP_DELAY = 0.5   # seconds to pause after each rendered step

total_reward = 0       # cumulative reward over every episode of this run
episode_rewards = []   # reward collected in each individual episode
for ep in range(N_EPISODES):
    state = env.reset()
    done = False
    ep_reward = 0
    while not done:
        action = agent.get_action(state)
        next_state, reward, done, info = env.step(action)
        agent.train((state, action, next_state, reward, done))
        state = next_state
        total_reward += reward
        ep_reward += reward

        if RENDER:
            # Sleeping STEP_DELAY seconds and printing the full Q-table on
            # every step makes rendered runs slow; disable RENDER to train.
            print('s:', state, 'a:', action)
            print('Episode: {}, Total reward: {}, eps: {}'.format(ep, total_reward, agent.eps))
            env.render()
            print(agent.q_table)
            time.sleep(STEP_DELAY)
            clear_output(wait=True)
    episode_rewards.append(ep_reward)

print('Finished {} episodes, total reward: {}, final eps: {}'.format(
    len(episode_rewards), total_reward, agent.eps))

s: 34 a: 1
Episode: 14, Total reward: 56, eps: 0.03381119958765021
  (Down)
01
[41m1[0m1
[[ 1.17793826e-01  2.41039653e-02  1.46490189e-02  1.03941472e-01
   1.53877191e+00]
 [ 1.07075168e-02  9.47940081e-04  3.84138827e-04  1.50014695e-03
   1.48869536e-01]
 [ 3.84108513e-04  2.22512521e-03  6.68406265e-04  1.17688106e-02
   2.07464772e-01]
 [ 2.13738877e-03  4.78248703e-04  5.74278227e-04  2.89060000e-04
   3.94067939e-02]
 [ 0.00000000e+00  1.93030089e-04  0.00000000e+00  0.00000000e+00
   0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   1.00000000e-02]
 [ 9.70000000e-05  1.93030000e-04  9.40900000e-07  0.00000000e+00
   2.97010362e-02]
 [ 3.84129789e-04  0.00000000e+00  1.87239100e-06  0.00000000e+00
  -1.00000000e-02]
 [ 2.40193384e-03  6.18934764e-05  5.74287262e-04  1.69269767e-03
   1.78467838e-01]
 [ 9.33481283e-04  0.00000000e+00  0.00000000e+00  1.93030000e-04
   2.97028724e-02]
 [ 7.21756337e-05  1.96934587e-04  1.93038945e-04  1.683244

KeyboardInterrupt: 