In [3]:
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from cartpole_environment import CartPoleEnv

###SETUP###
# Seed the numpy, `random` and TensorFlow RNGs used by this notebook so that
# epsilon-greedy exploration, replay sampling and weight initialisation repeat
# across Restart & Run All. Seeding happens before the environment is built.
# NOTE(review): CartPoleEnv may keep its own RNG — confirm whether it needs separate seeding.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

gamma = 0.995        # Discount factor applied to future rewards in the Q-learning target
num_actions = 2      # Number of discrete actions (size of the network's output layer)
state_shape = 4      # Length of an observation vector (size of the network's input layer)
policy = 'egreedy'   # Action-selection policy; select_action only implements 'egreedy'
experience_replay = True  # Train from a replay buffer after every environment step
decay_epsilon = True      # Dynamic epsilon-greedy exploration: shrink epsilon over time

epsilon = 1          # Initial epsilon-greedy exploration probability
epsilon_min = 0.01   # Floor that epsilon never decays below
decay_rate = 0.9995  # Multiplicative epsilon decay applied after every environment step
batch_size = 20      # Size of batch taken from replay buffer

game = CartPoleEnv()

###BUILD THE ARCHITECTURE OF THE MODEL###
def build_architecture(learning_rate = 0.001, input_dim = 4, n_actions = 2,
                       hidden_units = 24, n_hidden_layers = 3):
    """Build and compile the Q-network: an MLP mapping a state vector to one Q-value per action.

    Args:
        learning_rate: Adam learning rate.
        input_dim: length of the state vector fed to the network (4 in this notebook).
        n_actions: number of Q-values produced, one per discrete action (2 in this notebook).
        hidden_units: width of each hidden ReLU layer (100 was tried without improvement).
        n_hidden_layers: number of hidden ReLU layers.

    Returns:
        A keras.Model compiled with the Adam optimizer and mean-squared-error loss.
    """
    inputs = keras.Input(shape=(input_dim,))
    x = inputs
    for _ in range(n_hidden_layers):
        x = layers.Dense(hidden_units, activation = 'relu')(x)
    # Linear head: Q-values are unbounded regression targets, so no output squashing.
    outputs = layers.Dense(n_actions, activation = 'linear')(x)

    model = keras.Model(inputs=inputs, outputs=outputs)
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer = optimizer, loss = 'mse')
    return model

def select_action(state, policy, epsilon, model):
    """Choose an action for a single `state` according to `policy`.

    Only the epsilon-greedy policy ('egreedy') is implemented. With probability
    `epsilon` a uniformly random action is taken from `game.action_space`;
    otherwise the action with the highest Q-value predicted by `model`.

    Args:
        state: one environment observation (not a batch).
        policy: exploration policy name; must be 'egreedy'.
        epsilon: probability of taking a random action.
        model: Q-network whose ``predict`` returns one row of Q-values per input state.

    Returns:
        The index of the chosen action.

    Raises:
        ValueError: if `policy` is not supported (previously this surfaced as an
            UnboundLocalError on `action`).
    """
    if policy != 'egreedy':
        raise ValueError(f"Unsupported policy: {policy!r}")

    if epsilon > np.random.rand(1)[0]:
        # Explore: uniformly random action.
        return random.randrange(game.action_space.n)

    # Exploit: wrap the single state in a batch of one and take the greedy action.
    q_values = model.predict(np.array([state,]), verbose=0)
    return np.argmax(q_values)

