In [1]:
import random
from collections import deque
from copy import deepcopy

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [2]:
SEED = 1
BATCH_SIZE = 256
LR = 0.0003
UP_COEF = 0.05
GAMMA = 0.99
EPS = 1.1920929e-07
V_MAX = 10
V_MIN = -10
N_ATOMS = 51
DELTA_Z = (V_MAX - V_MIN) / (N_ATOMS - 1)

# set device
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# random seed
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if use_cuda:
    torch.cuda.manual_seed_all(SEED)

In [3]:
EPS

1.1920929e-07

In [None]:
class CategoricalDuelingDQN(nn.Module):
    def __init__(self, obs_space, action_space, n_atoms):
        super().__init__()

        self.head = nn.Sequential(
            nn.Linear(obs_space, 256),
            nn.SELU()
        )

        self.val = nn.Sequential(
            nn.Linear(256, 256),
            nn.SELU(),
            nn.Linear(256, n_atoms)
        )

        self.adv = nn.Sequential(
            nn.Linear(256, 256),
            nn.SELU(),
            nn.Linear(256, action_space * n_atoms)
        )

        self.log_softmax = nn.LogSoftmax(dim=-1)

        self.register_buffer(
            'support', torch.arange(V_MIN, V_MAX + DELTA_Z, DELTA_Z))

    def forward(self, x):
        out = self.head(x)
        val_out = self.val(out).reshape(out.shape[0], 1, N_ATOMS)
        adv_out = self.adv(out).reshape(out.shape[0], -1, N_ATOMS)
        adv_mean = adv_out.mean(dim=1, keepdim=True)
        out = val_out + adv_out - adv_mean
        out = self.log_softmax(out)
        probs = out.exp()

        return probs

In [None]:
losses = []


def learn(net, tgt_net, optimizer, rep_memory):
    net.train()
    tgt_net.train()

    dataloader = DataLoader(rep_memory,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            pin_memory=use_cuda)
    # like a double DQN
    for i, (s, a, r, _s, d) in enumerate(dataloader):
        if i > 0:
            break
        s_batch = s.to(device).float()
        a_batch = a.detach().to(device).long()
        _s_batch = _s.to(device).float()
        r_batch = r.detach().to(device).float()
        is_done = 1. - d.detach().to(device).float()

        _p_batch = net(_s_batch)
        _weights = _p_batch * net.support
        _q_batch = _weights.sum(dim=2)
        _a_batch = torch.argmax(_q_batch, dim=1)

        with torch.no_grad():
            _p_batch_tgt = tgt_net(_s_batch)
        _p_best = _p_batch_tgt[range(BATCH_SIZE), _a_batch]
        _p_proj = projection(_p_best, r_batch, is_done)

        p_batch = net(s_batch)
        p_acting = p_batch[range(BATCH_SIZE), a_batch.data]

        # loss
        loss = -(_p_proj * torch.clamp(p_acting, min=EPS).log()).sum(dim=1).mean()
