In [1]:
from utils import *
import gym

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy

class targetNet(nn.Module):
    def __init__(self, off_net):
        super(targetNet, self).__init__()
        self.net = copy.deepcopy(off_net)
        self.off_net = off_net
        
    def alpha_update(alpha = 0.05):
        for off, tgt in zip(self.off_net.parameters(), self.net.parameters()):
            tgt.data.copy_(off.data*alpha + tgt.data*(1-alpha))
    
    def copy_off_net():
        self.net.load_state_dict(self.off_net.state_dict())
    
    def forward(self, *x):
        return self.net(*x)
        

class Actor(nn.Module):
    def __init__(self, in_, act_n, hidden=256):
        super(Actor, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, act_n),
        )
        
    def get_action(self, act_v):
        return act_v.argmax()
    
    def forward(self, x):
        return self.net(x)

class DistCritic(nn.Module):
    def __init__(self, in_, act_v, atom=51, hidden=256):
        super(DistCritic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_, hidden),
            nn.ReLU()
        )
        self.net_out = nn.Sequential(
            nn.Linear(hidden + act_n, hidden),
            nn.ReLU(),
            nn.Linear(hidden, atom)
        )
    
    def forward(self, obs, act):
        return self.net_out(torch.cat([self.net(obs), act], dim=1))

In [3]:
ACT_LR = 0.00025
CRT_LR = 0.001

GAMMA = 0.99

In [60]:
V_MAX = 10
V_MIN = -10
V_DIST = 0.4
ATOMS = 51

def transform_dist(dist, rew, gamma, unroll_n):
    support_start = V_MIN
    
    dist = dist.cpu().detach().numpy()
    unroll_n = unroll_n.cpu()
    ret = np.zeros_like(dist.shape)
    for n in unroll_n.unique().tolist():
        indices = (unroll_n == n).squeeze().numpy()
        for i in range(ATOMS):
            support = support_start + i * V_DIST
            print(rew[indices])
            reward = rew[indices][0].item()
            next_support = reward + (gamma**n) * support
            next_support = min(max(next_support, -10), 10)
            
            float_index = (next_support - V_MIN) / V_DIST

            floor = math.floor(float_index)
            ceil = math.ceil(float_index)
            
            if floor == ceil:
                ret[indices][:,floor] += dist[indices][:,i]
            else:
                print(indices.shape)
                ret[indices][:,floor] += dist[indices][:,i] * (ceil - float_index)
                ret[indices][:,ceil] += dist[indices][:,i]  * (float_index - floor)
            
    return torch.FloatTensor(ret).cuda()

In [61]:
transform_dist(q_next, rew_, GAMMA, unroll_n_)

tensor([[1.]], device='cuda:0')
(10,)


IndexError: boolean index did not match indexed array along dimension 0; dimension is 2 but corresponding boolean dimension is 10

In [5]:
V_MAX = 10
V_MIN = -10
V_DIST = 0.4
ATOMS = 51

def transform_dist(dist, rew, gamma, unroll_n):
    support_start = V_MIN
    
    ret = torch.zeros_like(dist)
    for n in unroll_n.unique().tolist():
        indices = (unroll_n == n).squeeze()
        for i in range(ATOMS):
            support = support_start + i * V_DIST
            next_support = rew[indices] + (gamma**n) * support
            next_support[next_support > 10] = 10
            next_support[next_support < -10] = -10
            
            float_index = (next_support - V_MIN) / V_DIST

            floor = math.floor(float_index[0])
            ceil = math.ceil(float_index[0])
            

            if floor == ceil:
                ret[indices][:,floor] += dist[indices][:,i]
            else:
                ret[indices][:,floor] += dist[indices][:,i] * (ceil - float_index).squeeze()
                ret[indices][:,ceil] += dist[indices][:,i]  * (float_index - floor).squeeze()
            
    return torch.FloatTensor(ret).cuda()

In [6]:
env = gym.make("CartPole-v1")
act_n = env.action_space.n
obs_n = env.observation_space.shape[0]

actor = Actor(obs_n, act_n).cuda()
actor_tgt = targetNet(actor)
actor_optim = optim.Adam(actor.parameters(), ACT_LR)

critic = DistCritic(obs_n, act_n).cuda()
critic_tgt = targetNet(critic)
critic_optim = optim.Adam(critic.parameters(), CRT_LR)

ST_SIZE = 100000
ST_INIT = 5000
BATCH = 512
storage = Replay(ST_SIZE, True)
noise = NoiseMaker(act_n, "ou", decay=True)

agent = Agent(env, actor, noise)
agent.set_n_step(3, GAMMA)
support = torch.FloatTensor([V_MIN + i*V_DIST for i in range(51)]).cuda()

In [7]:
EPOCH = 2000

for epoch in range(EPOCH):
    for i, step in enumerate(agent.episode(epoch)):
        storage.push(step)
        if len(storage) < ST_INIT:
            continue
            
        sample, indices, weights = storage.sample(BATCH)
        weights_ = torch.FloatTensor(weights).unsqueeze(1)
        obs, act_v, act, next_obs, rew, done, etc, unroll_n = list(zip(*sample))
        
        obs_ = torch.FloatTensor(obs).cuda()
        act_v_ = torch.FloatTensor(act_v).cuda()
        act_ = torch.LongTensor(act).unsqueeze(1).cuda()
        next_obs_ = torch.FloatTensor(next_obs).cuda()
        rew_ = torch.FloatTensor(rew).unsqueeze(1).cuda()
        done_ = torch.BoolTensor(done).cuda()
        unroll_n_ = torch.FloatTensor(unroll_n).unsqueeze(1).cuda()

        #Critic update
        q_pred_prob = critic(obs_, act_v_)
        q_pred = F.softmax(q_pred_prob)
        
        q_next_prob = critic_tgt(next_obs_, actor_tgt(next_obs_))
        q_next = F.softmax(q_next_prob)
        q_next[done] = 0.
        
        q_target = transform_dist(q_next, rew_, GAMMA, unroll_n_)
    print()

0 17.00000 
1 11.00000 
2 10.00000 
3 15.00000 
4 9.00000 
5 9.00000 
6 16.00000 
7 12.00000 
8 11.00000 
9 10.00000 
10 9.00000 
11 13.00000 
12 13.00000 
13 10.00000 
14 12.00000 
15 25.00000 
16 16.00000 
17 12.00000 
18 12.00000 
19 9.00000 
20 15.00000 
21 9.00000 
22 10.00000 
23 16.00000 
24 23.00000 
25 9.00000 
26 18.00000 
27 10.00000 
28 14.00000 
29 9.00000 
30 20.00000 
31 10.00000 
32 11.00000 
33 13.00000 
34 23.00000 
35 9.00000 
36 14.00000 
37 20.00000 
38 10.00000 
39 12.00000 
40 19.00000 
41 11.00000 
42 14.00000 
43 12.00000 
44 15.00000 
45 14.00000 
46 9.00000 
47 25.00000 
48 9.00000 
49 11.00000 
50 9.00000 
51 12.00000 
52 9.00000 
53 9.00000 
54 12.00000 
55 11.00000 
56 10.00000 
57 32.00000 
58 25.00000 
59 33.00000 
60 10.00000 
61 12.00000 
62 10.00000 
63 12.00000 
64 9.00000 
65 11.00000 
66 10.00000 
67 8.00000 
68 35.00000 
69 20.00000 
70 10.00000 
71 31.00000 
72 9.00000 
73 10.00000 
74 14.00000 
75 10.00000 
76 31.00000 
77 20.00000 
78 9.00000 


KeyboardInterrupt: 