In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import gym

In [2]:
# Hyper-parameters and environment shared (as module-level names) by the
# Net / DQN classes and the training loop below.
batch_size = 32  # number of stored transitions sampled per DQN.learn() call
lr = 0.01  # learning rate passed to the Adam optimizer in DQN.__init__
ϵ = 0.9  # probability that DQN.act picks the greedy action (a random action otherwise)
γ = 0.9  # discount factor applied to the bootstrapped next-state value in DQN.learn
target_replace_iter = 100  # target net is re-synced from eval net every this many learn() calls
memory_capacity = 2000  # number of rows in the replay memory array
# NOTE(review): `.unwrapped` presumably strips gym's wrappers (e.g. the
# episode step limit) -- confirm against the installed gym version.
env = gym.make("CartPole-v0").unwrapped
nA = env.action_space.n  # number of discrete actions
nS = env.observation_space.shape[0]  # length of the observation vector

[2017-06-20 22:39:00,544] Making new env: CartPole-v0


In [3]:
class Net(nn.Module):
    """Small fully connected Q-network: state -> one action value per action.

    Architecture: Linear(n_states, n_hidden) -> ReLU -> Linear(n_hidden, n_actions).
    Both weight matrices are initialised from N(0, 0.1); biases keep the
    default ``nn.Linear`` initialisation.

    Args:
        n_states: size of the input state vector. Defaults to the notebook
            global ``nS`` when omitted.
        n_actions: number of output action values. Defaults to the notebook
            global ``nA`` when omitted.
        n_hidden: width of the single hidden layer (default 10).
    """

    def __init__(self, n_states=None, n_actions=None, n_hidden=10):
        super(Net, self).__init__()
        # Resolve the globals lazily (at construction time, not at class
        # definition time) so ``Net()`` keeps working exactly as before while
        # explicit sizes make the class usable without the notebook globals.
        if n_states is None:
            n_states = nS
        if n_actions is None:
            n_actions = nA

        self.fc1 = nn.Linear(n_states, n_hidden)
        nn.init.normal_(self.fc1.weight, mean=0.0, std=0.1)
        self.out = nn.Linear(n_hidden, n_actions)
        nn.init.normal_(self.out.weight, mean=0.0, std=0.1)

    def forward(self, x):
        """Return the action values for each state row in ``x``, i.e. Q[s, :]."""
        hidden = F.relu(self.fc1(x))
        return self.out(hidden)

In [15]:
class DQN(object):
    """Deep Q-learning agent with a replay memory and a target network.

    Relies on the notebook-level names ``Net``, ``nS``, ``nA``, ``lr``, ``ϵ``,
    ``γ``, ``batch_size``, ``memory_capacity`` and ``target_replace_iter``.

    Each replay-memory row is laid out as ``[s (nS values), a, r, s_ (nS values)]``.

    NOTE(review): transitions do not record whether ``s_`` was terminal, so
    ``learn`` bootstraps from terminal next states as well.
    """

    def __init__(self):
        # eval_net is trained; target_net is a periodically refreshed copy
        # used to compute the bootstrap targets.
        self.eval_net, self.target_net = Net(), Net()
        self.learn_step_counter = 0  # number of learn() updates performed
        self.memory_counter = 0  # total number of transitions ever stored
        self.memory = np.zeros((memory_capacity, nS * 2 + 2))
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=lr)
        self.loss_func = nn.MSELoss()

    def act(self, x):
        """Choose an action for the single state ``x``.

        With probability ``ϵ`` the greedy action under ``eval_net`` is
        returned, otherwise a uniformly random action in ``[0, nA)``.

        Returns:
            int: index of the chosen action.
        """
        if np.random.uniform() < ϵ:
            state = torch.as_tensor(np.asarray(x), dtype=torch.float32).unsqueeze(0)
            # Action selection needs no gradients.
            with torch.no_grad():
                action_values = self.eval_net(state)
            # argmax over the action dimension; .item() works whether or not
            # the reduced dimension is kept, unlike indexing with [0, 0].
            action = int(action_values.argmax(dim=1).item())
        else:
            action = int(np.random.randint(0, nA))
        return action

    def store_transition(self, s, a, r, s_):
        """Write the transition ``(s, a, r, s_)`` into the replay memory.

        The memory is a ring buffer: once ``memory_capacity`` rows have been
        written, the oldest rows are overwritten.
        """
        transition = np.hstack((s, [a, r], s_))
        index = self.memory_counter % memory_capacity
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
        """Run one gradient step on ``eval_net`` from a random replay batch.

        Does nothing if no transition has been stored yet. Every
        ``target_replace_iter`` updates the target network is re-synced from
        the evaluation network before the update.
        """
        # Only rows that have actually been written are valid samples; the
        # rest of the buffer is still the zero placeholder from __init__.
        n_stored = min(self.memory_counter, memory_capacity)
        if n_stored == 0:
            return

        if self.learn_step_counter % target_replace_iter == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        # Sample with replacement, so a batch can be drawn from fewer than
        # batch_size stored rows.
        sample_index = np.random.choice(n_stored, batch_size)
        b_memory = self.memory[sample_index, :]
        b_s = torch.as_tensor(b_memory[:, :nS], dtype=torch.float32)
        b_a = torch.as_tensor(b_memory[:, nS:nS + 1], dtype=torch.int64)
        b_r = torch.as_tensor(b_memory[:, nS + 1:nS + 2], dtype=torch.float32)
        b_sp = torch.as_tensor(b_memory[:, -nS:], dtype=torch.float32)

        # Q(s, a) for the actions that were actually taken: shape (batch, 1).
        q_eval = self.eval_net(b_s).gather(1, b_a)
        # Target values come from the frozen network; no gradient flows back.
        q_next = self.target_net(b_sp).detach()
        # keepdim=True keeps the (batch, 1) shape so the sum with b_r is
        # element-wise instead of broadcasting to (batch, batch).
        q_target = b_r + γ * q_next.max(1, keepdim=True)[0]
        loss = self.loss_func(q_eval, q_target)

        # update by optimizer
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [None]:
# Train a fresh DQN agent on `env` for 400 episodes.
dqn = DQN()

