In [15]:
!apt-get install swig
!pip3 install box2d box2d-kengz

Reading package lists... Done
Building dependency tree       
Reading state information... Done
swig is already the newest version (3.0.12-1).
0 upgraded, 0 newly installed, 0 to remove and 25 not upgraded.


In [0]:
import gym
import torch
from collections import namedtuple

# Version tag of this agent helper code.
version = "1.0.0"

# One transition record as produced by Agent.episode / Agent.unroll_step:
#   obs      - observation the action was chosen from
#   act_v    - raw actor output for `obs`
#   noise_v  - actor output after exploration noise (and clipping) was applied
#   act      - action actually passed to env.step (actor.get_action(noise_v))
#   last_obs - observation returned after the (last unrolled) step
#   rew      - reward after bias/scale; discounted sum once unrolled
#   done     - episode-termination flag
#   etc      - the extra info object returned by env.step
#   n        - number of steps folded into this record (-1 for a raw single step)
StepInfo = namedtuple("StepInfo", ("obs", "act_v", "noise_v", "act", "last_obs", "rew", "done", "etc", "n"))

class Agent:
    """Run an actor in a gym-style environment and emit n-step transitions.

    ``episode`` is a generator that plays one episode and yields ``StepInfo``
    records whose reward is the discounted sum of up to ``n_step`` consecutive
    (biased and scaled) rewards. ``prepare`` must be called before ``episode``.

    The environment is assumed to follow the classic gym API visible in this
    class: ``reset()`` returns an observation, ``step(act)`` returns
    ``(obs, rew, done, info)`` and ``render("rgb_array")`` returns a frame.
    """

    def __init__(self, env, actor, noise=None, max_step=None, device=None):
        """Store collaborators and default settings.

        Args:
            env: environment to interact with.
            actor: callable mapping an observation batch to action values; must
                also provide ``get_action(values)``.
            noise: optional object with ``get_noise()`` added to actor output.
            max_step: if given, written to ``env._max_episode_steps``.
            device: torch device for the actor input; auto-detected when None.
        """
        self.env = env
        if max_step is not None:
            self.env._max_episode_steps = max_step
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.actor = actor
        self.noise = noise
        self.device = device

        # Rendering stays disabled until set_renderer() is called.
        self.wait = -1
        self.interval = -1
        self.frame = None

        # Unroll settings; episode() refuses to run until prepare() sets them.
        self.step_set = False
        self.n_step = 1
        self.gamma = 0.99
        self.scale_factor = 1
        self.bias = 0

    def prepare(self, step_unroll, gamma, scale=1, bias=0):
        """Configure n-step unrolling and reward shaping.

        Each raw reward ``r`` is stored as ``(r + bias) / scale``.
        """
        self.step_set = True
        self.n_step = step_unroll
        self.gamma = gamma
        self.scale_factor = scale
        self.bias = bias

    def set_renderer(self, rend_wait=0, rend_interval=1, frame=None):
        """Enable rendering.

        Args:
            rend_wait: epochs below this value are not rendered.
            rend_interval: render only epochs divisible by this value; values
                <= 0 disable rendering.
            frame: optional list that collects the rendered frames.
        """
        self.wait = rend_wait
        self.interval = rend_interval
        self.frame = frame

    def reset(self):
        """Close the environment and reset it."""
        self.env.close()
        self.env.reset()

    def render(self, epoch):
        """Render the current state once if `epoch` is selected for rendering."""
        if self.wait >= 0 and epoch < self.wait:
            return
        # `> 0` (not `>= 0`): an interval of 0 would divide by zero below.
        if self.interval > 0 and epoch % self.interval == 0:
            rend = self.env.render("rgb_array")
            if self.frame is not None:
                # Reuse the frame just rendered instead of rendering twice.
                self.frame.append(rend)

    def episode(self, epoch):
        """Play one episode, yielding n-step ``StepInfo`` records.

        Records are yielded as soon as ``n_step`` raw steps are buffered; when
        the episode ends the remaining (shorter) tails are flushed as well.
        A one-line summary is printed at the end of the episode.
        """
        assert self.step_set, "call prepare() before episode()"
        buffer = []
        self.obs = self.env.reset()
        self.render(epoch)

        raw_total = 0      # sum of rewards exactly as returned by the env
        shaped_total = 0   # sum of rewards after bias/scale
        count = 0
        while True:
            with torch.no_grad():
                # as_tensor avoids the slow list-of-ndarray tensor construction.
                obs_t = torch.as_tensor(
                    self.obs, dtype=torch.float32, device=self.device
                ).unsqueeze(0)
                act_v = self.actor(obs_t).cpu().squeeze(0).numpy()
                if self.noise is not None:
                    noise_v = act_v + self.noise.get_noise()
                else:
                    noise_v = act_v
                # Only continuous (shaped) action spaces have low/high bounds.
                if self.env.action_space.shape:
                    noise_v = noise_v.clip(self.env.action_space.low, self.env.action_space.high)
                act = self.actor.get_action(noise_v)

            next_obs, rew, done, etc = self.env.step(act)
            obs = self.obs
            self.obs = next_obs
            count += 1

            raw_total += rew
            rew = (rew + self.bias) / self.scale_factor
            shaped_total += rew

            self.render(epoch)

            buffer.append(StepInfo(obs, act_v, noise_v, act, next_obs, rew, done, etc, -1))
            if done:
                break

            if len(buffer) < self.n_step:
                continue

            yield self.unroll_step(buffer)
            buffer.pop(0)

        # Flush what is left; these windows end at the terminal step.
        while buffer:
            yield self.unroll_step(buffer)
            buffer.pop(0)
        print("ep#%4d"%epoch, "count : %4d, scaled_rew : %03.5f, total_rew : %03.5f"%(count, shaped_total, raw_total))
        return

    def unroll_step(self, buffer):
        """Collapse the buffered steps into one n-step ``StepInfo``.

        The reward is ``sum(gamma**i * buffer[i].rew)``. ``done`` is taken from
        the last step for a full window and forced to True for a shorter one
        (short windows only occur while flushing at the end of an episode).
        """
        assert len(buffer), "unroll_step() needs at least one buffered step"

        rew_sum = 0
        # Fold from the back so each earlier reward discounts everything after it.
        for item in reversed(buffer):
            rew_sum *= self.gamma
            rew_sum += item.rew

        first, last = buffer[0], buffer[-1]
        done = last.done if len(buffer) == self.n_step else True
        return StepInfo(first.obs, first.act_v, first.noise_v, first.act, last.last_obs, rew_sum, done, first.etc, len(buffer))

