# 03 Algorithms

## 3.6 Sarsa & n-step Sarsa

### Sarsa
既然可以使用Temporal-Difference算法来估计状态价值，那么也可以直接使用其来估计动作价值。这一点非常重要，因为可以将估计的动作价值与策略提升相结合，来学习最优策略。

给定策略$\pi$，我们的目标是：对于所有的$s \in \mathcal{S}$和动作$a \in \mathcal{A}(s)$，估计状态价值$q_{\pi}(s, a)$。

假设我们有遵循策略$\pi$生成的样本集$ (s_0, a_0, r_1, s_1, a_1, r_2, s_2, ... ,s_t, a_t, r_{t+1}, s_{t+1}, a_{t+1}, ...) $，也可以表示为$\{(s_i, a_i, r_{i+1}, s_{i+1})\}^{T}_{i=0}$。

Sarsa算法利用这些样本来估计动作价值:
$$
\begin{cases}
q_{t+1}(s_t, a_t) = q_t(s_t) - \alpha_t(s_t, a_t)[q_t(s_t,a_t) - (r_{t+1} + \gamma q_t(s_{t+1}, a_{t+1}))] & (s, t)=(s_t, a_t) \\
q_{t+1}(s, a) = q_t(s, a) & (s, a) \neq (s_t, a_t)
\end{cases}
$$

其中，$t=0,1,2,...$，$\alpha_t(s_t, a_t)$是一个很小的正数，代表学习率。

Sarsa的推导过程与Temporal-Difference类似，感兴趣的话可以回顾上一章节。

### Sarsa Algorithms
- 初始化：对于所有的状态-动作对$(s,a)$和时刻$t$初始化$\alpha_t(s,a)=\alpha>0$、初始化$q_0(s,a)$、初始化贪婪策略$\pi_0$，$\epsilon \in (0,1)$
- 对于每一个episode:
- $\qquad$ 遵循策略$\pi_0(s_0)$在$t_0$时刻生成动作$a_0$
- $\qquad$ 如果$s_t(t=0,1,2,\cdots)$不是目标状态，则：
- $\qquad\qquad$ 给定$(s_t,a_t)$采样得到$(r_{t+1},s_{t+1}, a_{t+1})$，其中$a_{t+1}$遵循$\pi_t(s_{t+1})$
- $\qquad\qquad$ 更新动作价值：$q_{t+1}(s_t, a_t) \leftarrow q_{t}(s_t, a_t)-\alpha_{t}(s_t, a_t)[q_{t}(s_t,a_t) - (r_{t+1}+\gamma q_{t}(s_{t+1},a_{t+1}))]$
- $\qquad\qquad$ 更新策略：如果$a = \arg\max_a q_{t+1}(s_t, a)$， $\pi_{t+1}(a|s_t)=1-\frac{\epsilon}{|\cal A(s_t)|}(|\cal A(s_t)| - 1)$，否则$\pi_{t+1}(a|s_t) = \frac{\epsilon}{|\cal A(s_t)|}$
- $\qquad$ $s_t \leftarrow s_{t+1}$, $a_t \leftarrow a_{t+1}$

### Example

In [1]:
import time

import numpy as np
import gymnasium as gym
from tqdm import tqdm

