In [13]:
import numpy as np
from tabulate import tabulate

In [14]:
# Transition table of the four-state MDP, keyed as
# transitions[state][action] -> list of 3-tuples.
# Each tuple looks like (next_state, probability, reward) -- the field meaning
# is inferred from the values, TODO confirm. For every action the second
# elements sum to 1, and the third element is 10 exactly when the destination
# is a Rich_* state and 0 otherwise.
# Key order matters: MDP derives its state list and action list from it.
transitions = {
    'Poor_Unknown': {
        'save': [('Poor_Unknown', 1, 0)],
        'advertise': [('Poor_Famous', 0.5, 0), ('Poor_Unknown', 0.5, 0)]
    },
    'Poor_Famous': {
        'save': [('Rich_Famous', 0.5, 10), ('Poor_Unknown', 0.5, 0)],
        'advertise': [('Poor_Famous', 1, 0)]
    },
    'Rich_Unknown': {
        'save': [('Poor_Unknown', 0.5, 0), ('Rich_Unknown', 0.5, 10)],
        'advertise': [('Poor_Unknown', 0.5, 0), ('Poor_Famous', 0.5, 0)]
    },
    'Rich_Famous': {
        'save': [('Rich_Unknown', 0.5, 10), ('Rich_Famous', 0.5, 10)],
        'advertise': [('Poor_Famous', 1, 0)]
    }
}

In [15]:
class MDP:
    """Minimal container for a tabular Markov decision process.

    Attributes:
        states: State names, in the key order of ``transitions``.
        actions: Action names, taken from the first state's entry in
            ``transitions`` (an empty list when there are no states).
        transitions: The mapping passed to the constructor, stored by
            reference (it is not copied).
    """

    def __init__(self, transitions):
        """Build the state and action lists from a nested transition mapping.

        Args:
            transitions: Mapping of ``state -> {action -> outcomes}``.
        """
        self.states = list(transitions)
        # Guard the first-state lookup: an empty mapping used to raise
        # IndexError here; it now simply yields no actions.
        self.actions = list(transitions[self.states[0]]) if self.states else []
        self.transitions = transitions
        
# Module-level MDP instance; func() reads mdp.actions to label the best action.
mdp = MDP(transitions)

In [16]:
# Seed row (horizon 0) of the results table: one (value, action) pair per state
# in the order PU, PF, RU, RF. The action slot holds the placeholder 0 because
# no action has been chosen yet. func() appends one row per additional horizon.
value_matrix = [[(reward, 0) for reward in (0.0, 0.0, 10.0, 10.0)]]

In [17]:
# Display the action order: func() indexes this list with np.argmax, so
# index 0 must be 'save' and index 1 must be 'advertise'.
mdp.actions

['save', 'advertise']

In [18]:
def func(discount_factor, horizons):
    """Run finite-horizon value iteration for the four-state MDP.

    The dynamics are hard-coded in the expressions below rather than read from
    ``mdp.transitions``. States are ordered (Poor_Unknown, Poor_Famous,
    Rich_Unknown, Rich_Famous); the two Rich states earn an immediate reward
    of 10 and the two Poor states earn 0.

    Side effect: for every horizon from 1 up to ``horizons`` one row of
    ``(best_value, best_action)`` tuples is appended to the module-level
    ``value_matrix``, in increasing horizon order.

    Args:
        discount_factor: Discount applied to the previous horizon's values.
        horizons: Non-negative whole number of look-ahead steps.

    Returns:
        A NumPy array with the best value of each state at this horizon.

    Raises:
        ValueError: If ``horizons`` is negative or not a whole number, since
            the recursion would otherwise never reach the base case.
    """
    if horizons < 0:
        raise ValueError("horizons must be a non-negative whole number")
    if horizons == 0:
        # Horizon 0: only the immediate reward of each state.
        return np.array([0, 0, 10, 10])

    prev = func(discount_factor, horizons - 1)

    # One array per state; entry 0 is the value of 'save' and entry 1 the value
    # of 'advertise', which must match the order of mdp.actions.
    action_values = [
        # Poor_Unknown (the original hard-coded 0.9 here instead of discount_factor)
        np.array([discount_factor * (prev[0]),
                  discount_factor * (0.5 * prev[0] + 0.5 * prev[1])]),
        # Poor_Famous
        np.array([discount_factor * (0.5 * prev[3] + 0.5 * prev[0]),
                  discount_factor * (prev[1])]),
        # Rich_Unknown
        np.array([10 + discount_factor * (0.5 * prev[2] + 0.5 * prev[0]),
                  10 + discount_factor * (0.5 * prev[1] + 0.5 * prev[0])]),
        # Rich_Famous
        np.array([10 + discount_factor * (0.5 * prev[2] + 0.5 * prev[3]),
                  10 + discount_factor * prev[1]]),
    ]

    # np.argmax returns the first maximum, so ties resolve to 'save'.
    best = [(np.max(values), mdp.actions[np.argmax(values)]) for values in action_values]
    value_matrix.append(best)
    return np.array([value for value, _action in best])
    

In [19]:
# func() appends one row per horizon to value_matrix. Drop any rows left over
# from an earlier execution (keeping the horizon-0 seed row) so that re-running
# this cell does not duplicate the table.
del value_matrix[1:]
_ = func(0.9, 6)

In [20]:
# One table row per horizon (starting at horizon 0); each cell is the
# (best value, best action) tuple for that state.
print(
    tabulate(
        value_matrix,
        headers=['V(PU)', 'V(PF)', 'V(RU)', 'V(RF)'],
        tablefmt='fancy_grid',
        numalign='center',
        stralign='center',
    )
)

╒═══════════════════════════════════╤══════════════════════════════╤══════════════════════════════╤══════════════════════════════╕
│               V(PU)               │            V(PF)             │            V(RU)             │            V(RF)             │
╞═══════════════════════════════════╪══════════════════════════════╪══════════════════════════════╪══════════════════════════════╡
│             (0.0, 0)              │           (0.0, 0)           │          (10.0, 0)           │          (10.0, 0)           │
├───────────────────────────────────┼──────────────────────────────┼──────────────────────────────┼──────────────────────────────┤
│           (0.0, 'save')           │        (4.5, 'save')         │        (14.5, 'save')        │        (19.0, 'save')        │
├───────────────────────────────────┼──────────────────────────────┼──────────────────────────────┼──────────────────────────────┤
│       (2.025, 'advertise')        │        (8.55, 'save')        │       (16.525,