# Q Learning

In [None]:
import numpy as np

class QLearner:
    """Q-learning agent with a linear action-value approximator.

    Q(s, a) is estimated as ``state @ weights[:, a]`` (no bias term), where
    ``weights`` has shape ``(state_space_size, action_space_size)``. Action
    selection is purely greedy; any exploration is up to the caller.
    """

    def __init__(self, state_space_size, action_space_size, learning_rate=0.01, discount_factor=0.9):
        """Create an agent whose weights all start at zero.

        Args:
            state_space_size: Length of the state feature vector.
            action_space_size: Number of discrete actions.
            learning_rate: Step size of the semi-gradient update.
            discount_factor: Discount applied to the bootstrapped next-state value.
        """
        print(f"state space size: {state_space_size}")
        print(f"action space size: {action_space_size}")
        self.state_space_size = state_space_size
        self.action_space_size = action_space_size
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor

        # One weight column per action: Q(s, a) = s . weights[:, a].
        self.weights = np.zeros((state_space_size, action_space_size))

    def _q_values(self, state):
        """Return the vector of estimated Q-values, one entry per action."""
        return np.dot(state, self.weights)

    def select_action(self, state):
        """Return the index of the action with the highest estimated Q-value.

        Ties go to the lowest action index (``np.argmax`` semantics).
        """
        return np.argmax(self._q_values(state))

    def update(self, state, action, reward, next_state, done=False):
        """Apply one semi-gradient Q-learning step for a single transition.

        Args:
            state: Feature vector of the state in which ``action`` was taken.
            action: Index of the action taken.
            reward: Reward received for the transition.
            next_state: Feature vector of the resulting state.
            done: True when ``next_state`` is terminal. A terminal state has
                no future return, so the target is ``reward`` alone (no
                bootstrap). Defaults to False so existing callers keep working.
        """
        state = np.asarray(state, dtype=float)

        target = reward
        if not done:
            # Off-policy greedy bootstrap: max over a' of Q(s', a').
            target += self.discount_factor * np.max(self._q_values(next_state))

        td_error = target - self._q_values(state)[action]
        # For a linear approximator, dQ(s, a) / d weights[:, a] is just the state.
        self.weights[:, action] += self.learning_rate * td_error * state

In [None]:
import gym
from gym import wrappers
import numpy as np
import matplotlib.pyplot as plt

env = gym.make('CartPole-v1')

train_step = 500
eval_interval = 5        # run one greedy evaluation episode every N training episodes
report_threshold = 200   # only print evaluations scoring above this

# `shape` is the public attribute of a Box space; `_shape` is a private detail.
agent = QLearner(env.observation_space.shape[0], env.action_space.n, learning_rate=0.005)

for i in range(train_step):
    state = env.reset()
    while True:
        # Behaviour policy is uniformly random; Q-learning is off-policy, so the
        # greedy policy is still learned from these transitions.
        action = env.action_space.sample()

        next_state, reward, done, info = env.step(action)

        # A true terminal state has no future return, so it must not be
        # bootstrapped. QLearner computes Q(s, a) = s @ weights with no bias,
        # so an all-zero next-state vector bootstraps exactly 0 and the target
        # reduces to `reward`. Episodes cut off by the TimeLimit wrapper are not
        # terminal and keep bootstrapping from the real next state.
        terminal = done and not info.get('TimeLimit.truncated', False)
        bootstrap_state = np.zeros_like(next_state) if terminal else next_state
        agent.update(state, action, reward, bootstrap_state)

        state = next_state

        if done:
            break

    if i % eval_interval == 0:
        # Greedy evaluation episode to track learning progress.
        state = env.reset()
        eval_rwrd = 0
        while True:
            action = agent.select_action(state)
            state, reward, done, _ = env.step(action)
            eval_rwrd += reward
            if done:
                break
        if eval_rwrd > report_threshold:
            print(f'step: {i}')
            print(f'eval rewards: {eval_rwrd}')

env.close()

state space size: 4
action space size: 2
step: 60
eval rewards: 500.0
step: 65
eval rewards: 500.0
step: 70
eval rewards: 500.0
step: 75
eval rewards: 500.0
step: 80
eval rewards: 320.0
step: 90
eval rewards: 500.0
step: 100
eval rewards: 401.0
step: 105
eval rewards: 279.0
step: 110
eval rewards: 404.0
step: 125
eval rewards: 412.0
step: 160
eval rewards: 238.0
step: 185
eval rewards: 238.0
step: 200
eval rewards: 308.0
step: 215
eval rewards: 256.0
step: 275
eval rewards: 500.0
step: 295
eval rewards: 500.0
step: 300
eval rewards: 500.0
step: 315
eval rewards: 377.0
step: 320
eval rewards: 500.0
step: 325
eval rewards: 500.0
step: 330
eval rewards: 398.0
step: 335
eval rewards: 500.0
step: 340
eval rewards: 500.0
step: 345
eval rewards: 500.0
step: 350
eval rewards: 500.0
step: 355
eval rewards: 500.0
step: 360
eval rewards: 500.0
step: 365
eval rewards: 462.0
step: 410
eval rewards: 500.0
step: 415
eval rewards: 500.0
step: 420
eval rewards: 500.0
step: 425
eval rewards: 500.0
step:

In [None]:
env = gym.make('CartPole-v1')


def _run_greedy_episode(env, agent):
    """Play one episode choosing actions with `agent.select_action`; return the summed reward."""
    state = env.reset()
    total_reward = 0
    while True:
        action = agent.select_action(state)
        state, reward, done, _ = env.step(action)
        total_reward += reward
        if done:
            return total_reward


for _ in range(10):
    print(f'rewards: {_run_greedy_episode(env, agent)}')

# RecordVideo grabs frames itself on reset/step (rendering in rgb_array mode),
# so no explicit env.render() is needed. Calling it requests "human" mode, which
# gym warns against when recording and which requires a display.
env = wrappers.RecordVideo(env, './video')
total_reward = _run_greedy_episode(env, agent)

env.close()

rewards: 238.0
rewards: 231.0
rewards: 277.0
rewards: 200.0
rewards: 200.0
rewards: 264.0
rewards: 214.0
rewards: 208.0
rewards: 244.0
rewards: 221.0


  logger.warn(
  logger.deprecation(
If you want to render in human mode, initialize the environment in this way: gym.make('EnvName', render_mode='human') and don't call the render method.
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(


In [None]:
from IPython.display import HTML
from base64 import b64encode
# The recording cell writes into './video' relative to the working directory,
# so read from the same relative path (on Colab this is /content/video).
# Use `with` so the file handle is closed after reading.
with open('video/rl-video-episode-0.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# Last expression of the cell, so Jupyter renders the embedded video.
HTML(f"""
<video width=400 controls>
      <source src="{data_url}" type="video/mp4">
</video>
""")