In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Categorical
import gym
import random
import collections

In [41]:
# Build the CartPole-v1 environment that every later cell interacts with.
env = gym.make(id='CartPole-v1')

In [43]:
# Hyper-parameters
# -- learning --
GAMMA = 0.98         # discount factor applied to the bootstrapped next-state value
ALPHA = 1e-3         # learning rate handed to the Adam optimizer
Q_TARG_PERIOD = 10   # copy the online net into the target net every this many episodes

# -- exploration / run length --
EPSILON = 1          # starting epsilon for epsilon-greedy action selection
EPISODE = 2000       # episode budget for the training loop

In [52]:
class DQN_Net(nn.Module):
    """Small fully connected network mapping an observation vector to one Q-value per action.

    Architecture: Linear -> tanh -> Linear -> tanh -> Linear. With the default
    arguments this is the original hard-coded 4 -> 24 -> 24 -> 2 layout.

    Args:
        obs_dim: Length of the input observation vector (default 4).
        n_actions: Number of discrete actions, i.e. the output width (default 2).
        hidden_dim: Width of both hidden layers (default 24).
    """

    def __init__(self, obs_dim=4, n_actions=2, hidden_dim=24):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, n_actions)

    def Q(self, x):
        """Return the raw Q-values for observation(s) ``x``.

        ``x`` may be a single observation ``(obs_dim,)`` or a batch
        ``(batch, obs_dim)``; the last dimension of the result is ``n_actions``.
        """
        # torch.tanh replaces the deprecated F.tanh.
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        # The output layer stays linear: these values are regressed onto unbounded
        # TD targets, so squashing them (e.g. with softmax) would be wrong.
        return self.fc3(x)

    def forward(self, x):
        """Standard nn.Module entry point, so ``net(x)`` works; identical to ``Q``."""
        return self.Q(x)