import numpy as np
import math

class NoiseMaker():
    """Exploration-noise source: plain Gaussian or an OU-style correlated process.

    The noise magnitude is multiplied by an epsilon that decays exponentially
    from ``param["start"]`` to ``param["end"]`` with time constant
    ``param["decay"]`` (in calls), unless ``decay`` is False.
    """

    def __init__(self, action_size, n_type = None, param = None, decay = True):
        """Create the generator.

        Args:
            action_size: size/shape of each noise sample.
            n_type: "normal" (default when None) or "ou".
            param: settings dict; built-in defaults are used when None.
            decay: when False, samples are returned unscaled.
        """
        self.action_size = action_size
        self.state = np.zeros(action_size, dtype=np.float32)
        self.count = 0
        self.decay = decay
        self.type = "normal" if n_type is None else n_type

        if param is not None:
            self.param = param
        else:
            defaults = {"start": 0.9, "end": 0.02, "decay": 2000}
            if n_type == "ou":
                # Extra settings used only by the correlated process.
                defaults.update({"ou_mu": 0.0, "ou_th": 0.15, "ou_sig": 0.2})
            self.param = defaults

    def get_noise(self):
        """Return the next noise sample and advance the internal step counter."""
        cfg = self.param
        eps = cfg["end"] + (cfg["start"] - cfg["end"]) * math.exp(-1 * self.count / cfg["decay"])

        sample = np.random.normal(size=self.action_size)
        if self.type == "ou":
            # Pull the running state toward ou_mu and perturb it with fresh noise.
            drift = cfg["ou_th"] * (cfg["ou_mu"] - self.state)
            self.state += drift + cfg["ou_sig"] * sample
            sample = self.state

        scale = eps if self.decay else 1
        self.count += 1

        return sample * scale

import collections
import random
import numpy as np
import math

