### **SARSA in FrozenLake**

SARSA algorithm for an agent in the FrozenLake environment. <br> <br>
Works for the 4x4 and 8x8 layouts. <br>
Note that the 8x8 layout needs significantly more episodes. <br>
You can use this example with is_slippery = True to explore how the agent behaves in a stochastic environment.
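
For reference, the on-policy update applied at every step of the training loop can be written as follows, where $\alpha$ is the learning rate, $\gamma$ is the discount factor, $s'$ is the next state, and $a'$ is the action the ε-greedy policy actually selects in $s'$:

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \, Q(s', a') - Q(s, a) \right]$$

Because $a'$ comes from the same ε-greedy policy that is being followed, the learned values reflect the exploratory behaviour as well as the greedy choice.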

In [1]:
import numpy as np
import gymnasium as gym
import session_info

In [21]:
# SARSA parameters
alpha = 0.05  # Learning rate
gamma = 0.9  # Discount factor

epsilon = 1.0     # starting epsilon  
epsilon_min = 0.03
epsilon_decay_rate = 0.9999 
episodes = 100000      # Number of episodes
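
To get a feel for the exploration schedule, the sketch below estimates how many episodes the multiplicative decay needs before epsilon reaches its floor. It assumes the same rule as the training loop (one multiplication per episode, clipped at `epsilon_min`); the names `epsilon_start` and `episodes_to_floor` are introduced here for illustration only.

```python
import math

epsilon_start = 1.0
epsilon_min = 0.03
epsilon_decay_rate = 0.9999

# Smallest n such that epsilon_start * epsilon_decay_rate**n <= epsilon_min
episodes_to_floor = math.ceil(
    math.log(epsilon_min / epsilon_start) / math.log(epsilon_decay_rate)
)
print(f"epsilon reaches its floor after about {episodes_to_floor} episodes")

# Epsilon at a few checkpoints, mirroring the decay rule in the training loop
for n in (1_000, 10_000, 25_000, 50_000):
    eps = max(epsilon_min, epsilon_start * epsilon_decay_rate ** n)
    print(f"episode {n:>6}: epsilon = {eps:.3f}")
```

With these settings the floor is reached after roughly 35,000 episodes, so the agent acts mostly at random during the first few thousand episodes and low average rewards early in training are expected.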

In [22]:
env = gym.make('FrozenLake-v1', map_name="4x4", is_slippery=True)

height = env.unwrapped.nrow
width = env.unwrapped.ncol


# Initialize Q-table
Q = np.zeros((env.observation_space.n, env.action_space.n))
   
def epsilon_greedy_policy(state, epsilon):  # Exploitation vs. exploration
    if np.random.random() < epsilon:
        return env.action_space.sample()
    else:
        return np.argmax(Q[state,:])

# SARSA algorithm

episode_reward = []

for episode in range(episodes):
    state, _ = env.reset()
    action = epsilon_greedy_policy(state, epsilon)
    done = False
    total_reward = 0
    while not done:
        
        next_state, reward, terminated, truncated, _= env.step(action)
        total_reward += reward
        next_action = epsilon_greedy_policy(next_state, epsilon)
        
        # SARSA update
        Q[state, action] += alpha * (reward + gamma * Q[next_state, next_action] - Q[state, action])
        
        state = next_state
        action = next_action
        done = terminated or truncated  # also stop when the time limit truncates the episode

        if done:
            episode_reward.append(total_reward)
            if episode % 100 == 0:
                avg_reward = np.mean(episode_reward[-100:])
                print(f"Episode {episode}, Average Reward: {avg_reward}")

        
    if epsilon > epsilon_min:
        epsilon = max(epsilon_min, epsilon * epsilon_decay_rate)

Episode 0, Average Reward: 0.0
Episode 100, Average Reward: 0.04
Episode 200, Average Reward: 0.0
Episode 300, Average Reward: 0.0
Episode 400, Average Reward: 0.03
Episode 500, Average Reward: 0.03
Episode 600, Average Reward: 0.0
Episode 700, Average Reward: 0.03
Episode 800, Average Reward: 0.01
Episode 900, Average Reward: 0.02
Episode 1000, Average Reward: 0.0
Episode 1100, Average Reward: 0.03
Episode 1200, Average Reward: 0.02
Episode 1300, Average Reward: 0.04
Episode 1400, Average Reward: 0.02
Episode 1500, Average Reward: 0.03
Episode 1600, Average Reward: 0.03
Episode 1700, Average Reward: 0.01
Episode 1800, Average Reward: 0.01
Episode 1900, Average Reward: 0.03
Episode 2000, Average Reward: 0.02
Episode 2100, Average Reward: 0.0
Episode 2200, Average Reward: 0.02
Episode 2300, Average Reward: 0.04
Episode 2400, Average Reward: 0.01
Episode 2500, Average Reward: 0.02
Episode 2600, Average Reward: 0.02
Episode 2700, Average Reward: 0.01
Episode 2800, Average Reward: 0.03
...

Episode 24100, Average Reward: 0.29
Episode 24200, Average Reward: 0.32
Episode 24300, Average Reward: 0.39
Episode 24400, Average Reward: 0.46
Episode 24500, Average Reward: 0.37
Episode 24600, Average Reward: 0.42
Episode 24700, Average Reward: 0.24
Episode 24800, Average Reward: 0.4
Episode 24900, Average Reward: 0.39
Episode 25000, Average Reward: 0.28
Episode 25100, Average Reward: 0.29
Episode 25200, Average Reward: 0.4
Episode 25300, Average Reward: 0.48
Episode 25400, Average Reward: 0.35
Episode 25500, Average Reward: 0.4
Episode 25600, Average Reward: 0.46
Episode 25700, Average Reward: 0.42
Episode 25800, Average Reward: 0.44
Episode 25900, Average Reward: 0.22
Episode 26000, Average Reward: 0.18
Episode 26100, Average Reward: 0.37
Episode 26200, Average Reward: 0.31
Episode 26300, Average Reward: 0.39
Episode 26400, Average Reward: 0.3
Episode 26500, Average Reward: 0.37
Episode 26600, Average Reward: 0.38
Episode 26700, Average Reward: 0.46
...

Episode 47400, Average Reward: 0.58
Episode 47500, Average Reward: 0.44
Episode 47600, Average Reward: 0.49
Episode 47700, Average Reward: 0.44
Episode 47800, Average Reward: 0.41
Episode 47900, Average Reward: 0.32
Episode 48000, Average Reward: 0.27
Episode 48100, Average Reward: 0.32
Episode 48200, Average Reward: 0.35
Episode 48300, Average Reward: 0.4
Episode 48400, Average Reward: 0.4
Episode 48500, Average Reward: 0.42
Episode 48600, Average Reward: 0.43
Episode 48700, Average Reward: 0.41
Episode 48800, Average Reward: 0.66
Episode 48900, Average Reward: 0.53
Episode 49000, Average Reward: 0.57
Episode 49100, Average Reward: 0.53
Episode 49200, Average Reward: 0.53
Episode 49300, Average Reward: 0.4
Episode 49400, Average Reward: 0.43
Episode 49500, Average Reward: 0.28
Episode 49600, Average Reward: 0.51
Episode 49700, Average Reward: 0.56
Episode 49800, Average Reward: 0.5
Episode 49900, Average Reward: 0.52
Episode 50000, Average Reward: 0.38
...

Episode 70300, Average Reward: 0.55
Episode 70400, Average Reward: 0.4
Episode 70500, Average Reward: 0.32
Episode 70600, Average Reward: 0.19
Episode 70700, Average Reward: 0.51
Episode 70800, Average Reward: 0.42
Episode 70900, Average Reward: 0.3
Episode 71000, Average Reward: 0.22
Episode 71100, Average Reward: 0.5
Episode 71200, Average Reward: 0.55
Episode 71300, Average Reward: 0.46
Episode 71400, Average Reward: 0.47
Episode 71500, Average Reward: 0.48
Episode 71600, Average Reward: 0.46
Episode 71700, Average Reward: 0.45
Episode 71800, Average Reward: 0.79
Episode 71900, Average Reward: 0.5
Episode 72000, Average Reward: 0.47
Episode 72100, Average Reward: 0.7
Episode 72200, Average Reward: 0.56
Episode 72300, Average Reward: 0.46
Episode 72400, Average Reward: 0.32
Episode 72500, Average Reward: 0.67
Episode 72600, Average Reward: 0.42
Episode 72700, Average Reward: 0.55
Episode 72800, Average Reward: 0.61
Episode 72900, Average Reward: 0.5
Episode 73000, Average Reward: 0.4

Episode 93400, Average Reward: 0.56
Episode 93500, Average Reward: 0.63
Episode 93600, Average Reward: 0.61
Episode 93700, Average Reward: 0.47
Episode 93800, Average Reward: 0.5
Episode 93900, Average Reward: 0.55
Episode 94000, Average Reward: 0.48
Episode 94100, Average Reward: 0.27
Episode 94200, Average Reward: 0.51
Episode 94300, Average Reward: 0.47
Episode 94400, Average Reward: 0.56
Episode 94500, Average Reward: 0.45
Episode 94600, Average Reward: 0.5
Episode 94700, Average Reward: 0.64
Episode 94800, Average Reward: 0.61
Episode 94900, Average Reward: 0.51
Episode 95000, Average Reward: 0.35
Episode 95100, Average Reward: 0.5
Episode 95200, Average Reward: 0.49
Episode 95300, Average Reward: 0.72
Episode 95400, Average Reward: 0.59
Episode 95500, Average Reward: 0.37
Episode 95600, Average Reward: 0.57
Episode 95700, Average Reward: 0.44
Episode 95800, Average Reward: 0.42
Episode 95900, Average Reward: 0.42
Episode 96000, Average Reward: 0.43
...
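
The 100-episode averages in the log above are noisy (for example 0.79 at episode 71800 and 0.32 at episode 72400). A wider window makes the trend easier to read. The helper below is a minimal sketch; `block_means` is a hypothetical name, and the toy list stands in for the `episode_reward` list collected during training.

```python
def block_means(rewards, window=1000):
    """Mean reward over consecutive, non-overlapping blocks of `window` episodes."""
    return [
        sum(rewards[i:i + window]) / window
        for i in range(0, len(rewards) - window + 1, window)
    ]

# Toy data standing in for `episode_reward`: 1 = reached the goal, 0 = did not
toy_rewards = [0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0]
print(block_means(toy_rewards, window=4))
```

After training, `block_means(episode_reward, window=1000)` would give one value per 1000 episodes. Any incomplete final block is dropped.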

In [23]:
# Test the learned policy
def test_policy(n_episodes=100):
    successes = 0
    for _ in range(n_episodes):
        state, _ = env.reset()
        done = False
        while not done:
            action = np.argmax(Q[state, :])
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated  # end the episode on success, failure, or time limit
            if reward == 1:
                successes += 1
    return successes / n_episodes

success_rate = test_policy()
print(f"Success rate: {success_rate:.2%}")

# Display the learned Q-table
print("\nLearned Q-table:")
print('[ ','←', '↓', '→', '↑', ' ]')
print()

Success rate: 60.00%

Learned Q-table:
[  ← ↓ → ↑  ]
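
A success rate measured over only 100 episodes carries noticeable sampling error. As a rough guide, the sketch below computes a normal-approximation 95% interval for the 60/100 result above. The function name `success_rate_interval` is made up for this example, and the approximation is only one of several reasonable choices.

```python
import math

def success_rate_interval(successes, n_episodes, z=1.96):
    """Normal-approximation confidence interval for a success proportion."""
    p = successes / n_episodes
    se = math.sqrt(p * (1 - p) / n_episodes)  # standard error of the proportion
    return p - z * se, p + z * se

low, high = success_rate_interval(60, 100)
print(f"60/100 successes: 95% interval of roughly {low:.2f} to {high:.2f}")
```

Calling `test_policy` with a larger `n_episodes` narrows this interval.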



In [24]:
# Print learned policy
print("Learned Policy:")
print("===============")

policy = np.argmax(Q, axis=1)
policy_symbols = ['←', '↓', '→', '↑', 'S', 'G', 'H']  # 0-3 = actions; 4+ = special tiles

desc = env.unwrapped.desc
height, width = desc.shape

# Mark special tiles in the policy
for i in range(height):
    for j in range(width):
        idx = i * width + j
        tile = desc[i][j]
        if tile == b'H':
            policy[idx] = 6  # Hole
        elif tile == b'G':
            policy[idx] = 5  # Goal
        elif tile == b'S':
            policy[idx] = 4  # Start

# Print policy with symbols
for i in range(height):
    for j in range(width):
        idx = i * width + j
        print(policy_symbols[policy[idx]], end=' ')
    print()

# Print value function
print("\nValue Function:")
print("===================")
V = np.max(Q, axis=1)
for i in range(height):
    for j in range(width):
        idx = i * width + j
        print(f"{V[idx]:.2f}", end=' ')
    print()


Learned Policy:
S ↑ → ↑ 
← H → H 
↑ ↓ ← H 
H → → G 

Value Function:
0.07 0.05 0.05 0.04 
0.09 0.00 0.07 0.00 
0.13 0.23 0.26 0.00 
0.00 0.36 0.72 0.00 
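
Some arrows in the learned policy look counter-intuitive, such as ← in the first column of the second row, right next to a hole. The sketch below illustrates why, assuming the slip rule described in the Gymnasium documentation for `is_slippery=True`: the agent moves in the intended direction or in one of the two perpendicular directions, each with probability 1/3. It is a simplified re-implementation for illustration, not the library's code, and `slippery_outcomes` is a hypothetical helper.

```python
# Action indices as used by the Q-table columns: 0=left, 1=down, 2=right, 3=up
MOVES = {0: (0, -1), 1: (1, 0), 2: (0, 1), 3: (-1, 0)}

def slippery_outcomes(state, action, nrow=4, ncol=4):
    """Possible next states for one step, each assumed to have probability 1/3."""
    row, col = divmod(state, ncol)
    outcomes = []
    for actual in ((action - 1) % 4, action, (action + 1) % 4):
        d_row, d_col = MOVES[actual]
        new_row = min(max(row + d_row, 0), nrow - 1)  # walls block movement
        new_col = min(max(col + d_col, 0), ncol - 1)
        outcomes.append(new_row * ncol + new_col)
    return outcomes

print(slippery_outcomes(4, 0))  # state 4, choosing "left"
print(slippery_outcomes(4, 1))  # state 4, choosing "down"
```

Under this model, choosing ← in state 4 can only lead to states 0, 4, or 8, while choosing ↓ can slip into the hole at state 5. Steering away from the hole is therefore the safer choice, even though it does not point toward the goal.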


In [25]:
print(Q)

[[0.07040552 0.05362913 0.05905908 0.04898796]
 [0.03216473 0.03749068 0.03987198 0.04788849]
 [0.04660429 0.04576061 0.0494567  0.04596121]
 [0.03548499 0.02959899 0.0263902  0.03947054]
 [0.09359795 0.06681629 0.06210963 0.03430018]
 [0.         0.         0.         0.        ]
 [0.05095785 0.05102267 0.06939275 0.02124467]
 [0.         0.         0.         0.        ]
 [0.05765238 0.08247041 0.09045897 0.12983188]
 [0.14526121 0.22881533 0.17502163 0.12676798]
 [0.26267495 0.17131603 0.18857097 0.0790739 ]
 [0.         0.         0.         0.        ]
 [0.         0.         0.         0.        ]
 [0.18157252 0.25071402 0.35813658 0.24043191]
 [0.38637847 0.49320144 0.72049254 0.44314665]
 [0.         0.         0.         0.        ]]


In [28]:
# Test the learned policy
env = gym.make('FrozenLake-v1', map_name="4x4", is_slippery=True, render_mode='human')
state, _ = env.reset()
state = int(state)
done = False
total_reward = 0

while not done:
    action = np.argmax(Q[state, :])
    state, reward, terminated, truncated, _ = env.step(action)
    state = int(state)
    done = terminated or truncated
    total_reward += reward
    env.render()

print(f"\nTotal reward: {total_reward}")


Total reward: 1.0


In [27]:
env.close()

In [8]:
session_info.show(html=False)

-----
gymnasium           1.0.0
numpy               1.26.4
session_info        1.0.0
-----
IPython             8.26.0
jupyter_client      8.6.2
jupyter_core        5.7.2
-----
Python 3.12.3 (main, Feb  4 2025, 14:48:35) [GCC 13.3.0]
Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
-----
Session information updated at 2025-05-16 13:25
