In [1]:
import gym
from utils import *

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import math

class NoiseLinear(nn.Linear):
    """Linear layer with learnable, independent Gaussian weight/bias noise (NoisyNet-style).

    The effective weight on every forward pass is
    ``weight + sigma_weight * eps_weight`` (and likewise for the bias), where
    ``sigma_*`` are trainable noise scales initialised to ``val`` and ``eps_*``
    are non-trainable buffers resampled from N(0, 1) on *every* call to
    ``forward`` (there is no train/eval switch).

    Args:
        in_: number of input features.
        out_: number of output features.
        val: initial value of every noise scale ``sigma_*``.
        bias: whether to learn an additive (noisy) bias.
    """

    def __init__(self, in_, out_, val = 0.01, bias = True):
        # nn.Linear.__init__ already calls reset_parameters(), before the sigma
        # parameters below exist, so reset_parameters must only touch weight/bias.
        super(NoiseLinear, self).__init__(in_, out_, bias)
        # float() so an integer `val` (e.g. 0) still yields a float Parameter;
        # an int tensor cannot require gradients.
        self.sigma_weight = nn.Parameter(torch.full((out_, in_), float(val)))
        self.register_buffer("eps_weight", torch.zeros(out_, in_))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_,), float(val)))
            self.register_buffer("eps_bias", torch.zeros(out_))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the mean weight (and bias, if any) uniformly in ±sqrt(3 / in_features)."""
        std = math.sqrt(3 / self.in_features)
        self.weight.data.uniform_(-std, std)
        # With bias=False nn.Linear registers `bias` as None.
        if self.bias is not None:
            self.bias.data.uniform_(-std, std)

    def forward(self, x):
        """Apply the layer with a freshly sampled noise draw."""
        self.eps_weight.normal_()
        bias = self.bias
        if bias is not None:
            self.eps_bias.normal_()
            bias = bias + self.sigma_bias * self.eps_bias
        return F.linear(x, self.weight + self.sigma_weight * self.eps_weight, bias)

class Critic(nn.Module):
    """Dueling Q-network built entirely from ``NoiseLinear`` layers.

    Two independent heads read the same input:
      * ``v``  -- one output per row (state value);
      * ``ad`` -- one output per action (advantage).
    They are recombined as ``V + (A - mean(A))``.

    Args:
        in_: size of the input (observation) vector.
        out_: number of discrete actions.
        hidden_: width of the single hidden layer inside each head.
    """

    def __init__(self, in_, out_, hidden_ = 256):
        super().__init__()
        # Value head first, then advantage head, so parameter creation
        # (and the random init it triggers) happens in that order.
        self.v = self._head(in_, hidden_, 1)
        self.ad = self._head(in_, hidden_, out_)

    @staticmethod
    def _head(in_, hidden_, out_):
        """Build one noisy two-layer MLP: NoiseLinear -> ReLU -> NoiseLinear."""
        return nn.Sequential(NoiseLinear(in_, hidden_), nn.ReLU(), NoiseLinear(hidden_, out_))

    def get_action(self, act_v):
        """Return the index of the largest entry of ``act_v``, taken over the flattened tensor."""
        return torch.argmax(act_v)

    def forward(self, x):
        """Return Q-values; the advantage is mean-centred along dim 1 (so ``x`` is expected batched)."""
        value = self.v(x)
        advantage = self.ad(x)
        centred = advantage - advantage.mean(dim=1, keepdim=True)
        return value + centred

In [3]:
LR = 0.0005
GAMMA = 0.995
TARGET_UPDATE = 10  # copy online -> target weights every this many episodes
SEED = 0

# Seed the Python, NumPy and PyTorch RNGs (network init and noise draws) so a
# Restart & Run All is reproducible. The gym environment itself is not seeded
# here because its seeding API differs between gym versions.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when available instead of crashing on .cuda() on CPU-only machines.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make("CartPole-v1")
act_n = env.action_space.n
obs_n = env.observation_space.shape[0]

net = Critic(obs_n, act_n).to(DEVICE)
net_tgt = Critic(obs_n, act_n).to(DEVICE)
net_tgt.load_state_dict(net.state_dict())
# The target network is only read (never passed to the optimizer); freezing it
# keeps autograd from recording graphs through it.
net_tgt.requires_grad_(False)
opt = optim.Adam(net.parameters(), LR)

agent = Agent(env, net)
# Presumably configures n-step (3) returns discounted by GAMMA -- see utils.Agent.
agent.set_n_step(3, GAMMA)

ST_SIZE = 100000   # replay capacity
ST_INIT = 1000     # transitions to collect before the first gradient step
ST_SAMPLE = 512    # minibatch size drawn from replay
# NOTE(review): second argument presumably enables prioritized replay
# (sample() returns indices and importance weights below) -- confirm in utils.
storage = Replay(ST_SIZE, True)

In [4]:
EPOCH = 1000
PRIO_EPS = 1e-5  # keeps every replay priority strictly positive so no transition is starved

# Train on whatever device the online network lives on.
device = next(net.parameters()).device


def to_tensor(values, dtype):
    """Batch a tuple of per-transition values into one tensor on ``device``.

    Going through ``np.asarray`` first avoids PyTorch's very slow path for
    building a tensor from a sequence of separate numpy arrays.
    """
    return torch.as_tensor(np.asarray(values), dtype=dtype, device=device)


try:
    for epoch in range(EPOCH):
        n_steps = 0  # stays 0 (instead of a NameError / stale value) for an empty episode
        for n_steps, transition in enumerate(agent.episode(epoch), start=1):
            storage.push(transition)

            # Warm-up: only collect experience until the buffer holds ST_INIT items.
            if len(storage) < ST_INIT:
                continue

            sample, indices, weights = storage.sample(ST_SAMPLE)
            weights_ = to_tensor(weights, torch.float32).unsqueeze(1)
            obs, _act_v, act, next_obs, rew, done, _etc, n = zip(*sample)

            obs_ = to_tensor(obs, torch.float32)
            act_ = to_tensor(act, torch.int64).unsqueeze(1)
            next_obs_ = to_tensor(next_obs, torch.float32)
            rew_ = to_tensor(rew, torch.float32).unsqueeze(1)
            done_ = to_tensor(done, torch.bool)
            n_ = to_tensor(n, torch.float32).unsqueeze(1)

            # Double DQN target: the online net picks the next action, the target
            # net evaluates it; terminal transitions get no bootstrap term.
            # (Plain DQN would use net_tgt(next_obs_).max(1, keepdim=True)[0].)
            # No gradients flow into the target, so build it without autograd.
            with torch.no_grad():
                act_opt = net(next_obs_).argmax(dim=1, keepdim=True)
                q_next = net_tgt(next_obs_).gather(1, act_opt)
                q_next[done_] = 0.0
                # Each sample carries its own step count n, so discount by GAMMA**n.
                q = rew_ + (GAMMA ** n_) * q_next
            q_pred = net(obs_).gather(1, act_)

            # Importance-weighted squared TD error (prioritized replay correction).
            loss_ = weights_ * (q_pred - q) ** 2
            opt.zero_grad()
            loss_.mean().backward()
            opt.step()

            # Priority update: flatten explicitly (squeeze() would yield a 0-d
            # array for a batch of one) and keep priorities above zero.
            new_prios = (loss_.detach().view(-1) + PRIO_EPS).cpu().numpy()
            storage.update_priorities(indices, new_prios)

        if epoch % TARGET_UPDATE == 0:
            net_tgt.load_state_dict(net.state_dict())
        print(epoch, n_steps)
finally:
    # Runs even when training is interrupted (e.g. KeyboardInterrupt).
    agent.reset()

0 79
1 35
2 71
3 45
4 73
5 60
6 73
7 77
8 42
9 56
10 51
11 99
12 109
13 78
14 66
15 94
16 62
17 46
18 48
19 122
20 40
21 77
22 35
23 62
24 60
25 63
26 77
27 65
28 74
29 122
30 134
31 55
32 111
33 266
34 170
35 217
36 116
37 357
38 205
39 179
40 215
41 129
42 181
43 200
44 172
45 209
46 188
47 210
48 208
49 243
50 196
51 346
52 266
53 190
54 205
55 207
56 198
57 252
58 216
59 187
60 219
61 307
62 296
63 264
64 214
65 236
66 237
67 243
68 233
69 319
70 229
71 361
72 270
73 256
74 206
75 239
76 259
77 239
78 281
79 253
80 218
81 306
82 263
83 320
84 268
85 260
86 400
87 321
88 232
89 462
90 292
91 495
92 500
93 257
94 291
95 301
96 500
97 316
98 259
99 362
100 230
101 500
102 476
103 250
104 223
105 290
106 215
107 500
108 500
109 500
110 270
111 443
112 295
113 293
114 347
115 460
116 273
117 246
118 245
119 275
120 256
121 266
122 358
123 242


KeyboardInterrupt: 