class Replay:
    """Fixed-size replay buffer with optional prioritized sampling.

    Items and their priorities live in two parallel deques of the same maximum
    length, so the oldest entry of each is evicted together once full.
    """

    def __init__(self, size, prio = False, alph = 0.6, beta = 0.4):
        """Create the buffer.

        Args:
            size: maximum number of stored items.
            prio: enable priority-proportional sampling.
            alph: exponent applied to priorities (forced to 0 when prio is off).
            beta: initial importance-weight exponent (forced to 0 when prio is off).
        """
        self.memory = collections.deque(maxlen = size)
        self.size = size
        self.priorities = collections.deque(maxlen = size)
        self.prio = prio
        self.alph = alph if prio else 0
        self.beta = beta if prio else 0
        self.count = 0  # total number of pushes, drives the beta schedule

    def push(self, data):
        """Append an item; with priorities on it receives the current maximum."""
        self.memory.append(data)
        self.count += 1
        if self.prio:
            # max() walks the deque directly; no temporary array is needed.
            max_prio = max(self.priorities) if self.priorities else 1.
            self.priorities.append(max_prio)

    def prepare(self, env):
        """No-op hook kept for interface compatibility."""
        pass

    def sample(self, size):
        """Draw `size` items with replacement.

        Returns:
            (sample, indices, weights): the drawn items, their positions in the
            buffer, and importance weights normalised to sum to 1.

        Raises:
            ValueError: if the buffer is empty.
        """
        total = len(self)
        if total == 0:
            raise ValueError("cannot sample from an empty replay buffer")

        if self.prio:
            probs = np.array(self.priorities, dtype=np.float32)
            min_prio = probs.min()
            # Rescale so the smallest priority is 1 before exponentiation; the
            # normalisation below makes this a pure numerical-range measure.
            if min_prio < 1:
                probs /= min_prio
            probs = probs ** self.alph
            probs /= probs.sum()
        else:
            probs = np.ones(total,) / total

        indices = np.random.choice(total, size, p=probs)
        sample = [self.memory[idx] for idx in indices]

        # beta starts at self.beta and approaches 1 as more items are pushed.
        beta = 1. + (self.beta - 1.) * math.exp(-1 * self.count / self.size)
        weights = (total * probs[indices]) ** (-beta)
        weights /= weights.sum()

        return sample, indices, weights

    def update_priorities(self, indices, prios):
        """Overwrite the priorities at `indices`; no-op when priorities are off.

        `prios` may be any sequence or array and is not modified. A small
        constant is added to the stored values so no priority is exactly zero.
        """
        if not self.prio:
            return
        # Work on a fresh array so the caller's data is left untouched.
        adjusted = np.asarray(prios, dtype=np.float64) + 1e-8
        for idx, prio in zip(indices, adjusted):
            self.priorities[int(idx)] = float(prio)

    def __len__(self):
        """Number of items currently stored."""
        return len(self.memory)

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy

import math

