In [2]:
import numpy as np

# Four-state chain MDP: A - B - C - D, with two moves available everywhere.
states = ["A", "B", "C", "D"]
actions = ["left", "right"]

# Transition model P[s][a] -> {s_next: probability}.
# Every move is deterministic: it shifts one position along the chain, and
# at either end the move that would leave the chain keeps the agent in place
# ('left' in A stays in A, 'right' in D stays in D).
P = {
    state: {
        "left": {states[max(idx - 1, 0)]: 1.0},
        "right": {states[min(idx + 1, len(states) - 1)]: 1.0},
    }
    for idx, state in enumerate(states)
}

# Reward model R[s][a]: zero everywhere ...
R = {state: {action: 0 for action in actions} for state in states}
# ... except for stepping right out of C (the move that enters D).
R["C"]["right"] = 1

gamma = 0.9   # discount factor applied to the successor state's value
theta = 1e-4  # policy evaluation stops once a sweep changes no value by this much


def policy_evaluation(policy, V):
    """Iteratively estimate the state values of a fixed policy.

    ``policy`` maps each state to the action taken there. ``V`` maps each
    state to its current value estimate; it is updated in place and the
    same dict is returned. The model (``states``, ``P``, ``R``) and the
    constants ``gamma`` and ``theta`` are read from module level.

    States are swept in the order of ``states``; each update immediately
    uses values already refreshed earlier in the same sweep. Sweeping stops
    after the first sweep in which no value moved by ``theta`` or more.
    """
    converged = False
    while not converged:
        largest_change = 0
        for state in states:
            previous = V[state]
            action = policy[state]
            transitions = P[state][action]
            # One-step backup: expected reward plus discounted successor value.
            V[state] = sum(
                prob * (R[state][action] + gamma * V[nxt])
                for nxt, prob in transitions.items()
            )
            largest_change = max(largest_change, abs(previous - V[state]))
        converged = largest_change < theta
    return V


def policy_improvement(V, policy):
    """Make the policy greedy with respect to the value estimates ``V``.

    For every state, each action in ``actions`` is scored by its one-step
    backup (reward plus ``gamma`` times the successor's value) and the
    highest-scoring action is written into ``policy`` in place. When several
    actions share the top score, the one listed first in ``actions`` wins.

    Returns ``(policy, policy_stable)``, where ``policy_stable`` is True only
    if no state's action changed during this call.
    """
    changed = False
    for state in states:
        previous = policy[state]
        # Score of each action: expected reward plus discounted next value.
        scores = {
            action: sum(
                prob * (R[state][action] + gamma * V[nxt])
                for nxt, prob in P[state][action].items()
            )
            for action in actions
        }
        greedy = max(scores, key=scores.get)
        policy[state] = greedy
        if previous != greedy:
            changed = True
    return policy, not changed


def policy_iteration():
    """Alternate policy evaluation and greedy improvement until stable.

    Starts from a policy that picks an action for each state with
    ``np.random.choice`` and from an all-zero value table, both local to
    this call. Each round evaluates the current policy, improves it, and
    prints the round number, the values rounded to three decimals, and the
    improved policy. The printed values are those of the policy as it was
    before that round's improvement step.

    Returns ``(policy, V)`` from the first round in which improvement left
    every state's action unchanged.
    """
    # Random starting policy, drawn state by state in the order of `states`.
    policy = {}
    for state in states:
        policy[state] = np.random.choice(actions)
    V = dict.fromkeys(states, 0)

    iteration = 0
    stable = False
    while not stable:
        iteration += 1
        print(f"\nIteration {iteration}: Policy Evaluation & Improvement")
        V = policy_evaluation(policy, V)
        policy, stable = policy_improvement(V, policy)
        rounded = {s: round(V[s], 3) for s in states}
        print("Values:", rounded)
        print("Policy:", policy)

    print("\n✅ Optimal policy found!")
    return policy, V


# Run policy iteration once at module level; the resulting policy and value
# table stay available as `optimal_policy` / `optimal_values`.
# The starting policy is drawn at random inside policy_iteration(), so the
# per-iteration log it prints can differ from run to run.
optimal_policy, optimal_values = policy_iteration()

# Summary: one line per state, value shown to four decimals next to the
# action the final policy takes in that state.
print("\nFinal Optimal Values and Policy:")
for s in states:
    print(f"State {s}: V = {optimal_values[s]:.4f}, π*(s) = {optimal_policy[s]}")



Iteration 1: Policy Evaluation & Improvement
Values: {'A': 0.81, 'B': 0.9, 'C': 1.0, 'D': 0.0}
Policy: {'A': 'right', 'B': 'right', 'C': 'right', 'D': 'left'}

Iteration 2: Policy Evaluation & Improvement
Values: {'A': 4.263, 'B': 4.736, 'C': 5.263, 'D': 4.737}
Policy: {'A': 'right', 'B': 'right', 'C': 'right', 'D': 'left'}

✅ Optimal policy found!

Final Optimal Values and Policy:
State A: V = 4.2628, π*(s) = right
State B: V = 4.7365, π*(s) = right
State C: V = 5.2628, π*(s) = right
State D: V = 4.7365, π*(s) = left
