# Setup

Run the following cell to initialize the notebook.

In [1]:
import numpy as np
import warnings

class FiniteMDPWarning(UserWarning):
    """Warning category used by FiniteMDP, e.g. when an iterative solver hits its iteration limit."""

class FiniteMDP:
    def __init__(self, states, actions, P, R, seed=None):
        self.states = states.copy()
        self.size = len(self.states)
        self.state_index = {s: i for i, s in enumerate(states)}
        self.actions = actions.copy()        
        self.P = {
            action:  np.array(M, dtype=np.float64, copy=True)
            for action, M in P.items()
        }

        self.CP = {}
        for action, P_a in self.P.items():        
            self.CP[action] = np.cumsum(P_a, axis=1)
            self.CP[action][:, -1] = 1.0
        self.CP[action][:, -1] = 1.0

        self.R = {
            action:  np.array(M, dtype=np.float64, copy=True)
            for action, M in R.items()
        }
        
        self.P[None] = np.eye(self.size, dtype=np.float64)
        self.R[None] = np.zeros(shape=(self.size, self.size), dtype=np.float64)
        self.terminal_indices = {i for i, state in enumerate(states) if not actions[state]}

        self.rng = np.random.default_rng(seed)
        self.current_state_index = None
        self.terminated = True        

    def _policy_value_function_lu(self, indexed_policy, gamma):
        PP = np.array([self.P[indexed_policy[i]][i] for i in range(self.size)], dtype=np.float64)
        RR = np.array([self.R[indexed_policy[i]][i] for i in range(self.size)], dtype=np.float64)
        b = np.sum(PP * RR, axis=1)
        A = np.eye(self.size) - gamma * PP
        for i in range(self.size):
            if i in self.terminal_indices:
                A[i][i] = 1.0
                b[i] = 0
        return np.linalg.solve(A, b)

    def _policy_value_function_jacobi(self, indexed_policy, gamma, stop_tol, max_iterations, start=None):
        if start is None:
            V = np.zeros(self.size, dtype=np.float64)
        else:
            V = start.copy()
        PP = np.array([self.P[indexed_policy[i]][i] for i in range(self.size)], dtype=np.float64)
        RR = np.array([self.R[indexed_policy[i]][i] for i in range(self.size)], dtype=np.float64)
        b = np.sum(PP * RR, axis=1)
        for _ in range(max_iterations):
            Vnew = b + gamma * PP @ V
            sup_norm =  np.max(np.abs(V - Vnew))
            if stop_tol is not None and sup_norm < stop_tol:
                break
            V = Vnew
        else:
            if stop_tol is not None:
                warnings.warn(
                    f'Maximum number of iterations reached in policy_value_function. Final delta: {sup_norm:5.3e}',
                    FiniteMDPWarning,
                    stacklevel=3,
                )   
        return V
        
    def _policy_value_function_gs(self, indexed_policy, gamma, stop_tol, max_iterations, start=None):
        if start is None:
            V = np.zeros(self.size, dtype=np.float64)
        else:
            V = start.copy()
        PP = np.array([self.P[indexed_policy[i]][i] for i in range(self.size)], dtype=np.float64)
        RR = np.array([self.R[indexed_policy[i]][i] for i in range(self.size)], dtype=np.float64)
        b = np.sum(PP * RR, axis=1)
        max_delta = -np.inf
        for _ in range(max_iterations):
            sup_norm = 0.0
            for i in range(self.size):
                new_value = b[i] + gamma * PP[i] @ V
                sup_norm = max(sup_norm, abs(new_value - V[i]))
                V[i] = new_value
            if stop_tol is not None and sup_norm < stop_tol:
                break
        else:
            if stop_tol is not None:
                warnings.warn(
                    f'Maximum number of iterations reached in policy_value_function. Final delta: {sup_norm:5.3e}',
                    FiniteMDPWarning,
                    stacklevel=3,
                )   
        return V
        
    def _policy_value_function(self, policy, gamma, method, stop_tol, max_iterations, start, return_type):
        indexed_policy = self.size * [None]
        for state, action in policy.items():
            indexed_policy[self.state_index[state]] = action
        match method:
            case 'lu':
                VV = self._policy_value_function_lu(indexed_policy, gamma)
            case 'jacobi':
                VV = self._policy_value_function_jacobi(indexed_policy, gamma, stop_tol, max_iterations, start)
            case 'gs':
                VV = self._policy_value_function_gs(indexed_policy, gamma, stop_tol, max_iterations, start)
            case _:
                raise ValueError("method should be 'lu', 'jacobi' or 'gs'")

        match return_type:
            case 'dict':
                return {state: VV[self.state_index[state]] for state in self.states}
            case 'array':
                return VV
            case _:
                raise ValueError("return_type must be 'dict' or 'array'")

    def policy_value_function(self, policy, gamma=1, method='lu', stop_tol=1E-8, max_iterations=100, start=None, return_type='dict'):
        if start is not None:
            start = np.array([start[s] for s in self.states])
        return self._policy_value_function(policy, gamma, method, stop_tol, max_iterations, start, return_type)

    def value_iteration(self, gamma=1, stop_tol=1E-8, max_iterations=100, start=None):
        if start is None:
            V = np.zeros(self.size, dtype=np.float64)
        else:
            V = np.array([start[s] for s in states], dtype=np.float64)
            for i in range(self.size):
                if i in self.terminal_indices:
                    V[i] = 0.0
        for _ in range(max_iterations):
            max_delta = 0.0
            for i, state in enumerate(self.states):
                if i in self.terminal_indices:
                    continue
                new_value = -np.inf
                for action in self.actions[state]:
                    new_value = max(new_value, 
                                    sum(self.P[action][i, j] * (self.R[action][i, j] + gamma * V[j]) for j in range(self.size)))
                if stop_tol is not None:
                    max_delta = max(max_delta, abs(V[i] - new_value))
                V[i] = new_value
            if stop_tol is not None and max_delta < stop_tol:
                break
        else:
                warnings.warn(
                    f'Maximum number of iterations reached in policy_value_function. Final delta: {max_delta:5.3e}',
                    FiniteMDPWarning,
                    stacklevel=2,
                )   
        # Compute optimal policy
        policy = {}
        for i,state in enumerate(self.states):
            if i in self.terminal_indices:
                policy[state] = None
                continue
            max_value = -np.inf
            max_action = None
            for action in self.actions[state]:
                new_value = sum(self.P[action][i, j] * (self.R[action][i, j] + gamma * V[j]) for j in range(self.size))
                if new_value > max_value:
                    max_value = new_value
                    max_action = action
            policy[state] = max_action

        return {state: V[i] for i, state in enumerate(self.states)}, policy

    def policy_iteration(self, gamma=1, method='lu', stop_tol=1E-8, max_iterations=100, relaxations=20, start=None):
        if start is None:
            V = np.zeros(self.size, dtype=np.float64)
        else:
            V = np.array([start[s] for s in states], dtype=np.float64)
            for i in range(self.size):
                if i in self.terminal_indices:
                    V[i] = 0.0

        # Initialize random policy
        policy = {}
        for i, state in enumerate(self.states):
            if i in self.terminal_indices:
                policy[state] = None
                continue
            policy[state] = self.rng.choice(self.actions[state])

        # Compute approximate value of current policy
        V = self._policy_value_function(policy, gamma, method, None, relaxations, None, 'array')

        for _ in range(max_iterations):
            # Compute improved policy
            new_policy = {}
            for i, state in enumerate(self.states):
                if i in self.terminal_indices:
                    new_policy[state] = None
                    continue
                max_value = -np.inf
                max_action = None
                for action in self.actions[state]:
                    new_value = sum(self.P[action][i, j] * (self.R[action][i, j] + gamma * V[j]) for j in range(self.size))
                    if new_value > max_value:
                        max_value = new_value
                        max_action = action
                new_policy[state] = max_action

            # Stop criterion 1: no change in optimal policy
            if new_policy == policy:
                 break
            # Compute approximate value of improved policy
            Vnew = self._policy_value_function(new_policy, gamma, method, None, relaxations, V, 'array')
            # Stop criterion 2: Change in V smaller than stop_tol
            if np.max(np.abs(Vnew - V)) < stop_tol:
                break
            # Update for next iteration
            V = Vnew
            policy = new_policy
        else:
            warnings.warn(
                f'Maximum number of iterations reached in policy_value_function. Final delta: {max_delta:5.3e}',
                FiniteMDPWarning,
                stacklevel=2,
            )   
        # Compute higher precision approximation for value function of final policy
        V = self._policy_value_function(policy, gamma, method, stop_tol, max_iterations, V, 'dict')

        return V, policy

    def reset(self, initial_state=None):
        if initial_state is None:
            while True:
                index = self.rng.integers(0, len(self.states))
                if index not in self.terminal_indices:
                    break
            self.current_state_index = index
        else:
            self.current_state_index = self.states.index(initial_state)
        self.terminated = False
        return self.states[self.current_state_index]

    def step(self, action):
        if self.current_state_index is None:
            raise RuntimeError('MDP not initialized, call reset() before calling step() for the first time')
        if self.terminated:
            raise RuntimeError('run terminated, call reset() to start a new run')

        u = self.rng.random()
        next_state_index = np.searchsorted(self.CP[action][self.current_state_index], u, side='right')
        next_state =  self.states[next_state_index]
        self.terminated = next_state_index in self.terminal_indices
        reward = self.R[action][self.current_state_index, next_state_index]
        self.current_state_index = next_state_index
        state = self.states[self.current_state_index]
        
        return (state, reward, self.terminated)
        
    def sarsa(
        self,
        gamma=1.0,
        alpha=0.1,
        epsilon=0.1,
        n_episodes=100_000,
        max_steps_per_episode=10_000,
        seed=None,
    ):
        """
        Tabular SARSA(0) consistent with (P, R).

        - Rewards on transitions into terminal states are allowed.
        - Terminal states have no actions (actions[state] is None).
        - No bootstrapping from terminal states.
        """
        rng = np.random.default_rng(seed)

        # ------------------------------------------------------------
        # Initialize Q(s,a) only for non-terminal states
        # ------------------------------------------------------------
        Q = {}
        for s in self.states:
            if self.actions[s] is None:
                continue
            for a in self.actions[s]:
                Q[(s, a)] = 0.0

        # ------------------------------------------------------------
        # Epsilon-greedy policy (greedy over Q)
        # ------------------------------------------------------------
        def epsilon_greedy_action(state):
            actions = self.actions[state]
            if actions is None:
                return None

            if rng.random() < epsilon:
                return rng.choice(actions)

            q_vals = [Q[(state, a)] for a in actions]
            return actions[int(np.argmax(q_vals))]

        # ------------------------------------------------------------
        # SARSA learning loop
        # ------------------------------------------------------------
        for _ in range(n_episodes):

            state = self.reset()
            state_index = self.state_index[state]

            if state_index in self.terminal_indices:
                continue

            action = epsilon_greedy_action(state)

            for _ in range(max_steps_per_episode):

                i = self.state_index[state]

                next_state, _, terminated = self.step(action)
                j = self.state_index[next_state]

                # IMPORTANT: reward always comes from R
                reward = self.R[action][i, j]

                if terminated:
                    # no bootstrap from terminal states
                    Q[(state, action)] += alpha * (
                        reward - Q[(state, action)]
                    )
                    break

                next_action = epsilon_greedy_action(next_state)

                Q[(state, action)] += alpha * (
                    reward
                    + gamma * Q[(next_state, next_action)]
                    - Q[(state, action)]
                )

                state = next_state
                action = next_action

        # ------------------------------------------------------------
        # Extract greedy policy and value function
        # ------------------------------------------------------------
        policy = {}
        V = {}

        for s in self.states:
            idx = self.state_index[s]

            if idx in self.terminal_indices:
                policy[s] = None
                V[s] = 0.0
                continue

            actions = self.actions[s]
            q_vals = [(Q[(s, a)], a) for a in actions]
            best_q, best_a = max(q_vals, key=lambda x: x[0])

            policy[s] = best_a
            V[s] = best_q

        return Q, policy, V