class NoiseLinear(nn.Linear):
    """Linear layer whose weights (and bias) get learnable Gaussian noise.

    Every forward pass draws fresh standard-normal noise into the ``eps_*``
    buffers and adds ``sigma * eps`` to the regular weight and bias.
    """

    def __init__(self, in_, out_, val = 0.017, bias = True):
        """Create the layer.

        Args:
            in_: number of input features.
            out_: number of output features.
            val: initial value of every sigma (noise scale) parameter.
            bias: whether the layer has a (noisy) bias term.
        """
        super(NoiseLinear, self).__init__(in_,out_,bias)
        self.sigma_weight = nn.Parameter(torch.full((out_, in_), val))
        self.register_buffer("eps_weight", torch.zeros(out_, in_))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_,), val))
            self.register_buffer("eps_bias", torch.zeros(out_))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight and bias uniformly in +/- sqrt(1 / in_features)."""
        std = math.sqrt(1 / self.in_features)
        self.weight.data.uniform_(-std, std)
        # The bias is None when the layer was built with bias=False.
        if self.bias is not None:
            self.bias.data.uniform_(-std, std)

    def forward(self, x):
        """Apply the layer with freshly sampled noise."""
        self.eps_weight.normal_()
        bias = self.bias
        if bias is not None:
            self.eps_bias.normal_()
            bias = bias + self.sigma_bias * self.eps_bias.data
        return F.linear(x, self.weight + self.sigma_weight * self.eps_weight, bias)

class targetNet(nn.Module):
    """Slow-moving copy of a network, used as a training target.

    Holds a deep copy (``net``) of the wrapped network plus a reference to the
    original (``off_net``) so the copy can be synchronised on demand.
    """

    def __init__(self, off_net):
        """Deep-copy `off_net` as the target and remember the original."""
        super(targetNet, self).__init__()
        self.net = copy.deepcopy(off_net)
        self.off_net = off_net

    def alpha_update(self, alpha = 0.05):
        """Blend parameters: target <- alpha * source + (1 - alpha) * target."""
        keep = 1 - alpha
        pairs = zip(self.off_net.parameters(), self.net.parameters())
        for source_p, target_p in pairs:
            blended = source_p.data * alpha + target_p.data * keep
            target_p.data.copy_(blended)

    def copy_off_net(self):
        """Overwrite the target with the source network's current state."""
        source_state = self.off_net.state_dict()
        self.net.load_state_dict(source_state)

    def forward(self, *x):
        """Evaluate the target copy on the given inputs."""
        return self.net(*x)
        

class Actor(nn.Module):
    """Policy network: three linear layers with ReLU, squashed by tanh.

    ``action_provider`` converts the network's output values into whatever the
    environment expects as an action.
    """

    def __init__(self, in_, act_n, action_provider, hidden=512):
        """Build the network.

        Args:
            in_: observation size.
            act_n: number of action values produced.
            action_provider: callable applied by ``get_action``.
            hidden: width of the first hidden layer; the second is half of it.
        """
        super(Actor, self).__init__()
        half = int(hidden/2)
        layers = [
            nn.Linear(in_, hidden),
            nn.ReLU(),
            nn.Linear(hidden, half),
            nn.ReLU(),
            nn.Linear(half, act_n),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)
        self.action_provider = action_provider

    def get_action(self, act_v):
        """Translate action values into an environment action."""
        provider = self.action_provider
        return provider(act_v)

    def forward(self, x):
        """Return action values in [-1, 1] for the observation batch `x`."""
        out = self.net(x)
        return out

class DistCritic(nn.Module):
    """Distributional critic: maps (observation, action) to per-atom logits.

    The observation goes through one hidden layer first; the action is then
    concatenated to that feature vector before the output head.
    """

    def __init__(self, in_, act_v, atom=51, hidden=512):
        """Build the network.

        Args:
            in_: observation size.
            act_v: size of the action vector concatenated to the features.
            atom: number of output logits (atoms of the value distribution).
            hidden: width of the observation feature layer.
        """
        super(DistCritic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_, hidden),
            nn.ReLU()
        )
        self.net_out = nn.Sequential(
            # Use the constructor argument, not a module-level global, for the
            # action width so the class works outside the defining script.
            nn.Linear(hidden + act_v, int(hidden/2)),
            nn.ReLU(),
            nn.Linear(int(hidden/2), atom)
        )

    def forward(self, obs, act):
        """Return unnormalised atom logits of shape (batch, atom)."""
        return self.net_out(torch.cat([self.net(obs), act], dim=1))

In [0]:
# Fixed support of the categorical value distribution used by transform_dist
# and the training loop.
V_MAX = 10   # largest representable value
V_MIN = -5   # smallest representable value
ATOMS = 51   # number of evenly spaced atoms between V_MIN and V_MAX
V_DIST = (V_MAX-V_MIN)/(ATOMS-1)  # spacing between neighbouring atoms

