In [17]:
from typing import Union, Literal, List, Callable, Any, Tuple
import numpy as np

In [None]:
class Action:
    def __init__(self, name, next_state, reward, prob = 1):
        pass

In [69]:
class MDP:
    """
    Markov Decision Process
        horizon: number of time steps, excluding start and terminal
        state_space: all possible state names
        action_space: all possible action names
        actions: (time, state) => {action: (prob, next_state, reward)[]}, there must be one possible action at the horizon
    """
    def __init__(self, 
                 horizon: Union[int, Literal['infty']], 
                 state_space: List[str],
                 action_space: List[str],
                 actions # (time, state) => {action : (prob, next_state, reward)}[]
                 ):
        self.horizon = horizon
        self.state_space = state_space
        self.action_space = action_space
        self.actions = actions

    def get_actions(self, time: int, state: str):
        """Returns available actions for a given state and time."""
        return self.actions(time, state)
    
    def get_response(self, time: int, state: str, action: str):
        actions = self.actions(time, state)
        if action not in actions.keys():
            raise Exception(f"Valid actions are {actions.keys()}")
        return actions[action]
    
    def get_expected_reward(self, time: int, state: str, action: str):
        response = self.get_response(time, state, action)
        return np.sum([next_state[0]*next_state[2] for next_state in response])
    
    def evaluate_policy(self, policy: Callable[[int, str], Tuple[float, str]], 
                        time, state, 
                        method: Union[Literal['dp'], Literal['iter']]):
        
        # If at the horizon, return the reward
        if time == self.horizon:
            return list(self.actions(self.horizon, state).values())[0][0][2]
        
        current_policy = policy(time, state)

        eval = 0
        for action, prob in current_policy.items():
            response = self.get_response(time, state, action)
            eval += prob*np.sum([next_state[0]*(next_state[2]+self.evaluate_policy(policy, time+1, next_state[1])) for next_state in response])

        return eval
    
    def summary(self):
        pass


In [65]:
p = 0.8
N = 3

''' Returns possible action {action: (prob, next state, reward)} '''
def actions(time, state):
    if time < N:
        if state == 'free':
            return {
                'move': [(p, 'free', 0), (1-p, 'parked', 0)],
                'park': [(1, 'terminal', time)]
            }
        elif state == 'parked':
            return {
                'move': [(p, 'free', 0), (1-p, 'parked', 0)]
            }
    else:
        if state == 'free':
            return {
                'park': [(1, 'terminal', N)]
            }
        else:
            return {
                'move': [(1, 'terminal', 0)]
            }

mdp = MDP(N, ['free', 'parked', 'terminal'], ['move', 'park'], actions)

In [66]:
def policy(time: int, state: str):
    if state == 'free':
        return {'park': 0.5, 'move': 0.5}
    return {'move': 0.5}

In [67]:
mdp.actions(1, "parked")

{'move': [(0.8, 'free', 0), (0.19999999999999996, 'parked', 0)]}

In [68]:
mdp.evaluate_policy(policy, 1, "parked")

1.0

In [27]:
mdp.get_expected_reward(1, "free", "move")

0.0

In [16]:
from typing import Dict
import numpy as np

class PolicyIteration:
    def __init__(self, mdp: MDP, gamma: float = 1.0):
        self.mdp = mdp
        self.gamma = gamma
        self.value_function = {state: 0.0 for state in self.mdp.state_space}
        self.policy = {state: np.random.choice(self.mdp.action_space) for state in self.mdp.state_space if state != 'terminal'}

    def evaluate_policy(self, tolerance: float = 1e-6):
        """Evaluate the current policy until convergence."""
        while True:
            delta = 0
            for state in self.mdp.state_space:
                if state == 'terminal':
                    continue
                v = self.value_function[state]
                action = self.policy[state]
                transitions = self.mdp.get_transitions(0, state, action)
                new_value = sum(prob * (reward + self.gamma * self.value_function[next_state])
                                for prob, next_state, reward in transitions)
                self.value_function[state] = new_value
                delta = max(delta, abs(v - new_value))
            if delta < tolerance:
                break

    def policy_improvement(self):
        """Improve the current policy based on the value function."""
        policy_stable = True
        for state in self.mdp.state_space:
            if state == 'terminal':
                continue
            old_action = self.policy[state]
            action_values: Dict[str, float] = {}
            for action in self.mdp.get_actions(state, 0):
                transitions = self.mdp.get_transitions(0, state, action)
                action_values[action] = sum(prob * (reward + self.gamma * self.value_function[next_state])
                                            for prob, next_state, reward in transitions)
            best_action = max(action_values, key=action_values.get)
            self.policy[state] = best_action
            if best_action != old_action:
                policy_stable = False
        return policy_stable

    def iterate_policy(self):
        """Run policy iteration until the policy is stable."""
        iteration = 0
        while True:
            iteration += 1
            print(f"Iteration {iteration}: Policy Evaluation")
            self.evaluate_policy()
            print(f"Iteration {iteration}: Policy Improvement")
            if self.policy_improvement():
                print(f"Policy converged after {iteration} iterations.")
                break
        return self.policy, self.value_function


# Example usage
mdp = MDP(
    horizon=N,
    state_space=['free', 'parked', 'terminal'],
    action_space=['move', 'park'],
    transition=transition
)

pi = PolicyIteration(mdp)
optimal_policy, optimal_value_function = pi.iterate_policy()

print("Optimal Policy:")
for state, action in optimal_policy.items():
    print(f"State {state}: Action {action}")

print("\nOptimal Value Function:")
for state, value in optimal_value_function.items():
    print(f"State {state}: Value {value}")


Iteration 1: Policy Evaluation
Iteration 1: Policy Improvement
Policy converged after 1 iterations.
Optimal Policy:
State free: Action move
State parked: Action move

Optimal Value Function:
State free: Value 0.0
State parked: Value 0.0
State terminal: Value 0.0
