

<p><img height="80px" src="https://www.upm.es/sfs/Rectorado/Gabinete%20del%20Rector/Logos/UPM/Escudo/EscUpm.jpg" align="left" hspace="0px" vspace="0px"></p>

**Course "Artificial Neural Networks and Deep Learning" - Universidad Politécnica de Madrid (UPM)**

# **Deep Q-Learning for Cartpole**

This notebook includes an implementation of the Deep Q-learning (DQN) algorithm for the cartpole problem (see [OpenAI's Cartpole](https://gym.openai.com/envs/CartPole-v1/)).


##Libraries

In [None]:
#!pip install gym[Box_2D]
#!pip install box2d-py
#!pip install pyglet

In [1]:
import gym
import numpy as np
from tensorflow import keras
import matplotlib.pyplot as plt
import time
import keras.backend as K
from tensorflow.keras.regularizers import l2
import random
import tensorflow as tf

In [2]:
physical_devices = tf.config.list_physical_devices('GPU') 
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

## Hyperparameters

In [3]:
GAMMA = 0.99

MEMORY_SIZE = 200000
LEARNING_RATE = 0.001
BATCH_SIZE = 128
EXPLORATION_MAX = 1
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995
NUMBER_OF_EPISODES = 2000
MAX_STEPS = 1000
K_STEPS = 1000
TRAIN_STEPS = 4
REGULARIZER_FACTOR = 0.001

## Class ReplayMemory

Memory of transitions for experience replay.

In [4]:
class ReplayMemory:

    def __init__(self,number_of_observations):
        # Create replay memory
        self.states = np.zeros((MEMORY_SIZE, number_of_observations))
        self.states_next = np.zeros((MEMORY_SIZE, number_of_observations))
        self.actions = np.zeros(MEMORY_SIZE, dtype=np.int32)
        self.rewards = np.zeros(MEMORY_SIZE)
        self.terminal_states = np.zeros(MEMORY_SIZE, dtype=bool)
        self.current_size=0

    def store_transition(self, state, action, reward, state_next, terminal_state):
        # Store a transition (s,a,r,s') in the replay memory
        i = self.current_size
        self.states[i] = state
        self.states_next[i] = state_next
        self.actions[i] = action
        self.rewards[i] = reward
        self.terminal_states[i] = terminal_state
        self.current_size = i + 1
        
        if self.current_size >= MEMORY_SIZE - 1:
            self.current_size = 0
            

    def sample_memory(self, batch_size):
        # Generate a sample of transitions from the replay memory
        batch = np.random.choice(self.current_size, batch_size)
        states = self.states[batch]
        states_next = self.states_next[batch]
        rewards = self.rewards[batch]
        actions = self.actions[batch]   
        terminal_states = self.terminal_states[batch]  
        return states, actions, rewards, states_next, terminal_states

## Class DQN

Reinforcement learning agent with a Deep Q-Network.

In [5]:
class DQN:

    def __init__(self, number_of_observations, number_of_actions):
        # Initialize variables and create neural model
        self.exploration_rate = EXPLORATION_MAX
        self.number_of_actions = number_of_actions
        self.number_of_observations = number_of_observations
        self.scores = []
        self.memory = ReplayMemory(number_of_observations)
        self.model = keras.models.Sequential()
        self.model.add(keras.layers.Dense(64, input_shape=(number_of_observations,),
                                          activation="relu" ,kernel_initializer="he_normal"))
        self.model.add(keras.layers.Dense(64, activation="relu" ,kernel_initializer="he_normal"))
        self.model.add(keras.layers.Dense(64, activation="relu" ,kernel_initializer="he_normal"))
        self.model.add(keras.layers.Dense(number_of_actions, activation="linear", kernel_regularizer=l2(REGULARIZER_FACTOR)))
        self.model.compile(loss="mse", optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE))
        # self.model.compile(loss="mse", optimizer=keras.optimizers.SGD(learning_rate=LEARNING_RATE, momentum = 0.9))
        self.model.save("model1.h5")
        self.target_model = keras.models.load_model("model1.h5")
        #self.target_model.compile(loss="mse", optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE))
        # self.model.compile(loss="mse", optimizer=keras.optimizers.SGD(learning_rate=LEARNING_RATE, momentum = 0.9))

    def masked_huber_loss(self, mask_value, clip_delta):
        def f(y_true, y_pred):
            error = y_true - y_pred
            cond  = K.abs(error) < clip_delta
            mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
            masked_squared_error = 0.5 * K.square(mask_true * (y_true - y_pred))
            linear_loss  = mask_true * (clip_delta * K.abs(error) - 0.5 * (clip_delta ** 2))
            huber_loss = tf.where(cond, masked_squared_error, linear_loss)
            return K.sum(huber_loss) / K.sum(mask_true)
        f.__name__ = 'masked_huber_loss'
        return f    
    
    def remember(self, state, action, reward, next_state, terminal_state):
        # Store a tuple (s, a, r, s') for experience replay
        state = np.reshape(state, [1, self.number_of_observations])
        next_state = np.reshape(next_state, [1, self.number_of_observations])
        self.memory.store_transition(state, action, reward, next_state, terminal_state)

    def select(self, state):
        # Generate an action for a given state using epsilon-greedy policy
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.number_of_actions)
        else:
            state = np.reshape(state, [1, self.number_of_observations])
            q_values = self.model.predict(state)
            return np.argmax(q_values[0])

    def learn(self, step):
        # Learn the value Q using a sample of examples from the replay memory
        if self.memory.current_size < BATCH_SIZE: return

        states, actions, rewards, next_states, terminal_states = self.memory.sample_memory(BATCH_SIZE)

        q_targets = self.model.predict(states)
        print(np.shape(q_targets))
        print(q_targets)
        q_next_states = self.target_model.predict(next_states)

        for i in range(BATCH_SIZE):
            if (terminal_states[i]):
                q_targets[i][actions[i]] = rewards[i]
            else:
                q_targets[i][actions[i]] = rewards[i] + GAMMA * np.max(q_next_states[i])

        self.model.train_on_batch(states, q_targets)

        # Copy model to target model
        if total_steps % K_STEPS == 0:
            self.model.save("model1.h5")
            self.target_model = keras.models.load_model("model1.h5")
            #self.target_model.compile(loss="mse", optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE))

    def add_score(self, score):
        # Add the obtained score in a list to be presented later
        self.scores.append(score)

    def display_scores_graphically(self):
        # Display the obtained scores graphically
        plt.plot(self.scores)
        plt.xlabel("Episode")
        plt.ylabel("Score")         