def experience_replay_update(batch_size, len_history, state_history,state_next_history,
                             rewards_history, action_history, done_history, model, discount=None):
    """Run one DQN training step on a minibatch sampled from the replay buffer.

    Samples `batch_size` transitions (uniformly, with replacement) from the first
    `len_history` entries of the parallel history lists. For each sampled
    transition, the target for the action actually taken is
    ``reward + discount * max_a Q(next_state, a)``, or just ``reward`` when the
    transition ended the episode. The targets of the other actions are left at
    the model's current prediction, so they contribute no error. `model` is
    then fitted on these targets.

    Args:
        batch_size: number of transitions to sample.
        len_history: number of stored transitions to sample from.
        state_history, state_next_history, rewards_history, action_history,
        done_history: parallel lists forming the replay buffer.
        model: Q-network exposing ``predict`` and ``fit``.
        discount: discount factor; when None the module-level ``gamma`` is used.
    """
    if discount is None:
        discount = gamma

    # Indices of the sampled transitions (with replacement).
    indices = np.random.choice(len_history, size = batch_size)

    state_sample = np.array([state_history[i] for i in indices])
    state_next_sample = np.array([state_next_history[i] for i in indices])
    rewards_sample = np.array([rewards_history[i] for i in indices], dtype=float)
    action_sample = np.array([action_history[i] for i in indices], dtype=int)
    done_sample = np.array([bool(done_history[i]) for i in indices])

    # Two batched forward passes replace the former one-predict()-per-sample loop.
    y_train = model.predict(state_sample, verbose=0)
    q_next = model.predict(state_next_sample, verbose=0)

    # Terminal transitions have no future reward; np.where ignores q_next for them.
    targets = np.where(done_sample, rewards_sample,
                       rewards_sample + discount * np.max(q_next, axis=1))
    y_train[np.arange(len(y_train)), action_sample] = targets

    model.fit(state_sample, y_train, verbose = 0)

def cartpole(n_runs, learning_rate, gamma, policy, epsilon, experience_replay, batch_size, decay_epsilon):
    """Train a DQN agent on the global `game` environment for `n_runs` episodes.

    Prints one line per episode with its step count and the current epsilon.

    Args:
        n_runs: number of episodes to play.
        learning_rate: Adam learning rate for the Q-network.
        gamma: discount factor used by the replay updates.
        policy: action-selection policy passed to `select_action`.
        epsilon: initial exploration probability.
        experience_replay: if True, train from the replay buffer after each step.
        batch_size: minibatch size for replay updates.
        decay_epsilon: if True, multiply epsilon by the module-level `decay_rate`
            after every step, never going below `epsilon_min`.

    Returns:
        List with the total reward of every episode, in play order.
    """
    model = build_architecture(learning_rate)

    # Replay buffer: parallel lists with one entry per environment step.
    action_history = []
    state_history = []
    state_next_history = []
    rewards_history = []
    done_history = []
    episode_reward_history = []
    # Beyond this many transitions the oldest one is dropped.
    max_memory_length = 1000000

    for run in range(1, n_runs + 1):
        state = game.reset()
        episode_reward = 0
        n_steps = 0

        while True:
            #game.render() #Uncomment to watch the agent's attempts in a pop-up window.
            n_steps += 1
            action = select_action(state, policy, epsilon, model)

            # Decay the probability of taking a random action.
            if decay_epsilon:
                epsilon = max(epsilon * decay_rate, epsilon_min)

            state_next, reward, done, _ = game.step(action)
            episode_reward += reward

            # Store the transition in the replay buffer.
            action_history.append(action)
            state_history.append(state)
            state_next_history.append(state_next)
            done_history.append(done)
            rewards_history.append(reward)
            state = state_next

            # Train after every non-terminal step once the buffer holds more than one batch.
            if experience_replay and len(done_history) > batch_size and not done:
                experience_replay_update(batch_size, len(done_history), state_history,
                                         state_next_history, rewards_history,
                                         action_history, done_history, model,
                                         discount=gamma)

            # Keep the buffer bounded by dropping the oldest transition.
            if len(rewards_history) > max_memory_length:
                for buffer in (rewards_history, state_history, state_next_history,
                               action_history, done_history):
                    del buffer[0]

            if done:
                print("Run:" + str(run) + ", Steps:" + str(n_steps) + ", Epsilon:" + str(epsilon))
                break

        episode_reward_history.append(episode_reward)

    return episode_reward_history


In [4]:
# Hyper-parameters come from the setup cell; only learning_rate (0.01) overrides
# build_architecture's default. Keep the per-episode rewards for later analysis
# instead of dumping the whole list as cell output.
episode_rewards = cartpole(n_runs=2000, learning_rate=0.01, gamma=gamma, policy=policy,
                           epsilon=epsilon, experience_replay=experience_replay,
                           batch_size=batch_size, decay_epsilon=decay_epsilon)

