In [1]:
import random

# actions
actions = ["LEFT", "RIGHT"]

# Q-table initialization
Q = {s: {a: 0.0 for a in actions} for s in range(5)}

# hyperparameters
alpha = 0.1
gamma = 0.9
epsilon = 0.2


# epsilon-greedy policy
def choose_action(state):
    if random.random() < epsilon:
        return random.choice(actions)
    return max(Q[state], key=Q[state].get)


# environment step
def step(state, action):
    if action == "RIGHT":
        next_state = min(state + 1, 4)
    else:
        next_state = max(state - 1, 0)

    reward = 10 if next_state == 4 else -1
    done = next_state == 4
    return next_state, reward, done


# -------- SARSA Training --------
for episode in range(200):

    state = 0
    action = choose_action(state)   # ✅ SARSA: choose A_t first

    done = False
    while not done:

        # take action
        next_state, reward, done = step(state, action)

        # ✅ SARSA: choose next action A_{t+1} using same policy
        next_action = choose_action(next_state)

        # ✅ SARSA update
        Q[state][action] += alpha * (
            reward
            + gamma * Q[next_state][next_action]
            - Q[state][action]
        )

        # move forward
        state = next_state
        action = next_action


print("Learned Q-table:")
for s in Q:
    print(s, Q[s])

Learned Q-table:
0 {'LEFT': 1.1581729593051227, 'RIGHT': 3.1891998575266864}
1 {'LEFT': 1.1012098107791095, 'RIGHT': 5.128135685112724}
2 {'LEFT': 1.7336192834958057, 'RIGHT': 6.960986623309074}
3 {'LEFT': 4.030154111596354, 'RIGHT': 9.99999999294492}
4 {'LEFT': 0.0, 'RIGHT': 0.0}


In [2]:
## TODO: 搞清楚是怎么一步一步算出来的