# 06.DQN

1）使用随机权重初始化网络 $Q(s, a, w)$ 和目标网络 $\hat Q(s, a, w)$（$Q$ 和 $\hat Q$ 的初始权重相同），令 $\epsilon \leftarrow 1.0$，并清空回放缓冲区。

2）以概率 $\epsilon$ 选择一个随机动作 $a$，否则 $a=\arg\max_a Q(s, a, w)$。

3）在模拟器中执行动作a，观察奖励r和下一个状态s'。

4）将转移过程(s, a, r, s')存储在回放缓冲区中。

5）从回放缓冲区中采样一个随机的小批量转移过程。

6）对于采样得到的小批量中的每个转移过程，如果片段在此步结束，则计算目标 $y=r$，否则计算 $y=r+\gamma \max_{a'} \hat Q(s', a', w)$。

7）计算损失：$L=(Q(s, a, w)-y)^2$。

8）固定网络$\hat Q(s, a, w)$不变，通过最小化模型参数的损失，使用SGD算法更新$Q(s, a)$。

9）每 N 步，将权重从网络 $Q$ 复制到目标网络 $\hat Q$。

10）从步骤2开始重复，直到收敛为止。


In [2]:
import collections
import copy
import random
from collections import defaultdict
import math
import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriter

In [3]:
class Net(nn.Module):
    """Fully connected network with a single ReLU hidden layer.

    Args:
        obs_size: Number of input features.
        hidden_size: Width of the hidden layer.
        q_table_size: Number of output units.
    """

    def __init__(self, obs_size, hidden_size, q_table_size):
        super().__init__()

        # Linear -> ReLU -> Linear, wrapped in a Sequential stored as
        # ``self.net`` so parameter names stay ``net.0.*`` / ``net.2.*``.
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, q_table_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the output of the layer stack applied to ``state``."""
        output = self.net(state)
        return output


class DQN:
    """Minimal DQN helper: collects transitions and builds regression targets.

    Naming note: in this class ``tgt_net`` is the network used to choose
    actions and to produce ``y_hat`` (its outputs carry gradients), while
    ``net`` is the frozen copy used to compute the bootstrap term of ``y`` and
    refreshed by :meth:`update_net_parameters`.  The attribute and parameter
    names are kept as they are for compatibility with existing callers.

    Args:
        env: Environment exposing ``reset()`` returning ``(obs, info)``,
            ``step(action)`` returning ``(obs, reward, terminated, truncated,
            info)`` and ``action_space.sample()``.
        tgt_net: Network used for action selection and for ``y_hat``.
        net: Frozen network used for the bootstrap part of ``y``.
        gamma: Discount factor.  When ``None`` (the default) the module-level
            name ``gamma`` is looked up at call time, which is what the
            original implementation always did.
    """

    def __init__(self, env, tgt_net, net, gamma=None):
        self.env = env
        self.tgt_net = tgt_net
        self.net = net
        self.gamma = gamma

    def generate_train_data(self, batch_size, epsilon):
        """Roll out an epsilon-greedy policy and return a random batch.

        ``2 * batch_size`` transitions are collected, shuffled, and the first
        ``batch_size`` of them are returned.

        Args:
            batch_size: Number of transitions to return.
            epsilon: Probability of taking a random action; otherwise the
                action with the highest ``tgt_net`` output is taken.

        Returns:
            A list of ``[state, action, reward, new_state, terminated]``.
        """
        # Use the environment owned by this instance, not a global ``env``.
        state, _ = self.env.reset()
        train_data = []
        while len(train_data) < batch_size * 2:
            # Explore with probability epsilon (the original comparison was
            # inverted and explored with probability 1 - epsilon).
            if np.random.random() < epsilon:
                action = self.env.action_space.sample()
            else:
                with torch.no_grad():
                    q_values = self.tgt_net(
                        torch.as_tensor(state, dtype=torch.float32)
                    )
                action = int(torch.argmax(q_values))
            new_state, reward, terminated, truncated, _ = self.env.step(action)
            # Only ``terminated`` is stored: a truncated episode did not reach
            # a terminal state, so its target should still bootstrap.
            train_data.append([state, action, reward, new_state, terminated])
            if terminated or truncated:
                # The episode is over either way; start a new one.
                state, _ = self.env.reset()
            else:
                state = new_state
        random.shuffle(train_data)
        return train_data[:batch_size]

    def calculate_y_hat_and_y(self, batch):
        """Build predictions and targets for a batch of transitions.

        For every transition the target is ``reward`` if the episode
        terminated at that step, otherwise
        ``reward + gamma * max(net(new_state))``.

        Args:
            batch: Non-empty sequence of
                ``[state, action, reward, new_state, terminated]``.

        Returns:
            ``(y_hat, y)``: two 1-D float32 tensors of length ``len(batch)``.
            ``y_hat`` holds the ``tgt_net`` output for the taken action and
            carries gradients; ``y`` holds the targets and does not.

        Raises:
            ValueError: If ``batch`` is empty.
        """
        if len(batch) == 0:
            raise ValueError("batch must contain at least one transition")

        # Fall back to the module-level ``gamma`` when none was given.
        discount = gamma if self.gamma is None else self.gamma

        states, actions, rewards, next_states, dones = zip(*batch)
        state_t = torch.as_tensor(np.array(states), dtype=torch.float32)
        next_state_t = torch.as_tensor(np.array(next_states), dtype=torch.float32)
        action_t = torch.as_tensor(actions, dtype=torch.int64)
        reward_t = torch.as_tensor(rewards, dtype=torch.float32)
        done_t = torch.as_tensor([bool(d) for d in dones], dtype=torch.bool)

        # One batched forward pass through the frozen network; no gradients
        # must flow into the targets.
        with torch.no_grad():
            next_q_max = self.net(next_state_t).max(dim=1).values
        y = torch.where(done_t, reward_t, reward_t + discount * next_q_max)

        # Pick, for each row, the output that belongs to the action taken.
        q_all = self.tgt_net(state_t)
        y_hat = q_all.gather(1, action_t.unsqueeze(1)).squeeze(1)
        return y_hat, y

    def update_net_parameters(self, update=True):
        """Copy the weights of ``tgt_net`` into ``net`` when ``update`` is true."""
        if update:
            self.net.load_state_dict(self.tgt_net.state_dict())


In [3]:
# Create the environment.
env = gym.make("CartPole-v1")
# env = DiscreteOneHotWrapper(env)

hidden_num = 64
# Build the two networks with identical architecture.
net = Net(env.observation_space.shape[0], hidden_num, env.action_space.n)
tgt_net = Net(env.observation_space.shape[0], hidden_num, env.action_space.n)
# Step 1 of the algorithm requires both networks to start from the same
# weights; without this the first targets come from an unrelated random net.
net.load_state_dict(tgt_net.state_dict())
dqn = DQN(env=env, net=net, tgt_net=tgt_net)

# Optimizer: only ``tgt_net`` is trained by gradient descent; ``net`` is
# refreshed by copying weights.
opt = optim.Adam(tgt_net.parameters(), lr=0.001)

# Loss function.
loss = nn.MSELoss()

# Training-progress logging (disabled).
# writer = SummaryWriter(log_dir="logs/DQN", comment="DQN")

In [4]:
# Discount factor; DQN.calculate_y_hat_and_y reads this module-level name.
gamma = 0.8

NUM_ITERATIONS = 10000
BATCH_SIZE = 256
EPSILON = 0.8  # forwarded unchanged to dqn.generate_train_data
SYNC_EVERY = 5  # iterations between calls to dqn.update_net_parameters

for iteration in range(NUM_ITERATIONS):
    # Collect a fresh batch and build predictions / targets for it.
    batch = dqn.generate_train_data(BATCH_SIZE, EPSILON)
    y_hat, y = dqn.calculate_y_hat_and_y(batch)

    # One optimizer step on the mean squared error.
    opt.zero_grad()
    mse = loss(y_hat, y)
    mse.backward()
    opt.step()

    print("MSE: {}".format(mse.item()))

    # Runs on iterations 0, SYNC_EVERY, 2 * SYNC_EVERY, ...
    if not iteration % SYNC_EVERY:
        dqn.update_net_parameters(update=True)

MSE: 1.5298908948898315
MSE: 1.077877402305603
MSE: 1.0466701984405518
MSE: 1.0176589488983154
MSE: 0.9971619844436646
MSE: 0.9662405848503113
MSE: 1.0380014181137085
MSE: 1.018809199333191
MSE: 0.9971290826797485
MSE: 0.9604581594467163
MSE: 0.9392753839492798
MSE: 1.0199854373931885
MSE: 0.9792607426643372
MSE: 0.9508342742919922
MSE: 0.9413241147994995
MSE: 0.9178609848022461
MSE: 0.9918165802955627
MSE: 0.9651069641113281
MSE: 0.9283190965652466
MSE: 0.9043717384338379
MSE: 0.8891723155975342
MSE: 0.95167475938797
MSE: 0.9404323697090149
MSE: 0.9061200022697449
MSE: 0.871833324432373
MSE: 0.864561140537262
MSE: 0.9276864528656006
MSE: 0.9229756593704224
MSE: 0.877948522567749
MSE: 0.8488312363624573
MSE: 0.8364596366882324
MSE: 0.9151417016983032
MSE: 0.8785946369171143
MSE: 0.8482170104980469
MSE: 0.8330167531967163
MSE: 0.8029308319091797
MSE: 0.8889617919921875
MSE: 0.8504251837730408
MSE: 0.8235427141189575
MSE: 0.8093757033348083
MSE: 0.7784193158149719
MSE: 0.8525643944740295

KeyboardInterrupt: 

# 预测

In [9]:
# Run one episode with the greedy policy of the trained network.
# NOTE(review): RecordVideo usually needs a frame-returning render mode such
# as "rgb_array"; confirm that "human" actually produces a video with the
# installed gym version.
env = gym.make("CartPole-v1", render_mode="human")
env = gym.wrappers.RecordVideo(env, video_folder="video")

total_rewards = 0
try:
    state, info = env.reset()
    while True:
        with torch.no_grad():
            q_table_state = dqn.tgt_net(torch.Tensor(state))
        action = int(torch.argmax(q_table_state))
        state, reward, terminated, truncated, info = env.step(action)
        total_rewards += reward
        # Stop on truncation too: a policy that never drops the pole only
        # ever ends through the time limit, so checking ``terminated`` alone
        # would loop forever.
        if terminated or truncated:
            break
finally:
    # Closing the env lets the wrappers release their resources.
    env.close()

KeyboardInterrupt: 

In [8]:
# Drive the environment with random actions for 1000 steps.
env = gym.make("CartPole-v1", render_mode="human")
env = gym.wrappers.RecordVideo(env, video_folder="video")
try:
    env.reset()
    for i in range(1000):
        _obs, _reward, terminated, truncated, _info = env.step(
            env.action_space.sample()
        )
        # Stepping a finished episode is not valid; start a new one instead.
        if terminated or truncated:
            env.reset()
finally:
    # Closing the env lets the wrappers release their resources.
    env.close()

KeyboardInterrupt: 