#         loss = -(_p_proj * (p_acting + EPS).log()).sum(dim=1).mean()
        losses.append(loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def projection(_p_best, r_batch, is_done):
    with torch.no_grad():
        _p_proj = np.zeros((BATCH_SIZE, N_ATOMS), dtype=np.float32)
        r_batch_np = r_batch.cpu().numpy()
        is_done_np = is_done.cpu().numpy()
        _p_best_np = _p_best.cpu().numpy()
        
    batch_id = range(BATCH_SIZE)
    for i in range(N_ATOMS):
        z = np.clip(r_batch_np + GAMMA * (V_MIN + i * DELTA_Z) * is_done_np,
                    V_MIN, V_MAX)
        b = (z - V_MIN) / DELTA_Z
        l = np.floor(b).astype(np.int64)
        u = np.ceil(b).astype(np.int64)
        
        _p_proj[batch_id, l[batch_id]] += _p_best_np[batch_id, i] * (u - b)[batch_id]
        _p_proj[batch_id, u[batch_id]] += _p_best_np[batch_id, i] * (b - l)[batch_id]
        
#     _p_proj = np.clip(_p_proj, EPS, None)
#     _p_proj = _p_proj / _p_proj.sum(axis=1, keepdims=1)
    return torch.tensor(_p_proj).to(device).float()


def select_action(obs, tgt_net):
    tgt_net.eval()
    with torch.no_grad():
        state = torch.tensor([obs]).to(device).float()
        probs = target_net(state)
        weights = probs * net.support
        q = weights.sum(dim=2)
        action = torch.argmax(q, dim=1)

    return action.item()

## Main

In [None]:
# make an environment
# env = gym.make('CartPole-v0')
env = gym.make('CartPole-v1')
# env = gym.make('MountainCar-v0')
# env = gym.make('LunarLander-v2')

env.seed(SEED)
obs_space = env.observation_space.shape[0]
action_space = env.action_space.n

# hyperparameter
n_episodes = 1000
learn_start = 1500
memory_size = 50000
update_frq = 1
use_eps_decay = False
epsilon = 0.001
eps_min = 0.001
decay_rate = 0.0001
n_eval = env.spec.trials

# global values
total_steps = 0
learn_steps = 0
rewards = []
reward_eval = deque(maxlen=n_eval)
is_learned = False
is_solved = False

# make two nerual networks
net = CategoricalDuelingDQN(obs_space, action_space, N_ATOMS).to(device)
target_net = deepcopy(net)

# make a optimizer
optimizer = optim.Adam(net.parameters(), lr=LR, eps=EPS)

# make memory
rep_memory = deque(maxlen=memory_size)

In [None]:
use_cuda

True

In [None]:
env.spec.max_episode_steps

500

In [None]:
env.spec.trials

100

In [None]:
env.spec.reward_threshold

475.0

In [None]:
# play
for i in range(1, n_episodes + 1):
    obs = env.reset()
    done = False
    ep_reward = 0
    while not done:
#         env.render()
        if np.random.rand() < epsilon:
            action = env.action_space.sample()
        else:
            action = select_action(obs, target_net)

        _obs, reward, done, _ = env.step(action)

        rep_memory.append((obs, action, reward, _obs, done))

        obs = _obs
        total_steps += 1
        ep_reward += reward

        if use_eps_decay:
            epsilon -= epsilon * decay_rate
            epsilon = max(eps_min, epsilon)

        if len(rep_memory) >= learn_start:
            if len(rep_memory) == learn_start:
                print('\n============  Start Learning  ============\n')
            learn(net, target_net, optimizer, rep_memory)
            learn_steps += 1

        if learn_steps == update_frq:
            # target smoothing update
            for t, n in zip(target_net.parameters(), net.parameters()):
                t.data = UP_COEF * n.data + (1 - UP_COEF) * t.data
            learn_steps = 0

    if done:
        rewards.append(ep_reward)
        reward_eval.append(ep_reward)
        print('{:3} Episode in {:5} steps, reward {:.2f}'.format(
            i, total_steps, ep_reward))

        if len(reward_eval) >= n_eval:
            if np.mean(reward_eval) >= env.spec.reward_threshold:
                print('\n{} is sloved! {:3} Episode in {:3} steps'.format(
                    env.spec.id, i, total_steps))
                torch.save(target_net.state_dict(),
                           f'./test/saved_models/{env.spec.id}_ep{i}_clear_model_cdddqn.pt')
                break
env.close()

  1 Episode in    10 steps, reward 10.00
  2 Episode in    18 steps, reward 8.00
  3 Episode in    28 steps, reward 10.00
  4 Episode in    38 steps, reward 10.00
  5 Episode in    48 steps, reward 10.00
  6 Episode in    58 steps, reward 10.00
  7 Episode in    67 steps, reward 9.00
  8 Episode in    78 steps, reward 11.00
  9 Episode in    88 steps, reward 10.00
 10 Episode in    98 steps, reward 10.00
 11 Episode in   107 steps, reward 9.00
 12 Episode in   115 steps, reward 8.00
 13 Episode in   124 steps, reward 9.00
 14 Episode in   133 steps, reward 9.00
 15 Episode in   143 steps, reward 10.00
 16 Episode in   153 steps, reward 10.00
 17 Episode in   163 steps, reward 10.00
 18 Episode in   172 steps, reward 9.00
 19 Episode in   181 steps, reward 9.00
 20 Episode in   190 steps, reward 9.00
 21 Episode in   198 steps, reward 8.00
 22 Episode in   207 steps, reward 9.00
 23 Episode in   215 steps, reward 8.00
 24 Episode in   226 steps, reward 11.00
 25 Episode in   236 steps, 

203 Episode in  1899 steps, reward 8.00
204 Episode in  1907 steps, reward 8.00
205 Episode in  1917 steps, reward 10.00
206 Episode in  1926 steps, reward 9.00
207 Episode in  1935 steps, reward 9.00
208 Episode in  1945 steps, reward 10.00
209 Episode in  1953 steps, reward 8.00
210 Episode in  1962 steps, reward 9.00
211 Episode in  1972 steps, reward 10.00
212 Episode in  1980 steps, reward 8.00
213 Episode in  1990 steps, reward 10.00
214 Episode in  1998 steps, reward 8.00
215 Episode in  2007 steps, reward 9.00
216 Episode in  2017 steps, reward 10.00
217 Episode in  2027 steps, reward 10.00
218 Episode in  2035 steps, reward 8.00
219 Episode in  2044 steps, reward 9.00
220 Episode in  2054 steps, reward 10.00
221 Episode in  2062 steps, reward 8.00
222 Episode in  2071 steps, reward 9.00
223 Episode in  2079 steps, reward 8.00
224 Episode in  2089 steps, reward 10.00
225 Episode in  2099 steps, reward 10.00
226 Episode in  2109 steps, reward 10.00
227 Episode in  2119 steps, re

401 Episode in 49733 steps, reward 500.00
402 Episode in 50038 steps, reward 305.00
403 Episode in 50538 steps, reward 500.00
404 Episode in 51038 steps, reward 500.00
405 Episode in 51322 steps, reward 284.00
406 Episode in 51759 steps, reward 437.00
407 Episode in 52259 steps, reward 500.00
408 Episode in 52731 steps, reward 472.00
409 Episode in 53111 steps, reward 380.00
410 Episode in 53512 steps, reward 401.00
411 Episode in 53921 steps, reward 409.00
412 Episode in 54266 steps, reward 345.00
413 Episode in 54549 steps, reward 283.00
414 Episode in 55049 steps, reward 500.00
415 Episode in 55426 steps, reward 377.00
416 Episode in 55684 steps, reward 258.00
417 Episode in 56017 steps, reward 333.00
418 Episode in 56333 steps, reward 316.00
419 Episode in 56833 steps, reward 500.00
420 Episode in 57096 steps, reward 263.00
421 Episode in 57364 steps, reward 268.00
422 Episode in 57664 steps, reward 300.00
423 Episode in 58025 steps, reward 361.00
424 Episode in 58393 steps, reward

597 Episode in 97763 steps, reward 211.00
598 Episode in 97944 steps, reward 181.00
599 Episode in 98131 steps, reward 187.00
600 Episode in 98321 steps, reward 190.00
601 Episode in 98491 steps, reward 170.00
602 Episode in 98811 steps, reward 320.00
603 Episode in 98970 steps, reward 159.00
604 Episode in 99193 steps, reward 223.00
605 Episode in 99363 steps, reward 170.00
606 Episode in 99540 steps, reward 177.00
607 Episode in 99693 steps, reward 153.00
608 Episode in 99876 steps, reward 183.00
609 Episode in 100070 steps, reward 194.00
610 Episode in 100228 steps, reward 158.00
611 Episode in 100433 steps, reward 205.00
612 Episode in 100601 steps, reward 168.00
613 Episode in 100805 steps, reward 204.00
614 Episode in 100982 steps, reward 177.00
615 Episode in 101173 steps, reward 191.00
616 Episode in 101359 steps, reward 186.00
617 Episode in 101545 steps, reward 186.00
618 Episode in 101738 steps, reward 193.00
619 Episode in 101991 steps, reward 253.00
620 Episode in 102219 s

788 Episode in 145273 steps, reward 168.00
789 Episode in 145560 steps, reward 287.00
790 Episode in 145817 steps, reward 257.00
791 Episode in 146031 steps, reward 214.00
792 Episode in 146214 steps, reward 183.00
793 Episode in 146464 steps, reward 250.00
794 Episode in 146698 steps, reward 234.00
795 Episode in 146952 steps, reward 254.00
796 Episode in 147196 steps, reward 244.00
797 Episode in 147407 steps, reward 211.00
798 Episode in 147604 steps, reward 197.00
799 Episode in 147816 steps, reward 212.00
800 Episode in 148012 steps, reward 196.00
801 Episode in 148121 steps, reward 109.00
802 Episode in 148340 steps, reward 219.00
803 Episode in 148527 steps, reward 187.00
804 Episode in 148694 steps, reward 167.00
805 Episode in 148902 steps, reward 208.00
806 Episode in 149067 steps, reward 165.00
807 Episode in 149210 steps, reward 143.00
808 Episode in 149364 steps, reward 154.00
809 Episode in 149582 steps, reward 218.00
810 Episode in 149921 steps, reward 339.00
811 Episode

In [None]:
plt.figure(figsize=(15, 5))
plt.title('Reward')
plt.plot(rewards)
plt.figure(figsize=(15, 5))
plt.title('Loss')
plt.plot(losses)
plt.show()

In [None]:
[
    ('CartPole-v0', 385, 0.05),
    ('CartPole-v1', None, 0.05),
    ('MountainCar-v0', None, 0.1),
    ('LunarLander-v2', None, 0.1)
]