In [0]:
import torch
import torch.nn as nn
import random
from collections import deque
import numpy as np
import torch.optim as optim
import gym

In [0]:
def cartpole_model (observation_space, action_space):
  return nn.Sequential (
      nn.Linear(observation_space, 24),
      nn.ReLU(),
      nn.Linear(24,24),
      nn.ReLU(),
      nn.Linear(24, action_space)
  )


In [0]:
class DQN:
  def __init__(self, observation_space, action_space):
    self.exploration_rate = MAX_EXPLORE
    self.action_space = action_space
    self.observation_space = observation_space
    self.memory = deque(maxlen=MEMORY_LEN)

    #define target net and policy net
    self.target_net = cartpole_model(self.observation_space, self.action_space)
    self.policy_net = cartpole_model(self.observation_space, self.action_space)

    #Copy weights
    self.target_net.load_state_dict(self.policy_net.state_dict())
    self.target_net.eval()

    # Define Loss Function, Optimizer and Limit Flag
    self.criterion = nn.MSELoss()
    self.optimizer = optim.Adam(self.policy_net.parameters())

    self.explore_limit = False
  
  # define Load Memory method
  def load_memory(self, state, action, reward, next_state, terminal):
    self.memory.append(( state, action, reward, next_state, terminal ))

  def predict_action(self, state):
    random_number = np.random.rand()
    if random_number < self.exploration_rate:
      return random.randrange(self.action_space)

    q_values = self.target_net(state).detach().numpy()

    return np.argmax(q_values[0])

  def experience_replay(self):
    if len(self.memory) < BATCH_SIZE:
      return 

    batch = random.sample(self.memory, BATCH_SIZE)

    #Update the q-values in each batch
    for state, action, reward, next_state, terminal in batch:
      q_update = reward
      if not terminal:
        q_update = reward + GAMMA  * self.target_net(next_state).max(axis=1)[0]
        q_values = self.target_net(state)
        q_values[0][action] = q_update


        # Update Weights and loss

        loss = self.criterion(self.policy_net(state), q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update the exploration rate
        if not self.explore_limit: 
          self.exploration_rate *= EXPLORE_DECAY
          if self.exploration_rate < MIN_EXPLORE:
            self.exploration_rate = MIN_EXPLORE
            self.explore_limit = True

In [0]:
ENV_NAME = "CartPole-v1"
BATCH_SIZE = 20
GAMMA = 0.95
LEARNING_RATE = 0.001
MAX_EXPLORE = 1.0
MIN_EXPLORE = 0.01
EXPLORE_DECAY = 0.995
MEMORY_LEN = 1_000_000
UPDATE_FREQ = 10

In [0]:
env = gym.make(ENV_NAME)
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
dqn = DQN(observation_space, action_space)

In [0]:
for i in range(100):
  state = env.reset()
  state = np.reshape(state, [1, observation_space])
  state = torch.from_numpy(state).float()
  score = 0
  while True:
    score += 1
    action = dqn.predict_action(state)
    next_state, reward, terminal, info = env.step(action)
    next_state = torch.from_numpy(np.reshape(next_state, [1,observation_space])).float()
    dqn.load_memory(state, action, reward, next_state,terminal)
    state = next_state
    if terminal:
      print(f'| {i+1:02} |{dqn.exploration_rate:.4f} | {score:03} |')
      break

    dqn.experience_replay()

  if i%UPDATE_FREQ == 0:
    dqn.target_net.load_state_dict(dqn.policy_net.state_dict())

| 01 |1.0000 | 020 |
| 02 |0.0100 | 169 |
| 03 |0.0100 | 158 |
| 04 |0.0100 | 159 |
| 05 |0.0100 | 165 |
| 06 |0.0100 | 155 |
| 07 |0.0100 | 162 |
| 08 |0.0100 | 162 |
| 09 |0.0100 | 154 |
| 10 |0.0100 | 169 |
| 11 |0.0100 | 150 |
| 12 |0.0100 | 153 |
| 13 |0.0100 | 164 |
| 14 |0.0100 | 175 |
| 15 |0.0100 | 166 |
| 16 |0.0100 | 152 |
| 17 |0.0100 | 169 |
| 18 |0.0100 | 160 |
| 19 |0.0100 | 178 |
| 20 |0.0100 | 155 |
| 21 |0.0100 | 016 |
| 22 |0.0100 | 157 |
| 23 |0.0100 | 164 |
| 24 |0.0100 | 154 |
| 25 |0.0100 | 165 |
| 26 |0.0100 | 162 |
| 27 |0.0100 | 165 |
| 28 |0.0100 | 171 |
| 29 |0.0100 | 159 |
| 30 |0.0100 | 175 |
| 31 |0.0100 | 163 |
| 32 |0.0100 | 169 |
| 33 |0.0100 | 155 |
| 34 |0.0100 | 176 |
| 35 |0.0100 | 163 |
| 36 |0.0100 | 180 |
| 37 |0.0100 | 158 |
| 38 |0.0100 | 166 |
| 39 |0.0100 | 157 |
| 40 |0.0100 | 177 |
| 41 |0.0100 | 173 |
| 42 |0.0100 | 164 |
| 43 |0.0100 | 160 |
| 44 |0.0100 | 159 |
| 45 |0.0100 | 159 |
| 46 |0.0100 | 168 |
| 47 |0.0100 | 168 |
| 48 |0.0100 

In [0]:
def play_agent(dqn, env):
  observation = env.reset()
  total_reward=0
  for _ in range(500):
    env.render()
    observation = torch.tensor(observation).type('torch.FloatTensor').view(1,-1)
    q_values = dqn.target_net(observation).detach().numpy()
    action = np.argmax(q_values[0])
    new_observation, reward, done, _ = env.step(action)
    total_reward += reward
    observation = new_observation
    if(done):
      break

  env.close()
  print("Rewards: ",total_reward)

In [0]:
play_agent(dqn, env)

NoSuchDisplayException: ignored

In [0]:
import random


In [0]:
random.random()

0.28017089325146627