## Environment Cartpole

Cartpole simulator from [Open Ai Gym](https://gym.openai.com/envs/CartPole-v1/):

<p><img height="200px" src="https://raw.githubusercontent.com/martin-molina/reinforcement_learning/main/images/cartpole_attributes.png" align="center" vspace="20px"</p>

State vector:
- state[0]: position
- state[1]: velocity
- state[2]: angle
- state[3]: angular velocity

Actions:
- 0 (push cart to the left)
- 1 (push cart to the right)

In [6]:
def create_environment():
    # Create simulated environment
    environment = gym.make("LunarLander-v2")
    number_of_observations = environment.observation_space.shape[0]
    number_of_actions = environment.action_space.n
    return environment, number_of_observations, number_of_actions

## Main program




In [7]:
environment, number_of_observations, number_of_actions = create_environment()
agent = DQN(number_of_observations, number_of_actions)
episode = 0
goal_reached = False
start_time = time.perf_counter()
scores = []
total_steps = 1
while (episode < NUMBER_OF_EPISODES) and not (goal_reached):
    episode += 1
    step = 1
    end_episode = False
    state = environment.reset()
    reward_accumulated = 0
    # Decrease exploration rate
    agent.exploration_rate *= EXPLORATION_DECAY
    agent.exploration_rate = max(EXPLORATION_MIN, agent.exploration_rate)
    while not (end_episode):
        # Select an action for the current state
        action = agent.select(state)

        # Execute the action on the environment
        state_next, reward, terminal_state, info = environment.step(action)

        if episode == NUMBER_OF_EPISODES:
            environment.render()

        # Store in memory the transition (s,a,r,s')
        agent.remember(state, action, reward, state_next, terminal_state)

        # print(state[0],reward)

        # Learn using a batch of experience stored in memory
        if total_steps % TRAIN_STEPS == 0:
            agent.learn(step)
        

        # Detect end of episode
        if terminal_state:
            print('Last reward: ', reward)
            reward_accumulated += reward
            agent.add_score(reward_accumulated)
            scores.append(reward_accumulated)
            if episode < 100:
                avg_score = sum(scores)/len(scores)
            else:
                avg_score = sum(scores[-100:])/100
            if avg_score >= 200:
                goal_reached = True
            print("Episode {0:>3}: ".format(episode), end='')
            print("Average score {0:>3} ".format(avg_score), end='')
            # print("estado: ", state)
            print("(exploration rate: %.2f, " % agent.exploration_rate, end='')
            print("transitions: " + str(agent.memory.current_size) + ")")
            print("Score: ", reward_accumulated)
            end_episode = True
        else:
            state = state_next
            step += 1
            total_steps += 1
            reward_accumulated += reward

environment.close()
if goal_reached:
    print("Reached goal sucessfully.")
else:
    print("Failure to reach the goal.")

print("Time:", round((time.perf_counter() - start_time) / 60), "minutes")

agent.display_scores_graphically()

Last reward:  -100
Episode   1: Average score -215.84796644733208 (exploration rate: 0.99, transitions: 99)
Score:  -215.84796644733208
(128, 4)
[[-2.14213    -1.7010877  -1.1015795  -0.17411038]
 [-1.6233885  -0.11200303 -0.5168948   0.21294129]
 [-1.7096754  -0.0366264  -0.4361338   0.1409618 ]
 [-2.0641413  -1.2373321  -0.7470242  -0.23962435]
 [-1.6350691  -0.20810497 -0.4423628  -0.06483456]
 [-1.8676782  -0.0416885  -0.44675475  0.25851977]
 [-1.8805354  -0.02990073 -0.4133581   0.1872369 ]
 [-1.6157124  -0.2003442  -0.34814328 -0.15613785]
 [-1.8287106  -0.03257149 -0.40786594  0.14515105]
 [-1.6157124  -0.2003442  -0.34814328 -0.15613785]
 [-1.5896394  -0.08875144 -0.40003002  0.36609954]
 [-2.3148963  -1.5840399  -0.8967421   0.3807979 ]
 [-1.7134792  -0.07183081 -0.33461696 -0.04784599]
 [-1.6514856  -0.08489418 -0.40208405  0.02909502]
 [-2.2058332  -1.400992   -0.8840779   0.13408503]
 [-1.8676782  -0.0416885  -0.44675475  0.25851977]
 [-2.5626278  -2.181586   -1.5297406  -

(128, 4)
[[-1.3199965   0.29929236 -0.49537897 -0.3823024 ]
 [-1.4296538   0.24398047 -0.5827315  -0.71319354]
 [-1.4853748  -0.4777949  -0.82891095 -1.4186916 ]
 [-1.3640152   0.32152867 -0.56424063 -0.48912895]
 [-1.3640152   0.32152867 -0.56424063 -0.48912895]
 [-1.4543362   0.29564112 -0.56377316 -0.7994976 ]
 [-1.3199965   0.29929236 -0.49537897 -0.3823024 ]
 [-1.7405835  -0.5137127  -0.8796437  -1.1723578 ]
 [-1.3185071  -0.20585012 -0.68001366 -1.2064795 ]
 [-1.4806339  -0.01696009 -0.5422038  -1.2246659 ]
 [-1.3341405   0.29749313 -0.52242285 -0.4041265 ]
 [-1.5252156   0.2557124  -0.53259885 -0.8858886 ]
 [-1.4253224   0.128187   -0.57286394 -0.81729233]
 [-1.4296538   0.24398047 -0.5827315  -0.71319354]
 [-1.4786267   0.2762586  -0.56631696 -0.67880225]
 [-1.8319433  -0.4876786  -0.8659487  -1.0666666 ]
 [-1.5442852   0.26752904 -0.516761   -0.8628958 ]
 [-1.3392221   0.3055501  -0.4862369  -0.4094121 ]
 [-1.570642   -0.55673635 -0.8989723  -1.3439178 ]
 [-1.4705392  -0.35118

(128, 4)
[[-1.37166882e+00  3.74915212e-01 -6.79649234e-01 -1.33331501e+00]
 [-1.31266642e+00 -2.99802393e-01 -9.82381701e-01 -2.26623940e+00]
 [-1.29310095e+00 -2.37763420e-01 -7.23452747e-01 -1.58106649e+00]
 [-1.72516978e+00 -3.16018790e-01 -1.08176982e+00 -1.46742487e+00]
 [-1.32064998e+00  3.79668742e-01 -6.05349779e-01 -1.62348044e+00]
 [-1.66722941e+00 -3.23080927e-01 -7.34690547e-01 -1.60770154e+00]
 [-1.28232932e+00 -1.46305934e-01 -8.73283327e-01 -2.07788873e+00]
 [-1.67344987e+00 -2.95628279e-01 -1.05365777e+00 -1.72123551e+00]
 [-1.45355940e+00  4.57725972e-01 -6.63458288e-01 -1.41172016e+00]
 [-1.55422807e+00 -3.05332631e-01 -7.14137018e-01 -1.64709032e+00]
 [-1.41345072e+00  4.05550927e-01 -6.45224512e-01 -1.55962312e+00]
 [-1.41947174e+00  4.12352294e-01 -6.28223419e-01 -1.55503488e+00]
 [-2.52320671e+00 -1.86143982e+00 -2.05469775e+00 -4.90628052e+00]
 [-1.27183115e+00 -1.01768330e-01 -8.51165533e-01 -2.00586677e+00]
 [-1.31266642e+00 -2.99802393e-01 -9.82381701e-01 -2.

(128, 4)
[[-1.54207504e+00 -9.03212950e-02 -1.05920744e+00 -2.40960717e+00]
 [-1.03765070e+00  4.22751278e-01 -7.48144686e-01 -1.39585364e+00]
 [-1.13160503e+00  1.58000663e-02 -9.73326862e-01 -2.55016494e+00]
 [-1.26382828e+00  2.86926717e-01 -8.38101327e-01 -1.95281792e+00]
 [-1.29253590e+00  3.07294965e-01 -8.36225569e-01 -1.96092188e+00]
 [-1.09079492e+00  2.96063453e-01 -8.33277524e-01 -1.81612706e+00]
 [-1.25331318e+00 -2.32972383e-01 -1.09745073e+00 -2.86194301e+00]
 [-1.79921317e+00 -1.44622564e-01 -1.15871811e+00 -1.97746348e+00]
 [-1.36535025e+00  4.47050929e-01 -8.08533251e-01 -1.98733616e+00]
 [-1.63907683e+00 -1.36985183e-01 -1.16086054e+00 -2.05103016e+00]
 [-1.61425304e+00 -1.35997653e-01 -1.15158725e+00 -2.32355523e+00]
 [-1.21038401e+00  1.70679033e-01 -7.96114802e-01 -2.44772100e+00]
 [-2.21811986e+00 -1.64170480e+00 -1.88602424e+00 -5.27275229e+00]
 [-1.02507830e+00  4.18076575e-01 -7.82641232e-01 -1.39634049e+00]
 [-1.54207504e+00 -9.03212950e-02 -1.05920744e+00 -2.

 [-1.3938035e+00  9.4242796e-02 -1.1936015e+00 -2.8710585e+00]]
(128, 4)
[[-2.93971753e+00 -1.70530736e+00 -2.55349112e+00 -6.29320574e+00]
 [-1.59393728e+00  3.59176159e-01 -1.11473656e+00 -2.39015627e+00]
 [-2.66787457e+00  2.16967568e-01 -7.55953491e-01 -1.79811108e+00]
 [-1.60338366e+00  2.45509192e-01 -1.05864239e+00 -2.80751634e+00]
 [-1.64698768e+00  1.57876357e-01 -1.01687193e+00 -2.67994308e+00]
 [-1.50675201e+00  3.81545305e-01 -1.10866392e+00 -2.43486214e+00]
 [-2.83240056e+00  2.23280177e-01 -8.02775800e-01 -1.89661407e+00]
 [-1.14857137e+00  3.51448983e-01 -1.02518916e+00 -1.83339143e+00]
 [-1.28767049e+00  2.17311248e-01 -1.14329028e+00 -2.33558869e+00]
 [-8.33945751e-01  3.39736432e-01 -7.82057941e-01 -2.71909499e+00]
 [-1.30531108e+00  2.68710822e-01 -1.13316655e+00 -2.19100404e+00]
 [-1.50027335e+00  4.26225990e-01 -1.07383466e+00 -2.62718868e+00]
 [-1.62882948e+00  3.00508767e-01 -1.05938792e+00 -2.83334017e+00]
 [-2.83240056e+00  2.23280177e-01 -8.02775800e-01 -1.896

(128, 4)
[[-1.15888476e+00  8.38341564e-02 -9.62859631e-01 -2.98155904e+00]
 [-2.00561428e+00 -1.98573358e-02 -1.46708775e+00 -3.36291051e+00]
 [-2.92478085e+00  7.83012658e-02 -1.51457155e+00 -2.73347902e+00]
 [-3.29646921e+00  1.49306014e-01 -1.37536895e+00 -2.42719269e+00]
 [-3.02322340e+00  2.02042773e-01 -1.04652405e+00 -2.67976928e+00]
 [-2.62999892e+00 -3.19879130e-03 -1.60756469e+00 -3.12020302e+00]
 [-1.85522616e+00  3.39018553e-01 -1.22979021e+00 -3.08634734e+00]
 [-1.81932902e+00  3.56086850e-01 -1.26261389e+00 -2.99270463e+00]
 [-1.77410758e+00  2.57689536e-01 -1.32448626e+00 -2.66903663e+00]
 [-3.61538100e+00  3.19577038e-01 -8.67198586e-01 -2.09451890e+00]
 [-1.18265188e+00  1.09961823e-01 -1.07865226e+00 -3.15526700e+00]
 [-7.05281019e-01  2.05852136e-01 -8.82412672e-01 -2.43558717e+00]
 [-1.66744030e+00  3.04502696e-01 -1.28962219e+00 -2.77784467e+00]
 [-1.85522616e+00  3.39018553e-01 -1.22979021e+00 -3.08634734e+00]
 [-1.71662736e+00  2.84697950e-01 -1.31027222e+00 -2.

(128, 4)
[[-3.29212070e+00  2.03404725e-01 -1.09860766e+00 -2.99595118e+00]
 [-1.42828202e+00 -1.92333341e-01 -1.23087478e+00 -3.63922691e+00]
 [-4.12264013e+00 -1.61982286e+00 -3.03480482e+00 -7.52647495e+00]
 [-3.85455465e+00  3.13618839e-01 -9.90974367e-01 -2.76957011e+00]
 [-1.13533449e+00 -1.23193868e-01 -9.09711480e-01 -2.78283405e+00]
 [-1.33172238e+00 -1.42650068e-01 -9.82621670e-01 -3.01006699e+00]
 [-1.53142345e+00 -3.01045775e-01 -1.38542223e+00 -3.81947279e+00]
 [-3.50316334e+00  1.10387020e-01 -1.57398129e+00 -3.11442733e+00]
 [-3.68436432e+00  2.55870044e-01 -1.06207395e+00 -2.94593239e+00]
 [-1.85688221e+00  2.65812486e-01 -1.33771932e+00 -2.97993279e+00]
 [-1.64693773e+00  5.42904697e-02 -1.43417645e+00 -2.82863975e+00]
 [-1.61607993e+00  1.20171361e-01 -1.39504349e+00 -2.61633372e+00]
 [-1.31896985e+00 -1.19680233e-01 -1.09288025e+00 -3.24394059e+00]
 [-1.34148884e+00 -1.50955558e-01 -9.99020040e-01 -3.02811623e+00]
 [-1.57792592e+00  8.00116584e-02 -1.41821122e+00 -2.

(128, 4)
[[-1.68154776e+00  1.16927922e-03 -1.50217104e+00 -2.79942870e+00]
 [-3.27170444e+00  1.98453382e-01 -1.11799264e+00 -3.08478355e+00]
 [-1.55573916e+00  7.03625828e-02 -1.40494180e+00 -2.52243757e+00]
 [-4.70592403e+00 -1.78228498e+00 -3.40273714e+00 -8.41095829e+00]
 [-2.89682460e+00  9.81177837e-02 -1.38524580e+00 -3.61290169e+00]
 [-3.39283013e+00  1.96231082e-01 -1.11059773e+00 -3.05951691e+00]
 [-4.00592136e+00  7.13589936e-02 -1.63761127e+00 -3.46194458e+00]
 [-1.67109859e+00 -5.33765435e-01 -1.32977259e+00 -3.61626911e+00]
 [-2.00194550e+00 -3.38965505e-02 -1.58440804e+00 -3.15900040e+00]
 [-1.65300453e+00 -2.24037021e-02 -1.51723921e+00 -2.83691740e+00]
 [-2.89682460e+00  9.81177837e-02 -1.38524580e+00 -3.61290169e+00]
 [-1.07334507e+00 -3.09604704e-01 -9.41709697e-01 -2.57469344e+00]
 [-2.89450240e+00  9.64396745e-02 -1.42817378e+00 -3.71193385e+00]
 [-1.48456037e+00 -4.47478831e-01 -1.17037499e+00 -3.22029090e+00]
 [-3.31286693e+00 -2.06272885e-01 -1.85999215e+00 -4.

(128, 4)
[[-1.6880395e+00 -1.9539525e-01 -1.6165111e+00 -3.0139802e+00]
 [-1.5032324e+00 -9.0681475e-01 -1.2911543e+00 -3.0804381e+00]
 [-3.9204328e+00  1.6933171e-01 -1.1165277e+00 -3.2454317e+00]
 [-2.3247161e+00 -2.2513349e-01 -1.6309628e+00 -3.8726568e+00]
 [-1.7420089e+00 -1.4752586e-01 -1.5916554e+00 -2.9122686e+00]
 [-2.2494035e+00  1.6449563e-02 -1.5082910e+00 -3.5436800e+00]
 [-1.4512942e+00 -8.5388672e-01 -1.1414957e+00 -2.9226146e+00]
 [-2.0456457e+00  1.1582068e-02 -1.4878050e+00 -3.3037953e+00]
 [-1.8374722e+00 -1.1538333e+00 -1.6382625e+00 -4.0137224e+00]
 [-1.7619332e+00 -1.1142217e+00 -1.5753909e+00 -3.8693597e+00]
 [-1.3343385e+00 -8.0211043e-01 -1.1183215e+00 -2.8427711e+00]
 [-2.2299860e+00  9.8004853e-03 -1.5184799e+00 -3.5008574e+00]
 [-5.3843703e+00 -2.4951863e+00 -4.2767658e+00 -9.9648075e+00]
 [-5.3843703e+00 -2.4951863e+00 -4.2767658e+00 -9.9648075e+00]
 [-1.4745691e+00 -8.6331433e-01 -1.1719633e+00 -2.9721346e+00]
 [-5.1727324e+00 -2.3040948e+00 -3.9790850e+00

KeyboardInterrupt: 

In [15]:
state = environment.reset()
step = 1
end_episode = False
reward_accumulated = 0
while not(end_episode):

    # Select an action for the current state
    action = agent.select(state)

    # Execute the action on the environment
    state_next, reward, terminal_state, info = environment.step(action)

    environment.render()
    if terminal_state or (state[6] == 1 and state[7] == 1):
        end_episode = True 
        reward_accumulated += reward
    else:
        state = state_next
        step += 1
        reward_accumulated += reward
print(reward_accumulated)
environment.close()

181.77185429499087
