In [3]:
import random
import gym
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam


from scores.score_logger import ScoreLogger

# Gym environment id the agent is trained on.
ENV_NAME = "CartPole-v1"

# --- Q-learning -------------------------------------------------------------
GAMMA = 0.95             # discount applied to the bootstrapped next-state value
LEARNING_RATE = 0.001    # step size handed to the Adam optimizer

# --- Experience replay ------------------------------------------------------
MEMORY_SIZE = 10 ** 6    # replay-buffer capacity; a full deque drops its oldest entries
BATCH_SIZE = 20          # transitions sampled per replay pass (and minimum buffer fill)

# --- Epsilon-greedy exploration schedule ------------------------------------
EXPLORATION_MAX = 1.0      # starting epsilon (fully random actions)
EXPLORATION_MIN = 0.01     # epsilon never decays below this floor
EXPLORATION_DECAY = 0.995  # multiplicative decay applied after every replay pass


class DQNSolver:
    """Deep Q-Network agent using an epsilon-greedy policy and experience replay.

    The Q-function is a small fully connected Keras network: two hidden layers
    of 24 ReLU units and one linear output per discrete action, trained with
    mean-squared-error loss and Adam.
    """

    def __init__(self, observation_space, action_space):
        """Create the Q-network and an empty replay buffer.

        Args:
            observation_space: number of values in one observation (network inputs).
            action_space: number of discrete actions (network outputs).
        """
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        # Bounded buffer: once MEMORY_SIZE transitions are stored, appending
        # evicts the oldest one.
        self.memory = deque(maxlen=MEMORY_SIZE)

        network = Sequential()
        network.add(Dense(24, input_shape=(observation_space,), activation="relu"))
        network.add(Dense(24, activation="relu"))
        network.add(Dense(self.action_space, activation="linear"))
        network.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE))
        self.model = network

    def remember(self, state, action, reward, next_state, done):
        """Append one (state, action, reward, next_state, done) transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def act(self, state):
        """Choose an action for ``state`` epsilon-greedily.

        With probability ``exploration_rate`` a uniformly random action index is
        returned; otherwise the index of the largest predicted Q-value.
        ``state`` is expected to carry a leading batch dimension of 1 (only row
        0 of the prediction is read; the caller in this file reshapes it to
        ``[1, observation_space]``).
        """
        explore = np.random.rand() < self.exploration_rate
        if explore:
            return random.randrange(self.action_space)
        predictions = self.model.predict(state)
        return np.argmax(predictions[0])

    def experience_replay(self):
        """Fit the network on a random minibatch of stored transitions, then decay epsilon.

        Does nothing (and leaves epsilon untouched) until at least BATCH_SIZE
        transitions are stored. Samples are fitted one at a time, so each
        prediction sees the weights already updated by the previous sample.
        """
        if len(self.memory) < BATCH_SIZE:
            return

        minibatch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in minibatch:
            if terminal:
                # No successor value to bootstrap from at the end of an episode.
                target = reward
            else:
                best_next_q = np.amax(self.model.predict(state_next)[0])
                target = reward + GAMMA * best_next_q
            # Only the taken action's output gets a new target; the other
            # outputs keep their current predictions, so they add no MSE error.
            targets = self.model.predict(state)
            targets[0][action] = target
            self.model.fit(state, targets, verbose=0)

        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate * EXPLORATION_DECAY)


def cartpole(max_runs=None, render=False):
    """Train a DQNSolver on ENV_NAME, printing and logging each episode's score.

    A transition is stored after every environment step. One replay/training
    pass runs after every step that did not end the episode.

    Args:
        max_runs: optional cap on the number of episodes. The default ``None``
            keeps the original behaviour of looping indefinitely. The loop
            then only stops when something else ends the process (e.g.
            ``ScoreLogger``, whose implementation is not visible here).
        render: if true, call ``env.render()`` on every step.
    """
    env = gym.make(ENV_NAME)
    try:
        score_logger = ScoreLogger(ENV_NAME)
        observation_space = env.observation_space.shape[0]
        action_space = env.action_space.n
        dqn_solver = DQNSolver(observation_space, action_space)
        run = 0
        while max_runs is None or run < max_runs:
            run += 1
            state = env.reset()
            state = np.reshape(state, [1, observation_space])
            step = 0
            while True:
                step += 1
                if render:
                    env.render()
                action = dqn_solver.act(state)
                state_next, reward, terminal, info = env.step(action)
                # gym's TimeLimit wrapper (in versions that set this key) marks
                # episodes cut off at the step cap with "TimeLimit.truncated".
                # Hitting the cap is not a failure. It is therefore neither
                # penalised nor stored as terminal, so the next-state value is
                # still bootstrapped. Without the key this behaves as before.
                truncated = bool(info.get("TimeLimit.truncated", False))
                failed = terminal and not truncated
                if failed:
                    reward = -reward
                state_next = np.reshape(state_next, [1, observation_space])
                dqn_solver.remember(state, action, reward, state_next, failed)
                state = state_next
                if terminal:
                    print("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
                    score_logger.add_score(step, run)
                    break
                dqn_solver.experience_replay()
    finally:
        # Release the environment even if training is stopped by an exception
        # or by SystemExit raised from inside the loop.
        env.close()


if __name__ == "__main__":
    cartpole()

W1226 15:56:56.644364 4569554240 deprecation_wrapper.py:119] From /Users/jashrathod/anaconda3/envs/tensorflow_env/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W1226 15:56:56.666205 4569554240 deprecation_wrapper.py:119] From /Users/jashrathod/anaconda3/envs/tensorflow_env/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W1226 15:56:56.688163 4569554240 deprecation_wrapper.py:119] From /Users/jashrathod/anaconda3/envs/tensorflow_env/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W1226 15:56:56.746692 4569554240 deprecation_wrapper.py:119] From /Users/jashrathod/anaconda3/envs/tensorflow_env/lib/python3.6/site-packages/keras/optimizers.py:790: The name tf.train.Opt

Run: 1, exploration: 0.9369146928798039, score: 33
Scores: (min: 33, avg: 33, max: 33)

Run: 2, exploration: 0.8955869907338783, score: 10


  z = np.polyfit(np.array(trend_x), np.array(y[1:]), 1)


Scores: (min: 10, avg: 21.5, max: 33)

Run: 3, exploration: 0.8390886103705794, score: 14
Scores: (min: 10, avg: 19, max: 33)

Run: 4, exploration: 0.7552531090661897, score: 22
Scores: (min: 10, avg: 19.75, max: 33)

Run: 5, exploration: 0.6935613678313175, score: 18
Scores: (min: 10, avg: 19.4, max: 33)

Run: 6, exploration: 0.5647174463480732, score: 42
Scores: (min: 10, avg: 23.166666666666668, max: 42)

Run: 7, exploration: 0.531750826943791, score: 13
Scores: (min: 10, avg: 21.714285714285715, max: 42)

Run: 8, exploration: 0.5082950737585841, score: 10
Scores: (min: 10, avg: 20.25, max: 42)

Run: 9, exploration: 0.4858739637363176, score: 10
Scores: (min: 10, avg: 19.11111111111111, max: 42)

Run: 10, exploration: 0.4598090507939749, score: 12
Scores: (min: 10, avg: 18.4, max: 42)

Run: 11, exploration: 0.3403786882983606, score: 61
Scores: (min: 10, avg: 22.272727272727273, max: 61)

Run: 12, exploration: 0.30945741577570285, score: 20
Scores: (min: 10, avg: 22.083333333333332,

Run: 91, exploration: 0.01, score: 231
Scores: (min: 10, avg: 109.78021978021978, max: 322)

Run: 92, exploration: 0.01, score: 219
Scores: (min: 10, avg: 110.96739130434783, max: 322)

Run: 93, exploration: 0.01, score: 256
Scores: (min: 10, avg: 112.52688172043011, max: 322)

Run: 94, exploration: 0.01, score: 192
Scores: (min: 10, avg: 113.37234042553192, max: 322)

Run: 95, exploration: 0.01, score: 246
Scores: (min: 10, avg: 114.76842105263158, max: 322)

Run: 96, exploration: 0.01, score: 199
Scores: (min: 10, avg: 115.64583333333333, max: 322)

Run: 97, exploration: 0.01, score: 344
Scores: (min: 10, avg: 118, max: 344)

Run: 98, exploration: 0.01, score: 258
Scores: (min: 10, avg: 119.42857142857143, max: 344)

Run: 99, exploration: 0.01, score: 214
Scores: (min: 10, avg: 120.38383838383838, max: 344)

Run: 100, exploration: 0.01, score: 201
Scores: (min: 10, avg: 121.19, max: 344)

Run: 101, exploration: 0.01, score: 287
Scores: (min: 10, avg: 123.73, max: 344)

Run: 102, expl

NameError: name 'exit' is not defined