Run:1, Steps:12, Epsilon:0.9940164725309134
Run:2, Steps:12, Epsilon:0.9880687476628001
Run:3, Steps:23, Epsilon:0.9767682342250465
Run:4, Steps:15, Epsilon:0.9694680571640535
Run:5, Steps:14, Epsilon:0.9627037921120016
Run:6, Steps:27, Epsilon:0.9497914172412162
Run:7, Steps:15, Epsilon:0.9426928596981636
Run:8, Steps:20, Epsilon:0.9333105749632893
Run:9, Steps:29, Epsilon:0.9198718777420618
Run:10, Steps:14, Epsilon:0.9134536598864085
Run:11, Steps:18, Epsilon:0.9052674235521029
Run:12, Steps:17, Epsilon:0.8976033527310644
Run:13, Steps:22, Epsilon:0.887781380065634
Run:14, Steps:20, Epsilon:0.8789456096400869
Run:15, Steps:19, Epsilon:0.8706330950236377
Run:16, Steps:11, Epsilon:0.8658565662672015
Run:17, Steps:17, Epsilon:0.8585261511080049
Run:18, Steps:11, Epsilon:0.853816044322082
Run:19, Steps:15, Epsilon:0.8474347881728043
Run:20, Steps:10, Epsilon:0.8432071351729019
Run:21, Steps:39, Epsilon:0.8269198412196082
Run:22, Steps:15, Epsilon:0.820739602095697
Run:23, Steps:9, Epsil

Run:182, Steps:10, Epsilon:0.2569587434617572
Run:183, Steps:11, Epsilon:0.2555489982609623
Run:184, Steps:14, Epsilon:0.25376595740133856
Run:185, Steps:14, Epsilon:0.25199535734456957
Run:186, Steps:17, Epsilon:0.24986194326712088
Run:187, Steps:10, Epsilon:0.24861544075299546
Run:188, Steps:11, Epsilon:0.24725146916859517
Run:189, Steps:10, Epsilon:0.2460179896962517
Run:190, Steps:14, Epsilon:0.24430144949920127
Run:191, Steps:10, Epsilon:0.24308268698169508
Run:192, Steps:10, Epsilon:0.24187000458396357
Run:193, Steps:11, Epsilon:0.24054304028773116
Run:194, Steps:8, Epsilon:0.23958255024511305
Run:195, Steps:10, Epsilon:0.23838732920698225
Run:196, Steps:9, Epsilon:0.23731672921032318
Run:197, Steps:16, Epsilon:0.23542529829330847
Run:198, Steps:11, Epsilon:0.2341336913997524
Run:199, Steps:12, Epsilon:0.2327327460258233
Run:200, Steps:10, Epsilon:0.2315716970511487
Run:201, Steps:10, Epsilon:0.23041644027694702
Run:202, Steps:13, Epsilon:0.22892321830863213
Run:203, Steps:14, Ep

Run:358, Steps:33, Epsilon:0.0693099123640835
Run:359, Steps:17, Epsilon:0.06872312876496313
Run:360, Steps:13, Epsilon:0.06827776607521828
Run:361, Steps:9, Epsilon:0.06797112991139537
Run:362, Steps:16, Epsilon:0.06742939525574491
Run:363, Steps:25, Epsilon:0.06659156568694086
Run:364, Steps:13, Epsilon:0.06616001666783133
Run:365, Steps:11, Epsilon:0.06579704491320074
Run:366, Steps:33, Epsilon:0.064720034176266
Run:367, Steps:21, Epsilon:0.06404386088367206
Run:368, Steps:17, Epsilon:0.06350166012321694
Run:369, Steps:19, Epsilon:0.06290110137173838
Run:370, Steps:16, Epsilon:0.062399775197875006
Run:371, Steps:11, Epsilon:0.06205743314548643
Run:372, Steps:31, Epsilon:0.06110272236152543
Run:373, Steps:13, Epsilon:0.060706743987567355
Run:374, Steps:23, Epsilon:0.06001244272784582
Run:375, Steps:27, Epsilon:0.05920751896649502
Run:376, Steps:13, Epsilon:0.05882382252580734
Run:377, Steps:22, Epsilon:0.05818014625703373
Run:378, Steps:20, Epsilon:0.05760110007832936
Run:379, Steps:

KeyboardInterrupt: 