In [1]:
import numpy as np
import random

class SARSAAgent:
    """Tabular SARSA agent with an epsilon-greedy action-selection policy.

    Q-values are stored in ``q_table``, a ``(state_size, action_size)`` array
    initialised to zeros. States and actions are used directly as row and
    column indices into that array.

    Args:
        state_size: Number of rows (states) in the Q-table.
        action_size: Number of columns (actions) in the Q-table.
        alpha: Learning rate applied to the TD error.
        gamma: Discount factor applied to the bootstrapped Q-value.
        epsilon: Probability of picking a uniformly random action.
    """

    def __init__(self, state_size, action_size, alpha=0.1, gamma=0.99, epsilon=0.1):
        self.state_size = state_size
        self.action_size = action_size
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.q_table = np.zeros((state_size, action_size))  # Q-value table

    def choose_action(self, state):
        """Pick an action index for ``state`` with an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action index in
        ``0 .. action_size - 1`` is returned; otherwise the index of the
        largest Q-value in the row for ``state`` (first index on ties).
        """
        if np.random.rand() < self.epsilon:
            # Explore: np.random.choice(n) samples from 0 .. n - 1.
            return np.random.choice(self.action_size)
        # Exploit: greedy with respect to the current Q-table row.
        return np.argmax(self.q_table[state])

    def update(self, state, action, reward, next_state, next_action=None):
        """Apply one on-policy SARSA update to ``q_table[state, action]``.

        Args:
            state: Index of the state the action was taken in.
            action: Index of the action that was taken.
            reward: Reward received for the transition.
            next_state: Index of the state reached.
            next_action: Index of the action chosen in ``next_state``. Pass
                ``None`` (the default) when ``next_state`` is terminal; the
                target is then just ``reward`` with no bootstrapped term.
        """
        if next_action is None:
            # Terminal transition: there is no successor action to bootstrap
            # from. Indexing the table with None would insert a new axis and
            # make the in-place update below fail.
            td_target = reward
        else:
            td_target = reward + self.gamma * self.q_table[next_state, next_action]
        td_error = td_target - self.q_table[state, action]
        self.q_table[state, action] += self.alpha * td_error  # Update rule

# Example usage: build an agent and run a single hand-made SARSA step.
env_states, env_actions = 10, 4
agent = SARSAAgent(state_size=env_states, action_size=env_actions)

# One simulated transition (s, a, r, s', a'). The agent picks `action` for
# `state` first and `next_action` for `next_state` second.
state = 0
action = agent.choose_action(state)
reward = -1  # Assumed reward from the environment
next_state = 0  # Assumed next state after taking the action
# TODO: consider a state-name -> index mapping for `choose_action` and `update`.
next_action = agent.choose_action(next_state)

agent.update(
    state=state,
    action=action,
    reward=reward,
    next_state=next_state,
    next_action=next_action,
)


In [2]:
# Human-readable state names. In `transition_probabilities` below the states
# are referred to by integer index instead: "S1" is 0, ..., "S10" is 9.
states = ["S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8", "S9", "S10"]
# Action names; these strings are the second-level keys of
# `transition_probabilities`.
actions = ["Continue", "Re-route", "Adjust Speed", "Choose Alternate"]

# Transition model: state index -> action name -> {next state index: probability}.
# The insertion order of each inner dict matters: PedestrianPaths.step reads
# the outcomes in that order when it samples the next state.
transition_probabilities = {
    0: {
        "Continue": {1: 0.7, 2: 0.3},
        "Re-route": {5: 1.0},
        "Adjust Speed": {0: 1.0},
        "Choose Alternate": {3: 0.6, 4: 0.4}
    },
    1: {
        "Continue": {2: 0.8, 3: 0.2},
        "Re-route": {5: 0.7, 4: 0.3},
        "Adjust Speed": {1: 1.0},
        "Choose Alternate": {6: 0.5, 3: 0.5}
    },
    2: {
        "Continue": {4: 0.6, 6: 0.4},
        "Re-route": {7: 0.6, 5: 0.4},
        "Adjust Speed": {2: 1.0},
        "Choose Alternate": {3: 0.5, 8: 0.5}
    },
    3: {
        "Continue": {4: 0.7, 5: 0.3},
        "Re-route": {7: 0.5, 8: 0.5},
        "Adjust Speed": {3: 1.0},
        "Choose Alternate": {6: 0.5, 9: 0.5}
    },
    4: {
        "Continue": {5: 0.8, 6: 0.2},
        "Re-route": {8: 0.7, 9: 0.3},
        "Adjust Speed": {4: 1.0},
        "Choose Alternate": {3: 0.5, 7: 0.5}
    },
    5: {
        "Continue": {6: 0.9, 7: 0.1},
        "Re-route": {5: 1.0},
        "Adjust Speed": {5: 1.0},
        "Choose Alternate": {3: 0.3, 9: 0.7}
    },
    6: {
        "Continue": {9: 0.95, 5: 0.05},
        "Re-route": {3: 0.7, 8: 0.3},
        "Adjust Speed": {6: 1.0},
        "Choose Alternate": {5: 0.3, 2: 0.7}
    },
    7: {
        "Continue": {8: 0.8, 5: 0.2},
        "Re-route": {6: 0.6, 9: 0.4},
        "Adjust Speed": {7: 1.0},
        "Choose Alternate": {6: 0.3, 3: 0.7}
    },
    8: {
        "Continue": {9: 1.0},
        "Re-route": {6: 0.5, 3: 0.5},
        "Adjust Speed": {8: 1.0},
        "Choose Alternate": {2: 0.6, 6: 0.4}
    },
    # State 9 is the goal state used by PedestrianPaths. Every action lists
    # only state 9, with probability 0.0, so these rows do not sum to 1.
    # NOTE(review): presumably this marks the state as terminal -- confirm
    # before feeding these rows to anything that expects a valid distribution.
    9: {
        "Continue": {9: 0.0},
        "Re-route": {9: 0.0},
        "Adjust Speed": {9: 0.0},
        "Choose Alternate": {9: 0.0}
    }
}