# System description

We consider a warehouse that can hold up to $M$ items in inventory. The goal of the problem is to minimize the costs of managing this inventory. The warehouse holds a single kind of item.

We assume that each day customers can order up to $K$ items, and denote by $p_j$ the probability that $j$ items are ordered on a given day ($j=0,1,\ldots,K$). Ordered items, if available, are removed from inventory by the beginning of the next day.

Each day, the inventory manager can request up to $L$ items to replenish inventory; these items are delivered by the beginning of the next day.

The following are the operational costs in the system:

- There is a holding cost $h_i$ for having $i$ items in inventory. We only assume that $h_i>h_j$ if $i>j$; beyond that, these costs can grow in any way (linearly, quadratically, or following some other structure, such as a convex or concave one).

- There is an ordering cost $b_i$ for requesting $i$ items to be added to inventory. This is also an increasing cost ($b_i>b_j$ if $i>j$), but there can be situations with a "jump" in the costs (requesting $i+1$ items is **much** more expensive than requesting $i$ items).

- There is a cost for not being able to deliver an item that is ordered. One simple hypothesis is that there is a unit cost $c$ for each lost item. However, we may want to associate different costs with losing more than one order.

Your task is to think about how this could be formalized as an MDP model. To do this, you will have to come up with a very precise description of:

- The state space.

- The actions.

- The transition probability matrices and reward matrices.

Keep in mind that any model with a chance of being realistic cannot be too small. For example, the warehouse size $M$ should be something like 20. Then our matrices will be $21\times 21$ (one row and one column for each inventory level $0,1,\ldots,20$), and we will need some programming just to set up these matrices!

So, don't worry about it now. Think about what kind of data structures we would need to represent the parameters in the problem. If you don't know any programming, don't worry: think just in terms of mathematical modeling! We will figure out together how to come up with a mathematical model.