In [6]:
from dataclasses import dataclass
from typing import Mapping, Optional, Sequence, Tuple

import numpy as np

In [65]:
@dataclass
class OptimalJobs:
    """Value iteration for a job-search MDP with log utility of wages.

    States are ``(job, status)`` tuples with ``job`` in ``1..len(probabilities)``
    and ``status`` either ``'E'`` (employed at ``job``) or ``'N'`` (unemployed,
    with an offer for ``job``). Actions are ``'A'`` (accept / keep the job) and
    ``'R'`` (reject the offer).

    Attributes:
        value_function: current value estimate per state. Every Bellman update
            replaces this attribute with a new mapping; the caller's dict is not
            mutated.
        probabilities: ``probabilities[j - 1]`` is the probability of the next
            offer being job ``j``.
        wages: ``wages[0]`` is the wage earned when rejecting an offer;
            ``wages[j]`` is the wage of job ``j``.
        gamma: discount factor (value iteration only converges for gamma < 1).
        alpha: per-period probability of losing a held job.
    """
    value_function: Mapping[Tuple[int, str], float]
    probabilities: Sequence[float]
    wages: Sequence[float]
    gamma: float
    alpha: float

    def transition_probability(self, in_state: Tuple[int, str], out_state: Tuple[int, str], action: str) -> float:
        """Return P(out_state | in_state, action).

        - 'R' from an unemployed state moves to a fresh offer ``(j, 'N')`` with
          probability ``probabilities[j - 1]`` and can never lead to employment.
        - 'A' keeps job ``in_job`` with probability ``1 - alpha``; with
          probability ``alpha`` the job is lost and a fresh offer ``(j, 'N')``
          arrives with probability ``probabilities[j - 1]``.
        - Every other combination (including unknown actions) has probability 0.
        """
        in_job, in_employ = in_state
        out_job, out_employ = out_state
        if action == 'R':
            # Rejecting only makes sense while unemployed and only leads to another offer.
            if in_employ == 'N' and out_employ == 'N':
                return self.probabilities[out_job - 1]
            return 0.0
        if action == 'A':
            if out_employ == 'N':
                return self.alpha * self.probabilities[out_job - 1]
            if out_employ == 'E' and in_job == out_job:
                return 1 - self.alpha
        return 0.0

    def expected_reward(self, in_state: Tuple[int, str], action: str) -> float:
        """Return the raw wage earned this period for ``action`` taken in ``in_state``.

        The value-iteration methods of this class use ``np.log`` of this wage as
        the per-period utility.

        Raises:
            ValueError: if ``action`` is neither 'A' nor 'R'.
        """
        if action == 'A':
            return self.wages[in_state[0]]
        if action == 'R':
            return self.wages[0]
        raise ValueError(f"unknown action {action!r}; expected 'A' or 'R'")

    def _expected_unemployed_value(self) -> float:
        """Return sum_j probabilities[j - 1] * V(j, 'N') under the current value function."""
        unemployed_vals = np.array([self.value_function[(j, 'N')]
                                    for j in range(1, len(self.probabilities) + 1)])
        return np.sum(unemployed_vals * self.probabilities)

    def _reject_value(self) -> float:
        """One-step lookahead value of rejecting an offer: log(w_0) + gamma * E_j[V(j, 'N')]."""
        return np.log(self.wages[0]) + self.gamma * self._expected_unemployed_value()

    def vf_if_employ(self, in_state: Tuple[int, str]) -> float:
        """One-step lookahead value of holding (or accepting) job ``in_state[0]``.

        log(w_i) + gamma * [(1 - alpha) * V(i, 'E') + alpha * E_j[V(j, 'N')]]
        """
        job = in_state[0]
        value = np.log(self.wages[job]) + self.gamma * (1 - self.alpha) * self.value_function[(job, 'E')]
        value += self._expected_unemployed_value() * self.gamma * self.alpha
        return value

    def vf_if_notemploy(self, in_state: Tuple[int, str]) -> float:
        """Bellman value of an unemployed state: the better of rejecting and accepting."""
        return max(self._reject_value(), self.vf_if_employ(in_state))

    def bellman_operator(self) -> None:
        """Apply one synchronous Bellman backup to every state.

        All new values are computed from the current value function and then
        swapped in together, so ``self.value_function`` is replaced by a copy.

        Raises:
            ValueError: if a state's status is neither 'E' nor 'N'.
        """
        updated = self.value_function.copy()
        for state in self.value_function:
            if state[1] == 'E':
                updated[state] = self.vf_if_employ(state)
            elif state[1] == 'N':
                updated[state] = self.vf_if_notemploy(state)
            else:
                raise ValueError(f"state {state!r} has unknown status; expected 'E' or 'N'")
        self.value_function = updated

    def get_actions(self) -> Mapping[Tuple[int, str], str]:
        """Return the greedy policy with respect to the current value function.

        Employed states always keep their job ('A'). An unemployed state accepts
        ('A') when accepting is at least as good as rejecting, otherwise 'R'.

        Raises:
            ValueError: if a state's status is neither 'E' nor 'N'.
        """
        # Rejecting has the same value in every unemployed state, so compute it once.
        reject_value = self._reject_value()
        actions = {}
        for state in self.value_function:
            if state[1] == 'E':
                actions[state] = 'A'
            elif state[1] == 'N':
                actions[state] = 'A' if self.vf_if_employ(state) >= reject_value else 'R'
            else:
                raise ValueError(f"state {state!r} has unknown status; expected 'E' or 'N'")
        return actions

    def compute_value_function(self, tolerance: float = 1e-4,
                               max_iterations: Optional[int] = None) -> Mapping[Tuple[int, str], float]:
        """Run value iteration until the max-norm change drops to ``tolerance``.

        Args:
            tolerance: stop once max |V_new - V_old| <= tolerance.
            max_iterations: optional cap on the number of Bellman sweeps; None
                (the default) means no cap.

        Returns:
            The converged value function (also stored in ``self.value_function``).

        Raises:
            RuntimeError: if ``max_iterations`` sweeps finish without convergence
                (e.g. because gamma >= 1).
        """
        keys = list(self.value_function)
        if not keys:
            return self.value_function
        iterations = 0
        diff = float('inf')
        while diff > tolerance:
            if max_iterations is not None and iterations >= max_iterations:
                raise RuntimeError(
                    f"value iteration did not converge within {max_iterations} iterations "
                    f"(last max change {diff:.3g}); check that gamma < 1"
                )
            old_vals = np.array([self.value_function[key] for key in keys])
            self.bellman_operator()
            new_vals = np.array([self.value_function[key] for key in keys])
            diff = np.max(np.abs(new_vals - old_vals))
            iterations += 1
        return self.value_function
        

