# CME 241: Assignment 4

## Question 3

The state space of this MDP is $\mathcal{S} = \{1, 2, \ldots, n, n+1\}$, where state $s \in \{1, \ldots, n\}$ means being employed at job $s$ and $s = n + 1$ is the unemployed state.

The action space is $\mathcal{A} = \{c, r\}$, where $c$ means accepting a new job and $r$ means rejecting it. As the transitions below show, the action only matters in the unemployed state $n + 1$.

The transition function is $\mathcal{P}(s, a, s') = \mathbb{P}\left[S_{t+1} = s' \mid S_t = s, A_t = a \right]$:

$$\mathcal{P}(s, a, s) = 1 - \alpha : \forall s \neq n + 1, a \in \mathcal{A}$$
$$\mathcal{P}(s, a, s') = 0 : \forall s \neq n + 1, s' \notin \{s, n + 1\}, a \in \mathcal{A}$$
$$\mathcal{P}(s, a, n + 1) = \alpha : \forall s \neq n + 1, a \in \mathcal{A}$$
$$\mathcal{P}(n + 1, c, s') = p_{s'} : \forall s' \neq n + 1$$
$$\mathcal{P}(n + 1, r, s')= 0.0 : \forall s' \neq n + 1$$
$$\mathcal{P}(n + 1, r, n + 1) = 1.0$$
$$\mathcal{P}(n + 1, c, n + 1)= 0.0$$
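The case analysis above can be checked mechanically. The sketch below is a minimal illustration rather than part of the assignment's solution: it stores $\mathcal{P}$ as a dense NumPy array `P[s, a, s']` and confirms that every $(s, a)$ slice sums to 1. The helper name `transition_tensor` is hypothetical, and it assumes 0-based indices, with index `n` standing for the unemployed state $n + 1$, action index 0 for $c$ and 1 for $r$; the $p_i$ and $\alpha$ are the example values used later in this notebook.

```python
import numpy as np


def transition_tensor(probs: np.ndarray, alpha: float) -> np.ndarray:
    """Hypothetical helper: P[s, a, s'] for the job MDP.

    Index len(probs) is the unemployed state; action 0 is 'c', 1 is 'r'.
    """
    n = len(probs)
    P = np.zeros((n + 1, 2, n + 1))
    for s in range(n):
        P[s, :, s] = 1 - alpha  # employed: keep job s w.p. 1 - alpha
        P[s, :, n] = alpha      # employed: become unemployed w.p. alpha
    P[n, 0, :n] = probs  # unemployed, accept: job s' w.p. p_{s'}
    P[n, 1, n] = 1.0     # unemployed, reject: stay unemployed
    # P[n, 0, n] stays 0: accepting always leads to some job.
    return P


P = transition_tensor(np.array([0.1, 0.2, 0.4, 0.1, 0.2]), alpha=0.1)
# Each (s, a) slice must be a probability distribution over s'.
print(np.allclose(P.sum(axis=2), 1.0))  # True
```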

The reward function $\mathcal{R}(s, a)$ is:

$$\mathcal{R}(s, a) = \log(w_s) : \forall s \neq n + 1, \forall a \in \mathcal{A}$$
$$\mathcal{R}(n + 1, c) = \log \left(\mathbb{E}\left[w_{S_{t+1}} \mid S_t = n + 1, A_t = c\right]\right) = \log \left(\sum_{i=1}^{n} p_i w_i \right)$$
$$\mathcal{R}(n + 1, r) = \log(w_{0})$$
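The reward function can be tabulated the same way. This is a minimal sketch under the same assumed conventions (index `n` is the unemployed state, action 0 is $c$, action 1 is $r$); the helper name `reward_table` is hypothetical, and the wages $2, 1, 4, 7, 3$, $w_0 = 1$ and $p_i$ are the example values from the code cells below.

```python
import numpy as np


def reward_table(wages: np.ndarray, w0: float, probs: np.ndarray) -> np.ndarray:
    """Hypothetical helper: R[s, a] following the definitions above.

    Index len(wages) is the unemployed state; action 0 is 'c', 1 is 'r'.
    """
    n = len(wages)
    R = np.zeros((n + 1, 2))
    R[:n, :] = np.log(wages)[:, None]  # employed: log(w_s) for either action
    R[n, 0] = np.log(wages @ probs)    # accept: log of the expected wage
    R[n, 1] = np.log(w0)               # reject: log of the unemployment wage
    return R


R = reward_table(
    wages=np.array([2.0, 1.0, 4.0, 7.0, 3.0]),
    w0=1.0,
    probs=np.array([0.1, 0.2, 0.4, 0.1, 0.2]),
)
print(R[5])  # approximately [1.194, 0.0] for (c, r) in the unemployed state
```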

Now we can write the general Bellman Optimality Equation:

$$ \pmb{V}^*(s) = \underset{a \in \mathcal{A}}{\max}\left\{ \mathcal{R}(s, a) 
    + \gamma \cdot \underset{s' \in \mathcal{S}}{\sum} \mathcal{P}(s, a, s') \cdot \pmb{V}^*(s') \right\}
$$

and solve for $\pmb{V}^*$ with the Value Iteration algorithm. From $\pmb{V}^*$ we can then back out an optimal (deterministic) policy $\pi^*$ by acting greedily with respect to $\pmb{Q}^*$:

$$\pmb{Q}^*(s,a) = \mathcal{R}(s,a) + \gamma \cdot \underset{s' \in \mathcal{S}}{\sum}\mathcal{P}(s, a, s') \cdot \pmb{V}^*(s')$$
$$\pi^*(s) = \underset{a \in \mathcal{A}}{\arg \max} \left\{ \pmb{Q}^*(s,a) \right\}$$
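For comparison with the loop-based class below, here is a minimal vectorized sketch of Value Iteration followed by the greedy $\arg\max$ over $\pmb{Q}^*$. It assumes dense arrays `P[s, a, s']` and `R[s, a]` (index `n` for the unemployed state, action 0 for $c$, action 1 for $r$), rebuilt inline so the block runs on its own. The function name `solve`, the sup-norm stopping rule and the tolerance are assumptions, not the assignment's specification.

```python
import numpy as np


def solve(P: np.ndarray, R: np.ndarray, gamma: float, tol: float = 1e-10):
    """Value Iteration on dense arrays P[s, a, s'] and R[s, a].

    Returns (V*, Q*, greedy action index for each state).
    """
    V = np.zeros(P.shape[0])
    while True:
        # Bellman optimality backup for every state at once:
        # Q[s, a] = R[s, a] + gamma * sum_{s'} P[s, a, s'] * V[s']
        V_new = (R + gamma * (P @ V)).max(axis=1)
        # Stop once the sup-norm change is below the tolerance.
        if np.max(np.abs(V_new - V)) < tol:
            Q = R + gamma * (P @ V_new)
            # argmax returns the first index on ties.
            return V_new, Q, Q.argmax(axis=1)
        V = V_new


# Example parameters (same values as the notebook cells below).
n, alpha, gamma, w0 = 5, 0.1, 0.1, 1.0
probs = np.array([0.1, 0.2, 0.4, 0.1, 0.2])
wages = np.array([2.0, 1.0, 4.0, 7.0, 3.0])

# Index n is the unemployed state; action 0 is 'c', action 1 is 'r'.
P = np.zeros((n + 1, 2, n + 1))
for s in range(n):
    P[s, :, s] = 1 - alpha  # keep job s
    P[s, :, n] = alpha      # lose job s
P[n, 0, :n] = probs         # accept: move to job s' w.p. p_{s'}
P[n, 1, n] = 1.0            # reject: stay unemployed

R = np.empty((n + 1, 2))
R[:n, :] = np.log(wages)[:, None]
R[n] = [np.log(wages @ probs), np.log(w0)]

V, Q, policy = solve(P, R, gamma)
print(np.round(V, 4))  # approximately [0.7761 0.0144 1.5378 2.1528 1.2217 1.3094]
print(policy[n])       # 0, i.e. accept when unemployed
```

In the employed states both columns of `Q` coincide, since neither the reward nor the transitions depend on the action there, so the greedy action in those states is only a tie-break.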

```python
from dataclasses import dataclass
from typing import (
    Mapping,
    List,
    Generic,
    TypeVar,
    Set
)

import numpy as np
```

```python
S = TypeVar("S")


@dataclass
class WageMaximizer(Generic[S]):
    """
    Solves the Wage-Utility Maximization problem via the
    Bellman Optimality Equation.
    """
    gamma: float
    alpha: float

    employed_states: List[S]
    unemployed_state: S

    employed_wages: np.ndarray
    unemployed_wage: float

    probs: np.ndarray

    def __post_init__(self) -> None:
        """Validate the parameters and build per-state lookup tables."""
        self._validate()
        # A tuple (rather than a set) fixes the iteration order, so ties
        # between actions are always broken the same way ('r' first).
        self.actions: tuple = ('r', 'c')
        self._state_probs = dict(zip(self.employed_states, self.probs))
        self._state_wages = dict(zip(self.employed_states, self.employed_wages))

    def _validate(self) -> None:
        """Validate the given problem parameters."""
        a = len(self.employed_wages)
        b = len(self.probs)
        c = len(self.employed_states)
        if not (a == b == c):
            raise ValueError("Check parameter lengths.")
        # Check elementwise: `array + [w]` would broadcast-add w to every
        # wage instead of appending it.
        if np.any(self.employed_wages <= 0) or self.unemployed_wage <= 0:
            raise ValueError("Must have positive wages.")
        # Floating-point sums rarely equal 1 exactly, so use a tolerance.
        if not np.isclose(self.probs.sum(), 1.0):
            raise ValueError("Check transition probabilities.")
        if self.unemployed_state in self.employed_states:
            raise ValueError(
                f"{self.unemployed_state} cannot also be an employed_state"
            )

    @property
    def states(self) -> List[S]:
        """All states: the employed states followed by the unemployed state."""
        return self.employed_states + [self.unemployed_state]

    def P(self, state: S, action: str, next_state: S) -> float:
        """Return the (s, a, s') transition probability."""
        if action not in self.actions:
            raise ValueError(f"{action=} is an invalid action")
        m: S = self.unemployed_state
        if (state, action, next_state) == (m, 'c', m):
            return 0.0
        if (state, action, next_state) == (m, 'r', m):
            return 1.0
        if (state, action) == (m, 'r'):
            return 0.0
        if (state, action) == (m, 'c'):
            return self._state_probs[next_state]
        # Employed: lose the job w.p. alpha, otherwise keep it.
        if next_state == m:
            return self.alpha
        if state == next_state:
            return 1 - self.alpha
        return 0.0

    def R(self, state: S, action: str) -> float:
        """Return the expected reward of (s, a)."""
        if action not in self.actions:
            raise ValueError(f"{action=} is an invalid action")
        m: S = self.unemployed_state
        if (state, action) == (m, 'r'):
            return np.log(self.unemployed_wage)
        if (state, action) == (m, 'c'):
            return np.log(self.employed_wages @ self.probs)
        return np.log(self._state_wages[state])

    def value_iteration(self, tol: float = 1e-5) -> Mapping[S, float]:
        """Compute the optimal value function, keyed by state."""
        vk = np.zeros(len(self.states))

        def maximize(s: S, vfunc: np.ndarray) -> float:
            """Apply the Bellman optimality backup to a single state."""
            return max(
                self.R(s, a) + self.gamma * sum(
                    self.P(s, a, s_) * vfunc[j]
                    for j, s_ in enumerate(self.states)
                )
                for a in self.actions
            )

        while True:
            improvement = np.array([maximize(s, vk) for s in self.states])
            if np.linalg.norm(improvement - vk) < tol:
                return dict(zip(self.states, improvement))
            vk = improvement

    def find_optimal_policy(self) -> Mapping[S, str]:
        """Find an optimal deterministic policy by acting greedily on Q*."""
        v_star = self.value_iteration()

        def q_star(s: S, a: str) -> float:
            """Compute Q* for a (state, action) pair."""
            return self.R(s, a) + self.gamma * sum(
                self.P(s, a, s_) * v_star[s_] for s_ in self.states
            )

        # max() returns the first maximizer, so ties go to 'r'.
        return {
            s: max(self.actions, key=lambda a: q_star(s, a))
            for s in self.states
        }
```

```python
gamma: float = 0.1
alpha: float = 0.1

employed_states: List[int] = [1, 2, 3, 4, 5]
unemployed_state: int = 6

employed_wages: List[float] = [2, 1, 4, 7, 3]
unemployed_wage: float = 1

probs: List[float] = [0.1, 0.2, 0.4, 0.1, 0.2]

solver = WageMaximizer[int](
    gamma=gamma,
    alpha=alpha,
    employed_states=employed_states,
    unemployed_state=unemployed_state,
    employed_wages=np.array(employed_wages),
    unemployed_wage=unemployed_wage,
    probs=np.array(probs)
)
v_star = solver.value_iteration()
v_star
```

```
{1: 0.7760895874094604,
 2: 0.014389425423822473,
 3: 1.5377897493950983,
 4: 2.1527521243062777,
 5: 1.221655618964289,
 6: 1.3094432712224158}
```

```python
# In every employed state the action changes neither the reward nor the
# transition probabilities, so Q*(s, 'c') == Q*(s, 'r') there and the
# reported action is only a tie-break. The action matters only in the
# unemployed state.
opt_policy = solver.find_optimal_policy()
opt_policy
```

```
{1: 'r', 2: 'r', 3: 'r', 4: 'r', 5: 'r', 6: 'c'}
```