- references
    - https://github.com/DeepRLChinese/DeepRL-Chinese/blob/master/09_trpo.py（运行有问题）
    - https://medium.com/@vladogim97/trpo-minimal-pytorch-implementation-859e46c4232e
        - https://gist.github.com/elumixor/c16b7bdc38e90aa30c2825d53790d217
- 对于 DRL 而言
    - 神经网络反而是简单的，就是一个超强的 function approximator；训练一个 deep neural network，就是学习一个函数近似器
        - $\pi_\theta(\cdot|s)=\pi_\theta(a|s)$
        - $V(s)$
    - 且在 DRL 的问题及应用里，我们需要更灵活多样地组织 learning/training 的 pipeline；

In [1]:
from collections import namedtuple
import gym
import torch
from torch import nn
from torch.distributions import Categorical
from torch.optim import Adam
from IPython.display import Image
gym.__version__

'0.15.4'

## 基本概念

### gym

In [2]:
# Create the CartPole environment shared by the cells below.
env = gym.make('CartPole-v0')

# Length of the observation vector (first entry of the observation space's shape).
state_size = env.observation_space.shape[0]
# Number of discrete actions exposed by the action space.
num_actions = env.action_space.n

In [3]:
# Inspect both spaces: (type, shape) of the observation space and (type, value) of the action space.
(type(env.observation_space), env.observation_space.shape), (type(env.action_space), env.action_space)

((gym.spaces.box.Box, (4,)), (gym.spaces.discrete.Discrete, Discrete(2)))

### Rollout

In [3]:
# Rollout is a record *type* (a namedtuple class), not an instance.
# Its fields hold the s, a, r, s' sequences of one trajectory.
Rollout = namedtuple(
    typename='Rollout',
    field_names=['states', 'actions', 'rewards', 'next_states'],
)

### Actor Critic 

- actor: $\pi_\theta(a|s)$
- critic: value function
    - advantage estimation
    - 可以是 action-value（Q value），也可以是 state-value，$V(s)$（V value）
- advantage estimation


In [4]:
# Policy network: observation -> one hidden ReLU layer -> action probabilities.
# The softmax normalizes over dim 1, so inputs are expected to carry a
# leading batch dimension.
actor_hidden = 32
actor = nn.Sequential(
    nn.Linear(in_features=state_size, out_features=actor_hidden),
    nn.ReLU(),
    nn.Linear(in_features=actor_hidden, out_features=num_actions),
    nn.Softmax(dim=1),
)

# Sample an action according to the policy's probability distribution.
def get_action(state):
    """Draw one action index for a single observation.

    The observation is converted to a float tensor and given a leading batch
    dimension of size 1 before being passed to ``actor``. The resulting
    probabilities parameterize a ``Categorical`` distribution, from which one
    action is sampled and returned via ``.item()``.
    """
    batch = torch.tensor(state).float().unsqueeze(0)
    action_probs = actor(batch)
    policy = Categorical(action_probs)
    sampled = policy.sample()
    return sampled.item()

In [7]:
# Value network: observation -> one hidden ReLU layer -> a single scalar output.
critic_hidden = 32
critic = nn.Sequential(
    nn.Linear(in_features=state_size, out_features=critic_hidden),
    nn.ReLU(),
    nn.Linear(in_features=critic_hidden, out_features=1),
)
# Adam over the critic's parameters only; the actor is not touched by this optimizer.
critic_optimizer = Adam(params=critic.parameters(), lr=0.005)

def update_critic(advantages):
    """Take one optimizer step on the critic from the given advantages.

    The loss is half the mean of the squared advantages. Gradients reach the
    critic only through the autograd graph attached to ``advantages``.
    """
    critic_optimizer.zero_grad()
    half_mse = advantages.pow(2).mean() * .5
    half_mse.backward()
    critic_optimizer.step()

In [8]:
Image(url='../imgs/policy_value_update_summary.png', width=400)

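$$
\begin{split}
A_t&=Q_t(s_t,a_t)-V(s_t)\\
&\approx R_{t+1}+\gamma V(s_{t+1}) -V(s_t) \qquad \text{TD(0)}
\end{split}
$$

- 注意：下面 `estimate_advantages` 的实现并不是单步 TD(0)，而是沿整条轨迹累积折扣回报，并用最后一个状态的价值 $V(s_T)$ 做 bootstrap：

$$
A_t \approx \sum_{k=0}^{T-t-1}\gamma^{k} R_{t+k+1}+\gamma^{T-t}V(s_T)-V(s_t)
$$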

In [5]:
def estimate_advantages(states, last_state, rewards, gamma=0.99):
    """Estimate per-step advantages for one rollout using the critic.

    For every step the discounted return is accumulated backwards through the
    rollout, bootstrapped from the critic's value of ``last_state``:
    ``G_t = r_t + gamma * G_{t+1}``, starting from ``critic(last_state)``.
    The advantage is ``G_t - critic(s_t)``. This is a full-trajectory
    bootstrapped return, not a one-step TD(0) target.

    Args:
        states: States visited in the rollout, one row per step.
        last_state: The state reached after the final step (a single,
            unbatched state; a batch dimension is added here).
        rewards: Rewards of the rollout, indexed by step along dim 0.
        gamma: Discount factor. Defaults to 0.99, the value previously
            hard-coded in this function.

    Returns:
        Tensor of advantages, one per step. It stays attached to the critic's
        autograd graph, both through the per-step values and through the
        bootstrap value.

    NOTE(review): the bootstrap value is used even when the rollout ended in a
    terminal state, because no ``done`` flag is passed in. Confirm this is
    intended.
    """
    values = critic(states)
    # Seed the backward recursion with the critic's estimate for the state
    # that follows the last recorded step.
    running_return = critic(last_state.unsqueeze(0))
    returns = torch.zeros_like(rewards)
    # Walk backwards so each step reuses the already-discounted tail.
    for i in reversed(range(rewards.shape[0])):
        running_return = returns[i] = rewards[i] + gamma * running_return
    return returns - values

## update_agent

In [7]:
def update_agent(rollouts):
    """Update the critic from a batch of rollouts and prepare policy inputs.

    Steps:
      1. Concatenate the states and actions of all rollouts.
      2. Estimate advantages per rollout (each rollout bootstraps from its own
         final next-state) and concatenate them.
      3. Fit the critic on the *raw* advantages.
      4. Detach and standardize the advantages for the policy side.
      5. Evaluate the actor on all states and pick out the probability of
         each action that was actually taken.

    Args:
        rollouts: Sequence of ``Rollout`` records.

    NOTE(review): ``probabilities`` and the standardized ``advantages`` are
    computed but not consumed in this function. The policy-update step is not
    present here.
    """
    states = torch.cat([r.states for r in rollouts], dim=0)
    actions = torch.cat([r.actions for r in rollouts], dim=0).flatten()

    # Per-rollout estimation. Attribute access avoids re-using the name
    # ``states`` inside the comprehension.
    advantages = [
        estimate_advantages(r.states, r.next_states[-1], r.rewards)
        for r in rollouts
    ]
    advantages = torch.cat(advantages, dim=0).flatten()

    # The critic must be fitted BEFORE standardization. After dividing by the
    # (unbiased) standard deviation, mean(adv ** 2) equals (n - 1) / n whatever
    # the critic outputs, so the critic loss would be a constant and its
    # gradient would vanish.
    update_critic(advantages)

    # The policy side treats advantages as constants. Detaching also avoids
    # reaching back into the critic graph already consumed by the backward pass.
    advantages = advantages.detach()

    # Normalize advantages to reduce skewness and improve convergence.
    # The epsilon keeps the division finite when all advantages are equal.
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    distribution = actor(states)

    # Important! We clamp the probabilities, so they do not reach zero
    distribution = torch.distributions.utils.clamp_probs(distribution)

    # Probability the current policy assigns to each action that was taken.
    probabilities = distribution[range(distribution.shape[0]), actions]

## training pipeline

In [7]:
def train(epochs=100, num_rollouts=10):
    """Run the training loop and plot the mean episode reward per epoch.

    Each epoch collects ``num_rollouts`` complete episodes with the current
    policy, passes them to ``update_agent``, and prints the mean total reward
    of those episodes. After the last epoch the per-epoch means are plotted.

    Args:
        epochs: Number of collect-then-update iterations.
        num_rollouts: Number of episodes collected per epoch.
    """
    # Imported locally (and first, so a missing package fails before any
    # training happens): the notebook's import cell does not bring ``plt``
    # into scope.
    import matplotlib.pyplot as plt

    mean_total_rewards = []

    for epoch in range(epochs):
        rollouts = []
        rollout_total_rewards = []

        for _ in range(num_rollouts):
            state = env.reset()
            done = False
            samples = []

            # One trajectory: act until the environment reports done.
            while not done:
                # Only the policy forward pass needs gradients disabled.
                with torch.no_grad():
                    action = get_action(state)
                next_state, reward, done, _info = env.step(action)

                # Collect samples
                samples.append((state, action, reward, next_state))
                state = next_state

            # Transpose the list of (s, a, r, s') tuples into four sequences.
            states, actions, rewards, next_states = zip(*samples)

            states = torch.stack([torch.from_numpy(s) for s in states], dim=0).float()
            next_states = torch.stack([torch.from_numpy(s) for s in next_states], dim=0).float()
            # Column vectors: one row per step.
            actions = torch.as_tensor(actions).unsqueeze(1)
            rewards = torch.as_tensor(rewards).unsqueeze(1)

            rollouts.append(Rollout(states, actions, rewards, next_states))
            rollout_total_rewards.append(rewards.sum().item())

        update_agent(rollouts)
        # Plain arithmetic mean over Python floats; no numpy needed.
        mtr = sum(rollout_total_rewards) / len(rollout_total_rewards)
        print(f'E: {epoch}.\tMean total reward across {num_rollouts} rollouts: {mtr}')
        mean_total_rewards.append(mtr)

    plt.plot(mean_total_rewards)
    plt.show()