## SARSA
##### *Mintae Kim, Hybrid Robotics Lab, UC Berkeley*
##### 07/16/2023

### Import Packages

In [None]:
import numpy as np
import random
from collections import defaultdict
from environment import Env

### Implementation of SARSA
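Since SARSA is a value-based (Q-function) algorithm, a single agent class is enough to implement it: the agent derives its $\epsilon$-greedy policy directly from its Q-table, so no separate policy class is needed. After every transition $\langle s, a, r, s', a' \rangle$ the agent applies the SARSA update $Q(s,a) \leftarrow Q(s,a) + \alpha\,[\,r + \gamma\, Q(s',a') - Q(s,a)\,]$, where $\alpha$ is the step size and $\gamma$ the discount factor.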

```python
class SARSAgent:
    """Outline of the SARSA agent interface implemented in the next code cell."""
    def __init__(self, env):
        # NOTE(review): the implementation in the next cell takes `actions`
        # (the list of action indices), not `env`.
        pass
    def learn(self, state, action, reward, next_state, next_action):
        # Update Q(state, action) from one <s, a, r, s', a'> transition.
        pass
    def get_action(self, state):
        # Choose an action for `state` from the agent's (epsilon-greedy) policy.
        pass
```

In [None]:
class SARSAgent:
    """Tabular SARSA agent acting epsilon-greedily with respect to its Q-table.

    Q-values live in ``q_table``, keyed by ``str(state)``. Each entry is a
    list with one value per action, and actions are used directly as indices
    into that list, so ``actions`` is expected to be ``0 .. n-1``.
    """

    def __init__(self, actions, step_size=0.01, discount_factor=0.9,
                 epsilon=0.1):
        """Create an agent with an all-zero Q-table.

        Args:
            actions: Sequence of available actions. Each action is used as an
                index into a state's Q-value list, so ``list(range(n))`` is
                the expected form.
            step_size: Learning rate (alpha) of the SARSA update.
            discount_factor: Discount (gamma) applied to Q(s', a').
            epsilon: Probability of choosing a uniformly random action.
        """
        self.actions = actions
        self.step_size = step_size
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        # One zero-initialised Q-value per action. Sized from `actions`
        # instead of a hard-coded 4 so environments with a different number
        # of actions index correctly.
        n_actions = len(actions)
        self.q_table = defaultdict(lambda: [0.0] * n_actions)

    def learn(self, state, action, reward, next_state, next_action,
              done=False):
        """Apply one SARSA update from the transition <s, a, r, s', a'>.

        Q(s, a) <- Q(s, a) + step_size * (target - Q(s, a)), where
        target = r + discount_factor * Q(s', a') for a non-terminal s',
        and target = r when ``done`` is True (no bootstrapping past the
        end of an episode).

        Args:
            state: Current state s (converted with ``str`` for lookup).
            action: Action index a taken in ``state``.
            reward: Reward r received for the transition.
            next_state: Resulting state s'.
            next_action: Action index a' chosen in ``next_state``.
            done: Whether ``next_state`` is terminal. Defaults to False,
                which reproduces the always-bootstrapping update.
        """
        state, next_state = str(state), str(next_state)
        current_q = self.q_table[state][action]
        if done:
            target = reward
        else:
            next_state_q = self.q_table[next_state][next_action]
            target = reward + self.discount_factor * next_state_q
        self.q_table[state][action] = (
            current_q + self.step_size * (target - current_q)
        )

    def get_action(self, state):
        """Return an action for ``state`` using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action is drawn from
        ``actions``; otherwise the action with the highest Q-value is
        returned, with ties broken at random by ``arg_max``.
        """
        if np.random.rand() < self.epsilon:
            # Explore: random action.
            action = np.random.choice(self.actions)
        else:
            # Exploit: greedy action w.r.t. the current Q-values.
            state = str(state)
            q_list = self.q_table[state]
            action = arg_max(q_list)
        return action


# Return best action based on the value of q function
def arg_max(q_list):
    max_idx_list = np.argwhere(q_list == np.amax(q_list))
    max_idx_list = max_idx_list.flatten().tolist()
    return random.choice(max_idx_list)

### Execution

In [None]:
if __name__ == "__main__":
    env = Env()
    agent = SARSAgent(actions=list(range(env.n_actions)))

    for episode in range(1000):
        # Start a fresh episode and pick the first action from the
        # agent's epsilon-greedy policy.
        state = env.reset()
        action = agent.get_action(state)

        done = False
        while not done:
            env.render()

            # Take the action, observe (s', r, done), then choose a' before
            # updating, so the update uses the action the agent will actually
            # execute next (on-policy).
            next_state, reward, done = env.step(action)
            next_action = agent.get_action(next_state)
            agent.learn(state, action, reward, next_state, next_action)

            # The chosen a' becomes the action executed on the next step.
            state, action = next_state, next_action

            # Show the current Q-table through the environment's display.
            env.print_value_all(agent.q_table)