### Imports

In [1]:
import gym
import numpy as np
import random
import collections
import keras
from keras.layers import Dense
from collections import deque

Using TensorFlow backend.


#### Set up agent environment

In [13]:
def cartpole():
    env = gym.make('CartPole-v0')
    obs_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn = DQN(obs_space, action_space)
    
    scores = []
    
    for episode in range(20):
        state = env.reset()
        state.reshape([1, obs_space])

        time_step = 0
        done = False
        
        while not done:
            env.render()
            
            action = dqn.act(state)
            
            observation, reward, done, info = env.step(action)
            
            if done:
                print('Terminal observation: ', observation)
                
            time_step += 1
            
            dqn.experience_replay()
                
        scores.append(time_step)
        
        mean_survival_time = np.mean(scores)
        

    print('Mean survival time: ', mean_survival_time)
    
    env.close()

#### Set up neural network architecture

In [14]:
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

class DQN:
    
    def __init__(self, 
                 obs_space, 
                 action_space, 
                 lr = 0.001, 
                 gamma = 0.95, 
                 exploration_rate = 1.0, 
                 expl_min = 0.01,
                 expl_decay = 0.995
                ):
        self.expl_rate = 1.0
        self.action_space = action_space
        
        self.memory = deque(maxlen = 2000)
        self.learning_rate = lr
        self.gamma = gamma
        self.epsilon = exploration_rate
        self.epsilon_min = expl_min
        self.epsilon_decay = expl_decay
        
        self.model = Sequential()
        self.model.add(Dense(24, input_dim = obs_space, activation = 'relu'))
        self.model.add(Dense(24, activation = 'relu'))
        self.model.add(Dense(self.action_space, activation = 'linear'))
        
        self.model.compile(loss = 'mse', optimizer = Adam(lr = self.learning_rate))
        
    def remember(self, state, action, reward, next_state, done):
        self.memory.append(state, action, reward, next_state, done)
        
    def act(self, state):
        if np.random.rand() < self.epsilon:
            return random.randrange(self.action_space)
        
        q_vals = self.model.predict(state)
        
        print(q_vals)
        
    def experience_replay(self):
        if len(self.memory) < 20:
            return
        
        batch = random.sample(self.memory, 20)
        
        for state, action, reward, state_next, done in batch:
            q_update = reward
            
            if not done:
                q_update = (reward + 0.95 * np.argmax(self.model.predict(state_next)[0]))
                
            q_values = self.model.predict(state)
            
            q_values[0][action] = q_update
            
            self.model.fit(state, q_values, verbose = 0)
            
            self.exlporation_rate *= 0.995
            self.exploration_rate = np.argmax(self.exploration_rate, 0.01)
            
        
# Run the CartPole demo: creates the environment and agent, plays the
# episodes and prints the mean survival time.
cartpole()

Terminal observation:  [-0.15445091 -1.16676791  0.21486145  1.92812715]
Terminal observation:  [ 0.34298423  1.75281896 -0.23858616 -2.11812046]
Terminal observation:  [-0.1349107  -1.74426555  0.21312319  2.80397714]
Terminal observation:  [-0.12267835 -0.93037528  0.22015555  1.54864251]
Terminal observation:  [ 0.02935208  0.17540226 -0.22117578 -0.92957897]
Terminal observation:  [-0.15592059 -1.01631017  0.21710727  1.60227923]
Terminal observation:  [-0.14638083 -0.43051265  0.21737366  0.98139304]
Terminal observation:  [-0.16231649 -1.21589824  0.21154089  1.91307441]
Terminal observation:  [ 0.14606858  0.56789618 -0.22507595 -1.26620572]
Terminal observation:  [-0.09071081 -0.07059813  0.21378475  0.94686404]
Terminal observation:  [-0.20862974 -1.73171872  0.24468498  2.4973602 ]
Terminal observation:  [ 0.09639624  1.20249048 -0.22121407 -2.13467393]
Terminal observation:  [-0.12410561 -0.57464313  0.22789974  1.37826173]
Terminal observation:  [-0.161637   -0.83206061  0.