In [1]:
import sys
import gym
import torch
import pylab
import random
import numpy as np
from collections import deque
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
from prioritized_memory import Memory

In [2]:
EPISODES = 500

# approximate Q function using Neural Network
# state is input and Q Value of each action is output of network
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_size, 24),
            nn.ReLU(),
            nn.Linear(24, 24),
            nn.ReLU(),
            nn.Linear(24, action_size)
        )

    def forward(self, x):
        return self.fc(x)

In [3]:
# DQN Agent for the Cartpole
# it uses Neural Network to approximate q function
# and prioritized experience replay memory & target q network
class DQNAgent():
    def __init__(self, state_size, action_size):
        # if you want to see Cartpole learning, then change to True
        self.render = False
        self.load_model = False

        # get size of state and action
        self.state_size = state_size
        self.action_size = action_size

        # These are hyper parameters for the DQN
        self.discount_factor = 0.99
        self.learning_rate = 0.001
        self.memory_size = 20000
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.explore_step = 5000
        self.epsilon_decay = (self.epsilon - self.epsilon_min) / self.explore_step
        self.batch_size = 64
        self.train_start = 1000

        # create prioritized replay memory using SumTree
        self.memory = Memory(self.memory_size)

        # create main model and target model
        self.model = DQN(state_size, action_size)
        self.model.apply(self.weights_init)
        self.target_model = DQN(state_size, action_size)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.learning_rate)

        # initialize target model
        self.update_target_model()

        if self.load_model:
            self.model = torch.load('save_model/cartpole_dqn')

    # weight xavier initialize
    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            torch.nn.init.xavier_uniform(m.weight)

    # after some time interval update the target model to be same with model
    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())

    # get action from model using epsilon-greedy policy
    def get_action(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        else:
            state = torch.from_numpy(state)
            state = Variable(state).float().cpu()
            q_value = self.model(state)
            _, action = torch.max(q_value, 1)
            return int(action)

    # save sample (error,<s,a,r,s'>) to the replay memory
    def append_sample(self, state, action, reward, next_state, done):
        target = self.model(Variable(torch.FloatTensor(state))).data
        old_val = target[0][action]
        target_val = self.target_model(Variable(torch.FloatTensor(next_state))).data
        if done:
            target[0][action] = reward
        else:
            target[0][action] = reward + self.discount_factor * torch.max(target_val)

        error = abs(old_val - target[0][action])

        self.memory.add(error, (state, action, reward, next_state, done))

    # pick samples from prioritized replay memory (with batch_size)
    def train_model(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon -= self.epsilon_decay

        mini_batch, idxs, is_weights = self.memory.sample(self.batch_size)
        mini_batch = np.array(mini_batch).transpose()

        states = np.vstack(mini_batch[0])
        actions = list(mini_batch[1])
        rewards = list(mini_batch[2])
        next_states = np.vstack(mini_batch[3])
        dones = mini_batch[4]

        # bool to binary
        dones = dones.astype(int)

        # Q function of current state
        states = torch.Tensor(states)
        states = Variable(states).float()
        pred = self.model(states)

        # one-hot encoding
        a = torch.LongTensor(actions).view(-1, 1)

        one_hot_action = torch.FloatTensor(self.batch_size, self.action_size).zero_()
        one_hot_action.scatter_(1, a, 1)

        pred = torch.sum(pred.mul(Variable(one_hot_action)), dim=1)

        # Q function of next state
        next_states = torch.Tensor(next_states)
        next_states = Variable(next_states).float()
        next_pred = self.target_model(next_states).data

        rewards = torch.FloatTensor(rewards)
        dones = torch.FloatTensor(dones)

        # Q Learning: get maximum Q value at s' from target model
        target = rewards + (1 - dones) * self.discount_factor * next_pred.max(1)[0]
        target = Variable(target)

        errors = torch.abs(pred - target).data.numpy()

        # update priority
        for i in range(self.batch_size):
            idx = idxs[i]
            self.memory.update(idx, errors[i])

        self.optimizer.zero_grad()

        # MSE Loss function
        loss = (torch.FloatTensor(is_weights) * F.mse_loss(pred, target)).mean()
        loss.backward()

        # and train
        self.optimizer.step()

In [4]:
    # In case of CartPole-v1, maximum length of episode is 500
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    model = DQN(state_size, action_size)

    agent = DQNAgent(state_size, action_size)
    scores, episodes = [], []

  result = entry_point.load(False)


In [5]:
import os
dirs = ['./save_graph', './save_model']
for i_dir in dirs:
    if not os.path.exists(i_dir):
        os.makedirs(i_dir)

In [6]:
bool_break = 0
for e in range(EPISODES):
    done = False
    score = 0

    state = env.reset()
    state = np.reshape(state, [1, state_size])

    while not done:
        if agent.render:
            env.render()

        # get action for the current state and go one step in environment
        action = agent.get_action(state)

        next_state, reward, done, info = env.step(action)
        next_state = np.reshape(next_state, [1, state_size])
        # if an action make the episode end, then gives penalty of -100
        reward = reward if not done or score == 499 else -10

        # save the sample <s, a, r, s'> to the replay memory
        agent.append_sample(state, action, reward, next_state, done)
        # every time step do the training
        if agent.memory.tree.n_entries >= agent.train_start:
            agent.train_model()

        score += reward
        state = next_state

        if done:
            # every episode update the target model to be same with model
            agent.update_target_model()

            # every episode, plot the play time
            score = score if score == 500 else score + 10
            scores.append(score)
            episodes.append(e)
            pylab.plot(episodes, scores, 'b')
            pylab.savefig("./save_graph/cartpole_dqn.png")
            print("episode:", e, "  score:", score, "  memory length:",
                  agent.memory.tree.n_entries, "  epsilon:", agent.epsilon)

            # if the mean of scores of last 10 episode is bigger than 490
            # stop training
            if np.mean(scores[-min(10, len(scores)):]) > 490:
                torch.save(agent.model, "./save_model/cartpole_dqn")
                # sys.exit()
                bool_break = 1
                break
    if bool_break:
        break


episode: 0   score: 16.0   memory length: 17   epsilon: 1.0
episode: 1   score: 13.0   memory length: 31   epsilon: 1.0
episode: 2   score: 17.0   memory length: 49   epsilon: 1.0
episode: 3   score: 12.0   memory length: 62   epsilon: 1.0
episode: 4   score: 23.0   memory length: 86   epsilon: 1.0
episode: 5   score: 18.0   memory length: 105   epsilon: 1.0
episode: 6   score: 22.0   memory length: 128   epsilon: 1.0
episode: 7   score: 18.0   memory length: 147   epsilon: 1.0
episode: 8   score: 38.0   memory length: 186   epsilon: 1.0
episode: 9   score: 17.0   memory length: 204   epsilon: 1.0
episode: 10   score: 9.0   memory length: 214   epsilon: 1.0
episode: 11   score: 19.0   memory length: 234   epsilon: 1.0
episode: 12   score: 22.0   memory length: 257   epsilon: 1.0
episode: 13   score: 21.0   memory length: 279   epsilon: 1.0
episode: 14   score: 20.0   memory length: 300   epsilon: 1.0
episode: 15   score: 18.0   memory length: 319   epsilon: 1.0
episode: 16   score: 13.

episode: 115   score: 106.0   memory length: 4008   epsilon: 0.40421799999993213
episode: 116   score: 128.0   memory length: 4137   epsilon: 0.37867599999993523
episode: 117   score: 192.0   memory length: 4330   epsilon: 0.34046199999993987
episode: 118   score: 116.0   memory length: 4447   epsilon: 0.3172959999999427
episode: 119   score: 92.0   memory length: 4540   epsilon: 0.2988819999999449
episode: 120   score: 132.0   memory length: 4673   epsilon: 0.2725479999999481
episode: 121   score: 94.0   memory length: 4768   epsilon: 0.2537379999999504
episode: 122   score: 143.0   memory length: 4912   epsilon: 0.22522599999995035
episode: 123   score: 117.0   memory length: 5030   epsilon: 0.20186199999994991
episode: 124   score: 116.0   memory length: 5147   epsilon: 0.17869599999994948
episode: 125   score: 160.0   memory length: 5308   epsilon: 0.14681799999994888
episode: 126   score: 120.0   memory length: 5429   epsilon: 0.12285999999994843
episode: 127   score: 132.0   memo

episode: 215   score: 128.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 216   score: 123.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 217   score: 177.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 218   score: 129.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 219   score: 123.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 220   score: 163.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 221   score: 168.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 222   score: 158.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 223   score: 186.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 224   score: 156.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 225   score: 141.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 226   score: 160.0   memory length: 20000   epsilon: 0.009999999999947773
epis

episode: 314   score: 198.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 315   score: 168.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 316   score: 195.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 317   score: 167.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 318   score: 144.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 319   score: 160.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 320   score: 318.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 321   score: 133.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 322   score: 187.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 323   score: 130.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 324   score: 159.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 325   score: 224.0   memory length: 20000   epsilon: 0.009999999999947773
epis

episode: 414   score: 17.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 415   score: 26.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 416   score: 39.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 417   score: 18.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 418   score: 131.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 419   score: 124.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 420   score: 141.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 421   score: 144.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 422   score: 164.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 423   score: 152.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 424   score: 160.0   memory length: 20000   epsilon: 0.009999999999947773
episode: 425   score: 500.0   memory length: 20000   epsilon: 0.009999999999947773
episode: