In [1]:
from utils import *
import gym

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy

import math

class NoiseLinear(nn.Linear):
    """Linear layer with learnable Gaussian noise on its weight and bias.

    Every forward pass draws fresh standard-normal noise into the ``eps_*``
    buffers and adds ``sigma * eps`` to the weight (and to the bias, when the
    layer has one) before applying the affine map.

    Args:
        in_: number of input features.
        out_: number of output features.
        val: initial value of every entry of the sigma parameters.
        bias: whether the layer has a bias term.
    """

    def __init__(self, in_, out_, val = 0.017, bias = True):
        super(NoiseLinear, self).__init__(in_,out_,bias)
        self.sigma_weight = nn.Parameter(torch.full((out_, in_), val))
        self.register_buffer("eps_weight", torch.zeros(out_, in_))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_,), val))
            self.register_buffer("eps_bias", torch.zeros(out_))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight (and bias, if present) uniformly from [-std, std)."""
        std = math.sqrt(1 / self.in_features)
        self.weight.data.uniform_(-std, std)
        # nn.Linear.__init__ invokes this method too, and with bias=False
        # self.bias is None at that point, so the bias init must be guarded.
        if self.bias is not None:
            self.bias.data.uniform_(-std, std)

    def forward(self, x):
        """Apply the layer to ``x`` using freshly sampled noise."""
        self.eps_weight.normal_()
        bias = self.bias
        if bias is not None:
            self.eps_bias.normal_()
            bias = bias + self.sigma_bias * self.eps_bias
        return F.linear(x, self.weight + self.sigma_weight * self.eps_weight, bias)

class targetNet(nn.Module):
    """Wraps a deep copy of ``off_net`` that can follow it softly or be hard-synced.

    ``net`` is the copy used for forward passes; ``off_net`` is the original
    network whose parameters are blended or copied into ``net``.
    """

    def __init__(self, off_net):
        super().__init__()
        # Copy first, then keep a handle on the source network.
        self.net = copy.deepcopy(off_net)
        self.off_net = off_net

    def alpha_update(self, alpha = 0.05):
        """Move each copied parameter a fraction ``alpha`` towards its source."""
        keep = 1 - alpha
        pairs = zip(self.off_net.parameters(), self.net.parameters())
        for source, target in pairs:
            blended = source.data * alpha + target.data * keep
            target.data.copy_(blended)

    def copy_off_net(self):
        """Overwrite the copy with the source network's full state dict."""
        source_state = self.off_net.state_dict()
        self.net.load_state_dict(source_state)

    def forward(self, *x):
        """Run the copied network on the given inputs."""
        return self.net(*x)
        

class Actor(nn.Module):
    """Three-layer MLP mapping an observation to an ``act_n``-dimensional output.

    Args:
        in_: observation vector size.
        act_n: size of the output (action) vector.
        hidden: width of both hidden layers.
    """

    def __init__(self, in_, act_n, hidden=256):
        super().__init__()
        layers = [
            nn.Linear(in_, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, act_n),
        ]
        self.net = nn.Sequential(*layers)

    def get_action(self, act_v):
        """Return the network output unchanged as the action."""
        return act_v

    def forward(self, x):
        """Compute the raw network output for observation batch ``x``."""
        return self.net(x)

class DistCritic(nn.Module):
    """Distributional critic: maps (observation, action) to ``atom`` logits.

    The observation goes through one hidden layer; its features are then
    concatenated with the action and passed through a two-layer head.

    Args:
        in_: observation vector size.
        act_v: action vector size (width of the ``act`` input to ``forward``).
        atom: number of output logits.
        hidden: width of the hidden layers.
    """

    def __init__(self, in_, act_v, atom=51, hidden=256):
        super(DistCritic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_, hidden),
            nn.ReLU()
        )
        self.net_out = nn.Sequential(
            # Size the head from the constructor argument, not from a
            # same-purpose global that may not exist (or may differ).
            nn.Linear(hidden + act_v, hidden),
            nn.ReLU(),
            nn.Linear(hidden, atom)
        )

    def forward(self, obs, act):
        """Return unnormalised logits of shape (batch, atom)."""
        return self.net_out(torch.cat([self.net(obs), act], dim=1))

In [3]:
# Learning rates passed to the actor and critic Adam optimizers.
ACT_LR = 1e-4
CRT_LR = 5e-4

# Discount factor, used for the agent's n-step setup and the critic target.
GAMMA = 0.99

