In [1]:
from utils import *
import gym

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy

import math

class NoiseLinear(nn.Linear):
    """Linear layer whose weight and bias are perturbed by learnable Gaussian noise.

    On every forward pass a fresh standard-normal sample is drawn for each
    weight (and bias) entry and scaled by a learnable ``sigma`` parameter
    before being added to the deterministic weight (and bias).

    Args:
        in_: number of input features.
        out_: number of output features.
        val: initial value for every ``sigma`` entry.
        bias: whether the layer has a (noisy) bias term.
    """

    def __init__(self, in_, out_, val = 0.017, bias = True):
        # nn.Linear.__init__ already calls reset_parameters() once, so the
        # override below must cope with ``bias=False``.
        super(NoiseLinear, self).__init__(in_, out_, bias)
        self.sigma_weight = nn.Parameter(torch.full((out_, in_), val))
        # Noise samples are buffers: they follow .to()/.cuda() but are not trained.
        self.register_buffer("eps_weight", torch.zeros(out_, in_))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_,), val))
            self.register_buffer("eps_bias", torch.zeros(out_))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weight (and bias, if present) uniformly in +-sqrt(1/in_features)."""
        std = math.sqrt(1 / self.in_features)
        self.weight.data.uniform_(-std, std)
        # The bias is None when the layer was built with bias=False.
        if self.bias is not None:
            self.bias.data.uniform_(-std, std)

    def forward(self, x):
        """Apply the linear map using freshly sampled noisy weight and bias."""
        self.eps_weight.normal_()
        bias = self.bias
        if bias is not None:
            self.eps_bias.normal_()
            bias = bias + self.sigma_bias * self.eps_bias
        return F.linear(x, self.weight + self.sigma_weight * self.eps_weight, bias)

class targetNet(nn.Module):
    """Holds a deep copy of a network and keeps it in sync with the original.

    ``net`` is the copy that is evaluated in ``forward``; ``off_net`` is the
    network it was copied from and from which it is refreshed.
    """

    def __init__(self, off_net):
        super().__init__()
        self.net = copy.deepcopy(off_net)
        self.off_net = off_net

    def alpha_update(self, alpha = 0.05):
        """Blend each copied parameter towards its source: tgt <- alpha*off + (1-alpha)*tgt."""
        keep = 1 - alpha
        pairs = zip(self.off_net.parameters(), self.net.parameters())
        for source_param, target_param in pairs:
            blended = source_param.data * alpha + target_param.data * keep
            target_param.data.copy_(blended)

    def copy_off_net(self):
        """Overwrite the copy with the source network's full state dict."""
        source_state = self.off_net.state_dict()
        self.net.load_state_dict(source_state)

    def forward(self, *x):
        """Evaluate the copied network on the given inputs."""
        return self.net(*x)
        

class Actor(nn.Module):
    """Three-layer MLP policy whose output is squashed to [-1, 1] by tanh.

    Args:
        in_: size of the input vector.
        act_n: size of the output vector.
        action_provider: callable applied by ``get_action`` to the network output.
        hidden: width of the first hidden layer; the second is half as wide.
    """

    def __init__(self, in_, act_n, action_provider, hidden=512):
        super().__init__()
        half_width = int(hidden / 2)
        stack = [
            nn.Linear(in_, hidden),
            nn.ReLU(),
            nn.Linear(hidden, half_width),
            nn.ReLU(),
            nn.Linear(half_width, act_n),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stack)
        self.action_provider = action_provider

    def get_action(self, act_v):
        """Pass ``act_v`` through the configured action provider and return the result."""
        provider = self.action_provider
        return provider(act_v)

    def forward(self, x):
        """Run the MLP on ``x``."""
        out = self.net(x)
        return out

class DistCritic(nn.Module):
    """Critic that outputs one logit per atom for an (observation, action) pair.

    The observation is first embedded by ``net``; the embedding is then
    concatenated with the action and mapped by ``net_out`` to ``atom`` logits.
    The output is raw logits; no softmax is applied here.

    Args:
        in_: size of the observation vector.
        act_v: size of the action vector concatenated to the embedding.
        atom: number of output logits.
        hidden: width of the observation embedding; the next layer is half as wide.
    """

    def __init__(self, in_, act_v, atom=51, hidden=512):
        super(DistCritic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_, hidden),
            nn.ReLU()
        )
        self.net_out = nn.Sequential(
            # Use the constructor argument, not a global that may not exist.
            nn.Linear(hidden + act_v, int(hidden/2)),
            nn.ReLU(),
            nn.Linear(int(hidden/2), atom)
        )

    def forward(self, obs, act):
        """Return the atom logits for ``obs`` and ``act`` (joined along dim 1)."""
        return self.net_out(torch.cat([self.net(obs), act], dim=1))

In [3]:
# Fixed support of the categorical return distribution: ATOMS evenly spaced
# values from V_MIN to V_MAX, V_DIST apart.
V_MAX = 10
V_MIN = -10
ATOMS = 51
V_DIST = (V_MAX-V_MIN)/(ATOMS-1)


