In [24]:
import numpy as np
from tqdm import tqdm
# import random
import math

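Define the state space $\mathcal{S} = \{(T, d) : T \in \{1, 2, 3\},\ 0 \le d \le \xi_T\}$, with thresholds $\xi_1 = 15$, $\xi_2 = 30$, $\xi_3 = 50$.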

In [25]:
# Every state is a pair (T, d): type T in {1, 2, 3} and level d in 0..xi_T,
# where xi_T is the threshold of type T (15, 30, 50 respectively).
states = [
    (T, d)
    for T, xi in enumerate([15, 30, 50], start=1)
    for d in range(xi + 1)
]

# Bidirectional lookup between a state's integer index and the state tuple.
id_to_state = dict(enumerate(states))
state_to_id = {s: i for i, s in enumerate(states)}
print(f"There are {len(states)} states in our MDP.")

There are 98 states in our MDP.


Define the action space $\mathcal{A} = \{0, 1\}$, where 0 means "do nothing" and 1 means "do maintenance".

In [26]:
# Human-readable label for each action id; the action list is its key set.
actions_dict = {0: "do nothing", 1: "do maintenance"}
actions = list(actions_dict)

Create the state-action cost matrix $\mathcal{C}^{(a)}_{(T,d)}$, with one row per state $(T,d)$ and one column per action $a$.

In [27]:
thresholds = {1: 15, 2: 30, 3: 50}

# Row s of C is [cost of action 0, cost of action 1] for state s = (T, d).
# Below the threshold: doing nothing is free and maintenance costs 1.
# At the threshold: doing nothing is forbidden (infinite cost) and maintenance costs 5.
C = np.array([
    [0, 1] if d < thresholds[T] else [np.inf, 5]
    for T, d in states
])
print(*C)

[0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [inf  5.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [inf  5.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [0. 1.] [inf  5.]


Get the transition probability matrix $\mathcal{P}^0$ under action 0.

In [55]:
def get_zero_prob(pi, lambda_):
    """
    Probability that the zero-inflated Poisson increment P_{t+1} equals 0.

    With probability `pi` the increment is a structural (deterministic) zero.
    With probability (1 - pi) it is drawn from Poisson(lambda_), which is itself
    0 with probability exp(-lambda_). Hence:

        P(P_{t+1} = 0) = pi + (1 - pi) * exp(-lambda_)
    """
    structural_zero = pi
    poisson_zero = (1 - pi) * np.exp(-lambda_)
    return structural_zero + poisson_zero

def get_y_prob(pi,lambda_,y):
    """
    Probability that the zero-inflated Poisson increment P_{t+1} equals y.

        P(P_{t+1} = 0) = pi + (1 - pi) * exp(-lambda_)
        P(P_{t+1} = y) = (1 - pi) * lambda_**y * exp(-lambda_) / y!   for y >= 1
        P(P_{t+1} = y) = 0                                           for y < 0

    Parameters
    ----------
    pi : float
        Probability of a structural (deterministic) zero.
    lambda_ : float
        Rate of the Poisson component.
    y : int
        Value whose probability is requested.
    """
    if y < 0:
        # The increment is never negative.
        return 0.0
    poisson_pmf = pow(lambda_, y) * np.exp(-lambda_) / math.factorial(y)
    if y == 0:
        # Structural zeros add extra mass at y = 0 only.
        return pi + (1 - pi) * poisson_pmf
    return (1 - pi) * poisson_pmf

def get_geq_prob(pi,lambda_,k):
    """
    Probability that the zero-inflated Poisson increment P_{t+1} is at least k:
    P(P_{t+1} >= k). In this notebook k = xi_T - d_t.

    For k <= 0 the event is certain. For k >= 1 the structural zeros can never
    reach k, so:

        P(P_{t+1} >= k) = (1 - pi) * (1 - P(Poisson(lambda_) < k))

    The Poisson pmf is built iteratively, as term_{i+1} = term_i * lambda_ / (i+1),
    so no factorial is formed and large k does not overflow. The result is clamped
    at 0, because for large k the CDF can round to slightly above 1 and would
    otherwise give a tiny negative probability.
    """
    if k <= 0:
        return 1.0
    term = math.exp(-lambda_)  # P(Poisson = 0)
    poisson_cdf = 0.0
    for i in range(k):
        poisson_cdf += term
        term *= lambda_ / (i + 1)
    return max(0.0, (1 - pi) * (1 - poisson_cdf))

def transition_prob0(T1, d1, T2, d2):
    """
    Transition probability P^0((T1, d1) -> (T2, d2)) under action 0 ("do nothing").

    Under action 0 the type T never changes. The level d grows by the
    zero-inflated Poisson increment P_{t+1}, whose parameters pi_zero_infl and
    lambda_zero_infl are read from module scope. All overshoot is absorbed at
    the threshold th = thresholds[T1]:

        d2 == d1          -> P(P_{t+1} = 0)
        d1 < d2 < th      -> P(P_{t+1} = d2 - d1)
        d2 == th          -> P(P_{t+1} >= th - d1)
        otherwise         -> 0   (type changes, or the level would decrease)

    Rows for states already at the threshold (d1 >= th) are all zero, because
    action 0 is not allowed there (its cost in C is inf).

    Raises
    ------
    ValueError
        If d2 lies beyond the threshold, which cannot happen for states in the
        state space.
    """
    th1 = thresholds[T1]
    if d1 >= th1:
        # At the threshold, action 0 is infeasible: no transition is defined.
        return 0.0
    if T1 != T2 or d2 < d1:
        # The type is fixed under action 0 and the level never decreases.
        return 0.0
    if d2 == d1:
        return get_zero_prob(pi_zero_infl, lambda_zero_infl)
    if d2 < th1:
        # P(P_{t+1} = d2 - d1)
        return get_y_prob(pi_zero_infl, lambda_zero_infl, d2 - d1)
    if d2 == th1:
        # P(P_{t+1} >= th1 - d1): every jump reaching the threshold lands on it.
        return get_geq_prob(pi_zero_infl, lambda_zero_infl, th1 - d1)
    raise ValueError(f"invalid transition {(T1, d1)} -> {(T2, d2)}: d2 exceeds threshold {th1}")


# Parameters of the zero-inflated Poisson increment used by transition_prob0:
# pi_zero_infl is the probability of a structural zero,
# lambda_zero_infl is the rate of the Poisson component.
pi_zero_infl = 1/2
lambda_zero_infl = 4

# Float dtype (not object) so later matrix-vector products run in native NumPy.
P_0 = np.array([
    [transition_prob0(T1, d1, T2, d2) for (T2, d2) in states]
    for (T1, d1) in states
], dtype=float)

# Sanity checks: entries are (numerically) non-negative. Every row is either a
# probability distribution (sums to 1) or all zeros (states at the threshold,
# where action 0 is infeasible).
assert P_0.min() > -1e-12, "negative transition probability in P_0"
for i, row_sum in enumerate(P_0.sum(axis=1)):
    assert (np.isclose(row_sum, 1, rtol=0, atol=1e-8)
            or np.isclose(row_sum, 0, rtol=0, atol=1e-8)), (i, row_sum)

Get the transition probability matrix $\mathcal{P}^1$ under action 1.

In [56]:
def transition_prob1(T1, d1, T2, d2, types=(1, 2, 3)):
    """
    Transition probability P^1((T1, d1) -> (T2, d2)) under action 1 ("do maintenance").

    Maintenance resets the level to 0. The next type is drawn uniformly from
    `types`, independently of the current state (T1, d1). So the probability is
    1/len(types) when d2 == 0 and 0 otherwise.

    Parameters
    ----------
    T1, d1 : current type and level (d1 does not affect the result)
    T2, d2 : next type and level
    types : tuple of valid component types; defaults to (1, 2, 3)

    Raises
    ------
    ValueError
        If T1 or T2 is not one of `types`.
    """
    if T1 not in types or T2 not in types:
        raise ValueError(f"unknown component type in transition {(T1, d1)} -> {(T2, d2)}")
    return 1 / len(types) if d2 == 0 else 0

# Float dtype (not object) so later matrix-vector products run in native NumPy.
P_1 = np.array([
    [transition_prob1(T1, d1, T2, d2) for (T2, d2) in states]
    for (T1, d1) in states
], dtype=float)

# Maintenance is allowed in every state, so every row must be a full
# probability distribution. An all-zero row here would indicate a bug.
for i, row_sum in enumerate(P_1.sum(axis=1)):
    assert np.isclose(row_sum, 1, rtol=0, atol=1e-8), (i, row_sum)

Create a dictionary mapping each action to its transition probability matrix.

In [57]:
# Action id -> transition matrix; position in the tuple is the action id.
P_dict = dict(enumerate((P_0, P_1)))

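Initialize the discount factor $\gamma$, the convergence threshold $\epsilon$, the maximum number of iterations (max\_iter) and the policy $\pi_0(s)$, and set $V_0(s) = 0$ for all $s \in \mathcal{S}$. Then perform value iteration, repeatedly applying the Bellman update $V_{k+1}(s) = \min_a \big[ \mathcal{C}^{(a)}_s + \gamma \sum_{s'} \mathcal{P}^a_{s s'} V_k(s') \big]$ until $\max_s |V_{k+1}(s) - V_k(s)| < \epsilon$.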

In [None]:
gamma = 0.9              # discount factor
e = pow(10, -8)          # stop once the largest per-state change in V is below this
max_iter = pow(10, 4)    # hard cap on the number of Bellman updates

# Stack the per-action transition matrices as floats: shape (|A|, |S|, |S|).
# The cast also covers matrices that were built with dtype=object.
P_stack = np.stack([np.asarray(P_dict[a], dtype=float) for a in actions])
costs = np.asarray(C, dtype=float)   # shape (|S|, |A|)

V = np.zeros(len(states))
pi = np.full(len(states), None)
converged = False

for _ in tqdm(range(max_iter)):
    # Q[s, a] = C(s, a) + gamma * sum_{s'} P^a(s, s') V(s').
    # For infeasible (s, a) the cost is inf and the P row is zero, so Q is inf,
    # never nan.
    Q = costs + gamma * (P_stack @ V).T
    V_new = Q.min(axis=1)
    pi = Q.argmin(axis=1)
    max_diff = np.max(np.abs(V_new - V))
    V = V_new
    if max_diff < e:
        converged = True
        break

if not converged:
    print(f"WARNING: value iteration did not converge within {max_iter} iterations")

print("\nOptimal Values:\n", V)
print("\nOptimal Policy:\n", pi)

  2%|▏         | 152/10000 [00:00<00:50, 196.39it/s]


Optimal Values:
 [0.95084707 0.999223   1.05012635 1.10202897 1.15278308 1.20355213
 1.26633731 1.37146276 1.47356793 1.47356793 1.47356793 1.47356793
 1.47356793 1.47356793 1.47356793 5.47356793 0.45629014 0.47918114
 0.50322057 0.52846615 0.55497831 0.5828202  0.61205788 0.64276152
 0.67500719 0.70887705 0.74445371 0.78180797 0.82099388 0.8620824
 0.90525022 0.95084707 0.999223   1.05012635 1.10202897 1.15278308
 1.20355213 1.26633731 1.37146276 1.47356793 1.47356793 1.47356793
 1.47356793 1.47356793 1.47356793 1.47356793 5.47356793 0.17142259
 0.18002247 0.18905378 0.19853818 0.20849838 0.21895826 0.22994289
 0.2414786  0.25359302 0.2663152  0.27967562 0.29370629 0.30844086
 0.32391463 0.34016467 0.35722995 0.37515135 0.39397183 0.41373648
 0.43449268 0.45629014 0.47918114 0.50322057 0.52846615 0.55497831
 0.5828202  0.61205788 0.64276152 0.67500719 0.70887705 0.74445371
 0.78180797 0.82099388 0.8620824  0.90525022 0.95084707 0.999223
 1.05012635 1.10202897 1.15278308 1.20355213 1.




In [68]:
import pandas as pd

In [87]:
# One row per state: the state tuple, its components, the optimal action and
# the optimal value (rounded to 4 decimals).
T_column, d_column = zip(*states)
results = pd.DataFrame({
    "state": states,
    "T": list(T_column),
    "d": list(d_column),
    "pi*": pi,
    "V*": [round(v, 4) for v in V],
})

In [90]:
results.loc[results["T"] == 3, ["state", "pi*", "V*"]]

Unnamed: 0,state,pi*,V*
47,"(3, 0)",0,0.1714
48,"(3, 1)",0,0.18
49,"(3, 2)",0,0.1891
50,"(3, 3)",0,0.1985
51,"(3, 4)",0,0.2085
52,"(3, 5)",0,0.219
53,"(3, 6)",0,0.2299
54,"(3, 7)",0,0.2415
55,"(3, 8)",0,0.2536
56,"(3, 9)",0,0.2663


In [84]:
# NOTE(review): this cell is a hard-coded copy of a LaTeX table for the T == 1
# states (columns: state, T, d, pi*, V*, with V* to 6 decimals). It does not
# update when `results` changes. Consider generating it from the data instead,
# e.g. results.query("T==1").to_latex(); TODO confirm the intended format.
# Because the literal is a raw string, each `\n` inside it is a backslash
# followed by "n", not a newline.
# `.replace(r'\\', r'\ ')` turns every pair of backslash characters into a
# single backslash followed by a space.
r'\\begin{tabular}{llrrlr}\n\\toprule\n & state & T & d & pi* & V* \\\\\n\\midrule\n0 & (1, 0) & 1 & 0 & 0 & 0.950847 \\\\\n1 & (1, 1) & 1 & 1 & 0 & 0.999223 \\\\\n2 & (1, 2) & 1 & 2 & 0 & 1.050126 \\\\\n3 & (1, 3) & 1 & 3 & 0 & 1.102029 \\\\\n4 & (1, 4) & 1 & 4 & 0 & 1.152783 \\\\\n5 & (1, 5) & 1 & 5 & 0 & 1.203552 \\\\\n6 & (1, 6) & 1 & 6 & 0 & 1.266337 \\\\\n7 & (1, 7) & 1 & 7 & 0 & 1.371463 \\\\\n8 & (1, 8) & 1 & 8 & 1 & 1.473568 \\\\\n9 & (1, 9) & 1 & 9 & 1 & 1.473568 \\\\\n10 & (1, 10) & 1 & 10 & 1 & 1.473568 \\\\\n11 & (1, 11) & 1 & 11 & 1 & 1.473568 \\\\\n12 & (1, 12) & 1 & 12 & 1 & 1.473568 \\\\\n13 & (1, 13) & 1 & 13 & 1 & 1.473568 \\\\\n14 & (1, 14) & 1 & 14 & 1 & 1.473568 \\\\\n15 & (1, 15) & 1 & 15 & 1 & 5.473568 \\\\\n\\bottomrule\n\\end{tabular}\n'.replace(r'\\',r'\ ')

'\\ begin{tabular}{llrrlr}\\n\\ toprule\\n & state & T & d & pi* & V* \\ \\ \\n\\ midrule\\n0 & (1, 0) & 1 & 0 & 0 & 0.950847 \\ \\ \\n1 & (1, 1) & 1 & 1 & 0 & 0.999223 \\ \\ \\n2 & (1, 2) & 1 & 2 & 0 & 1.050126 \\ \\ \\n3 & (1, 3) & 1 & 3 & 0 & 1.102029 \\ \\ \\n4 & (1, 4) & 1 & 4 & 0 & 1.152783 \\ \\ \\n5 & (1, 5) & 1 & 5 & 0 & 1.203552 \\ \\ \\n6 & (1, 6) & 1 & 6 & 0 & 1.266337 \\ \\ \\n7 & (1, 7) & 1 & 7 & 0 & 1.371463 \\ \\ \\n8 & (1, 8) & 1 & 8 & 1 & 1.473568 \\ \\ \\n9 & (1, 9) & 1 & 9 & 1 & 1.473568 \\ \\ \\n10 & (1, 10) & 1 & 10 & 1 & 1.473568 \\ \\ \\n11 & (1, 11) & 1 & 11 & 1 & 1.473568 \\ \\ \\n12 & (1, 12) & 1 & 12 & 1 & 1.473568 \\ \\ \\n13 & (1, 13) & 1 & 13 & 1 & 1.473568 \\ \\ \\n14 & (1, 14) & 1 & 14 & 1 & 1.473568 \\ \\ \\n15 & (1, 15) & 1 & 15 & 1 & 5.473568 \\ \\ \\n\\ bottomrule\\n\\ end{tabular}\\n'