# "Reinforcement Learning - Epsilon Greedy"
> "Reinforcement Learning Balancing Cartpole with Epsilon Greedy"

- author: Bhargav Lad
- toc: true 
- badges: true
- comments: true
- image: images/q_learning.gif
- categories: [ jupyter,Reinforcement-Learning,CartPole,q-learning,epsilon-greedy]

In [6]:
import sys
import random
import time
import gym
import numpy as np
from IPython import display
from base64 import b64decode

In [7]:
env = gym.make('CartPole-v1')

# discretizing continuous value

As the state of the environment has continous values we need to discretize them inorder to work with it. Here we set  different bin size for each variable.

The `get_discret_state` function will take in the continous values of the state and convert it to discret values based on our bining.

In [8]:
bin_sz = [2,2,11,5]
discreeting = []
for i in range(env.observation_space.high.shape[0]):
    bins = np.linspace(env.observation_space.low[i],env.observation_space.high[i],bin_sz[i])
    bins = np.sort(np.append(bins, [0]))
    discreeting.append(bins)

def get_discret_state(state):
    dis = []
    for i,n in enumerate(state):
        dis.append(np.digitize(n,discreeting[i])) # discritize based on bins
    return tuple([x for x in dis])

# Initialize Q table

We initialize the Q table with Random values

In [17]:
q_table = np.random.uniform(low=0,high=1,size=tuple([x.shape[0] for x in discreeting]+[env.action_space.n]))

# Parameters

These are the parameter values we set for our policy

In [9]:
num_episodes = 4*60000  # total episodes
steps_per_episode = 500
discount = 0.97  # discount

lr =1e-3   # learning rate
min_lr = 1e-4
max_lr = 1e-2
lr_decay = 0.6  # learning rate decay factor

explore_rate = 0.1  # exploration rate
max_explore = 1.0
min_explore = 0.1
decay = 0.03  # decay factor for exploration rate

# Epsilon Greedy Policy 

![](images/equation.png)

In [19]:
rewards =[]
for episode in range(num_episodes):
    state = get_discret_state(env.reset())
    done=False
    curr_reward=0

    for step in range(steps_per_episode):
        # print(state)
        explore_rate_threshold = np.random.uniform(0,1)
        
        # if greater than rate then exploit
        if explore_rate_threshold > explore_rate:
            action = np.argmax(q_table[state])
        else:
            action = env.action_space.sample()
            
        # step through with action
        new_state, reward, done, info = env.step(action)
        new_state = get_discret_state(new_state)
        
        # Update q table
        state_action = tuple(list(state)+[action])
        # print(new_state)
        # print(q_table[new_state])
        
        q_table[state_action] = (1-lr)*q_table[state_action] + lr*(reward + discount * np.max(q_table[new_state]))

        curr_reward+=reward
        state = new_state

        if done:
            break

    # update explore_rate and learning rate for next episode
    explore_rate = min_explore + (max_explore-min_explore)*np.exp(-decay*episode)
    lr  = max_lr + (max_lr -min_lr)*np.exp(-lr_decay*episode)

    print(f"Episode {episode} avg: {np.array(curr_reward).mean()}")

print("Done")

np.save("q_table_weights",q_table)

Episode 0 avg: 10.0
Episode 1 avg: 14.0
Episode 2 avg: 28.0
Episode 3 avg: 42.0
Episode 4 avg: 13.0
Episode 5 avg: 12.0
Episode 6 avg: 15.0
Episode 7 avg: 31.0
Episode 8 avg: 61.0
Episode 9 avg: 21.0
Episode 10 avg: 20.0
Episode 11 avg: 17.0
Episode 12 avg: 15.0
Episode 13 avg: 20.0
Episode 14 avg: 31.0
Episode 15 avg: 27.0
Episode 16 avg: 13.0
Episode 17 avg: 15.0
Episode 18 avg: 18.0
Episode 19 avg: 27.0
Episode 20 avg: 13.0
Episode 21 avg: 13.0
Episode 22 avg: 21.0
Episode 23 avg: 10.0
Episode 24 avg: 23.0
Episode 25 avg: 9.0
Episode 26 avg: 23.0
Episode 27 avg: 10.0
Episode 28 avg: 13.0
Episode 29 avg: 16.0
Episode 30 avg: 10.0
Episode 31 avg: 19.0
Episode 32 avg: 10.0
Episode 33 avg: 10.0
Episode 34 avg: 20.0
Episode 35 avg: 14.0
Episode 36 avg: 44.0
Episode 37 avg: 18.0
Episode 38 avg: 14.0
Episode 39 avg: 18.0
Episode 40 avg: 13.0
Episode 41 avg: 16.0
Episode 42 avg: 13.0
Episode 43 avg: 37.0
Episode 44 avg: 12.0
Episode 45 avg: 10.0
Episode 46 avg: 12.0
Episode 47 avg: 17.0
Epi

Episode 402 avg: 11.0
Episode 403 avg: 10.0
Episode 404 avg: 16.0
Episode 405 avg: 11.0
Episode 406 avg: 10.0
Episode 407 avg: 10.0
Episode 408 avg: 9.0
Episode 409 avg: 11.0
Episode 410 avg: 50.0
Episode 411 avg: 14.0
Episode 412 avg: 9.0
Episode 413 avg: 17.0
Episode 414 avg: 10.0
Episode 415 avg: 13.0
Episode 416 avg: 21.0
Episode 417 avg: 11.0
Episode 418 avg: 23.0
Episode 419 avg: 11.0
Episode 420 avg: 45.0
Episode 421 avg: 14.0
Episode 422 avg: 15.0
Episode 423 avg: 11.0
Episode 424 avg: 10.0
Episode 425 avg: 10.0
Episode 426 avg: 15.0
Episode 427 avg: 19.0
Episode 428 avg: 50.0
Episode 429 avg: 12.0
Episode 430 avg: 23.0
Episode 431 avg: 11.0
Episode 432 avg: 11.0
Episode 433 avg: 11.0
Episode 434 avg: 9.0
Episode 435 avg: 12.0
Episode 436 avg: 11.0
Episode 437 avg: 20.0
Episode 438 avg: 17.0
Episode 439 avg: 11.0
Episode 440 avg: 109.0
Episode 441 avg: 9.0
Episode 442 avg: 9.0
Episode 443 avg: 10.0
Episode 444 avg: 16.0
Episode 445 avg: 11.0
Episode 446 avg: 39.0
Episode 447 av

KeyboardInterrupt: 

# Testing with final learned q table values

In [20]:
render = True
q_table = np.load('./q_table_weights_final.npy')
final_reward=[]
while True:
    done = False
    state = get_discret_state(env.reset())
    cum_reward=0
    for step in range(steps_per_episode):
        if render : env.render()

        action = np.argmax(q_table[state])
        new_state,reward,done,_ = env.step(action)
        new_state = get_discret_state(new_state)
        cum_reward+=reward

        if done:

            if render : env.render()
            if cum_reward>=200:
                print("Goal Reached!!",cum_reward)
                final_reward.append(1)

            else:
                print("Failed !! ",cum_reward)
                final_reward.append(0)

            break
        state = new_state
    env.close()
    if cum_reward >=200:
        break


Goal Reached!! 256.0


In [None]:
# Final Result

![](images/q_learning.gif)