# 11.Dueling-DQN

这个对DQN的改进是在2015年的“Dueling Network Architectures for Deep Reinforcement Learning”论文中提出的。

该论文的核心发现是，神经网络所试图逼近的Q值Q(s, a)可以被分成两个量：状态的价值V(s)，以及这个状态下的动作优势A(s, a)。

在同一个状态下，所有动作的优势值的期望为0，因为所有动作的动作价值的期望就是这个状态的状态价值。

这种约束可以通过几种方法来实施，例如，通过损失函数。但是在论文中，作者提出了一个非常巧妙的解决方案：从神经网络的Q表达式中减去优势值的平均值，这样就能使优势值的平均值恒为0。

$$
Q(s,a) = V(s)+A(s,a)-\frac{1}{n}\sum _{a'}A(s,a')
$$

这使得对基础DQN的改动变得很简单：为了将其转换成Dueling DQN，只需要改变神经网络的结构，而不需要影响其他部分的实现。

In [1]:
import collections
import copy
import math
import random
import time
from collections import defaultdict

import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriter

In [2]:
class DuelingNet(nn.Module):
    """Dueling Q-network with separate value and advantage streams.

    The two streams are combined as
    ``Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')`` so that the advantages
    of each state have zero mean.

    Args:
        obs_size: Length of the observation vector fed to both streams.
        hidden_size: Width of the single hidden layer in each stream.
        q_table_size: Number of actions (outputs of the advantage stream).
    """

    def __init__(self, obs_size, hidden_size, q_table_size):
        super().__init__()

        # Advantage stream A(s, a): one output per action.
        self.a_net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, q_table_size),
        )

        # Value stream V(s): a single scalar per state.
        self.v_net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state):
        """Return Q-values of shape ``(batch, q_table_size)``.

        Args:
            state: A single observation of shape ``(obs_size,)`` or a batch of
                shape ``(batch, obs_size)``. Tensors are used as-is; any other
                array-like is converted to a float32 tensor first.

        Returns:
            A 2-D tensor of Q-values; a single observation yields one row.
        """
        if not isinstance(state, torch.Tensor):
            state = torch.as_tensor(state, dtype=torch.float32)
        # Promote a single observation to a batch of one so that the mean
        # over dim=1 below always runs over the action axis.
        if state.dim() == 1:
            state = state.unsqueeze(0)
        v = self.v_net(state)
        a = self.a_net(state)
        return v + a - a.mean(dim=1, keepdim=True)


class DiscreteOneHotWrapper(gym.ObservationWrapper):
    """Expose a ``Discrete`` observation as a one-hot vector.

    The wrapped environment must have a ``gym.spaces.Discrete`` observation
    space; the wrapper advertises a ``Box(0.0, 1.0, (n,), float32)`` instead.
    """

    def __init__(self, env):
        super().__init__(env)
        source_space = env.observation_space
        assert isinstance(source_space, gym.spaces.Discrete)
        self.observation_space = gym.spaces.Box(
            0.0, 1.0, (source_space.n,), dtype=np.float32
        )

    def observation(self, observation):
        """Return a copy of the space's lower bound with index ``observation`` set to 1.0."""
        one_hot = self.observation_space.low.copy()
        one_hot[observation] = 1.0
        return one_hot

# ReplayBuffer

