# Policy Gradient Methods

## Average state value
对所有状态值的加权平均
$$
\begin {align*}
\bar v =& \sum_{s\in S}v_\pi (s)\\
=& E_{S-d}(v_\pi(S))\\
=& E_{S-d}(\sum_{a\in A} q(s,a)\pi(a|s))
\end {align*}
$$

策略函数为$\pi(a|s,\theta)$
令$J(\theta)=\bar v$
对其求梯度

$$
\begin {align*}
\nabla_\theta J(\theta) =& E_{S-d}(\sum_{a\in A} q(s,a)\nabla_\theta\pi(a|s,\theta))\\
=& E_{S-d}(\sum_{a\in A} q(s,a)\pi(a|s,\theta) \nabla_\theta ln\pi(a|s,\theta))\\
=& E_{S-d}[E_{a-\pi(S,\Theta)}[q(s,a) \nabla_\theta ln\pi(a|s,\theta)]]\\
=& E_{S-d,a-\pi(S,\Theta)}[q(s,a) \nabla_\theta ln\pi(a|s,\theta)]
\end {align*}
$$

## Average reward
$$\bar r = (1-\gamma)\bar v$$

## Monte Carlo policy gradient (REINFORCE)
- 1、用随机权重初始化策略网络
- 2、运行N个完整的片段，保存其(s,a,r,s')状态转移
- 3、对于每个片段k的每一步t，计算后续步的带折扣的总奖励$Q_{k,t}=\sum_{i=0}\gamma_ir_i$
- 4、计算所有状态转移的损失函数 $L=-\sum_{k,t}Q_{k,t}ln\pi(a_{k,t}|s_{k,t})$
- 5、执行SGD更新权重，以最小化损失
- 6、从步骤2开始重复，直到收敛

In [1]:
import collections
import copy
import math
import random
import time
from collections import defaultdict

import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriter

In [2]:
# 1、用随机权重初始化策略网络
class PolicyNet(nn.Module):
    def __init__(self, obs_n, hidden_num, act_n):
        super().__init__()
        # 动作优势A(s, a)
        self.net = nn.Sequential(
            nn.Linear(obs_n, hidden_num),
            nn.ReLU(),
            nn.Linear(hidden_num, act_n),
            nn.Softmax(dim=1),
        )

    def forward(self, state):
        if len(torch.Tensor(state).size()) == 1:
            state = state.reshape(1, -1)
        return self.net(state)

In [3]:
def discount_reward(R, gamma):
    # r 为历史得分
    n = len(R)
    dr = 0
    for i in range(n):
        dr += gamma**i * R[i]
    return dr

In [4]:
# - 2、运行N个完整的片段，保存其(s,a,r,s')状态转移
def generate_episode(env, n_steps, net, predict=False):
    episode_history = dict()
    r_list = []

    for _ in range(n_steps):
        episode = []
        predict_reward = []
        state, info = env.reset()
        while True:
            p = net(torch.Tensor(state)).detach().numpy().reshape(-1)
            action = np.random.choice(list(range(env.action_space.n)), p=p)
            next_state, reward, terminated, truncted, info = env.step(action)
            episode.append([state, action, next_state, reward, terminated])
            predict_reward.append(reward)
            state = next_state
            if terminated or truncted:
                episode_history[_] = episode
                r_list.append(len(episode))
                episode = []
                predict_reward = []
                break
    if predict:
        return np.mean(r_list)
    return episode_history

In [5]:
# 对于每个片段k的每一步t，计算后续步的带折扣的总奖励
def calculate_t_discount_reward(reward_list, gamma):
    discount_reward = []
    total_reward = 0
    for i in reward_list[::-1]:
        total_reward = total_reward * gamma + i
        discount_reward.append(total_reward)
    return discount_reward[::-1]

- 4、计算所有状态转移的损失函数 $L=-\sum_{k,t}Q_{k,t}ln\pi(a_{k,t}|s_{k,t})$

In [6]:
def loss(batch, gamma):
    l = 0
    for episode in batch.values():
        reward_list = [
            reward for state, action, next_state, reward, terminated in episode
        ]
        state = [state for state, action, next_state, reward, terminated in episode]
        action = [action for state, action, next_state, reward, terminated in episode]
        qt = calculate_t_discount_reward(reward_list, gamma)
        pi = net(torch.Tensor(state))
        pi_a = pi.gather(dim=1, index=torch.LongTensor(action).reshape(-1, 1))
        l -= torch.Tensor(qt) @ torch.log(pi_a)
    return l / len(batch.values())

## 训练

In [7]:
## 初始化环境
env = gym.make("CartPole-v1", max_episode_steps=200)
# env = gym.make("CartPole-v1", render_mode = "human")

state, info = env.reset()

obs_n = env.observation_space.shape[0]
hidden_num = 64
act_n = env.action_space.n
net = PolicyNet(obs_n, hidden_num, act_n)

# 定义优化器
opt = optim.Adam(net.parameters(), lr=0.1)

# 记录
writer = SummaryWriter(log_dir="logs/PolicyGradient/reinforce", comment="test1")

In [8]:
epochs = 200
batch_size = 20
gamma = 0.9

for epoch in range(epochs):
    batch = generate_episode(env, batch_size, net)
    l = loss(batch, gamma)

    # 反向传播
    opt.zero_grad()
    l.backward()
    opt.step()

    writer.add_scalars(
        "Loss",
        {"loss": l.item(), "max_steps": generate_episode(env, 10, net, predict=True)},
        epoch,
    )

    print(
        "epoch:{},  Loss: {}, max_steps: {}".format(
            epoch, l.detach(), generate_episode(env, 10, net, predict=True)
        )
    )

  pi = net(torch.Tensor(state))


epoch:0,  Loss: tensor([78.3558]), max_steps: 24.5
epoch:1,  Loss: tensor([94.8375]), max_steps: 29.7
epoch:2,  Loss: tensor([89.3743]), max_steps: 22.2
epoch:3,  Loss: tensor([58.2732]), max_steps: 34.4
epoch:4,  Loss: tensor([91.8468]), max_steps: 49.6
epoch:5,  Loss: tensor([121.1975]), max_steps: 45.1
epoch:6,  Loss: tensor([125.6664]), max_steps: 45.5
epoch:7,  Loss: tensor([79.8753]), max_steps: 50.0
epoch:8,  Loss: tensor([109.9683]), max_steps: 74.9
epoch:9,  Loss: tensor([89.4750]), max_steps: 81.1
epoch:10,  Loss: tensor([119.8851]), max_steps: 70.2
epoch:11,  Loss: tensor([94.1659]), max_steps: 86.1
epoch:12,  Loss: tensor([84.3775]), max_steps: 62.7
epoch:13,  Loss: tensor([70.3001]), max_steps: 51.2
epoch:14,  Loss: tensor([45.8444]), max_steps: 48.7
epoch:15,  Loss: tensor([47.5268]), max_steps: 37.9
epoch:16,  Loss: tensor([47.2998]), max_steps: 37.2
epoch:17,  Loss: tensor([46.3998]), max_steps: 33.2
epoch:18,  Loss: tensor([35.2964]), max_steps: 32.2
epoch:19,  Loss: t

# 预测

In [9]:
env = gym.make("CartPole-v1", render_mode="human")
env = gym.wrappers.RecordVideo(env, video_folder="video")

state, info = env.reset()
total_rewards = 0

while True:
    p = net(torch.Tensor(state)).detach().numpy().reshape(-1)
    action = np.random.choice(list(range(env.action_space.n)), p=p)
    state, reward, terminated, truncted, info = env.step(action)
    if terminated:
        break

  logger.warn(
  logger.warn(