def _project_onto_support(values):
    """Clip ``values`` to [V_MIN, V_MAX]; return fractional, floor and ceil atom indices."""
    clipped = np.clip(values, V_MIN, V_MAX)
    indices = ((clipped - V_MIN) / V_DIST).reshape(-1)
    lower = np.floor(indices).astype(np.int64)
    upper = np.ceil(indices).astype(np.int64)
    return indices, lower, upper


def transform_dist(dist, rew, gamma, unroll_n, done):
    """Project a next-state categorical distribution through the Bellman update.

    Each atom value ``z`` is mapped to ``rew + gamma**unroll_n * z``, clipped to
    the support, and its probability mass is split linearly between the two
    neighbouring atoms. Rows flagged ``done`` are replaced by the projection of
    the reward alone.

    Args:
        dist: tensor of per-atom probabilities with ATOMS columns, one row per sample.
        rew: tensor with one reward per sample (any shape that flattens to the batch).
        gamma: scalar discount factor.
        unroll_n: tensor with the number of unrolled steps per sample.
        done: boolean tensor with one terminal flag per sample.

    Returns:
        A float32 tensor shaped like ``dist`` on the same device as ``dist``.
    """
    device = dist.device
    dist = dist.detach().cpu().numpy()
    # Flatten the per-sample vectors explicitly: ``squeeze()`` would turn a
    # batch of one into a 0-d array and break the row/column indexing below.
    rew = rew.detach().cpu().numpy().reshape(-1)
    unroll_n = unroll_n.detach().cpu().numpy().reshape(-1)
    done = done.detach().cpu().numpy().astype(bool).reshape(-1)

    ret = np.zeros_like(dist)
    discount = gamma ** unroll_n  # loop-invariant
    for atom in range(ATOMS):
        support = V_MIN + atom * V_DIST
        indices, lower, upper = _project_onto_support(rew + discount * support)

        # Lands exactly on an atom: all mass goes there.
        exact = lower == upper
        ret[exact, lower[exact]] += dist[exact, atom]
        # Otherwise split the mass by distance to the two neighbouring atoms.
        split = ~exact
        ret[split, lower[split]] += dist[split, atom] * (upper - indices)[split]
        ret[split, upper[split]] += dist[split, atom] * (indices - lower)[split]

    # Terminal rows do not depend on ``dist``, so handle them once after the
    # atom loop rather than once per atom.
    if done.any():
        ret[done] = 0.0
        done_rows = np.flatnonzero(done)
        indices, lower, upper = _project_onto_support(rew[done])

        exact = lower == upper
        ret[done_rows[exact], lower[exact]] = 1.0
        split = ~exact
        ret[done_rows[split], lower[split]] = (upper - indices)[split]
        ret[done_rows[split], upper[split]] = (indices - lower)[split]

    # Return on the caller's device rather than forcing CUDA.
    return torch.FloatTensor(ret).to(device)

In [4]:
# Hyper-parameters: Adam learning rates for the actor and the critic, and the
# discount factor (passed to agent.prepare here and to transform_dist in the
# training cell).
ACT_LR = 0.001
CRT_LR = 0.001
GAMMA = 0.99

# Environment and the sizes of its action / observation vectors.
env = gym.make("MountainCarContinuous-v0")
act_n = env.action_space.shape[0]
obs_n = env.observation_space.shape[0]

# Online actor (identity action provider), its target copy, and its optimizer.
# NOTE: .cuda() requires a CUDA device to be available.
actor = Actor(obs_n, act_n, lambda x:x).cuda()
actor_tgt = targetNet(actor)
actor_optim = optim.Adam(actor.parameters(), ACT_LR)

# Online critic with ATOMS output logits, its target copy, and its optimizer.
critic = DistCritic(obs_n, act_n, ATOMS).cuda()
critic_tgt = targetNet(critic)
critic_optim = optim.Adam(critic.parameters(), CRT_LR)

# Replay storage: ST_SIZE is passed to Replay; the training cell skips updates
# until the storage holds at least ST_INIT items and then samples BATCH items.
ST_SIZE = 50000
ST_INIT = 10000
BATCH = 512
# Replay / NoiseMaker / Agent are not defined in this notebook (presumably they
# come from `from utils import *`).
# NOTE(review): the second Replay argument presumably enables prioritized
# sampling (the training cell consumes sample weights and calls
# update_priorities) -- confirm against utils.
storage = Replay(ST_SIZE, True)
# NOTE(review): "ou" presumably selects Ornstein-Uhlenbeck noise and "ou_sig"
# its sigma -- confirm against utils.NoiseMaker.
noise = NoiseMaker(act_n, "ou", decay=True)
noise.param["ou_sig"] = 0.6