In [4]:
# Fixed categorical support of the critic: ATOMS evenly spaced values on
# [V_MIN, V_MAX], V_DIST apart.
V_MAX = 20
V_MIN = -20
ATOMS = 51
V_DIST = (V_MAX-V_MIN)/(ATOMS-1)


def _project_rows(ret, rows, values, mass):
    """Add ``mass[k]`` to row ``rows[k]`` of ``ret`` at the support position of ``values[k]``.

    A value that falls between two atoms has its mass split linearly between
    them. Each row index must appear at most once in ``rows``.
    """
    values = np.clip(values, V_MIN, V_MAX)
    # Clip the fractional index as well: float round-off can otherwise push
    # ceil() one past the last atom.
    pos = np.clip((values - V_MIN) / V_DIST, 0, ATOMS - 1)
    lower = np.floor(pos).astype(np.int64)
    upper = np.ceil(pos).astype(np.int64)

    exact = lower == upper
    ret[rows[exact], lower[exact]] += mass[exact]

    split = ~exact
    ret[rows[split], lower[split]] += mass[split] * (upper - pos)[split]
    ret[rows[split], upper[split]] += mass[split] * (pos - lower)[split]


def transform_dist(dist, rew, gamma, unroll_n, done):
    """Project the bootstrapped return distribution back onto the fixed support.

    For non-terminal rows every atom ``z`` is moved to
    ``rew + gamma**unroll_n * z`` and its probability redistributed over the
    neighbouring atoms. For terminal rows the target is all mass at ``rew``.

    Args:
        dist: (batch, ATOMS) tensor of next-state atom probabilities.
        rew: per-row reward tensor, shape (batch,) or (batch, 1).
        gamma: scalar discount factor.
        unroll_n: per-row exponent applied to ``gamma``, same shape as ``rew``.
        done: boolean tensor of shape (batch,), True for terminal rows.

    Returns:
        float32 tensor of shape (batch, ATOMS) on the same device as ``dist``.
    """
    device = dist.device
    probs = dist.detach().cpu().numpy()
    # reshape(-1) rather than squeeze(): a batch of one must stay 1-D.
    rew = rew.detach().cpu().numpy().reshape(-1)
    unroll_n = unroll_n.detach().cpu().numpy().reshape(-1)
    done = done.detach().cpu().numpy().reshape(-1).astype(bool)

    ret = np.zeros_like(probs)

    live = np.flatnonzero(~done)
    discount = gamma ** unroll_n[live]
    live_rew = rew[live]
    for atom in range(ATOMS):
        support = V_MIN + atom * V_DIST
        _project_rows(ret, live, live_rew + discount * support, probs[live, atom])

    # Terminal rows do not depend on the atom, so handle them once, outside
    # the atom loop.
    ended = np.flatnonzero(done)
    if ended.size:
        _project_rows(ret, ended, rew[ended], np.ones(ended.size, dtype=ret.dtype))

    return torch.as_tensor(ret, dtype=torch.float32, device=device)

In [5]:
# Environment, plus the sizes of its action and observation vectors.
env = gym.make("Pendulum-v0")
act_n = env.action_space.shape[0]
obs_n = env.observation_space.shape[0]

# Online actor on the GPU, its target wrapper (targetNet deep-copies the
# network it is given) and an Adam optimizer over the online parameters only.
# NOTE: the .cuda() calls require a CUDA device.
actor = Actor(obs_n, act_n).cuda()
actor_tgt = targetNet(actor)
actor_optim = optim.Adam(actor.parameters(), ACT_LR)

# Same trio for the distributional critic, which outputs ATOMS logits.
critic = DistCritic(obs_n, act_n, ATOMS).cuda()
critic_tgt = targetNet(critic)
critic_optim = optim.Adam(critic.parameters(), CRT_LR)

# Replay storage settings: capacity, number of stored steps required before
# training starts, and the sample size per update.
ST_SIZE = 50000
ST_INIT = 10000
BATCH = 512
# Replay and NoiseMaker come from utils (not shown here). The training loop
# uses sample weights and update_priorities, so the second Replay argument
# presumably enables prioritized sampling, and "ou" presumably selects
# Ornstein-Uhlenbeck noise -- TODO confirm in utils.
storage = Replay(ST_SIZE, True)
noise = NoiseMaker(act_n, "ou", decay=True)
noise.param["decay"] = ST_SIZE