In [66]:
# Problem setup: n jobs with random offer probabilities and wages.
# Seeded so that Restart & Run All reproduces the same results.
SEED = 0
rng = np.random.default_rng(SEED)

n = 10
rands = rng.random(n)
probs = rands / rands.sum()       # offer probabilities for jobs 1..n (sum to 1)
wages = rng.random(n + 1) * 100   # wages[0]: wage when rejecting; wages[1..n]: job wages
dict_keys = [(i, c) for c in ['E', 'N'] for i in range(1, n + 1)]
value_function = dict.fromkeys(dict_keys, 0.0)  # start value iteration from V = 0
gamma = 0.5   # discount factor; must be < 1 for value iteration to converge
alpha = 0.2   # per-period probability of losing a held job

In [67]:
# Build the solver from the problem parameters defined in the setup cell.
optim = OptimalJobs(
    value_function=value_function,
    probabilities=probs,
    wages=wages,
    gamma=gamma,
    alpha=alpha,
)

In [68]:
# Run value iteration to convergence and display the resulting state values.
optim.compute_value_function(tolerance=1e-4)

{(1, 'E'): 8.965454174298525,
 (2, 'E'): 8.728637113705796,
 (3, 'E'): 8.858567476725499,
 (4, 'E'): 6.771465687378746,
 (5, 'E'): 8.731514108195988,
 (6, 'E'): 7.187558652229709,
 (7, 'E'): 7.944154937384902,
 (8, 'E'): 7.486429367781601,
 (9, 'E'): 7.158429294244973,
 (10, 'E'): 6.974224725439922,
 (1, 'N'): 8.965454174298525,
 (2, 'N'): 8.728637113705796,
 (3, 'N'): 8.858567476725499,
 (4, 'N'): 8.401314289192019,
 (5, 'N'): 8.731514108195988,
 (6, 'N'): 8.401314289192019,
 (7, 'N'): 8.401314289192019,
 (8, 'N'): 8.401314289192019,
 (9, 'N'): 8.401314289192019,
 (10, 'N'): 8.401314289192019}

In [69]:
# Greedy policy for the value function computed above: 'A' = accept/keep the job, 'R' = reject the offer.
optim.get_actions()

{(1, 'E'): 'A',
 (2, 'E'): 'A',
 (3, 'E'): 'A',
 (4, 'E'): 'A',
 (5, 'E'): 'A',
 (6, 'E'): 'A',
 (7, 'E'): 'A',
 (8, 'E'): 'A',
 (9, 'E'): 'A',
 (10, 'E'): 'A',
 (1, 'N'): 'A',
 (2, 'N'): 'A',
 (3, 'N'): 'A',
 (4, 'N'): 'R',
 (5, 'N'): 'A',
 (6, 'N'): 'R',
 (7, 'N'): 'R',
 (8, 'N'): 'R',
 (9, 'N'): 'R',
 (10, 'N'): 'R'}