In [53]:
def train(q, q_tar, optimizer, memory=None, batch_size=32, gamma=None,
          terminal_reward=-100.0, double_dqn=False):
    """Run one DQN gradient step on a random minibatch of replay transitions.

    Args:
        q: Online network exposing ``.Q(states)``; its parameters are updated
            through ``optimizer``.
        q_tar: Target network used only to build the bootstrap target. It is
            evaluated under ``torch.no_grad()`` and never receives gradients.
        optimizer: Optimizer over ``q``'s parameters.
        memory: Sequence of ``(state, action, reward, next_state, done)`` tuples.
            Defaults to the notebook-global ``buffer``.
        batch_size: Number of transitions sampled without replacement (default 32).
        gamma: Discount factor. Defaults to the notebook-global ``GAMMA``.
        terminal_reward: Reward substituted for transitions with ``done=True``
            (default -100, as before).
        double_dqn: If True, choose the next action with ``q`` and evaluate it with
            ``q_tar`` (Double DQN) instead of taking ``q_tar``'s max.

    Returns:
        The scalar loss of this step, as a Python float.

    Raises:
        ValueError: from ``random.sample`` if ``memory`` holds fewer than
            ``batch_size`` transitions.
    """
    if memory is None:
        memory = buffer
    if gamma is None:
        gamma = GAMMA

    # Sample a random minibatch from the replay memory and split it into columns.
    batch = random.sample(memory, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    s = torch.stack([torch.as_tensor(x, dtype=torch.float) for x in states])
    s2 = torch.stack([torch.as_tensor(x, dtype=torch.float) for x in next_states])
    a = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
    done = torch.tensor(dones, dtype=torch.float).unsqueeze(1)
    # Terminal transitions get a fixed penalty in place of the env reward.
    # NOTE(review): with the old gym step API a time-limit truncation also reports
    # done=True, so reaching the episode cap is penalized too — confirm intent.
    r = torch.tensor(
        [terminal_reward if d else rew for rew, d in zip(rewards, dones)],
        dtype=torch.float,
    ).unsqueeze(1)

    # Q(s, a) for the actions actually taken; `a` acts as per-row column indices.
    q_sa = q.Q(s).gather(1, a)

    # The bootstrap target is a constant w.r.t. this update: no graph, no target-net grads.
    with torch.no_grad():
        if double_dqn:
            next_a = q.Q(s2).argmax(1, keepdim=True)
            next_q = q_tar.Q(s2).gather(1, next_a)
        else:
            next_q = q_tar.Q(s2).max(1, keepdim=True)[0]
        # (1 - done) stops bootstrapping past the end of an episode.
        q_target = r + gamma * next_q * (1.0 - done)

    # Huber loss instead of MSE: less sensitive to outliers, which helps avoid
    # exploding gradients from large TD errors.
    loss = F.smooth_l1_loss(q_sa, q_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In [54]:
# Replay memory; train() samples its minibatches from this global `buffer`.
buffer = collections.deque(maxlen=50000)
ep = 1

# Online network (trained) and target network (periodically synced copy of it).
q_net = DQN_Net()
q_targ = DQN_Net()
# Start the target network with the same weights as the online network.
q_targ.load_state_dict(q_net.state_dict())

optimizer = optim.Adam(q_net.parameters(), ALPHA)

# Gradient updates only start once the replay memory holds more than this many transitions.
WARMUP_TRANSITIONS = 2000

# Local copy so that re-running this cell restarts exploration from the configured
# EPSILON instead of the decayed value a previous run left in the global.
epsilon = EPSILON

# Reset the environment for the first episode.
state = env.reset()

while ep <= EPISODE:  # `<=` so exactly EPISODE episodes run (ep starts at 1)
    done = False
    total_reward = 0

    while not done:
        # Epsilon-greedy action selection; only the greedy branch needs the network,
        # and it needs no autograd graph.
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                action = q_net.Q(torch.from_numpy(state).float()).argmax().item()

        # Advance the environment and accumulate this episode's reward.
        state_next, reward, done, _ = env.step(action)
        total_reward += reward

        # Store the transition, then move on to the next state.
        buffer.append((state, action, reward, state_next, done))
        state = state_next

        if len(buffer) > WARMUP_TRANSITIONS:
            train(q_net, q_targ, optimizer)

    # Episode finished: periodically sync the target network with the online one.
    if ep % Q_TARG_PERIOD == 0:
        q_targ.load_state_dict(q_net.state_dict())

    print(ep, total_reward)
    ep += 1
    # Decay exploration with the episode count.
    epsilon = 1 / ((ep / 100) + 1)
    state = env.reset()

1 13.0
2 31.0
3 19.0
4 16.0
5 65.0
6 17.0
7 11.0
8 28.0
9 12.0
10 27.0
11 27.0
12 18.0
13 29.0
14 20.0
15 25.0
16 28.0
17 9.0
18 15.0
19 18.0
20 31.0
21 11.0
22 28.0
23 26.0
24 10.0
25 23.0
26 24.0
27 25.0
28 19.0
29 12.0
30 12.0
31 12.0
32 8.0
33 24.0
34 10.0
35 11.0
36 13.0
37 38.0
38 14.0
39 49.0
40 12.0
41 10.0
42 18.0
43 26.0
44 11.0
45 11.0
46 20.0
47 12.0
48 18.0
49 19.0
50 14.0
51 38.0
52 18.0
53 9.0
54 12.0
55 14.0
56 18.0
57 14.0
58 35.0
59 9.0
60 17.0
61 11.0
62 10.0
63 12.0
64 14.0
65 14.0
66 32.0
67 12.0
68 14.0
69 20.0
70 9.0
71 13.0
72 12.0
73 10.0
74 10.0
75 15.0
76 18.0
77 19.0
78 11.0
79 26.0
80 9.0
81 15.0
82 11.0
83 11.0
84 19.0
85 19.0
86 12.0
87 24.0
88 18.0
89 17.0
90 15.0
91 10.0
92 13.0
93 11.0
94 12.0
95 18.0
96 16.0
97 10.0
98 11.0
99 26.0
100 15.0
101 16.0
102 17.0
103 12.0
104 10.0
105 16.0
106 16.0
107 10.0
108 9.0
109 13.0
110 10.0
111 21.0
112 8.0
113 11.0
114 11.0
115 14.0
116 17.0
117 11.0
118 10.0
119 9.0
120 16.0
121 12.0
122 12.0
123 9.0
124 11.0
12

KeyboardInterrupt: 