# NOTE(review): the meaning of the positional arguments below (200; 2, 15, 0.4;
# 10, 1) is not visible in this notebook -- check the Agent class in utils
# before changing them.
agent = Agent(env, actor, noise, 200)
agent.prepare(2, GAMMA, 15, 0.4)
agent.set_renderer(10, 1)

In [None]:
import time
EPOCH = 2000

# Build every batch tensor on the device the online actor already lives on,
# instead of hard-coding CUDA in each line.
device = next(actor.parameters()).device
# Atom values of the categorical return distribution. linspace gives exactly
# ATOMS entries, whereas a float-step arange can be off by one element; the
# tensor is loop-invariant, so it is built once.
support_ = torch.linspace(V_MIN, V_MAX, ATOMS, device=device)

for epoch in range(EPOCH):
    for i, step in enumerate(agent.episode(epoch)):
        storage.push(step)
        # Only start learning once the replay storage has warmed up.
        if len(storage) < ST_INIT:
            continue

        sample, indices, weights = storage.sample(BATCH)
        weights_ = torch.FloatTensor(weights).to(device)
        obs, act_v, act, next_obs, rew, done, etc, unroll_n = list(zip(*sample))

        # ``act`` and ``etc`` are unpacked but not needed for the updates below.
        obs_ = torch.FloatTensor(obs).to(device)
        act_v_ = torch.FloatTensor(act_v).to(device)
        next_obs_ = torch.FloatTensor(next_obs).to(device)
        rew_ = torch.FloatTensor(rew).unsqueeze(1).to(device)
        done_ = torch.BoolTensor(done).to(device)
        unroll_n_ = torch.FloatTensor(unroll_n).unsqueeze(1).to(device)

        # Critic update: cross-entropy between the predicted distribution and
        # the projected target distribution, weighted by the sample weights.
        critic_optim.zero_grad()
        q_pred = critic(obs_, act_v_)

        # The target is a constant for the critic loss, so skip graph building.
        with torch.no_grad():
            q_next_prob = critic_tgt(next_obs_, actor_tgt(next_obs_))
            q_next = F.softmax(q_next_prob, dim=1)
        q_target = transform_dist(q_next, rew_, GAMMA, unroll_n_, done_)

        q_entropy = -F.log_softmax(q_pred, dim=1) * q_target
        q_entropy_sum = q_entropy.sum(dim=1)
        q_loss = (weights_ * q_entropy_sum).sum()
        q_loss.backward()
        critic_optim.step()

        # The per-sample loss is fed back to the storage as the new priority.
        storage.update_priorities(indices, q_entropy_sum.cpu().detach().numpy())

        # Actor update: maximise the expected return of the critic's
        # distribution, E[Z] = sum_i p_i * z_i. The sum over atoms (not the
        # mean) is the expectation; the mean would scale it by 1/ATOMS.
        actor_optim.zero_grad()
        q_dist = critic(obs_, actor(obs_))
        q_v = -(F.softmax(q_dist, dim=1) * support_).sum(dim=1)

        actor_loss = q_v.mean()
        # This also accumulates gradients in the critic; they are cleared by
        # critic_optim.zero_grad() at the start of the next step.
        actor_loss.backward()
        actor_optim.step()

        # Soft-update both target networks towards the online networks.
        critic_tgt.alpha_update()
        actor_tgt.alpha_update()
    print()

0 -0.91627 
1 -1.11707 
2 -1.42747 
3 -1.85634 
4 -0.84389 
5 -1.89273 
6 -2.29308 
7 -0.31503 
8 -0.84702 
9 -1.30178 
10 -1.52799 
11 -2.38618 
12 -1.89704 
13 -1.75188 
14 -1.35804 
15 -1.45759 
16 -1.45772 
17 -1.55481 
18 -1.60103 
19 -1.53913 
20 -1.39499 
21 -1.34526 
22 -1.40310 
23 -1.53959 
24 -1.56184 
25 -1.52537 
26 -1.73302 
27 -1.59828 
28 -1.47762 
29 -1.65750 
30 -1.55487 
31 -1.64586 
32 -1.56533 
33 -1.64330 
34 -1.55276 
35 -1.61379 
36 -1.53489 
37 -1.55909 
38 -1.65481 
39 -1.59034 
40 -1.59098 
41 -1.63825 
42 -1.53064 
43 -1.58349 
44 -1.52386 
45 -1.54443 
46 -1.63313 
47 -1.61822 
48 -1.48371 
49 -1.65142 
50 1.44583 
51 1.38923 
52 1.25972 
53 1.40394 
54 1.75195 
55 1.59172 
56 1.49496 
57 1.31670 
58 1.71529 
59 1.62855 
60 1.43940 
61 1.42240 
62 1.48549 
63 1.51233 
64 1.52103 
65 1.26191 
66 1.44834 
67 1.73605 
68 1.39005 
69 1.27958 
70 1.61700 
71 1.88094 
72 1.23405 
73 1.35790 
74 1.93951 
75 1.50078 
76 1.83004 
77 1.99333 
78 1.43471 
79 1.28698 