def transform_dist(dist, rew, gamma, unroll_n, done):
    """Project a next-state value distribution onto the fixed atom support.

    Every atom value ``z`` is moved to ``rew + gamma**unroll_n * z``, clipped to
    ``[V_MIN, V_MAX]``, and its probability mass is split linearly between the
    two neighbouring atoms. For rows flagged ``done`` the result is instead the
    projection of the reward alone.

    Args:
        dist: tensor of per-row atom probabilities, indexed as [row, atom].
        rew: tensor of (n-step) rewards, one per row.
        gamma: scalar discount factor.
        unroll_n: tensor with the number of unrolled steps per row.
        done: boolean tensor, one flag per row.
        (presumably rew/unroll_n are (batch, 1) and done is (batch,), as built
        by the training loop - confirm against other callers)

    Returns:
        A CPU float tensor with the same shape as `dist`.
    """
    dist = dist.cpu().detach().numpy()
    ret = np.zeros_like(dist)
    rew = rew.cpu().numpy()
    unroll_n = unroll_n.cpu().numpy()
    done = done.cpu().numpy()

    # Per-row discount does not depend on the atom, so compute it once.
    discount = gamma ** unroll_n
    for atom in range(ATOMS):
        support = V_MIN + atom * V_DIST
        next_support = np.clip(rew + discount * support, V_MIN, V_MAX)

        # reshape(-1) keeps a 1-D vector even for a batch of one (squeeze would
        # collapse it to 0-d and break the row indexing below).
        indices = ((next_support - V_MIN) / V_DIST).reshape(-1)
        # Guard against float round-off pushing an index just outside the support.
        indices = np.clip(indices, 0, ATOMS - 1)
        l = np.floor(indices).astype(np.int64)
        r = np.ceil(indices).astype(np.int64)

        # Mass landing exactly on an atom goes there entirely ...
        eq = l == r
        ret[eq, l[eq]] += dist[eq, atom]
        # ... otherwise it is shared by distance between the two neighbours.
        neq = ~eq
        ret[neq, l[neq]] += dist[neq, atom] * (r - indices)[neq]
        ret[neq, r[neq]] += dist[neq, atom] * (indices - l)[neq]

    # Terminal rows: there is no next state, so the target is the reward itself.
    # This overwrites whatever the loop accumulated for those rows, which is why
    # it only needs to run once after the loop.
    if done.any():
        rows = np.flatnonzero(done)
        ret[done] = 0.0
        next_support = np.clip(rew[done], V_MIN, V_MAX)

        indices = ((next_support - V_MIN) / V_DIST).reshape(-1)
        indices = np.clip(indices, 0, ATOMS - 1)
        l = np.floor(indices).astype(np.int64)
        r = np.ceil(indices).astype(np.int64)

        eq = l == r
        ret[rows[eq], l[eq]] = 1.0

        neq = ~eq
        ret[rows[neq], l[neq]] = (r - indices)[neq]
        ret[rows[neq], r[neq]] = (indices - l)[neq]

    return torch.FloatTensor(ret)

In [34]:
# Optimiser learning rates for the actor and the critic, and the discount factor.
ACT_LR = 0.001
CRT_LR = 0.001
GAMMA = 0.99

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")     

# Environment and the sizes of its action / observation vectors.
env = gym.make("BipedalWalker-v3")
act_n = env.action_space.shape[0]
obs_n = env.observation_space.shape[0]

# Actor with an identity action provider, its target copy and optimiser.
actor = Actor(obs_n, act_n, lambda x:x).to(device)
actor_tgt = targetNet(actor)
actor_optim = optim.Adam(actor.parameters(), ACT_LR)

# Distributional critic with ATOMS outputs, its target copy and optimiser.
critic = DistCritic(obs_n, act_n, ATOMS).to(device)
critic_tgt = targetNet(critic)
critic_optim = optim.Adam(critic.parameters(), CRT_LR)

ST_SIZE = 50000   # replay buffer capacity
ST_INIT = 10000   # transitions required in the buffer before training starts
ST_DECAY = 5000   # time constant for noise decay and for the history decay in the actor update
BATCH = 512       # transitions sampled per training step

# Prioritised replay buffer and OU exploration noise with overridden sigma/decay.
storage = Replay(ST_SIZE, True)
noise = NoiseMaker(act_n, "ou", decay=True)
noise.param["ou_sig"] = 0.6
noise.param["decay"] = ST_DECAY

# 3-step unrolled transitions; stored rewards are divided by 30.
agent = Agent(env, actor, noise, device=device)
agent.prepare(3, GAMMA, 30)



In [0]:
import time
EPOCH = 2000

# Atom values of the categorical value distribution. They never change, so
# build the tensor once instead of on every training step.
support_ = torch.arange(V_MIN, V_MAX + V_DIST - 1e-8, V_DIST).to(device)