print("Collecting experience...")
for i_episode in range(400):
    s = env.reset()
    ep_r = 0  # sum of the shaped rewards collected in this episode
    while True:
        # env.render()
        a = dqn.act(s)
        s_, r, done, info = env.step(a)
        x, x_dot, theta, theta_dot = s_

        # reward modify
        # The reward returned by env.step is replaced by a shaped one:
        # r1 grows as |x| shrinks relative to env.x_threshold, r2 grows as
        # |theta| shrinks relative to env.theta_threshold_radians.
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        r = r1 + r2
        dqn.store_transition(s, a, r, s_)
        ep_r += r

        # Learning only starts once more than memory_capacity transitions
        # have been stored; from then on one update runs per environment
        # step, and the episode return is printed when an episode ends.
        if dqn.memory_counter > memory_capacity:
            dqn.learn()
            if done:
                print("Ep: ", i_episode, "| Ep_r: ", round(ep_r, 2))

        if done:
            break
        s = s_

Collecting experience...
Ep:  202 | Ep_r:  2.77
Ep:  203 | Ep_r:  1.59
Ep:  204 | Ep_r:  1.76
Ep:  205 | Ep_r:  2.29
Ep:  206 | Ep_r:  2.64
Ep:  207 | Ep_r:  1.28
Ep:  208 | Ep_r:  4.02
Ep:  209 | Ep_r:  3.23
Ep:  210 | Ep_r:  2.33
Ep:  211 | Ep_r:  2.76
Ep:  212 | Ep_r:  1.66
Ep:  213 | Ep_r:  2.47
Ep:  214 | Ep_r:  5.73
Ep:  215 | Ep_r:  2.25
Ep:  216 | Ep_r:  1.35
Ep:  217 | Ep_r:  2.55
Ep:  218 | Ep_r:  4.07
Ep:  219 | Ep_r:  14.36
Ep:  220 | Ep_r:  2.62
Ep:  221 | Ep_r:  4.41
Ep:  222 | Ep_r:  1.68
Ep:  223 | Ep_r:  3.63
Ep:  224 | Ep_r:  2.08
Ep:  225 | Ep_r:  6.01
Ep:  226 | Ep_r:  11.04
Ep:  227 | Ep_r:  2.4
Ep:  228 | Ep_r:  10.67
Ep:  229 | Ep_r:  1.79
Ep:  230 | Ep_r:  8.12
Ep:  231 | Ep_r:  1.01
Ep:  232 | Ep_r:  0.65
Ep:  233 | Ep_r:  3.02
Ep:  234 | Ep_r:  4.02
Ep:  235 | Ep_r:  27.37
Ep:  236 | Ep_r:  16.13
Ep:  237 | Ep_r:  21.01
Ep:  238 | Ep_r:  6.53
Ep:  239 | Ep_r:  1.36
Ep:  240 | Ep_r:  15.97
Ep:  241 | Ep_r:  14.38
Ep:  242 | Ep_r:  15.15
Ep:  243 | Ep_r:  20.74
