# Teaching A Deep Q Neural Network How To Balance
![Cat Balance](https://i.imgur.com/YUEtPEu.jpg)

In [26]:
# Dependencies
import gym
import random
import numpy as np
from collections import deque
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam
from mish import Mish

In [27]:
# Environment
env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

# Seed
seed = 123
env.seed(seed)
random.seed(seed)
np.random.seed(seed)

In [28]:
# Training parameterss
discount_factor = 0.99
epsilon_decay = 0.995
epsilon_min = 0.01
batch_size = 32
train_start = 1000
memory_size = 10000
n_episodes = 1000
n_win_ticks = 195

In [29]:
# Build model
def build_model():
    model = Sequential()
    model.add(Dense(96, input_dim=state_size, kernel_initializer='he_uniform'))
    model.add(Mish())
    model.add(Dense(48, kernel_initializer='he_uniform'))
    model.add(Mish())
    model.add(Dense(24, kernel_initializer='he_uniform'))
    model.add(Mish())
    model.add(Dense(action_size, kernel_initializer='he_uniform'))
    model.compile(Adam(), loss='mse')
    return model

model = build_model()
target_model = build_model()

model.summary()

Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_35 (Dense)             (None, 96)                480       
_________________________________________________________________
mish_29 (Mish)               (None, 96)                0         
_________________________________________________________________
dense_36 (Dense)             (None, 48)                4656      
_________________________________________________________________
mish_30 (Mish)               (None, 48)                0         
_________________________________________________________________
dense_37 (Dense)             (None, 24)                1176      
_________________________________________________________________
mish_31 (Mish)               (None, 24)                0         
_________________________________________________________________
dense_38 (Dense)             (None, 2)               

In [30]:
# Training helpers
# Source: https://github.com/yanpanlau/CartPole/blob/master/DQN/CartPole_DQN.py
def update_target_model():
    target_model.set_weights(model.get_weights())

def get_action(state, epsilon):
    if np.random.rand() <= epsilon:
        return random.randrange(action_size)
    else:
        q_value = model.predict(state)
        return np.argmax(q_value[0])

def train_replay():
    if len(memory) < train_start:
        return
    minibatch = random.sample(memory,  min(batch_size, len(memory)))
    state_t, action_t, reward_t, state_t1, terminal = zip(*minibatch)
    state_t = np.concatenate(state_t)
    state_t1 = np.concatenate(state_t1)
    targets = model.predict(state_t)
    Q_sa = target_model.predict(state_t1)
    targets[range(batch_size), action_t] = reward_t + discount_factor*np.max(Q_sa, axis=1)*np.invert(terminal)
    model.train_on_batch(state_t, targets)

In [31]:
# Training
scores = deque(maxlen=100)
episodes = []
memory = deque(maxlen=memory_size)

def learn_to_balance():
    epsilon = 1.0 # Start with randomness

    for e in range(n_episodes):
        done = False
        score = 0
        state = env.reset()
        state = np.reshape(state, [1, state_size])

        while not done:
            action = get_action(state, epsilon)
            next_state, reward, done, info = env.step(action)
            next_state = np.reshape(next_state, [1, state_size])

            memory.append((state, action, reward if not done else -100, next_state, done))
            if epsilon > epsilon_min:
                epsilon *= epsilon_decay # Decrease randomness
            train_replay()
            score += reward
            state = next_state
            
            env.render()

            if done:
                env.reset()
                update_target_model()
                scores.append(score)
                episodes.append(e)

                if e % 100 == 0:
                    print('[Episode {}] Score: {}'.format(e, score))

                if np.mean(scores) >= n_win_ticks:
                    print('[Episode {}] Solved! \o/'.format(e - 100))
                    return
learn_to_balance()

[Episode 0] Score: 52.0
[Episode 100] Score: 178.0
[Episode 22] Solved! \o/
