# SARSA and Q-Learning Demo

In [7]:
# Import Necessary Libraries
import numpy as np
from enum import IntEnum

In [8]:
class Movement(IntEnum):
    """The four grid moves an agent can take, numbered 1 through 4.

    Members are declared in the order UP, RIGHT, DOWN, LEFT, so iterating the
    enum yields them in that order.
    """

    UP, RIGHT, DOWN, LEFT = 1, 2, 3, 4

In [None]:
class GridWorld2D:
    """Deterministic 2-D grid world with a tabular action-value function.

    States are ``(x, y)`` cells with ``0 <= x < width`` and ``0 <= y < height``.
    Actions are ``Movement`` values (1-4); ``Movement.UP`` decreases ``y`` and
    ``Movement.DOWN`` increases it. A move that would leave the grid keeps the
    agent on the edge cell. Rewards: -1 per move, -100 for entering an
    obstacle and +100 for entering the goal; the last two end the episode.

    ``QTable`` has shape ``(width, height, 4)``. Because ``Movement`` values
    start at 1, the value of action ``a`` is stored at index ``a - 1`` of the
    last axis (see ``_q_index``).
    """

    def __init__(self,
                 width: int,
                 height: int,
                 start: tuple[int, int],
                 goal: tuple[int, int],
                 obstacles: list[tuple[int, int]]):
        """Create the grid and an all-zero Q-table.

        Args:
            width: Number of columns (valid x coordinates are ``0..width-1``).
            height: Number of rows (valid y coordinates are ``0..height-1``).
            start: ``(x, y)`` cell where every episode begins.
            goal: ``(x, y)`` terminal cell that yields reward +100.
            obstacles: ``(x, y)`` terminal cells that yield reward -100.

        Raises:
            ValueError: If ``start`` or ``goal`` lies outside the grid.
        """
        self.width = width
        self.height = height

        if start[0] < 0 or start[0] > width - 1:
            raise ValueError('Invalid starting point x coordinate. Should be integer in range of [0, self.width).')
        if start[1] < 0 or start[1] > height - 1:
            raise ValueError('Invalid starting point y coordinate. Should be integer in range of [0, self.height).')
        if goal[0] < 0 or goal[0] > width - 1:
            raise ValueError('Invalid goal point x coordinate. Should be integer in range of [0, self.width).')
        if goal[1] < 0 or goal[1] > height - 1:
            raise ValueError('Invalid goal point y coordinate. Should be integer in range of [0, self.height).')

        # Store coordinates as tuples: ``step`` builds ``next_state`` as a
        # tuple, and a tuple never compares equal to a list, so list inputs
        # would otherwise make the goal and obstacles unreachable.
        self.start = tuple(start)
        self.goal = tuple(goal)
        self.state = self.start

        self.obstacles = [tuple(cell) for cell in obstacles]
        self.QTable = np.zeros((width, height, 4))

    def reset(self):
        """Put the agent back on the start cell and zero the Q-table."""
        self.state = self.start
        self.QTable = np.zeros((self.width, self.height, 4))

    @staticmethod
    def _q_index(action: int) -> int:
        """Map a ``Movement`` value (1-4) to its column in ``QTable`` (0-3)."""
        return int(action) - 1

    def _td_update(self, state: tuple[int, int], action: int, target: float, alpha: float):
        """Move Q(state, action) a fraction ``alpha`` of the way towards ``target``."""
        x, y = state
        column = self._q_index(action)
        self.QTable[x, y, column] += alpha * (target - self.QTable[x, y, column])

    def step(self, action: int):
        """Simulate one move from ``self.state`` without modifying ``self.state``.

        Args:
            action: A ``Movement`` value (1-4).

        Returns:
            ``(next_state, reward, end)``: the resulting ``(x, y)`` cell, the
            reward for entering it, and whether the episode is over.

        Raises:
            ValueError: If ``action`` is not one of the four ``Movement`` values.
        """
        x, y = self.state

        # Clamp to the grid: moving into a wall leaves the agent where it is.
        if action == Movement.RIGHT:
            x = min(x + 1, self.width - 1)
        elif action == Movement.LEFT:
            x = max(x - 1, 0)
        elif action == Movement.UP:
            y = max(y - 1, 0)
        elif action == Movement.DOWN:
            y = min(y + 1, self.height - 1)
        else:
            raise ValueError(f'Invalid action {action!r}. Should be a Movement value in range of [1, 4].')

        next_state = (x, y)
        reward = -1
        end = False

        if next_state in self.obstacles:
            reward = -100
            end = True
        elif next_state == self.goal:
            reward = 100
            end = True

        return next_state, reward, end

    def epsilon_greedy_policy(self, state: tuple[int, int], epsilon: float) -> int:
        """Pick an action for ``state`` using an epsilon-greedy rule.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest Q-value is returned, with ties
        broken uniformly at random.

        Args:
            state: ``(x, y)`` cell to choose an action for.
            epsilon: Exploration probability, in the open interval (0, 1).

        Returns:
            A ``Movement`` member (an ``int`` in 1-4), directly usable by ``step``.

        Raises:
            ValueError: If ``epsilon`` is not strictly between 0 and 1.
        """
        if epsilon >= 1 or epsilon <= 0:
            raise ValueError('Invalid epsilon value. Should be in range of (0,1).')

        if np.random.uniform() <= epsilon:
            # Explore: any of the four actions with equal probability.
            return Movement(np.random.randint(4) + 1)

        # Exploit. ``argmax`` yields a 0-based column index, so shift it back
        # to a 1-based Movement value. Random tie-breaking keeps an all-zero
        # row from always selecting the first action.
        q_values = self.QTable[state[0], state[1]]
        best_columns = np.flatnonzero(q_values == q_values.max())
        return Movement(int(np.random.choice(best_columns)) + 1)

    def SARSA(self, episodes: int, alpha: float, gamma: float, epsilon: float):
        """Train ``QTable`` from scratch with on-policy SARSA.

        Args:
            episodes: Number of episodes; each starts at ``self.start`` and
                runs until a goal or obstacle cell is entered.
            alpha: Learning rate.
            gamma: Discount factor.
            epsilon: Exploration rate of the epsilon-greedy policy, in (0, 1).
        """
        self.reset()
        for _ in range(episodes):
            self.state = self.start
            action = self.epsilon_greedy_policy(self.state, epsilon)
            end = False
            while not end:
                next_state, reward, end = self.step(action)
                if end:
                    # Terminal transition: there is no successor value to bootstrap from.
                    target = reward
                    next_action = None
                else:
                    # On-policy: bootstrap from the action that will actually be taken next.
                    next_action = self.epsilon_greedy_policy(next_state, epsilon)
                    next_value = self.QTable[next_state[0], next_state[1], self._q_index(next_action)]
                    target = reward + gamma * next_value
                self._td_update(self.state, action, target, alpha)

                self.state = next_state
                action = next_action

    def QLearning(self, episodes: int, alpha: float, gamma: float, epsilon: float):
        """Train ``QTable`` from scratch with off-policy Q-learning.

        Args:
            episodes: Number of episodes; each starts at ``self.start`` and
                runs until a goal or obstacle cell is entered.
            alpha: Learning rate.
            gamma: Discount factor.
            epsilon: Exploration rate of the epsilon-greedy behaviour policy, in (0, 1).
        """
        self.reset()
        for _ in range(episodes):
            self.state = self.start
            end = False
            while not end:
                # The action must be chosen from the current state before stepping.
                action = self.epsilon_greedy_policy(self.state, epsilon)
                next_state, reward, end = self.step(action)
                if end:
                    target = reward
                else:
                    # Off-policy: bootstrap from the greedy value of the next state.
                    target = reward + gamma * np.max(self.QTable[next_state[0], next_state[1]])
                self._td_update(self.state, action, target, alpha)

                self.state = next_state