In [1]:
import gym
import random
import os
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

In [2]:
# parameters
# --- run mode ---
isTrain = False            # True: train and save the model; False: load weights and evaluate
render = False             # call env.render() on every step when True
weights = 'cartpole.h5'    # file the model is saved to / weights are loaded from

# --- exploration schedule ---
explore_rate = 1.0         # probability of taking a random action in infer()
min_explore = 0.01         # value explore_rate is decayed towards (and used for evaluation)
decay = 0.995              # multiplicative factor applied to explore_rate per model update

# --- learning ---
gamma = 0.95               # discount applied to the best predicted next-state value
batch_size = 32            # number of stored transitions sampled per model update

# --- run lengths ---
num_episodes = 10000       # episodes played by run_train()
num_evals = 100            # episodes played by eval()

In [3]:
# make model
# Environment the agent interacts with, plus the replay memory of past
# transitions (the deque silently drops the oldest entry beyond 1000).
env = gym.make('CartPole-v0')
history = deque(maxlen=1000)

# Small fully connected network: 4 inputs -> 24 -> 24 -> 2 linear outputs
# (one output per action index used by infer()/update_model()).
model = Sequential()
model.add(Dense(24, input_dim=4, activation='relu'))
model.add(Dense(24, activation='relu'))
model.add(Dense(2, activation='linear'))
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and rejected by recent Keras releases.
model.compile(loss='mse', optimizer=Adam(learning_rate=0.001))
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 24)                120       
_________________________________________________________________
dense_1 (Dense)              (None, 24)                600       
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 50        
Total params: 770
Trainable params: 770
Non-trainable params: 0
_________________________________________________________________


In [4]:
# load saved weights if not in training mode
if not isTrain:
    if os.path.isfile(weights):
        model.load_weights(weights)
        # With trained weights loaded, act almost greedily: only the minimum
        # amount of random exploration remains.
        explore_rate = min_explore
    else:
        # Do not fail silently: without the file the network keeps its
        # initial weights and explore_rate keeps its configured value, so
        # any evaluation result would be meaningless.
        print("Warning: weights file '{}' not found; the model is untrained "
              "and explore_rate stays at {}".format(weights, explore_rate))

In [5]:
# function to update model
def update_model():
    """Fit the model on a random batch of stored transitions, then decay exploration.

    Reads the module-level ``history``, ``batch_size``, ``gamma``, ``model``,
    ``decay`` and ``min_explore``; updates the module-level ``explore_rate``.

    Does nothing until ``history`` holds at least ``batch_size`` transitions.
    Each sampled transition ``(state, action, reward, next_state, done)`` is
    fitted individually: the target for the taken action is ``reward`` when
    ``done`` is true, otherwise ``reward + gamma * max(model.predict(next_state)[0])``;
    the targets of the other actions are the model's current predictions, so
    only the taken action contributes to the loss.

    Returns:
        None.
    """
    global explore_rate
    if len(history) < batch_size:  # Start training only when queue has min batch size entries
        return
    sample_batch = random.sample(history, batch_size)
    for state, action, reward, next_state, done in sample_batch:
        target = reward
        if not done:
            # Bootstrap from the best predicted value of the next state.
            target = reward + gamma * np.amax(model.predict(next_state)[0])
        target_f = model.predict(state)  # current predictions for this state
        target_f[0][action] = target
        model.fit(state, target_f, epochs=1, verbose=0)
    if explore_rate > min_explore:  # Reduce the explore rate after each batch train
        # Clamp so the last decay step cannot undershoot the configured floor.
        explore_rate = max(min_explore, explore_rate * decay)

In [6]:
# Function to infer
def infer(state):
    """Choose an action for ``state`` using the current exploration rate.

    A uniform draw from ``np.random.rand()`` is compared with the module-level
    ``explore_rate``: if the draw is less than or equal to it, a random action
    (0 or 1) is returned; otherwise the index of the largest value in the
    first row of ``model.predict(state)`` is returned.
    """
    explore = np.random.rand() <= explore_rate
    if not explore:
        # Exploit: pick the action the model currently scores highest.
        predicted = model.predict(state)
        return np.argmax(predicted[0])
    # Explore: random action, 0 or 1.
    return random.randrange(2)