# Agent also comes from utils; the meaning of the trailing 200 and 1 is not
# visible here (200 presumably caps the episode length -- TODO confirm).
agent = Agent(env, actor, noise, 200, 1)
# Presumably switches the agent to 2-step returns discounted by GAMMA; the
# sampled `unroll_n` field is used as the exponent of GAMMA in the target.
agent.set_n_step(2, GAMMA)

In [None]:
EPOCH = 2000

# Atom values z_i of the critic's categorical support. linspace yields exactly
# ATOMS points ending at V_MAX, which a float-step arange does not guarantee.
# Built once here instead of on every training step.
support_ = torch.linspace(V_MIN, V_MAX, ATOMS).cuda()

for epoch in range(EPOCH):
    for step in agent.episode(epoch):
        storage.push(step)
        # Do not train until the replay storage holds ST_INIT steps.
        if len(storage) < ST_INIT:
            continue

        sample, indices, weights = storage.sample(BATCH)
        weights_ = torch.FloatTensor(weights).cuda()
        obs, act_v, act, next_obs, rew, done, etc, unroll_n = zip(*sample)

        obs_ = torch.FloatTensor(obs).cuda()
        act_v_ = torch.FloatTensor(act_v).cuda()
        next_obs_ = torch.FloatTensor(next_obs).cuda()
        rew_ = torch.FloatTensor(rew).unsqueeze(1).cuda()
        done_ = torch.BoolTensor(done).cuda()
        unroll_n_ = torch.FloatTensor(unroll_n).unsqueeze(1).cuda()

        # --- Critic update: cross-entropy between the predicted distribution
        # and the projected target distribution.
        critic_optim.zero_grad()
        q_pred = critic(obs_, act_v_)

        # The target is treated as a constant, so skip building its graph.
        with torch.no_grad():
            q_next_logits = critic_tgt(next_obs_, actor_tgt(next_obs_))
            q_next = F.softmax(q_next_logits, dim=1)
        q_target = transform_dist(q_next, rew_, GAMMA, unroll_n_, done_)

        q_entropy = -F.log_softmax(q_pred, dim=1) * q_target
        q_entropy_sum = q_entropy.sum(dim=1)
        q_loss = (weights_ * q_entropy_sum).sum()
        q_loss.backward()
        critic_optim.step()

        # Per-sample loss becomes the new priority of each sampled transition.
        storage.update_priorities(indices, q_entropy_sum.cpu().detach().numpy())

        # --- Actor update: maximise the critic's expected value of the
        # actor's action. The expectation is the SUM over atoms of p_i * z_i
        # (a mean over atoms would optimise Q / ATOMS instead).
        actor_optim.zero_grad()
        q_dist = critic(obs_, actor(obs_))
        q_v = (F.softmax(q_dist, dim=1) * support_).sum(dim=1)

        actor_loss = -q_v.mean()
        actor_loss.backward()
        actor_optim.step()

        # --- Soft-update both target networks towards the online networks.
        critic_tgt.alpha_update()
        actor_tgt.alpha_update()

    # Terminates the current output line after each episode.
    print()

0 -855.42269 
1 -1748.63874 
2 -850.76813 
3 -865.73192 
4 -1069.79105 
5 -1667.22458 
6 -1512.76431 
7 -1445.10003 
8 -1169.07088 
9 -1575.89961 
10 -986.86838 
11 -1491.19924 
12 -961.27757 
13 -1293.81928 
14 -968.86290 
15 -1278.34625 
16 -783.98711 
17 -1664.46951 
18 -1447.51600 
19 -753.97188 
20 -1731.16238 
21 -1089.48874 
22 -1470.78677 
23 -918.15091 
24 -1792.64299 
25 -1021.93554 
26 -1579.49300 
27 -1172.04712 
28 -1395.47010 
29 -1400.34475 
30 -1250.95015 
31 -1169.67890 
32 -1516.22624 
33 -1297.65744 
34 -1549.69517 
35 -1724.44230 
36 -1589.37821 
37 -1058.61556 
38 -926.60655 
39 -963.18824 
40 -1700.29890 
41 -991.90680 
42 -1148.20777 
43 -1138.64827 
44 -1326.73774 
45 -1826.94966 
46 -1677.00062 
47 -1692.87860 
48 -1107.36377 
49 -1077.69971 
50 -1602.51753 
51 -1159.66959 
52 -1121.87536 
53 -1581.67320 
54 -1471.61078 
55 -938.49375 
56 -1434.01179 
57 -1496.79673 
58 -1420.23193 
59 -1364.15904 
60 -1473.60496 
61 -921.22945 
62 -1066.13030 
63 -1492.45345 