def _float_batch(values):
    """Stack per-sample values into one float32 tensor on `device`."""
    # Going through a single ndarray avoids the slow element-wise conversion
    # torch performs for a tuple of arrays.
    return torch.as_tensor(np.asarray(values), dtype=torch.float32).to(device)


for epoch in range(EPOCH):
    for step in agent.episode(epoch):
        storage.push(step)
        # Only start learning once the buffer holds enough transitions.
        if len(storage) < ST_INIT:
            continue

        sample, indices, weights = storage.sample(BATCH)
        weights_ = _float_batch(weights)
        # The executed action and the env info are stored but not needed here.
        obs, act_v, noise_v, _act, next_obs, rew, done, _etc, unroll_n = zip(*sample)

        obs_ = _float_batch(obs)
        act_v_ = _float_batch(act_v)
        noise_v_ = _float_batch(noise_v)
        next_obs_ = _float_batch(next_obs)
        rew_ = _float_batch(rew).unsqueeze(1)
        done_ = torch.as_tensor(np.asarray(done, dtype=bool)).to(device)
        unroll_n_ = _float_batch(unroll_n).unsqueeze(1)

        #Critic update
        critic_optim.zero_grad()
        q_pred = critic(obs_, noise_v_)

        # The target is a constant for this step, so skip building its graph.
        with torch.no_grad():
            q_next_prob = critic_tgt(next_obs_, actor_tgt(next_obs_))
            q_next = F.softmax(q_next_prob, dim=1)
        q_target = transform_dist(q_next, rew_, GAMMA, unroll_n_, done_).to(device)

        # Cross-entropy between projected target and predicted distribution.
        q_entropy = -F.log_softmax(q_pred, dim=1) * q_target
        q_entropy_sum = q_entropy.sum(dim=1)
        q_loss = (weights_ * q_entropy_sum).sum()
        q_loss.backward()
        critic_optim.step()

        # Per-sample loss becomes the new sampling priority.
        storage.update_priorities(indices, q_entropy_sum.detach().cpu().numpy())

        #Actor update
        actor_optim.zero_grad()

        #history decaying
        # Recently stored samples (small age) lean on the stored action, older
        # ones on the current actor output.
        age = torch.as_tensor(len(storage) - indices, dtype=torch.float32)
        st_decay = torch.exp(-age / ST_DECAY).unsqueeze(1).to(device)
        act_avg = actor(obs_) * (1-st_decay) + act_v_ * st_decay
        q_dist = critic(obs_,act_avg)

        # Negative expected value = -sum_i p_i * z_i (a sum over atoms, not a mean).
        q_v = -(F.softmax(q_dist,dim=1) * support_).sum(dim=1)

        actor_loss = q_v.mean()
        actor_loss.backward()
        actor_optim.step()

        #target update
        critic_tgt.alpha_update()
        actor_tgt.alpha_update()

ep#   0 count :  122, scaled_rew : -3.95219, total_rew : -118.56573
ep#   1 count :   39, scaled_rew : -3.82891, total_rew : -114.86728
ep#   2 count :   65, scaled_rew : -3.46146, total_rew : -103.84365
ep#   3 count :  132, scaled_rew : -3.56875, total_rew : -107.06239
ep#   4 count :   79, scaled_rew : -3.77794, total_rew : -113.33821
ep#   5 count :  107, scaled_rew : -4.21066, total_rew : -126.31990
ep#   6 count :  110, scaled_rew : -3.33480, total_rew : -100.04394
ep#   7 count :   43, scaled_rew : -3.76773, total_rew : -113.03189
ep#   8 count :   96, scaled_rew : -3.85814, total_rew : -115.74435
ep#   9 count :   77, scaled_rew : -3.33190, total_rew : -99.95697
ep#  10 count :  122, scaled_rew : -3.66036, total_rew : -109.81073
ep#  11 count :   56, scaled_rew : -4.12780, total_rew : -123.83396
ep#  12 count :   55, scaled_rew : -3.86181, total_rew : -115.85424
ep#  13 count :   90, scaled_rew : -3.63319, total_rew : -108.99569
ep#  14 count :  103, scaled_rew : -3.62144, tota