In [1]:
import gym
from utils import *

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import math

class NoiseLinear(nn.Linear):
    """Linear layer whose parameters are perturbed by learnable Gaussian noise.

    Each forward pass uses ``weight + sigma_weight * eps_weight`` and, when a
    bias is present, ``bias + sigma_bias * eps_bias``. The ``eps_*`` buffers
    are redrawn from a standard normal on every call; the ``sigma_*`` tensors
    are trainable parameters.

    Args:
        in_: Number of input features.
        out_: Number of output features.
        val: Initial value of every entry of ``sigma_weight`` / ``sigma_bias``.
        bias: When ``False`` the layer has neither a bias nor bias noise.
    """

    def __init__(self, in_, out_, val = 0.017, bias = True):
        # nn.Linear.__init__ already invokes the overridden reset_parameters().
        super(NoiseLinear, self).__init__(in_, out_, bias)
        self.sigma_weight = nn.Parameter(torch.full((out_, in_), val))
        # Noise lives in buffers: it follows .to()/.cuda() and the state_dict
        # but is never seen by the optimizer.
        self.register_buffer("eps_weight", torch.zeros(out_, in_))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_,), val))
            self.register_buffer("eps_bias", torch.zeros(out_))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight (and bias, if any) uniformly in ``[-b, b]``, ``b = sqrt(1 / in_features)``."""
        bound = math.sqrt(1 / self.in_features)
        self.weight.data.uniform_(-bound, bound)
        # The bias is None when the layer was built with bias=False; without
        # this guard construction itself fails with AttributeError.
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, x):
        """Apply the layer to ``x`` using freshly sampled parameter noise."""
        self.eps_weight.normal_()
        noisy_weight = self.weight + self.sigma_weight * self.eps_weight

        noisy_bias = self.bias
        if noisy_bias is not None:
            self.eps_bias.normal_()
            noisy_bias = noisy_bias + self.sigma_bias * self.eps_bias
        return F.linear(x, noisy_weight, noisy_bias)

class Critic(nn.Module):
    """Dueling Q-network with separate state-value and advantage heads.

    Both heads are two-layer MLPs built from ``NoiseLinear`` layers with a
    ReLU in between. ``forward`` combines them as
    ``Q = V + (A - mean(A))``, the mean being taken over the action axis.

    Args:
        in_: Size of the observation vector.
        out_: Number of discrete actions.
        hidden_: Width of the hidden layer in each head.
    """

    def __init__(self, in_, out_, hidden_ = 256):
        super(Critic, self).__init__()
        # State-value head: one scalar per observation.
        self.v = nn.Sequential(
            NoiseLinear(in_, hidden_),
            nn.ReLU(),
            NoiseLinear(hidden_, 1)
        )

        # Advantage head: one value per action.
        self.ad = nn.Sequential(
            NoiseLinear(in_, hidden_),
            nn.ReLU(),
            NoiseLinear(hidden_, out_)
        )

    def get_action(self, act_v):
        """Return the index of the largest entry of ``act_v`` as a 0-d tensor.

        ``argmax`` is taken over the flattened tensor, so this is only an
        action index when ``act_v`` holds the values of a single observation.
        """
        return act_v.argmax()

    def forward(self, x):
        """Return Q-values for ``x``; the last axis of the result indexes actions."""
        value = self.v(x)
        advantage = self.ad(x)
        # Reduce over the last (action) axis rather than a hard-coded dim=1:
        # identical for (batch, obs) input, and also valid for a single
        # unbatched observation of shape (obs,).
        centered = advantage - advantage.mean(dim=-1, keepdim=True)
        return value + centered

In [3]:
# --- Hyperparameters ---------------------------------------------------------
LR = 0.0001          # Adam learning rate
GAMMA = 0.99         # discount factor
TARGET_UPDATE = 5    # copy online -> target network every this many episodes

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell no longer fails on machines without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Environment -------------------------------------------------------------
env = gym.make("CartPole-v1")
act_n = env.action_space.n
obs_n = env.observation_space.shape[0]

# --- Networks and optimizer --------------------------------------------------
net = Critic(obs_n, act_n).to(DEVICE)
net_tgt = Critic(obs_n, act_n).to(DEVICE)
# Start the target network as an exact copy of the online network.
net_tgt.load_state_dict(net.state_dict())
opt = optim.Adam(net.parameters(), LR)
# Kept so the name stays defined; the training cell computes its own
# importance-weighted squared error instead of calling this.
loss = nn.MSELoss()

# --- Agent (from utils) ------------------------------------------------------
agent = Agent(env, net)
# presumably configures 3-step returns discounted by GAMMA — confirm in utils
agent.set_n_step(3, GAMMA)

# --- Replay storage (from utils) ---------------------------------------------
ST_SIZE = 100000     # storage capacity
ST_INIT = 10000      # transitions to collect before training starts
ST_SAMPLE = 512      # batch size drawn per training step
# NOTE(review): the second argument looks like it enables prioritized sampling
# (sample() returns indices/weights in the training cell) — confirm in utils.
storage = Replay(ST_SIZE, True)

In [None]:
EPOCH = 2000

# Build batch tensors on whatever device the online network lives on instead
# of hard-coding CUDA.
device = next(net.parameters()).device

for epoch in range(EPOCH):
    # Number of transitions produced by this episode. Tracked separately so
    # the print below is well-defined even if the episode yields nothing.
    steps = 0
    for i, step in enumerate(agent.episode(epoch)):
        steps = i + 1
        storage.push(step)

        # Do not train until the storage holds at least ST_INIT transitions.
        if len(storage) < ST_INIT:
            continue

        sample, indices, weights = storage.sample(ST_SAMPLE)
        weights = torch.FloatTensor(weights).unsqueeze(1).to(device)
        # act_v and etc are unpacked but not needed for the update.
        obs, act_v, act, next_obs, rew, done, etc, n = list(zip(*sample))

        obs_ = torch.FloatTensor(obs).to(device)
        act_ = torch.LongTensor(act).unsqueeze(1).to(device)
        next_obs_ = torch.FloatTensor(next_obs).to(device)
        rew_ = torch.FloatTensor(rew).unsqueeze(1).to(device)
        done_ = torch.BoolTensor(done).to(device)
        n_ = torch.FloatTensor(n).unsqueeze(1).to(device)

        # The bootstrap target is a constant for the optimizer, so compute it
        # without recording an autograd graph.
        with torch.no_grad():
            # Plain Q-learning alternative:
            #   q_next = net_tgt(next_obs_).max(1)[0].unsqueeze(1)

            # Double Q-learning: the online network picks the action, the
            # target network evaluates it.
            act_opt = net(next_obs_).max(1)[1].unsqueeze(1)
            q_next = net_tgt(next_obs_).gather(1, act_opt)
            # No bootstrapping from transitions flagged as done.
            q_next[done_] = 0
            # n_ is the per-sample exponent applied to the discount factor;
            # presumably rew already holds the accumulated n-step reward —
            # confirm in utils.
            q = rew_ + (GAMMA**n_) * q_next

        q_pred = net(obs_).gather(1, act_)

        opt.zero_grad()
        # Per-sample squared error scaled by the weights returned by sample().
        loss_ = weights * (q_pred - q) ** 2
        loss_mean = loss_.mean()
        loss_mean.backward()
        opt.step()

        # Priority update.
        # NOTE(review): priorities are the weighted squared errors, not the
        # raw |TD error| — confirm this is what Replay expects.
        # squeeze(1) keeps the batch axis even for a batch of one.
        storage.update_priorities(indices, loss_.detach().squeeze(1).cpu().numpy())

    # Target-network sync happens per episode, not per gradient step.
    if epoch%TARGET_UPDATE == 0:
        net_tgt.load_state_dict(net.state_dict())
    print(epoch, steps)

agent.reset()

0 12
1 22
2 27
3 12
4 13
5 17
6 47
7 24
8 15
9 31
10 17
11 13
12 19
13 27
14 25
15 25
16 21
17 33
18 27
19 14
20 12
21 12
22 19
23 16
24 25
25 18
26 44
27 13
28 17
29 15
30 17
31 12
32 11
33 13
34 20
35 43
36 26
37 20
38 21
39 20
40 27
41 20
42 32
43 20
44 18
45 17
46 12
47 13
48 22
49 20
50 43
51 12
52 14
53 17
54 24
55 17
56 16
57 21
58 21
59 12
60 19
61 24
62 49
63 16
64 13
65 16
66 18
67 22
68 18
69 10
70 18
71 13
72 39
73 16
74 13
75 14
76 13
77 55
78 23
79 45
80 18
81 39
82 17
83 16
84 16
85 25
86 19
87 14
88 17
89 16
90 19
91 26
92 13
93 17
94 29
95 14
96 16
97 20
98 12
99 29
100 19
101 30
102 17
103 12
104 24
105 19
106 23
107 36
108 18
109 19
110 15
111 17
112 17
113 30
114 15
115 16
116 12
117 25
118 15
119 18
120 15
121 14
122 43
123 15
124 26
125 26
126 12
127 35
128 23
129 16
130 15
131 17
132 12
133 12
134 31
135 25
136 28
137 12
138 23
139 12
140 17
141 20
142 14
143 29
144 18
145 17
146 34
147 18
148 29
149 22
150 20
151 12
152 39
153 17
154 16
155 16
156 22
157 42
158 

1109 390
1110 348
1111 326
1112 324
1113 404
1114 388
1115 331
1116 393
1117 400
1118 338
1119 397
1120 396
1121 350
1122 347
1123 338
1124 359
1125 481
1126 394
1127 500
1128 469
1129 357
1130 417
1131 500
1132 479
1133 380
1134 466
1135 448
1136 490
1137 500