In [3]:
class ReplayBuffer:
    """Replay buffer that gathers transitions by rolling out whole episodes.

    Each stored transition is the list
    ``[state, action, next_state, reward, terminated]``.

    Args:
        queue_size: Number of most recent transitions kept after each refill.
        replay_time: Number of finished episodes collected per rollout.
    """

    def __init__(self, queue_size, replay_time):
        self.queue = []
        self.queue_size = queue_size
        self.replay_time = replay_time

    def get_batch_queue(self, env, action_trigger, batch_size, epsilon):
        """Collect fresh experience and return a random mini-batch.

        Rollouts are run until the buffer holds at least ``queue_size``
        transitions (always at least one rollout), then one more rollout is
        added, the buffer is trimmed to the newest ``queue_size`` entries and
        ``batch_size`` transitions are sampled without replacement.

        Args:
            env: Environment whose ``reset()`` returns ``(state, info)`` and
                whose ``step()`` returns
                ``(next_state, reward, terminated, truncated, info)``.
            action_trigger: Callable mapping a state to the policy's action.
            batch_size: Number of transitions to sample.
            epsilon: Probability of using ``action_trigger``; with probability
                ``1 - epsilon`` an action is sampled from ``env.action_space``.

        Returns:
            A list of ``batch_size`` transitions.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        self._fill_queue(env, action_trigger, epsilon)
        self._collect_episodes(env, action_trigger, epsilon)
        self.queue = self.queue[-self.queue_size:]

        return random.sample(self.queue, batch_size)

    def _collect_episodes(self, env, action_trigger, epsilon):
        """Append transitions from ``self.replay_time`` finished episodes."""
        state, _ = env.reset()
        finished_episodes = 0

        while True:
            if np.random.uniform(0.0, 1.0) > epsilon:
                action = env.action_space.sample()
            else:
                action = action_trigger(state)

            next_state, reward, terminated, truncated, _ = env.step(action)
            # Only `terminated` is stored: a time-limit truncation is not a
            # true terminal state, so the target may still bootstrap from it.
            self.queue.append([state, action, next_state, reward, terminated])
            state = next_state

            # An episode also ends when it is truncated; without resetting
            # here a policy that never reaches a terminal state would keep
            # this loop running forever.
            if terminated or truncated:
                state, _ = env.reset()
                finished_episodes += 1
                continue
            # Checked only on a non-final step, so one extra transition of the
            # next episode is recorded before stopping.
            if finished_episodes >= self.replay_time:
                break

    def _fill_queue(self, env, action_trigger, epsilon):
        """Run rollouts (at least one) until ``queue_size`` transitions exist."""
        while True:
            self._collect_episodes(env, action_trigger, epsilon)
            if len(self.queue) >= self.queue_size:
                break

# DQN

In [4]:
class DQN:
    """Pair of Q-networks plus helpers for acting, targets and evaluation.

    ``tgt_net`` is the network queried for actions during data collection and
    for the differentiable predictions ``y_hat``; ``net`` is a copy that is
    only refreshed by :meth:`update_net_parameters` and is used for the
    bootstrap targets and for evaluation.

    Args:
        env: Environment used by :meth:`predict_reward`.
        obs_size: Observation vector length passed to ``net``.
        hidden_size: Hidden layer width passed to ``net``.
        q_table_size: Number of actions passed to ``net``.
        net: Callable ``net(obs_size, hidden_size, q_table_size)`` returning
            a module; it is called twice to build the two networks.
    """

    def __init__(self, env, obs_size, hidden_size, q_table_size, net):
        self.env = env
        self.net = net(obs_size, hidden_size, q_table_size)
        self.tgt_net = net(obs_size, hidden_size, q_table_size)

    def update_net_parameters(self, update=True):
        """Copy the weights of ``tgt_net`` into ``net``.

        ``update`` is unused and kept only for backward compatibility.
        """
        self.net.load_state_dict(self.tgt_net.state_dict())

    def get_action_trigger(self, state):
        """Return the greedy action index of ``tgt_net`` for ``state``."""
        obs = torch.as_tensor(state, dtype=torch.float32)
        with torch.no_grad():
            return int(torch.argmax(self.tgt_net(obs)))

    def calculate_y_hat_and_y(self, batch, gamma):
        """Compute predictions and bootstrap targets for a batch.

        Args:
            batch: Iterable of ``(state, action, next_state, reward,
                terminated)`` transitions.
            gamma: Discount factor applied to the next-state value.

        Returns:
            ``(y_hat, y)``: ``y_hat`` holds ``tgt_net``'s Q-value of the taken
            action for each state (with gradients); ``y`` holds
            ``reward + (1 - terminated) * gamma * max_a net(next_state)``
            (without gradients). Both are 1-D tensors of batch length.
        """
        states, actions, next_states, rewards, terminals = zip(*batch)

        state_batch = torch.as_tensor(np.array(states), dtype=torch.float32)
        next_batch = torch.as_tensor(np.array(next_states), dtype=torch.float32)
        action_index = torch.as_tensor(np.array(actions, dtype=np.int64)).reshape(-1, 1)
        reward_batch = torch.as_tensor(np.array(rewards, dtype=np.float32))
        not_done = 1.0 - torch.as_tensor(np.array(terminals, dtype=np.float32))

        # One batched forward pass instead of one pass per transition.
        with torch.no_grad():
            next_q_max = self.net(next_batch).reshape(len(batch), -1).max(dim=1).values
        y = reward_batch + not_done * gamma * next_q_max

        y_hat = self.tgt_net(state_batch).gather(1, action_index)
        return y_hat.reshape(-1), y

    def predict_reward(self, max_steps=100):
        """Run the greedy policy of ``net`` on ``self.env`` and return the mean reward.

        The environment is reset whenever an episode terminates or is
        truncated. The loop stops on the first non-final step whose index is
        at least ``max_steps``, so slightly more than ``max_steps`` steps may
        be taken.

        Args:
            max_steps: Minimum number of environment steps to evaluate.

        Returns:
            The mean per-step reward as a float.
        """
        state, _ = self.env.reset()
        step = 0
        rewards = []

        while True:
            step += 1
            obs = torch.as_tensor(state, dtype=torch.float32)
            with torch.no_grad():
                action = int(torch.argmax(self.net(obs)))
            state, reward, terminated, truncated, _ = self.env.step(action)
            rewards.append(reward)
            if terminated or truncated:
                state, _ = self.env.reset()
                continue
            if step >= max_steps:
                break
        return float(np.mean(rewards))

## 训练

In [5]:
# Network width and replay-buffer sizing used by the objects built below.
hidden_size = 64
queue_size = 500
replay_time = 50

## Initialize the environment
# Non-slippery FrozenLake, capped at 100 steps per episode, with the
# discrete observation wrapped by DiscreteOneHotWrapper.
env = frozen_lake.FrozenLakeEnv(is_slippery=False)
env.spec = gym.spec("FrozenLake-v1")
env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
env = DiscreteOneHotWrapper(env)

## Initialize the replay buffer
replay_buffer = ReplayBuffer(queue_size, replay_time)

## Initialize the DQN agent
obs_size = env.observation_space.shape[0]
q_table_size = env.action_space.n
dqn = DQN(env, obs_size, hidden_size, q_table_size, DuelingNet)

# Define the optimizer (it updates the parameters of dqn.tgt_net)
opt = optim.Adam(dqn.tgt_net.parameters(), lr=0.01)

# Define the loss function
loss = nn.MSELoss()

# NOTE(review): `comment` presumably has no effect once `log_dir` is given —
# confirm against the SummaryWriter documentation.
writer = SummaryWriter(log_dir="logs/DQN/Dueling-DQN", comment="test1")

In [6]:
# Training hyperparameters read by the training loop.
batch_size = 256  # transitions sampled from the replay buffer per epoch
epsilon = 0.8  # passed to ReplayBuffer.get_batch_queue together with the policy
epochs = 500  # number of optimization steps
gamma = 0.9  # discount factor passed to DQN.calculate_y_hat_and_y

In [7]:
# One optimization step per epoch: collect experience, sample a batch,
# regress y_hat onto y, then evaluate and log.
for epoch in range(epochs):
    batch = replay_buffer.get_batch_queue(
        env, dqn.get_action_trigger, batch_size, epsilon
    )
    y_hat, y = dqn.calculate_y_hat_and_y(batch, gamma)
    l = loss(y_hat, y)

    # Backpropagation
    opt.zero_grad()
    l.backward()
    opt.step()

    # Every 10 epochs (skipping epoch 0) call update_net_parameters(), which
    # loads dqn.tgt_net's weights into dqn.net.
    if epoch % 10 == 0 and epoch != 0:
        dqn.update_net_parameters()

    # Evaluate the current dqn.net and log loss and reward under the "MSE" tag.
    predict_reward = dqn.predict_reward()
    writer.add_scalars(
        "MSE", {"loss": l.item(), "predict_reward": predict_reward}, epoch
    )

    print(
        "epoch:{},  MSE: {}, epsilon: {}, 100 steps reward: {}".format(
            epoch, l, epsilon, predict_reward
        )
    )

epoch:0,  MSE: 0.009532208554446697, epsilon: 0.8, 100 steps reward: 0.0
epoch:1,  MSE: 0.005161388777196407, epsilon: 0.8, 100 steps reward: 0.0
epoch:2,  MSE: 0.008616828359663486, epsilon: 0.8, 100 steps reward: 0.0
epoch:3,  MSE: 0.006230560131371021, epsilon: 0.8, 100 steps reward: 0.0
epoch:4,  MSE: 0.005877045448869467, epsilon: 0.8, 100 steps reward: 0.0
epoch:5,  MSE: 0.0029788087122142315, epsilon: 0.8, 100 steps reward: 0.0
epoch:6,  MSE: 0.001360040856525302, epsilon: 0.8, 100 steps reward: 0.0
epoch:7,  MSE: 0.001736717065796256, epsilon: 0.8, 100 steps reward: 0.0
epoch:8,  MSE: 0.0011424050899222493, epsilon: 0.8, 100 steps reward: 0.0
epoch:9,  MSE: 0.0009270317386835814, epsilon: 0.8, 100 steps reward: 0.0
epoch:10,  MSE: 0.0028675051871687174, epsilon: 0.8, 100 steps reward: 0.0
epoch:11,  MSE: 0.0012744665145874023, epsilon: 0.8, 100 steps reward: 0.0
epoch:12,  MSE: 0.00660003162920475, epsilon: 0.8, 100 steps reward: 0.0
epoch:13,  MSE: 0.005725443828850985, epsilo

# 可视化预测

In [12]:
# DQN_Q = dqn.net

# env = frozen_lake.FrozenLakeEnv(is_slippery=False, render_mode="human")
# env.spec = gym.spec("FrozenLake-v1")
# # display_size = 512
# # env.window_size = (display_size, display_size)
# # env.cell_size = (
# #     env.window_size[0] // env.ncol,
# #     env.window_size[1] // env.nrow,
# # )
# env = gym.wrappers.RecordVideo(env, video_folder="video")

# env = DiscreteOneHotWrapper(env)

# state, info = env.reset()
# total_rewards = 0

# while True:
#     action = int(torch.argmax(DQN_Q(torch.Tensor(state))))
#     state, reward, terminated, truncted, info = env.step(action)
#     print(action)
#     if terminated:
#         break
# env.close()

2
2
1
1
1
2