In [3]:
class PedestrianPaths:
    """Episodic environment driven by a table of transition probabilities.

    The environment starts in state 0 and an episode ends when state 9 (the
    goal) is reached. Reaching the goal yields a reward of 1; every other
    step yields -0.1.

    Args:
        transitions: Optional transition table of the form
            ``{state: {action_name: {next_state: probability}}}``. When
            omitted, the module-level ``transition_probabilities`` is used.
    """

    # Action names in index order: an integer action ``i`` passed to ``step``
    # is translated to ``ACTION_NAMES[i]``.
    ACTION_NAMES = ("Continue", "Re-route", "Adjust Speed", "Choose Alternate")

    def __init__(self, transitions=None):
        self.state = 0  # Start state
        self.goal = 9
        self.states = 10
        self.actions = list(self.ACTION_NAMES)  # 4 possible actions
        # Only fall back to the module-level table when none is supplied, so
        # the class can also be used with a custom table.
        self.transition_probabilities = (
            transition_probabilities if transitions is None else transitions
        )

    def step(self, action):
        """Advance the environment by one step.

        Args:
            action: Either an action name (a key of the transition table) or
                an integer index into ``ACTION_NAMES``.

        Returns:
            A ``(state, reward, done)`` tuple: the new state, 1 if it is the
            goal and -0.1 otherwise, and whether the goal was reached.
        """
        # Action-index mapping: translate an integer index to its name.
        # Anything that matches no index is used as the lookup key unchanged.
        for index, name in enumerate(self.ACTION_NAMES):
            if action == index:
                action = name
                break

        outcomes = list(self.transition_probabilities[self.state][action].items())
        if len(outcomes) == 1:
            # Single outcome: deterministic, no random draw, and the listed
            # probability is not consulted.
            self.state = outcomes[0][0]
        else:
            # Walk the cumulative distribution in insertion order. The last
            # outcome absorbs whatever probability mass remains, so this works
            # for any number of outcomes (two outcomes reduce to comparing the
            # draw against the first listed probability).
            next_state = outcomes[-1][0]
            draw = np.random.rand()
            cumulative = 0.0
            for candidate, probability in outcomes[:-1]:
                cumulative += probability
                if draw < cumulative:
                    next_state = candidate
                    break
            self.state = next_state

        # Define rewards
        done = self.state == self.goal  # Episode ends when goal is reached
        reward = 1 if done else -0.1
        return self.state, reward, done

    def reset(self):
        """Put the environment back in state 0 and return that state."""
        self.state = 0
        return self.state

# Train a SARSA agent on the PedestrianPaths environment.
env = PedestrianPaths()
env_states, env_actions = 10, 4
agent = SARSAAgent(env_states, env_actions)

episodes = 100
for episode in range(episodes):
    # Every episode starts from the reset state with a freshly chosen action.
    state = env.reset()
    action = agent.choose_action(state)
    done = False
    while not done:
        next_state, reward, done = env.step(action)

        # No follow-up action is chosen once the episode has ended.
        next_action = None if done else agent.choose_action(next_state)

        # SARSA update; on the terminal step action 0 is passed as a
        # placeholder for the missing next action.
        agent.update(state, action, reward, next_state, 0 if done else next_action)

        state, action = next_state, next_action

# Show the learned Q-values.
print(agent.q_table)


[[-4.53978186e-02  1.12765220e-02  1.53519339e-01  7.00748711e-01]
 [-5.80372530e-03  8.00241428e-02  0.00000000e+00  0.00000000e+00]
 [ 3.19158994e-02 -1.00000000e-02 -1.90000000e-02 -1.00000000e-02]
 [ 1.40086500e-01 -1.58593925e-02 -1.90000000e-02  8.75136337e-01]
 [-1.18819000e-02  8.17675332e-01 -1.90000000e-02 -1.42865290e-02]
 [-1.00000000e-02 -1.90000000e-02 -1.90000000e-02  3.62873495e-01]
 [ 9.72187161e-01  6.54202213e-02  0.00000000e+00 -1.00000000e-02]
 [-1.00000000e-04  4.40829800e-02  0.00000000e+00  1.48572587e-02]
 [ 9.41850263e-01  0.00000000e+00  0.00000000e+00  6.63519755e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]
