<a href="https://colab.research.google.com/github/kerker83/DQN/blob/main/DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
import random
import gym
import numpy as np
from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
import os

## Environment and Hyperparameters

In [None]:
# Environment and run-level settings shared by the agent and the training loop.
env = gym.make('CartPole-v0')
# Length of the observation vector fed to the network, and number of discrete actions.
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
batch_size = 32  # transitions sampled per replay update
n_episodes = 1000

# Directory for periodic weight checkpoints.
output_dir = 'model_output/cartpole/'
# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(output_dir, exist_ok=True)

## Agent Class

In [None]:
class DQNAgent:
  """Deep Q-Network agent with an experience-replay buffer and epsilon-greedy exploration.

  A single Keras MLP maps a state to one Q-value per action. The same network is
  used both to pick actions and to compute bootstrap targets during replay (there
  is no separate target network).
  """

  def __init__(self, state_size, action_size, *, memory_size=2000, gamma=0.95,
               epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01,
               learning_rate=0.001):
    """Create the agent and build its Q-network.

    Args:
      state_size: length of the state vector (network input width).
      action_size: number of discrete actions (network output width).
      memory_size: capacity of the replay buffer; oldest transitions are
        discarded once it is full.
      gamma: discount factor applied to the bootstrapped next-state value.
      epsilon: initial probability of taking a random action.
      epsilon_decay: multiplicative factor applied to epsilon after each
        call to `train` while epsilon is above `epsilon_min`.
      epsilon_min: floor below which epsilon is no longer decayed.
      learning_rate: Adam learning rate.
    """
    self.state_size = state_size
    self.action_size = action_size
    # Bounded FIFO replay buffer of (state, action, reward, next_state, done).
    self.memory = deque(maxlen=memory_size)
    self.gamma = gamma
    self.epsilon = epsilon
    self.epsilon_decay = epsilon_decay
    self.epsilon_min = epsilon_min
    self.learning_rate = learning_rate
    self.model = self._build_model()

  def _build_model(self):
    """Return a compiled 2x32-unit ReLU MLP with a linear Q-value output per action."""
    model = Sequential()
    model.add(Dense(32, activation='relu', input_dim=self.state_size))
    model.add(Dense(32, activation='relu'))
    # Linear output: Q-values are unbounded regression targets.
    model.add(Dense(self.action_size, activation='linear'))
    model.compile(loss='mse', optimizer=Adam(learning_rate=self.learning_rate))
    return model

  def remember(self, state, action, reward, next_state, done):
    """Append one transition to the replay buffer."""
    self.memory.append((state, action, reward, next_state, done))

  def train(self, batch_size):
    """Fit the network on a random minibatch of stored transitions, then decay epsilon.

    Each sampled transition gets its own one-sample predict/fit, so later samples
    in the minibatch see weights already updated by earlier ones.

    Args:
      batch_size: number of transitions to sample. It must not exceed
        len(self.memory); random.sample raises ValueError otherwise.
    """
    minibatch = random.sample(self.memory, batch_size)
    for state, action, reward, next_state, done in minibatch:
      target = reward
      if not done:
        # Bellman target: r + gamma * max_a' Q(s', a').
        next_q = self.model.predict(next_state, verbose=0)
        target = reward + self.gamma * np.amax(next_q[0])
      # Start from the current predictions so that only the taken action's
      # Q-value is pulled toward the target; other outputs contribute zero loss.
      target_f = self.model.predict(state, verbose=0)
      target_f[0][action] = target
      self.model.fit(state, target_f, epochs=1, verbose=0)
    if self.epsilon > self.epsilon_min:
      self.epsilon *= self.epsilon_decay

  def act(self, state):
    """Choose an action epsilon-greedily.

    With probability epsilon a uniformly random action index is returned;
    otherwise the index of the highest predicted Q-value. Both branches
    return a plain Python int.
    """
    if np.random.rand() <= self.epsilon:
      return random.randrange(self.action_size)
    q_values = self.model.predict(state, verbose=0)
    return int(np.argmax(q_values[0]))

  def save(self, name):
    """Write the network weights to the file path `name`."""
    self.model.save_weights(name)

  def load(self, name):
    """Load network weights from the file path `name`."""
    self.model.load_weights(name)

## Running Episodes and Training the Agent

In [None]:
# Play n_episodes of CartPole. Store every transition, run one replay update per
# episode, and checkpoint the weights every 50 episodes.
agent = DQNAgent(state_size, action_size)

for e in range(n_episodes):
  state = env.reset()
  state = np.reshape(state, [1, state_size])
  done = False
  score = 0  # number of env steps survived in this episode
  while not done:
    # env.render()  # uncomment to visualise the episode
    action = agent.act(state)
    next_state, reward, done, _ = env.step(action)
    score += 1
    # The terminal step is penalised with -10.
    # NOTE(review): presumably CartPole-v0's time limit also sets done, which
    # would penalise (and store as terminal) successful time-outs too. Consider
    # checking info['TimeLimit.truncated'] -- verify for the installed gym.
    reward = reward if not done else -10
    next_state = np.reshape(next_state, [1, state_size])
    agent.remember(state, action, reward, next_state, done)
    state = next_state
  print("episode: {}/{}, score: {}, e: {}".format(e, n_episodes - 1, score, agent.epsilon))
  # Only replay once the buffer holds more than one minibatch.
  if len(agent.memory) > batch_size:
    agent.train(batch_size)
  # Checkpoint inside the episode loop. The original dedented check ran once
  # after the loop with e == 999 and therefore never saved anything.
  # NOTE(review): Keras 3 save_weights requires a '.weights.h5' suffix; confirm
  # the installed TF/Keras version accepts '.hdf5'.
  if e % 50 == 0:
    agent.save(os.path.join(output_dir, 'weights_{:04d}.hdf5'.format(e)))

episode: 0/999, score: 19, e: 1.0
episode: 1/999, score: 12, e: 1.0
episode: 2/999, score: 14, e: 0.995
episode: 3/999, score: 17, e: 0.990025
episode: 4/999, score: 29, e: 0.985074875
episode: 5/999, score: 36, e: 0.9801495006250001
episode: 6/999, score: 29, e: 0.9752487531218751
episode: 7/999, score: 9, e: 0.9703725093562657
episode: 8/999, score: 15, e: 0.9655206468094844
episode: 9/999, score: 11, e: 0.960693043575437
episode: 10/999, score: 10, e: 0.9558895783575597
episode: 11/999, score: 23, e: 0.9511101304657719
episode: 12/999, score: 23, e: 0.946354579813443
episode: 13/999, score: 11, e: 0.9416228069143757
episode: 14/999, score: 24, e: 0.9369146928798039
episode: 15/999, score: 35, e: 0.9322301194154049
episode: 16/999, score: 11, e: 0.9275689688183278
episode: 17/999, score: 15, e: 0.9229311239742362
episode: 18/999, score: 9, e: 0.918316468354365
episode: 19/999, score: 9, e: 0.9137248860125932
episode: 20/999, score: 14, e: 0.9091562615825302
episode: 21/999, score: 12