In [1]:
import numpy as np
from tqdm import tqdm
import random
import math

Task 3 (2 points): Compute the optimal policy using policy iteration.
How many iterations are needed to reach convergence?

Document:
- The states in which the optimal action is “do nothing” and those in which it is “do maintenance”.
- The total expected discounted cost of each state.
- The number of iterations needed for convergence.

## States, action space and cost (reward)

In [2]:
# Degradation threshold of each component type: a type-T component has failed once d == xi[T]
xi = {1: 15, 2: 30, 3: 50}

# State space: all (type, degradation) pairs, grouped by type, with d running 0..xi[T]
S = [(T, d) for T in xi for d in range(xi[T] + 1)]

# Get total number of states
num_states = len(S)
print(f"There are {len(S)} states in our MDP.")

# Action space
A = [0, 1]  # 0: "Do nothing", 1: "Do maintenance"

# Cost matrix C[s][a]. An infinite cost marks an action that is not allowed:
# in a failed state (d == xi[T]) only maintenance is possible, and it costs 5 instead of 1.
C = np.zeros((num_states, 2))
for row, (T, d) in enumerate(S):
    has_failed = d == xi[T]
    C[row] = (float("inf"), 5) if has_failed else (0, 1)

There are 98 states in our MDP.


## Transition probabilities

In [3]:
lambda_param = 4
pi = 0.5

def zero_inflated_poisson_pmf(k, lambda_param=lambda_param, pi=pi):
    """Probability mass of a zero-inflated Poisson (ZIP) distribution at ``k``.

    With probability ``pi`` the outcome is a structural zero; otherwise it is
    drawn from Poisson(``lambda_param``):

        P(0) = pi + (1 - pi) * exp(-lambda)
        P(k) = (1 - pi) * lambda**k * exp(-lambda) / k!     for k >= 1

    Parameters
    ----------
    k : int
        Outcome (e.g. the per-step degradation increase). Negative values have
        probability 0.
    lambda_param : float
        Poisson rate, expected to be >= 0. Defaults to the module-level ``lambda_param``.
    pi : float
        Zero-inflation probability in [0, 1]. Defaults to the module-level ``pi``.

    Returns
    -------
    float
        P(X = k).
    """
    if k < 0:
        return 0.0
    if k == 0:
        return pi + (1 - pi) * math.exp(-lambda_param)
    if lambda_param == 0:
        # Poisson(0) is a point mass at 0
        return 0.0
    # Evaluate lambda**k * exp(-lambda) / k! in log space: np.power on Python ints
    # silently wraps around in int64 (np.power(10, 19) is negative, np.power(4, 32) is 0),
    # and float powers / factorials overflow for large k.
    log_poisson = k * math.log(lambda_param) - lambda_param - math.lgamma(k + 1)
    return (1 - pi) * math.exp(log_poisson)

In [4]:
# Calculate the maximum k needed for zero-inflated Poisson
def find_max_k(lambda_param=lambda_param, pi=pi, epsilon=np.finfo(float).eps):
    """Return the first right-tail ``k`` whose ZIP probability drops below ``epsilon``.

    Scans k = 0, 1, 2, ... and returns the first ``k`` with ``k > lambda_param`` and
    ``zero_inflated_poisson_pmf(k, lambda_param, pi) < epsilon``. Past the Poisson
    mean the pmf is strictly decreasing, so every larger ``k`` is also below
    ``epsilon``.

    Parameters
    ----------
    lambda_param : float
        Poisson rate. Defaults to the module-level ``lambda_param``.
    pi : float
        Zero-inflation probability. Defaults to the module-level ``pi``.
    epsilon : float
        Probability threshold; defaults to float64 machine epsilon.

    Returns
    -------
    int
    """
    k = 0
    while True:
        prob = zero_inflated_poisson_pmf(k, lambda_param, pi)

        # Only stop past the mean: for a large rate the *left* tail (small k) is also
        # below epsilon, and stopping there would give a far too small cutoff.
        if k > lambda_param and prob < epsilon:
            return k
        k += 1

# Calculate the maximum k needed once
max_k = find_max_k()
print(f"Maximum k needed for sufficient precision: {max_k}")

Maximum k needed for sufficient precision: 30


In [5]:
# P[a][s][s'] where a is action, s is current state, s' is next state
P = np.zeros((2, num_states, num_states))

# Position of each (type, degradation) state in S, for O(1) lookups
state_index = {s: i for i, s in enumerate(S)}

