<a href="https://colab.research.google.com/github/datapirate09/Windy-Gridworld/blob/main/sarsa_algo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
import copy
import random


def grid_to_tuple(grid):
    return tuple(tuple(row) for row in grid)

def initialize_action_values(states, actions):
    action_values = {}
    for row in range(len(states)):
        for col in range(len(states[0])):
            states_dup = copy.deepcopy(states)
            states_dup[row][col] = 1
            state_tuple = grid_to_tuple(states_dup)
            for action in actions:
                state_action_pair = (state_tuple, action)
                action_values[state_action_pair] = 0.0
    return action_values

def epsilon_greedy_policy(state_tuple, actions, action_values, epsilon=0.1):
    if random.random() < epsilon:
        return random.choice(actions)
    else:
        q_values = [(action, action_values.get((state_tuple, action), 0.0)) for action in actions]
        max_value = max(q_values, key=lambda x: x[1])[1]
        best_actions = [action for action, value in q_values if value == max_value]
        return random.choice(best_actions)

def is_terminal_state(state):
    destination_row = 3
    destination_col = 7
    if (state[destination_row][destination_col] == 1):
        return True
    return False

def get_reward(cur_row, cur_col):
    goal_row = 3
    goal_col = 7
    dist = abs(cur_row - goal_row) + abs(cur_col - goal_col)
    if cur_row == goal_row and cur_col == goal_col:
        return 1.0
    return -dist * 0.01 - 0.01


def get_next_coordinates(cur_row, cur_col, action, max_rows=7, max_cols=10):
    if action == 0:  # left
        next_row = cur_row
        next_col = max(0, cur_col - 1)
    elif action == 1:  # down
        next_row = min(max_rows - 1, cur_row + 1)
        next_col = cur_col
    elif action == 2:  # right
        next_row = cur_row
        next_col = min(max_cols - 1, cur_col + 1)
    elif action == 3:  # up
        next_row = max(0, cur_row - 1)
        next_col = cur_col
    elif action == 4:  # down-left
        next_row = min(max_rows - 1, cur_row + 1)
        next_col = max(0, cur_col - 1)
    elif action == 5:  # up-left
        next_row = max(0, cur_row - 1)
        next_col = max(0, cur_col - 1)
    elif action == 6:  # up-right
        next_row = max(0, cur_row - 1)
        next_col = min(max_cols - 1, cur_col + 1)
    elif action == 7:  # down-right
        next_row = min(max_rows - 1, cur_row + 1)
        next_col = min(max_cols - 1, cur_col + 1)
    else:
        next_row, next_col = cur_row, cur_col
    return next_row, next_col


def get_after_wind_coordinates(row_expected, col_expected, grid):
    max_rows = len(grid)
    max_cols = len(grid[0])
    if 0 <= row_expected < max_rows and 0 <= col_expected < max_cols:
        wind_strength = grid[row_expected][col_expected]
        row_actual = max(0, row_expected - wind_strength)
    else:
        row_actual = row_expected
    col_actual = col_expected
    return row_actual, col_actual


def sarsa_algo(states, actions, action_values, source_row, source_col):
    grid_world = []
    for i in range(7):
        grid_world.append([0, 0, 0, 1, 1, 1, 2, 2, 1, 0])
    epsilon = 0.1

    for episode in range(10000):
        epsilon = max(0.01, epsilon * 0.9995)
        state_episode = copy.deepcopy(states)
        state_episode[source_row][source_col] = 1
        state_tuple = grid_to_tuple(state_episode)
        action = epsilon_greedy_policy(state_tuple, actions, action_values, epsilon)
        steps_per_episode = 0
        cur_row = source_row
        cur_col = source_col
        max_steps = 1000

        while(not is_terminal_state(state_episode) and steps_per_episode < max_steps):
            steps_per_episode += 1
            next_row_expected, next_col_expected = get_next_coordinates(cur_row, cur_col, action)
            next_row_actual, next_col_actual = get_after_wind_coordinates(next_row_expected, next_col_expected, grid_world)
            next_state = copy.deepcopy(states)
            next_state[next_row_actual][next_col_actual] = 1
            immediate_reward = get_reward(next_row_actual, next_col_actual)
            next_state_tuple = grid_to_tuple(next_state)
            next_action = epsilon_greedy_policy(next_state_tuple, actions, action_values, epsilon)
            next_state_action_pair = (next_state_tuple, next_action)

            if (is_terminal_state(next_state)):
                action_values[(state_tuple, action)] = action_values[(state_tuple, action)] + 0.2 * (immediate_reward - action_values[(state_tuple, action)])
                if episode % 100 == 0:
                  print(f"Episode {episode}: Steps {steps_per_episode+1}")
                break

            action_values[(state_tuple, action)] = action_values[(state_tuple, action)] + 0.2 * (immediate_reward + (0.9 * action_values[next_state_action_pair]) - action_values[(state_tuple, action)])
            state_episode = next_state
            state_tuple = next_state_tuple
            action = next_action
            cur_row = next_row_actual
            cur_col = next_col_actual
        if steps_per_episode >= max_steps:
            print(f"Episode {episode}: Reached maximum steps without finding goal")


