# Demo Frozen Lake SARSA Agent
 - Demos the SARSA on-policy RL agent for the Frozen Lake problem

#### Installs

In [11]:
! pip install gymnasium numpy matplotlib pyvirtualdisplay




#### Imports

In [12]:
import numpy as np                                # numerical operations
import gymnasium as gym                           # Gymnasium environments
from pyvirtualdisplay import Display              # headless display
import matplotlib.pyplot as plt                   # plotting
from matplotlib.animation import FuncAnimation    # animation
from IPython import display
from typing import List, Any


#### Define global configs and variables

In [13]:
# Rendering configs
DPI = 72
INTERVAL = 100 # ms

#### Utility class of helper functions

In [14]:
class Helpers:
  """
  Utility class of helper functions
  """
  @staticmethod
  def animateEnvironment(images: List[Any]):
    """
    Animates the environment
    :param images: Images
    """
    plt.figure(
        figsize=(images[0].shape[1]/DPI,images[0].shape[0]/DPI),
        dpi=DPI
        )
    patch = plt.imshow(images[0])
    plt.axis=('off')
    animate = lambda i: patch.set_data(images[i])
    ani = FuncAnimation(
        plt.gcf(),
        animate,
        frames=len(images),
        interval=INTERVAL)
    display.display(display.HTML(ani.to_jshtml()))
    plt.close()

#### Solution 3 Steps:
 - Step 1: Implement SARSA RL (on-policy) agent
 - Step 2: Implement the RL training loop
 - Step 3: Implement the RL evaluation (animation) policy


##### Step 1: SARSA RL (on-policy) implementation

In [15]:
class SARSAAgent:
    def __init__(
        self,
        env,
        alpha=0.1,
        gamma=0.99,
        epsilon=1.0,
        epsilon_decay=0.995,
        min_epsilon=0.01
        ):
        self.env = env
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.min_epsilon = min_epsilon
        n_states = env.observation_space.n
        n_actions = env.action_space.n
        self.Q = np.zeros((n_states, n_actions))

        print(f"Frozen Lake environment creation..")
        print(f"Observation space: {n_states}")
        print(f"Action space: {n_actions}")
        print(f"""SARSA hyperparameters are:
                  \nalpha: {self.alpha}
                  \nepsilon: {self.epsilon}
                  \nepsilon_decay: {self.epsilon_decay}
                  \nmin_epsilon: {self.min_epsilon}\n""")

    def choose_action(self, state):
        if np.random.rand() < self.epsilon:
            return self.env.action_space.sample()
        return np.argmax(self.Q[state])

    def update(self, s, a, r, s_next, a_next):
        td_target = r + self.gamma * self.Q[s_next, a_next]
        td_error = td_target - self.Q[s, a]
        self.Q[s, a] += self.alpha * td_error

    def decay_epsilon(self):
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)


##### Step 2: Training loop

In [16]:
def train(agent, env, n_episodes=5000, max_steps=100):
    rewards = []
    for ep in range(n_episodes):
        s, _ = env.reset()                      # start new episode
        a = agent.choose_action(s)
        total_reward = 0

        for _ in range(max_steps):
            s_next, r, done, _, _ = env.step(a)
            a_next = agent.choose_action(s_next)
            agent.update(s, a, r, s_next, a_next)
            s, a = s_next, a_next
            total_reward += r
            if done:
                break

        agent.decay_epsilon()
        rewards.append(total_reward)
        if (ep+1) % 500 == 0:
            print(f"Episode {ep+1}/{n_episodes}  Average Reward: {np.mean(rewards[-500:]):.3f}")
            print(f"\nQ: {agent.Q}\n")
    return rewards

# Create environment
env = gym.make("FrozenLake-v1", is_slippery=False, render_mode="rgb_array")
agent = SARSAAgent(env)

# Train
training_rewards = train(agent, env)


Frozen Lake environment creation..
Observation space: 16
Action space: 4
SARSA hyperparameters are:
                  
alpha: 0.1
                  
epsilon: 1.0
                  
epsilon_decay: 0.995
                  
min_epsilon: 0.01

Episode 500/5000  Average Reward: 0.504

Q: [[3.96785127e-01 8.18525433e-01 2.13011424e-01 2.63628251e-01]
 [3.94231227e-01 0.00000000e+00 0.00000000e+00 4.50177988e-07]
 [5.05376861e-05 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [4.72636787e-01 8.59210511e-01 0.00000000e+00 2.72826168e-01]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [5.13927750e-01 0.00000000e+00 8.66388045e-01 4.03899255e-01]
 [1.87271736e-01 8.59589894e-01 4.84089058e-01 0.00000000e+00]
 [1.97328217e-01 8.97787957e-01 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.0000

##### Step 3: Implement the RL evaluation (animation) policy

In [17]:
def animate_policy(agent, env, fps=2):
    # Start virtual display
    display = Display(visible=0, size=(400, 400))
    display.start()

    n_rows, n_cols = agent.Q.shape
    states, actions = [], []
    s, _ = env.reset()
    states.append(s)

    # Rollout under greedy policy
    for _ in range(100):
        a = np.argmax(agent.Q[s])
        s, _, done, _, _ = env.step(a)
        states.append(s)
        if done:
            break

    # Setup plot
    fig, ax = plt.subplots()
    # ax.set_xlim(0, env.desc.shape[1])
    # ax.set_ylim(0, env.desc.shape[0])
    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    agent_dot, = ax.plot([], [], 'ro', ms=20)

    def init():
        agent_dot.set_data([], [])
        return agent_dot,

    def update(frame):
        # Convert state index to (row, col)
        row, col = divmod(states[frame], n_cols)
        agent_dot.set_data(col + 0.5, n_rows - row - 0.5)
        return agent_dot,

    anim = FuncAnimation(fig, update, init_func=init,
                         frames=len(states), interval=1000/fps, repeat=False)


    plt.close(fig)  # prevent static display
    display.stop()
    return anim

# Generate and display the animation in, e.g., a Jupyter notebook
# anim = animate_policy(agent, env)
# display.display(HTML(anim.to_jshtml()))


In [18]:
class EvaluateAgent:
  """
  Evaluate the SARSA RL agent using animation of the simulation runs
  """
  def __init__(self, agent, env, n_episodes=10, max_steps=100):
    """
    Constructor
    """
    self.agent = agent
    self.env = env
    self.n_episodes = n_episodes
    self.max_steps = max_steps
    self.display = Display(visible=0, size=(400, 400))
    self.display.start()
    self.states = []

    self.images = []

  def _evaluate(self):
    """
    Evaluate the agent
    """
    for ep in range(self.n_episodes):
      s, _ = self.env.reset()
      self.states.append(s)

      # Rollout under greedy policy
      for _ in range(self.max_steps):
          a = np.argmax(agent.Q[s])
          s, _, done, _, _ = self.env.step(a)
          self.images.append(self.env.render())
          self.states.append(s)
          if done:
              break

      self.env.close()


  def run(self):
    """
    Run the RL evaluation with animation
    """
    self._evaluate()
    Helpers.animateEnvironment(self.images)







In [19]:
evaluate = EvaluateAgent(agent, env)
evaluate.run()