In [2]:
class Sarsa:
    """ Sarsa Algorithm """

    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1, epsilon_decay=0.99):

        self.env = env
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay

        self.returns = []
        self.q_tables = np.zeros((env.observation_space.n, env.action_space.n))
        self.policy = np.ones((env.observation_space.n, env.action_space.n)) / env.action_space.n

    @staticmethod
    def custom_reward(done, reward):
        if done and reward == 1:
            return 10
        elif done and reward == 0:
            return -5
        else:
            return -0.1

    def take_action(self, state):
        """ Take an epsilon-greedy action based on the Q-table """

        if np.random.rand() < self.epsilon:
            return np.random.choice(range(self.env.action_space.n), p=self.policy[state])
        else:
            return np.argmax(self.q_tables[state])

    def best_action(self, state):
        """ Return the best action based on the Q-table """
        return np.argmax(self.q_tables[state])

    def update_policy_and_values(self, state, action, reward, next_state, next_action):
        td_error = self.q_tables[state][action] - (reward + self.gamma * self.q_tables[next_state][next_action])
        self.q_tables[state][action] -= self.alpha * td_error

        best_action = self.best_action(state)
        policy = np.ones(self.env.action_space.n) * self.epsilon / self.env.action_space.n
        policy[best_action] = 1 - self.epsilon / self.env.action_space.n * (self.env.action_space.n - 1)
        self.policy[state] = policy

    def train(self, episodes=1000):
        for i in range(10):
            with tqdm(total=episodes // 10, desc=f'Episode {i + 1}') as pbar:
                for episode in range(episodes // 10):
                    state, info = self.env.reset()
                    action = self.take_action(state)
                    done = False

                    gamma_power = 1
                    episode_return = 0
                    while not done:
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        next_action = self.take_action(next_state)

                        done = terminated or truncated
                        reward = self.custom_reward(done, reward)

                        self.update_policy_and_values(state, action, reward, next_state, next_action)
                        state, action = next_state, next_action

                        episode_return += reward * gamma_power
                        gamma_power *= self.gamma

                    self.returns.append(episode_return)
                    if (episode + 1) % 10 == 0:
                        pbar.set_postfix(
                            {
                                'epoch': episodes / 10 * i + episode + 1,
                                'return': np.mean(self.returns),
                                'epsilon': self.epsilon
                            }
                        )
                    pbar.update(1)

                    self.epsilon *= self.epsilon_decay
                    self.epsilon = max(self.epsilon, 0.01)

    def visualize_policy(self, delay=0.5):
        state, info = self.env.reset()
        done = False

        while not done:
            self.env.render()
            action = np.argmax(self.policy[state])
            state, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            time.sleep(delay)

        self.env.render()
        self.env.close()

In [3]:
environment = gym.make('FrozenLake-v1', desc=None, map_name='4x4', is_slippery=True, render_mode='human')
environment.reset()

2025-02-16 17:53:37.874 python[96012:5287876] +[IMKClient subclass]: chose IMKClient_Modern
2025-02-16 17:53:37.874 python[96012:5287876] +[IMKInputSession subclass]: chose IMKInputSession_Modern


(0, {'prob': 1})

In [4]:
agent = Sarsa(environment, gamma=0.9, epsilon=0.99, alpha=0.1, epsilon_decay=0.99)

In [5]:
agent.train(100)
print(f"Optimal policy: {agent.policy}")
print(f"Optimal Q-tables: {agent.q_tables}")

Episode 1: 100%|██████████| 10/10 [00:18<00:00,  1.88s/it, epoch=10, return=-3.4, epsilon=0.904]
Episode 2: 100%|██████████| 10/10 [00:21<00:00,  2.13s/it, epoch=20, return=-3.37, epsilon=0.818]
Episode 3: 100%|██████████| 10/10 [00:18<00:00,  1.83s/it, epoch=30, return=-3.39, epsilon=0.74]
Episode 4: 100%|██████████| 10/10 [00:31<00:00,  3.19s/it, epoch=40, return=-3.16, epsilon=0.669]
Episode 5: 100%|██████████| 10/10 [00:41<00:00,  4.19s/it, epoch=50, return=-2.97, epsilon=0.605]
Episode 6: 100%|██████████| 10/10 [00:50<00:00,  5.04s/it, epoch=60, return=-2.85, epsilon=0.547]
Episode 7: 100%|██████████| 10/10 [00:45<00:00,  4.52s/it, epoch=70, return=-2.7, epsilon=0.495]
Episode 8: 100%|██████████| 10/10 [00:44<00:00,  4.47s/it, epoch=80, return=-2.54, epsilon=0.448]
Episode 9: 100%|██████████| 10/10 [00:44<00:00,  4.49s/it, epoch=90, return=-2.38, epsilon=0.405]
Episode 10:  10%|█         | 1/10 [00:06<01:00,  6.76s/it]2025-02-16 17:59:04.003 python[96012:5287876] TSM AdjustCapsLoc

Optimal policy: [[0.72547574 0.09150809 0.09150809 0.09150809]
 [0.09336607 0.09336607 0.09336607 0.71990179]
 [0.71990179 0.09336607 0.09336607 0.09336607]
 [0.09526178 0.09526178 0.09526178 0.71421466]
 [0.72547574 0.09150809 0.09150809 0.09150809]
 [0.25       0.25       0.25       0.25      ]
 [0.71990179 0.09336607 0.09336607 0.09336607]
 [0.25       0.25       0.25       0.25      ]
 [0.09150809 0.09150809 0.09150809 0.72547574]
 [0.09150809 0.72547574 0.09150809 0.09150809]
 [0.72547574 0.09150809 0.09150809 0.09150809]
 [0.25       0.25       0.25       0.25      ]
 [0.25       0.25       0.25       0.25      ]
 [0.09150809 0.09150809 0.72547574 0.09150809]
 [0.09150809 0.72547574 0.09150809 0.09150809]
 [0.25       0.25       0.25       0.25      ]]
Optimal Q-tables: [[-0.98674919 -1.00040282 -1.20326934 -1.0142414 ]
 [-2.43332946 -1.05580691 -1.26019265 -0.88527454]
 [-0.85879455 -0.87508781 -0.8625635  -0.92430642]
 [-1.66224271 -1.23718936 -2.0627924  -0.82230948]
 [-0.9838




In [6]:
agent.visualize_policy(delay=0.005)

### n-step Sarsa
我们来回顾一下动作价值的定义：
$$
q_{\pi}(s, a) = \mathbb{E}_{\pi}[G_t | S_t=s, A_t=a]
$$
其中 $G_t$ 是在时间步$t$之后收到的截断回报：
$$
G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_3 + \cdots
$$

事实上，可以对截断回报进行展开：
$$
\begin{align*}
Sarsa \leftarrow G_t^1 &= R_{t+1} + \gamma q_{\pi}(S_{t+1}, A_{t+1}) \\
G_t^2 &= R_{t+1} + \gamma R_{t+2} + \gamma^2 q_{\pi}(S_{t+2}, A_{t+2}) \\
&\vdots \\
n-step \ Sarsa \leftarrow G_t^n &= R_{t+1} + \gamma R_{t+2} + \cdots + \gamma^{n-1} R_{t+n} + \gamma^n q_{\pi}(S_{t+n}, A_{t+n}) \\
&\vdots \\
Monte \ Carlo \leftarrow G_t^{\infty} &= R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \cdots
\end{align*}
$$


**当$n=1$时（Sarsa）**：
$$
G_t^1 = \mathbb{E} [G_t^1|s,a] = \mathbb{E} [R_{t+1} + \gamma q_{\pi}(S_{t+1}, A_{t+1})|S=s,A=a]
$$
相应地，根据Robbins-Monro算法进行求解，可以得到近似求解算法：
$$
q_{t+1}(s_t, a_t) = q_t(s_t, a_t) - \alpha [q_t(s_t, a_t) - (R_{t+1} + \gamma q_t(S_{t+1}, A_{t+1}))]
$$


**当$n=\infty$时（Monte Carlo）**：
$$
q_{\pi}(s,a) = \mathbb{E}_{\pi}[G_{t}^{\infty}|S=s,A=a] = \mathbb{E}_{\pi}[R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + ... |S=s,A=a]
$$
相应地，求解算法：
$$
q_{t+1}(s_t, a_t) = r_{t+1} + \gamma r_{t+1} + \gamma^2 r_{t+3} ...
$$


**当$\infty > n >1$时（n-step Sarsa）**：
$$
q_{\pi}(s,a) = \mathbb{E}_{\pi}[G_{t}^{n}|S=s,A=a] = \mathbb{E}_{\pi}[R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + ... + \gamma^{n}q_{\pi}(S_{t+n},A_{t+n}) |S=s,A=a]
$$
相应地，根据Robbins-Monro算法进行求解，可以得到近似求解算法：
$$
q_{t+1}(s_t, a_t) = q_{t}(s_t, a_t) - \alpha_t(s_t, a_t)[q_{t}(s_t, a_t) - (r_{t+1} + \gamma r_{t+2} + ... + \gamma^{n-1}r_{t+n-1} + \gamma^n q_{t}(s_{t+n}, a_{t+n}))]
$$
由于在时刻$t$无法采集到$(r_{t+n}, s_{t+n}, a_{t+n})$，根据n-step Sarsa要求，直到$t+n$时刻$q_{t+1}(s,a)$才能被更新。

因此，对上式进行重写：
$$
q_{t+n}(s_t, a_t) = q_{t+n-1}(s_t, a_t) - \alpha_{t+n-1}(s_t, a_t)[q_{t+n-1}(s_t, a_t) - (r_{t+1} + \gamma r_{t+2} + ... + \gamma^{n-1}r_{t+n-1} + \gamma^n q_{t+n-1}(s_{t+n}, a_{t+n}))]
$$
其中，$q_{t+n}(s_t, a_t)$是$q_{\pi}(s_t, a_t)$在时刻$t+n$的估计。

### Example

In [7]:
class Sarsas:
    """ n-step Sarsa algorithm """

    def __init__(self, env, steps=20, alpha=0.1, gamma=0.95, epsilon=0.1, epsilon_decay=0.99):
        self.env = env
        self.steps = steps
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay

        self.states = []
        self.actions = []
        self.rewards = []

        self.returns = []
        self.q_tables = np.zeros((env.observation_space.n, env.action_space.n))
        self.policy = np.ones((env.observation_space.n, env.action_space.n)) / env.action_space.n

    @staticmethod
    def custom_reward(done, reward):
        if done and reward == 1:
            return 10
        elif done and reward == 0:
            return -5
        else:
            return -0.1

    def take_action(self, state):
        """ Take an epsilon-greedy action based on the Q-table """

        if np.random.rand() < self.epsilon:
            return np.random.choice(range(self.env.action_space.n), p=self.policy[state])
        else:
            return np.argmax(self.q_tables[state])

    def best_action(self, state):
        """ Return the best action based on the Q-table """
        return np.argmax(self.q_tables[state])

    def update_policy_and_values(self, state, action, reward, next_state, next_action, done):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)

        g = self.q_tables[next_state][next_action]
        if len(self.states) == self.steps or (done and len(self.states) > 0):
            for i in reversed(range(len(self.states))):
                g = self.rewards[i] + self.gamma * g
                td_error = g - self.q_tables[self.states[i]][self.actions[i]]
                self.q_tables[self.states[i]][self.actions[i]] -= self.alpha * td_error

        if done:
            self.states = []
            self.actions = []
            self.rewards = []

        best_action = self.best_action(state)
        policy = np.ones(self.env.action_space.n) * self.epsilon / self.env.action_space.n
        policy[best_action] = 1 - self.epsilon / self.env.action_space.n * (self.env.action_space.n - 1)
        self.policy[state] = policy

    def train(self, episodes=1000):
        for i in range(10):
            with tqdm(total=episodes // 10, desc=f'Episode {i + 1}') as pbar:
                for episode in range(episodes // 10):
                    state, info = self.env.reset()
                    action = self.take_action(state)
                    done = False

                    gamma_power = 1
                    episode_return = 0
                    while not done:
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        next_action = self.take_action(next_state)

                        done = terminated or truncated
                        reward = self.custom_reward(done, reward)

                        self.update_policy_and_values(state, action, reward, next_state, next_action, done)
                        state, action = next_state, next_action

                        episode_return += reward * gamma_power
                        gamma_power *= self.gamma

                    self.returns.append(episode_return)
                    if (episode + 1) % 10 == 0:
                        pbar.set_postfix(
                            {
                                'epoch': episodes / 10 * i + episode + 1,
                                'return': np.mean(self.returns),
                                'epsilon': self.epsilon
                            }
                        )
                    pbar.update(1)

                    self.epsilon *= self.epsilon_decay
                    self.epsilon = max(self.epsilon, 0.01)

In [8]:
environment = gym.make('FrozenLake-v1', desc=None, map_name='4x4', is_slippery=True, render_mode='human')
environment.reset()

(0, {'prob': 1})

In [9]:
agent = Sarsas(environment, steps=20, gamma=0.9, epsilon=0.99, alpha=0.1, epsilon_decay=0.99)

In [10]:
agent.train(100)
print(f"Optimal policy: {agent.policy}")
print(f"Optimal Q-tables: {agent.q_tables}")

Episode 1: 100%|██████████| 10/10 [00:22<00:00,  2.30s/it, epoch=10, return=-3.25, epsilon=0.904]
Episode 2: 100%|██████████| 10/10 [00:16<00:00,  1.61s/it, epoch=20, return=-3.42, epsilon=0.818]
Episode 3: 100%|██████████| 10/10 [00:11<00:00,  1.10s/it, epoch=30, return=-3.66, epsilon=0.74]
Episode 4: 100%|██████████| 10/10 [00:13<00:00,  1.38s/it, epoch=40, return=-3.7, epsilon=0.669]
Episode 5: 100%|██████████| 10/10 [00:21<00:00,  2.13s/it, epoch=50, return=-3.6, epsilon=0.605]
Episode 6: 100%|██████████| 10/10 [00:17<00:00,  1.76s/it, epoch=60, return=-3.59, epsilon=0.547]
Episode 7: 100%|██████████| 10/10 [00:15<00:00,  1.58s/it, epoch=70, return=-3.61, epsilon=0.495]
Episode 8: 100%|██████████| 10/10 [00:17<00:00,  1.71s/it, epoch=80, return=-3.61, epsilon=0.448]
Episode 9: 100%|██████████| 10/10 [00:15<00:00,  1.53s/it, epoch=90, return=-3.62, epsilon=0.405]
Episode 10: 100%|██████████| 10/10 [00:15<00:00,  1.50s/it, epoch=100, return=-3.63, epsilon=0.366]

Optimal policy: [[0.09150809 0.72547574 0.09150809 0.09150809]
 [0.09336607 0.71990179 0.09336607 0.09336607]
 [0.71421466 0.09526178 0.09526178 0.09526178]
 [0.12495926 0.62512223 0.12495926 0.12495926]
 [0.09150809 0.72547574 0.09150809 0.09150809]
 [0.25       0.25       0.25       0.25      ]
 [0.09817776 0.70546673 0.09817776 0.09817776]
 [0.25       0.25       0.25       0.25      ]
 [0.09150809 0.72547574 0.09150809 0.09150809]
 [0.09150809 0.09150809 0.09150809 0.72547574]
 [0.10533356 0.10533356 0.10533356 0.68399933]
 [0.25       0.25       0.25       0.25      ]
 [0.25       0.25       0.25       0.25      ]
 [0.23774751 0.28675746 0.23774751 0.23774751]
 [0.28675746 0.23774751 0.23774751 0.23774751]
 [0.25       0.25       0.25       0.25      ]]
Optimal Q-tables: [[1.99603384e+01 4.75052680e+06 2.60527521e+01 4.21945377e+01]
 [5.48021705e+00 8.59587619e+03 3.66939938e+00 1.26770259e+01]
 [8.24974743e+01 1.20411592e+00 1.31600557e+00 3.97501270e+00]
 [7.34880000e-01 2.05371




In [11]:
agent.visualize_policy(delay=0.005)

SyntaxError: invalid character '｜' (U+FF5C) (4086545296.py, line 1)