In [7]:
# function to train model
def run_train():
    """Play ``num_episodes`` episodes, recording transitions and updating the model.

    For every step the action comes from ``infer``; the transition
    ``(state, action, reward, next_state, done)`` is appended to ``history``
    with both states reshaped to ``[1, 4]``. After each episode the score is
    printed and ``update_model()`` is called once. The model is saved to
    ``weights`` in a ``finally`` clause, so it is written even if the loop is
    interrupted by an exception.
    """
    try:
        for episode in range(1, num_episodes + 1):
            observation = np.reshape(env.reset(), [1, 4])
            score = 0
            while True:
                if render:
                    env.render()
                chosen = infer(observation)
                raw_next, gained, finished, _ = env.step(chosen)
                following = np.reshape(raw_next, [1, 4])
                history.append((observation, chosen, gained, following, finished))
                observation = following
                score += gained
                if finished:
                    break
            print("Episode {}# Score: {}".format(episode, score))
            update_model()
    finally:
        # Persist whatever has been learned so far, even on interruption.
        model.save(weights)

In [8]:
# function to evaluate model
# cartpole is solved if avg score is greater than 195 for 100 consecutive runs
def eval():
    """Play ``num_evals`` episodes with ``infer`` and print per-episode and running-average scores.

    No transitions are stored and the model is not updated. After each episode
    two lines are printed: that episode's score and the average over the
    episodes played so far.

    Note: this name shadows Python's builtin ``eval`` at module level; it is
    kept because the notebook's final cell calls it by this name.
    """
    running_total = 0
    for episode in range(1, num_evals + 1):
        observation = np.reshape(env.reset(), [1, 4])
        episode_score = 0
        while True:
            if render:
                env.render()
            chosen = infer(observation)
            raw_next, gained, finished, _ = env.step(chosen)
            observation = np.reshape(raw_next, [1, 4])
            episode_score += gained
            if finished:
                break
        running_total += episode_score
        print("Episode {}# Score: {}".format(episode, episode_score))
        print("Avg Score: {}".format(running_total / episode))

In [9]:
# Entry point: evaluate the loaded model unless training mode is selected.
if not isTrain:
    eval()
else:
    run_train()

Episode 1# Score: 200.0
Avg Score: 200.0
Episode 2# Score: 200.0
Avg Score: 200.0
Episode 3# Score: 200.0
Avg Score: 200.0
Episode 4# Score: 200.0
Avg Score: 200.0
Episode 5# Score: 200.0
Avg Score: 200.0
Episode 6# Score: 200.0
Avg Score: 200.0
Episode 7# Score: 200.0
Avg Score: 200.0
Episode 8# Score: 200.0
Avg Score: 200.0
Episode 9# Score: 200.0
Avg Score: 200.0
Episode 10# Score: 200.0
Avg Score: 200.0
Episode 11# Score: 200.0
Avg Score: 200.0
Episode 12# Score: 200.0
Avg Score: 200.0
Episode 13# Score: 200.0
Avg Score: 200.0
Episode 14# Score: 200.0
Avg Score: 200.0
Episode 15# Score: 200.0
Avg Score: 200.0
Episode 16# Score: 200.0
Avg Score: 200.0
Episode 17# Score: 200.0
Avg Score: 200.0
Episode 18# Score: 200.0
Avg Score: 200.0
Episode 19# Score: 200.0
Avg Score: 200.0
Episode 20# Score: 200.0
Avg Score: 200.0
Episode 21# Score: 200.0
Avg Score: 200.0
Episode 22# Score: 200.0
Avg Score: 200.0
Episode 23# Score: 200.0
Avg Score: 200.0
Episode 24# Score: 200.0
Avg Score: 200.0
E