# Calculate P0 (Do nothing): the type is unchanged and degradation grows by a
# zero-inflated-Poisson jump k, i.e. (T, d) -> (T, d + k), capped at the failed state
for idx_s, (T, d) in enumerate(S):
    # Cannot do nothing in failed state: its P0 row stays all-zero
    if d == xi[T]:
        continue

    # Jumps that keep the component strictly below its failure threshold
    for d_prime in range(d, xi[T]):
        P[0][idx_s][state_index[(T, d_prime)]] = zero_inflated_poisson_pmf(d_prime - d)

    # Every jump of size >= xi[T] - d ends in the failed state. Use the exact
    # complement rather than a truncated tail sum so the row sums to 1; clip the
    # tiny negative round-off that can appear when the tail mass is ~0.
    P[0][idx_s][state_index[(T, xi[T])]] = max(0.0, 1.0 - P[0][idx_s].sum())

# Calculate P1 (Do maintenance): from any state, move to a fresh (zero-degradation)
# component whose type is drawn uniformly over the component types
new_component_states = [state_index[(T_prime, 0)] for T_prime in xi]
P[1][:, new_component_states] = 1 / len(new_component_states)

# Sanity checks: probabilities are non-negative and every available action's row is stochastic
failed = np.array([d == xi[T] for T, d in S])
assert (P >= 0).all(), "negative transition probability"
assert np.allclose(P[0][~failed].sum(axis=1), 1.0), "P0 rows must sum to 1"
assert np.allclose(P[1].sum(axis=1), 1.0), "P1 rows must sum to 1"

print(f"Transition probability matrices created with shape: {P.shape}")

Transition probability matrices created with shape: (2, 98, 98)


## Policy iteration

In [10]:
gamma = 0.9
# A new action must beat the current one by more than this to replace it. Keeping
# the current action on ties guarantees policy iteration terminates.
tol = 1e-10

# Initial policy: "do nothing" everywhere except failed states, where that action is
# unavailable (infinite cost) and maintenance is forced
policy = np.where(np.isinf(C[:, 0]), 1, 0)

state_ids = np.arange(num_states)
I = np.identity(num_states)
iterations = 0

while True:
    iterations += 1
    old_policy = policy.copy()

    # Step 1 -- policy evaluation: solve (I - gamma * P_pi) V = C_pi
    P_pi = P[policy, state_ids]   # row s is P[policy[s]][s]
    C_pi = C[state_ids, policy]   # C_pi[s] = C[s][policy[s]]
    V = np.linalg.solve(I - gamma * P_pi, C_pi)

    # Step 2 -- policy improvement with Q[s, a] = C[s][a] + gamma * sum_s' P[a][s][s'] V[s'].
    # Unavailable actions keep their infinite cost (their P0 rows are all zero), so
    # they are never selected.
    Q = C + gamma * (P @ V).T
    greedy = Q.argmin(axis=1)
    # Switch only on strict improvement so ties cannot make the policy cycle forever
    strictly_better = Q[state_ids, greedy] < Q[state_ids, policy] - tol
    policy = np.where(strictly_better, greedy, policy)

    # Stable policy => it is greedy w.r.t. its own value V, i.e. optimal
    if np.array_equal(policy, old_policy):
        break

print(f"Converged at iteration {iterations}")
print(f"Best policy {policy}")


Converged at iteration 4
Best policy [0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1]


In [None]:
print("\n=== Expected Discounted Costs for All States ===")

# One section per component type; degradation runs from 0 up to the failure
# threshold xi[T], so the ranges follow xi instead of hard-coded 16/31/51.
state_index = {s: i for i, s in enumerate(S)}
for T in xi:
    print(f"\nType {T} Component:")
    for d in range(xi[T] + 1):
        print(f"State ({T},{d}): {V[state_index[(T, d)]]:.4f}")


=== Expected Discounted Costs for All States ===

Type 1 Component:
State (1,0): 0.9508
State (1,1): 0.9992
State (1,2): 1.0501
State (1,3): 1.1020
State (1,4): 1.1528
State (1,5): 1.2036
State (1,6): 1.2663
State (1,7): 1.3715
State (1,8): 1.4736
State (1,9): 1.4736
State (1,10): 1.4736
State (1,11): 1.4736
State (1,12): 1.4736
State (1,13): 1.4736
State (1,14): 1.4736
State (1,15): 5.4736

Type 2 Component:
State (2,0): 0.4563
State (2,1): 0.4792
State (2,2): 0.5032
State (2,3): 0.5285
State (2,4): 0.5550
State (2,5): 0.5828
State (2,6): 0.6121
State (2,7): 0.6428
State (2,8): 0.6750
State (2,9): 0.7089
State (2,10): 0.7445
State (2,11): 0.7818
State (2,12): 0.8210
State (2,13): 0.8621
State (2,14): 0.9053
State (2,15): 0.9508
State (2,16): 0.9992
State (2,17): 1.0501
State (2,18): 1.1020
State (2,19): 1.1528
State (2,20): 1.2036
State (2,21): 1.2663
State (2,22): 1.3715
State (2,23): 1.4736
State (2,24): 1.4736
State (2,25): 1.4736
State (2,26): 1.4736
State (2,27): 1.4736
State (2