def print_policy(action_values, states, actions):
    empty_states = [[0 for _ in range(len(states[0]))] for _ in range(len(states))]
    direction_symbols = {
        0: "←",
        1: "↓",
        2: "→",
        3: "↑",
        4: "↙",
        5: "↖",
        6: "↗",
        7: "↘",
    }
    directions = [[0 for _ in range(10)] for _ in range(7)]

    for i in range(7):
        for j in range(10):
            empty_states[i][j] = 1
            state_tuple = grid_to_tuple(empty_states)

            best_value_function = float('-inf')
            best_action = -1
            for a in actions:
                state_value_function = action_values.get((state_tuple, a), 0)
                if state_value_function > best_value_function:
                    best_action = a
                    best_value_function = state_value_function
            directions[i][j] = best_action
            empty_states[i][j] = 0

    policy_symbols = [[direction_symbols.get(directions[i][j], "·") for j in range(10)] for i in range(7)]

    print("\nLearned Policy:")
    for i, row in enumerate(policy_symbols):
        row_str = ""
        for j, symbol in enumerate(row):
            if i == 3 and j == 7:
                row_str += "G  "
            else:
                row_str += symbol + "  "
        print(row_str)

grid_world = []
for i in range(7):
    grid_world.append([0, 0, 0, 1, 1, 1, 2, 2, 1, 0])

states = [[0]*10 for _ in range(7)]
actions = [0, 1, 2, 3, 4, 5, 6, 7]  # 0-left, 1-down, 2-right, 3-up, 4-down-left, 5-up-left, 6-up-right, 7-down-right

action_values = initialize_action_values(states, actions)
sarsa_algo(states, actions, action_values, 3, 0)
print_policy(action_values, states, actions)

Episode 0: Reached maximum steps without finding goal
Episode 100: Steps 21
Episode 200: Steps 16
Episode 300: Steps 17
Episode 400: Steps 16
Episode 500: Steps 18
Episode 600: Steps 15
Episode 700: Steps 17
Episode 800: Steps 15
Episode 900: Steps 16
Episode 1000: Steps 15
Episode 1100: Steps 15
Episode 1200: Steps 15
Episode 1300: Steps 15
Episode 1400: Steps 15
Episode 1500: Steps 17
Episode 1600: Steps 15
Episode 1700: Steps 18
Episode 1800: Steps 18
Episode 1900: Steps 15
Episode 2000: Steps 15
Episode 2100: Steps 15
Episode 2200: Steps 16
Episode 2300: Steps 15
Episode 2400: Steps 15
Episode 2500: Steps 15
Episode 2600: Steps 15
Episode 2700: Steps 15
Episode 2800: Steps 15
Episode 2900: Steps 15
Episode 3000: Steps 15
Episode 3100: Steps 15
Episode 3200: Steps 17
Episode 3300: Steps 15
Episode 3400: Steps 15
Episode 3500: Steps 15
Episode 3600: Steps 15
Episode 3700: Steps 15
Episode 3800: Steps 15
Episode 3900: Steps 15
Episode 4000: Steps 15
Episode 4100: Steps 15
Episode 4200