In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from scipy.stats import poisson
from typing import Tuple

# Seaborn "notebook" context with slightly enlarged fonts and an 8x4 default figure size.
sns.set('notebook', font_scale=1.1, rc={'figure.facecolor': 'none', 'axes.facecolor': 'none'} and {'figure.figsize': (8, 4)})

### 6.4 Sarsa

Implement Sarsa($\lambda$) in Easy21. Initialize the value function to zero. Use the same step-size and exploration schedules as in the previous section. Run the algorithm with parameter values $\lambda \in \{ 0, 0.1, 0.2, \dots, 1\}$. Stop each run after 1000 episodes and report the mean-squared error (MSE):
\begin{align}
    \sum_{s, a} (Q(s, a) - Q^{*}(s, a))^2
\end{align}
over all states $s$ and actions $a$, comparing the true values $Q^{*}(s, a)$ computed in the previous section with the estimated values $Q(s, a)$ computed by Sarsa. Plot the mean-squared error against $\lambda$. For $\lambda=0$ and $\lambda=1$ only, plot the learning curve of the MSE against the episode number.

---

We have the following pseudocode:

```python
for episode in range(num_episodes):

    initialize state matrix

    while game is not done:

        take action, observe reward and new_state

        choose new_action using fixed policy

        # Update rule
        Q[state, action] += alpha * (reward + gamma * Q[new_state, new_action] - Q[state, action])

        # Update current state, action
        state = new_state
        action = new_action
```

We implement SARSA, which has the following update rule (Sutton & Barto, section 6.4, page 129):

\begin{align}
    Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left[ R_{t+1} + \gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t) \right]
\end{align}

**Note that**
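- We do this update after every transition from a nonterminal state $S_t$ (in contrast to MC methods, which only update once the episode has ended); if $S_{t+1}$ is terminal, $Q(S_{t+1}, A_{t+1})$ is defined as zero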
- ...

In [None]:
class SARSA(BaseAlgo):
    """Tabular one-step Sarsa (on-policy TD control).

    The class relies on ``self.Q_table``, ``self.N_table`` and
    ``self.to_index`` which are not defined here; they are expected to be
    provided by ``BaseAlgo``.

    Note:
        ``lambda_coef`` is stored but not used by ``update_policy``: the
        update implemented here is the one-step Sarsa rule, i.e. no
        eligibility traces are maintained.
        TODO: add eligibility traces to obtain a true Sarsa(lambda).

    Args:
        env_dim: Environment dimensions, forwarded to ``BaseAlgo``.
        N0: Exploration constant, forwarded to ``BaseAlgo``.
        lambda_coef (float): Trace-decay parameter (currently unused).
        gamma (float): Discount factor applied to the bootstrapped value.
    """
    def __init__(self, env_dim, N0, lambda_coef, gamma):
        super().__init__(env_dim, N0)
        self.gamma = gamma
        self.lambda_coef = lambda_coef
        self.Q_star = None
        # Per-episode MSE history. Starts as an empty list so that
        # ``compute_mse`` can be called without any prior allocation; callers
        # may still replace it with a preallocated array.
        self.mse_episode = []

    def update_policy(self, obs, action, new_obs, new_action, reward, done=False):
        """Apply one Sarsa update to ``Q_table`` for the pair (obs, action).

        Implements
        ``Q(S, A) <- Q(S, A) + alpha * [R + gamma * Q(S', A') - Q(S, A)]``
        with ``alpha = 1 / N(S, A)`` (or 1 while the count is still zero).

        Args:
            obs: The state the action was taken in.
            action (int): The action taken in ``obs`` (hit or stick).
            new_obs: The state reached after taking ``action``.
            new_action (int): The action selected in ``new_obs``.
            reward (float): Reward observed on the transition.
            done (bool): Set to True when ``new_obs`` is terminal. The value
                of a terminal state is defined as zero, so ``new_obs`` and
                ``new_action`` are then not looked up at all. Defaults to
                False, which reproduces the previous behaviour.
        """
        sa_index = self.to_index(obs, action)

        # Step size decays with the visit count; fall back to 1 to avoid a
        # division by zero for a pair that has not been counted yet.
        # NOTE(review): N_table is never incremented in this class -
        # presumably BaseAlgo does it when an action is selected; confirm.
        visits = self.N_table[sa_index]
        alpha = 1 / visits if visits != 0 else 1

        # Bootstrapped value of the successor pair (zero after termination)
        if done:
            next_value = 0.0
        else:
            next_value = self.Q_table[self.to_index(new_obs, new_action)]

        # Compute td error (sometimes called delta)
        td_error = reward + self.gamma * next_value - self.Q_table[sa_index]

        # Update Q value
        self.Q_table[sa_index] += alpha * td_error

    def compute_mse(self, Q_star, episode):
        """Store the mean squared error (MSE) to the ground truth.

        The squared differences between ``Q_table`` and ``Q_star`` are summed
        and divided by the number of entries of ``Q_star``; the result is
        written to ``self.mse_episode[episode]``.

        Args:
            Q_star (np.array): Ground-truth obtained through MC learning with many episodes
            episode (int): The episode in which we are.
        """
        Q_star = np.asarray(Q_star)
        mse = float(np.sum(np.square(self.Q_table - Q_star)) / float(Q_star.size))

        history = self.mse_episode
        if history is None:
            history = []
        if episode >= len(history):
            # Grow the history on demand; skipped episodes are left as NaN.
            history = list(history) + [np.nan] * (episode + 1 - len(history))
        history[episode] = mse
        self.mse_episode = history

In [None]:
num_episodes = 1000

# Initialize environment
easy21 = Easy21(dealer_threshold=17, player_threshold=12)

# Initialize the Sarsa learner
sarsa_learner = SARSA(env_dim=easy21.dim, gamma=.9, lambda_coef=1, N0=100)


def select_action(learner, env, obs):
    """Behaviour policy: always hit (1) below the player threshold, else act epsilon-greedily."""
    if obs[1] < env.player_threshold:
        return 1
    return learner.take_eps_greedy_action(obs)


for episode_i in range(num_episodes):

    # Initial game state and first action
    obs, _, done, info = easy21._reset()
    action = select_action(sarsa_learner, easy21, obs)

    while not done:
        # Take the action, observe reward and new observation
        new_obs, reward, done, info = easy21.step(action)

        # Sarsa is on-policy: the next action is chosen *before* the update
        # and is then actually played on the next step. Once the episode is
        # over no action is selected; the current one is passed as a
        # placeholder.
        # NOTE(review): update_policy still looks up Q_table at the terminal
        # (new_obs, new_action) pair - confirm that to_index handles terminal
        # observations and that the value stored there stays zero.
        new_action = action if done else select_action(sarsa_learner, easy21, new_obs)

        # Update Q after every step, using the state the action was taken in
        sarsa_learner.update_policy(obs, action, new_obs, new_action, reward)

        # Move on to the next state-action pair
        obs, action = new_obs, new_action

In [None]:
def run_td_lambda(lambda_coef, Q_star, num_episodes=1000, N0=100, gamma=.9):
    """Train a fresh SARSA learner on Easy21 and record the MSE after every episode.

    Args:
        lambda_coef (float): Trace-decay parameter handed to ``SARSA``.
        Q_star (np.array): Ground-truth action values the estimate is compared to.
        num_episodes (int): Number of episodes to play.
        N0 (int): Exploration constant handed to ``SARSA``.
        gamma (float): Discount factor handed to ``SARSA``.

    Returns:
        The learner's ``mse_episode`` buffer, holding one entry per episode.
    """

    # Initialize environment
    easy21 = Easy21(
        dealer_threshold=17,
        player_threshold=12,
    )

    # Initialize SARSA(lambda) learner
    sarsa_learner = SARSA(
        env_dim=easy21.dim,
        gamma=gamma,
        lambda_coef=lambda_coef,
        N0=N0,
    )

    # Allocate the buffer compute_mse writes into, one slot per episode
    sarsa_learner.mse_episode = np.zeros(num_episodes)

    def select_action(obs):
        """Always hit (1) below the player threshold, otherwise act epsilon-greedily."""
        if obs[1] < easy21.player_threshold:
            return 1
        return sarsa_learner.take_eps_greedy_action(obs)

    for episode_i in range(num_episodes):

        # Initial game state and first action
        obs, _, done, info = easy21._reset()
        action = select_action(obs)

        while not done:
            # Get new observation
            new_obs, reward, done, info = easy21.step(action)

            # Get new action; it is the one played on the next step. After
            # termination nothing is selected and the current action is
            # passed as a placeholder.
            # NOTE(review): update_policy still looks up Q_table at the
            # terminal (new_obs, new_action) pair - confirm that to_index
            # handles terminal observations.
            new_action = action if done else select_action(new_obs)

            # Update the policy
            sarsa_learner.update_policy(obs, action, new_obs, new_action, reward)

            obs = new_obs
            action = new_action

        sarsa_learner.compute_mse(Q_star, episode_i)

    return sarsa_learner.mse_episode