In [1]:
import gym
from utils import *

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import math

class NoiseLinear(nn.Linear):
    """Linear layer with learnable, independent Gaussian noise on weight and bias.

    On every forward pass fresh standard-normal noise ``eps`` is drawn, and the
    effective parameters are ``weight + sigma_weight * eps_weight`` (and likewise
    ``bias + sigma_bias * eps_bias`` when the layer has a bias).

    Args:
        in_: number of input features.
        out_: number of output features.
        val: initial value of every entry of the learnable noise scales ``sigma_*``.
        bias: whether the layer has a (noisy) bias term.
    """

    def __init__(self, in_, out_, val = 0.01, bias = True):
        super(NoiseLinear, self).__init__(in_, out_, bias)
        # Learnable noise scales are Parameters (optimised); the noise samples
        # themselves are buffers (follow .cuda()/.to() and state_dict, not optimised).
        self.sigma_weight = nn.Parameter(torch.full((out_, in_), val))
        self.register_buffer("eps_weight", torch.zeros(out_, in_))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_,), val))
            self.register_buffer("eps_bias", torch.zeros(out_))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the mean weight/bias uniformly in [-sqrt(3/in), sqrt(3/in)].

        This is also called from ``nn.Linear.__init__`` (before the ``sigma_*``
        parameters exist), so it only touches ``weight`` and ``bias``. ``bias`` is
        ``None`` when the layer is built with ``bias=False``, so it is guarded.
        """
        std = math.sqrt(3 / self.in_features)
        with torch.no_grad():
            self.weight.uniform_(-std, std)
            if self.bias is not None:
                self.bias.uniform_(-std, std)

    def forward(self, x):
        """Apply the layer with freshly sampled noise.

        Noise is resampled on every call regardless of ``self.training``.
        """
        self.eps_weight.normal_()
        bias = self.bias
        if bias is not None:
            self.eps_bias.normal_()
            bias = bias + self.sigma_bias * self.eps_bias
        return F.linear(x, self.weight + self.sigma_weight * self.eps_weight, bias)

class Critic(nn.Module):
    """Dueling Q-network: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).

    Args:
        in_: size of the observation vector.
        out_: number of discrete actions (size of the Q-value output).
        hidden_: width of the single hidden layer in each stream.
    """

    def __init__(self, in_, out_, hidden_ = 256):
        super(Critic, self).__init__()
        # State-value stream V(s): one scalar per observation.
        self.v = nn.Sequential(
            nn.Linear(in_, hidden_),
            nn.ReLU(),
            nn.Linear(hidden_, 1)
        )

        # Advantage stream A(s, a): one value per action.
        self.ad = nn.Sequential(
            nn.Linear(in_, hidden_),
            nn.ReLU(),
            nn.Linear(hidden_, out_)
        )

    def get_action(self, act_v):
        """Return the index of the largest entry of ``act_v``.

        The argmax is taken over the flattened tensor, so this is meant for the
        Q-values of a single observation; for a batch it returns a flat index.
        """
        return act_v.argmax()

    def forward(self, x):
        """Return Q-values of shape ``x.shape[:-1] + (out_,)``.

        Accepts a single observation ``(in_,)`` or a batch ``(batch, in_)``. The
        advantage mean is subtracted over the last (action) axis so that V and A
        are identifiable.
        """
        v = self.v(x)
        adv = self.ad(x)
        return v + (adv - adv.mean(dim=-1, keepdim=True))

In [3]:
# --- Hyper-parameters -------------------------------------------------------
LR = 0.005
GAMMA = 0.95          # discount factor for the TD target
TARGET_UPDATE = 10    # copy net -> net_tgt every N epochs (one episode per epoch)

# NOTE(review): EPS_* and `counter` are not referenced by any visible cell;
# exploration presumably happens inside utils.Agent — confirm or remove.
EPS_START = 0.9
EPS_END = 0.02
EPS_DECAY = 2000
counter = 0

# Seed every RNG the notebook uses so runs are reproducible.
# (The gym environment itself is not seeded here; its seeding API differs across gym versions.)
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fall back to CPU instead of crashing when no GPU is present.
# NOTE(review): utils.Agent may itself assume CUDA — confirm.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Environment and networks -------------------------------------------------
env = gym.make("Acrobot-v1")
act_n = env.action_space.n
obs_n = env.observation_space.shape[0]

net = Critic(obs_n, act_n).to(DEVICE)
net_tgt = Critic(obs_n, act_n).to(DEVICE)
net_tgt.load_state_dict(net.state_dict())
opt = optim.Adam(net.parameters(), LR)
loss = nn.MSELoss()

agent = Agent(env, net)

# --- Replay buffer ------------------------------------------------------------
ST_SIZE = 10000    # argument to Replay (presumably its capacity)
ST_INIT = 1000     # transitions to collect before any gradient update
ST_SAMPLE = 256    # transitions sampled per update (mini-batch size)
storage = Replay(ST_SIZE)

In [None]:
EPOCH = 1000

# Put batches on whatever device the online network lives on (GPU or CPU).
device = next(net.parameters()).device

for epoch in range(EPOCH):
    i = -1  # keeps the episode-length print valid even if an episode yields no steps
    for i, step in enumerate(agent.episode(epoch)):
        storage.push(step)

        # Warm-up: only collect transitions until the buffer holds ST_INIT of them.
        if len(storage) < ST_INIT:
            continue

        # Each stored step unpacks into 7 fields; the 2nd (act_v) and the last are
        # not needed for the update.
        obs, _act_v, act, next_obs, rew, done, _ = zip(*storage.sample(ST_SAMPLE))

        # np.asarray first: building a tensor from a tuple of ndarrays is very slow.
        obs_ = torch.as_tensor(np.asarray(obs), dtype=torch.float32, device=device)
        act_ = torch.as_tensor(np.asarray(act), dtype=torch.long, device=device).unsqueeze(1)
        next_obs_ = torch.as_tensor(np.asarray(next_obs), dtype=torch.float32, device=device)
        rew_ = torch.as_tensor(np.asarray(rew), dtype=torch.float32, device=device).unsqueeze(1)
        done_ = torch.as_tensor(np.asarray(done), dtype=torch.bool, device=device)

        # Double DQN target: the online net picks the next action, the target net
        # evaluates it. (Vanilla DQN would use net_tgt(next_obs_).max(1)[0].unsqueeze(1).)
        # The target is a constant for the loss, so no autograd graph is built for it.
        with torch.no_grad():
            act_opt = net(next_obs_).argmax(dim=1, keepdim=True)
            q_next = net_tgt(next_obs_).gather(1, act_opt)
            q_next[done_] = 0  # no bootstrapping from terminal transitions
            q = rew_ + GAMMA * q_next
        # NOTE(review): if `done` is also set when the episode hits the step limit,
        # those time-outs are treated as terminal here — confirm how utils.Agent reports it.

        q_pred = net(obs_).gather(1, act_)

        opt.zero_grad()
        loss_ = loss(q_pred, q)
        loss_.backward()
        opt.step()

    if epoch % TARGET_UPDATE == 0:
        net_tgt.load_state_dict(net.state_dict())
    print(epoch, i + 1)

agent.reset()

0 500
1 500
2 147
3 124
4 229
5 500
6 467
7 281
8 500
9 500
10 500
11 500
12 500
13 500
14 500
15 500
16 500
17 500
18 500
19 500
20 500
21 500
22 500
23 500
24 500
25 500
26 500
27 500
28 500
29 500
30 500
31 500
32 500
33 500
34 500
35 500
36 500
37 500
38 500
39 500
40 500
41 500
42 500
43 500
44 500
45 500
46 500
47 500
48 500
49 500
50 500
51 500
52 500
53 500
54 500
55 500
56 500
57 500
58 500
59 500
60 500
61 500
62 500
63 500
64 500
65 500
66 500
67 500
68 500
69 500
70 500
71 500
72 500
73 500
74 500
75 500
76 500
77 500
78 500
79 500
80 500
81 500
82 500
83 500
84 500
85 500
86 500
87 500
88 500
89 500
90 500
91 500
92 500


In [None]:
# NOTE(review): scratch cell — `a` is not used by any other cell; consider deleting it.
a = torch.tensor([1.0], dtype=torch.float32)