![Graph](./image/seven_states_directed_graph.png)

In [23]:
# SARSA implementation for a 7-state directed graph
import numpy as np
import random

In [24]:

# Define the reward matrix R
R = np.array([
    [-1, -1, -1,  0, -1, -1, -1],
    [-1, -1,  0, -1, -1, -1, -1],
    [-1,  0, -1,  0, -1,  0, -1],
    [ 0, -1,  0, -1,  0, -1, -1],
    [-1, -1, -1,  0, -1,  0, 100],
    [-1, -1,  0, -1,  0, -1, 100],
    [-1, -1, -1, -1, -1, -1, 100],
])

In [25]:
n_states = R.shape[0]
Q = np.zeros_like(R, dtype=float)

In [26]:
# Hyperparameters
alpha = 0.9      # Learning rate
gamma = 0.8      # Discount factor
epsilon = 0.1    # Epsilon for epsilon-greedy policy
episodes = 5000  # Number of training episodes, sarsa is slow to converge

In [27]:
# Epsilon-greedy action selection

def choose_action(state):
    valid_actions = [a for a in range(n_states) if R[state, a] >= 0]
    if random.random() < epsilon:
        return random.choice(valid_actions)
    else:
        q_vals = Q[state]
        max_q = np.max([q_vals[a] if a in valid_actions else -np.inf for a in range(n_states)])
        best_actions = [a for a in valid_actions if q_vals[a] == max_q]
        return random.choice(best_actions)

In [28]:

# SARSA training loop
for _ in range(episodes):
    state = random.randint(0, n_states - 1)
    action = choose_action(state)

    while state != 6:
        next_state = action
        next_action = choose_action(next_state)

        # SARSA update rule
        Q[state, action] += alpha * (
            R[state, action] + gamma * Q[next_state, next_action] - Q[state, action]
        )

        state, action = next_state, next_action

In [29]:
# Normalize Q-table for easier interpretation
Q_normalized = Q / Q.max() * 100

# Print normalized Q-table
import pandas as pd
pd.set_option("display.precision", 2)
print("\n✅ Learned SARSA Q-table (normalized):\n")
print(pd.DataFrame(Q_normalized))


✅ Learned SARSA Q-table (normalized):

       0      1      2      3      4      5      6
0   0.00   0.00   0.00  63.99   0.00   0.00    0.0
1   0.00   0.00  63.07   0.00   0.00   0.00    0.0
2   0.00  28.53   0.00  19.48   0.00  79.96    0.0
3  50.84   0.00  49.06   0.00  80.00   0.00    0.0
4   0.00   0.00   0.00  44.54   0.00  80.00  100.0
5   0.00   0.00  32.49   0.00  44.59   0.00  100.0
6   0.00   0.00   0.00   0.00   0.00   0.00    0.0


In [30]:


# Print optimal policy derived from Q-table
print("\n📌 Optimal policy from each state:")
for s in range(n_states):
    best_a = np.argmax(Q[s])
    print(f"From state {s} ➜ go to state {best_a}")



📌 Optimal policy from each state:
From state 0 ➜ go to state 3
From state 1 ➜ go to state 2
From state 2 ➜ go to state 5
From state 3 ➜ go to state 4
From state 4 ➜ go to state 6
From state 5 ➜ go to state 6
From state 6 ➜ go to state 0
