In [1]:
import random
import gym
import numpy as np
import sys
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
import matplotlib.pyplot as plt
from tqdm import tqdm

# sys.path.append('sunblaze_envs')
# import sunblaze_envs

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# --- Q-learning target and optimiser ---
# Multiplier applied to the best predicted next-state Q-value when the
# replay target is built in Dqn.experience_replay.
GAMMA = 1
# Step size handed to the Adam optimiser when the network is compiled.
LEARNING_RATE = 1e-3

# --- Replay buffer ---
# Upper bound on stored transitions (deque maxlen); once full, each new
# transition pushes the oldest one out.
MEMORY_SIZE = 50_000
# Number of transitions drawn per replay call; replay is skipped until the
# buffer holds at least this many.
BATCH_SIZE = 32

# --- Epsilon-greedy exploration schedule ---
# Starting probability of picking a random action.
EXPLORATION_MAX = 1.0
# Floor below which the exploration probability is never reduced.
EXPLORATION_MIN = 1e-2
# Factor the exploration probability is multiplied by after every replay call.
EXPLORATION_DECAY = 0.995

In [3]:
class Dqn:
    """Epsilon-greedy Q-learning agent backed by a small dense Keras network.

    The agent keeps a bounded replay buffer of transitions, a network that
    maps a state to one Q-value per action, and an exploration probability
    that shrinks every time a replay pass is performed.
    """

    def __init__(self, observation_space, action_space):
        """Create the replay buffer and compile the Q-network.

        Args:
            observation_space: Length of the flat state vector fed to the
                first layer.
            action_space: Number of discrete actions; also the width of the
                network's linear output layer.
        """
        self.action_space = action_space
        self.exploration_rate = EXPLORATION_MAX
        self.memory = deque(maxlen=MEMORY_SIZE)

        # Two hidden ReLU layers of 24 units, then one linear unit per action.
        network = Sequential()
        for layer in (
            Dense(24, input_shape=(observation_space,), activation="relu"),
            Dense(24, activation="relu"),
            Dense(action_space, activation="linear"),
        ):
            network.add(layer)
        network.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE))
        self.model = network

    def remember(self, state, action, reward, next_state, done):
        """Append one transition tuple to the replay buffer."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def act(self, state):
        """Choose an action for ``state`` with an epsilon-greedy policy.

        With probability ``exploration_rate`` a uniformly random action index
        is returned; otherwise the index of the largest predicted Q-value.
        """
        explore = np.random.rand() < self.exploration_rate
        if not explore:
            predicted = self.model.predict(state)
            return np.argmax(predicted[0])
        return random.randrange(self.action_space)

    def experience_replay(self):
        """Fit the network on a random batch of stored transitions.

        Does nothing until the buffer holds at least ``BATCH_SIZE``
        transitions. Each sampled transition is fitted individually: the
        target for the taken action is the reward, plus ``GAMMA`` times the
        best predicted next-state Q-value when the transition is not
        terminal. After the batch, the exploration rate is decayed and
        clamped to ``EXPLORATION_MIN``.
        """
        if len(self.memory) >= BATCH_SIZE:
            for state, action, reward, state_next, terminal in random.sample(self.memory, BATCH_SIZE):
                if terminal:
                    target = reward
                else:
                    best_next = np.amax(self.model.predict(state_next)[0])
                    target = (reward + GAMMA * best_next)
                # Only the taken action's Q-value is replaced; the other
                # outputs keep the network's own predictions.
                q_values = self.model.predict(state)
                q_values[0][action] = target
                self.model.fit(state, q_values, verbose=0)
            decayed = self.exploration_rate * EXPLORATION_DECAY
            self.exploration_rate = max(EXPLORATION_MIN, decayed)
        

def train(env_name, score=199, training_episodes=10000000):
    """Train a new ``Dqn`` agent on a gym environment and return it.

    Each episode runs until the environment reports a terminal step. Every
    transition is stored in the agent's replay buffer and a replay pass is
    performed after each non-terminal step. An episode's score is the number
    of steps it lasted. Every 100 episodes the mean score of that window is
    printed, and training stops early once it reaches ``score``. When
    training ends, the per-episode scores are plotted.

    Args:
        env_name: Environment id passed to ``gym.make``.
        score: Mean 100-episode score at which training stops early.
        training_episodes: Maximum number of episodes to run (inclusive).

    Returns:
        The trained ``Dqn`` instance.
    """
    # env = sunblaze_envs.make(env_name)
    env = gym.make(env_name)
    all_scores = []
    try:
        observation_space = env.observation_space.shape[0]
        action_space = env.action_space.n
        dqn = Dqn(observation_space, action_space)
        scores = []
        # Episodes are numbered from 1; the upper bound is inclusive so that
        # exactly `training_episodes` episodes can run.
        for i in range(1, training_episodes + 1):
            state = np.reshape(env.reset(), [1, observation_space])
            step = 0
            while True:
                step += 1
                action = dqn.act(state)
                state_next, reward, terminal, info = env.step(action)
                # The last transition of an episode is stored with a negated
                # reward so that ending the episode acts as a penalty.
                # NOTE(review): this also applies when the environment reports
                # terminal for a reason other than failure (e.g. a step limit)
                # - confirm that is intended.
                reward = reward if not terminal else -reward
                state_next = np.reshape(state_next, [1, observation_space])
                dqn.remember(state, action, reward, state_next, terminal)
                state = state_next
                if terminal:
                    print("Run: " + str(i) + ", exploration: " + str(dqn.exploration_rate) + ", score: " + str(step))
                    scores.append(step)
                    all_scores.append(step)
                    break
                dqn.experience_replay()
            if i % 100 == 0:
                avg = np.array(scores).mean()
                print("===============")
                print("Episodes: " + str(i) + ", mean 100 episodes reward: " + str(avg))
                print("===============")
                scores = []
                if avg >= score:
                    break
    finally:
        # Release the environment even if training is interrupted or fails.
        env.close()

    plt.plot(all_scores)
    plt.ylabel('Reward')
    plt.show()
    return dqn

def play(dqn, env_name):
    """Run one rendered episode with ``dqn`` and print how many steps it lasted.

    Actions come from ``dqn.act``, so they are still subject to the agent's
    current ``exploration_rate``. Nothing is stored or learned during the
    episode. The environment is always closed afterwards so the render window
    and its resources are released, even if the episode is interrupted.

    Args:
        dqn: Agent exposing ``act(state)``.
        env_name: Environment id passed to ``gym.make``.
    """
    env = gym.make(env_name)
    try:
        observation_space = env.observation_space.shape[0]
        state = np.reshape(env.reset(), [1, observation_space])
        step = 0
        while True:
            step += 1
            env.render()
            action = dqn.act(state)
            # Reward and info are not needed when only replaying the policy.
            state_next, _reward, terminal, _info = env.step(action)
            state = np.reshape(state_next, [1, observation_space])
            if terminal:
                print("Score: " + str(step))
                break
    finally:
        env.close()

In [None]:
# Alternative environment kept for reference; presumably requires the
# commented-out sunblaze_envs import and make() call - TODO confirm.
# dqn = train('SunblazeCartPole-v0')
# `env` holds the gym environment id (a string), not an environment instance.
env = "CartPole-v1"
# Train with train()'s default stopping score and episode cap; the returned
# agent is what the play cell below uses.
dqn = train(env)





Run: 1, exploration: 1.0, score: 32


Run: 2, exploration: 0.918316468354365, score: 18
Run: 3, exploration: 0.8560822709551227, score: 15
Run: 4, exploration: 0.7514768435208588, score: 27
Run: 5, exploration: 0.6730128848950395, score: 23
Run: 6, exploration: 0.6337242817644086, score: 13
Run: 7, exploration: 0.5704072587541458, score: 22
Run: 8, exploration: 0.5082950737585841, score: 24
Run: 9, exploration: 0.4858739637363176, score: 10
Run: 10, exploration: 0.457510005540005, score: 13
Run: 11, exploration: 0.43952667968844233, score: 9
Run: 12, exploration: 0.42013897252428334, score: 10
Run: 13, exploration: 0.3995984329713264, score: 11
Run: 14, exploration: 0.38389143477919885, score: 9
Run: 15, exploration: 0.3614809303671764, score: 13
Run: 16, exploration: 0.2649210072611673, score: 63
Run: 17, exploration: 0.22566020663225933, score: 33
Run: 18, exploration: 0.19415447453059972, score: 31
Run: 19, exploration: 0.15888051309497406, score: 41
Run: 20, exploration: 0.1176

Run: 161, exploration: 0.01, score: 11
Run: 162, exploration: 0.01, score: 11
Run: 163, exploration: 0.01, score: 10
Run: 164, exploration: 0.01, score: 12
Run: 165, exploration: 0.01, score: 9
Run: 166, exploration: 0.01, score: 10
Run: 167, exploration: 0.01, score: 8
Run: 168, exploration: 0.01, score: 10
Run: 169, exploration: 0.01, score: 10
Run: 170, exploration: 0.01, score: 10
Run: 171, exploration: 0.01, score: 10
Run: 172, exploration: 0.01, score: 9
Run: 173, exploration: 0.01, score: 10
Run: 174, exploration: 0.01, score: 9
Run: 175, exploration: 0.01, score: 11
Run: 176, exploration: 0.01, score: 8
Run: 177, exploration: 0.01, score: 8
Run: 178, exploration: 0.01, score: 10
Run: 179, exploration: 0.01, score: 8
Run: 180, exploration: 0.01, score: 10
Run: 181, exploration: 0.01, score: 10
Run: 182, exploration: 0.01, score: 9
Run: 183, exploration: 0.01, score: 9
Run: 184, exploration: 0.01, score: 9
Run: 185, exploration: 0.01, score: 9
Run: 186, exploration: 0.01, score: 

In [None]:
# Run one rendered episode with the trained agent on the same environment id.
play(dqn, env)