# CME 241 -- Assignment 13

In [1]:
from collections import defaultdict
from typing import Dict, List, Tuple
import operator
import itertools as it

import numpy as np

from rl.markov_decision_process import (
    FiniteMarkovDecisionProcess as FiniteMDP, 
    FinitePolicy,
    TransitionStep
)
from rl.chapter3.simple_inventory_mdp_cap import (
    SimpleInventoryMDPCap, 
    InventoryState
)
from rl.distribution import Constant, Choose

In [2]:
def compute_G(
    episode: List["TransitionStep"],
    gamma: float = 1.0,
) -> Dict[Tuple["InventoryState", int], float]:
    """Compute the first-visit return G for each (state, action) pair.

    The return from step ``t`` is
    ``reward_t + gamma * reward_{t+1} + gamma**2 * reward_{t+2} + ...``
    summed to the end of ``episode``. Only the first occurrence of each
    ``(state, action)`` pair is recorded (first-visit Monte Carlo), and the
    result is keyed in first-visit order.

    Args:
        episode: Steps of one (possibly truncated) episode; each step must
            expose ``state``, ``action`` and ``reward`` attributes, and the
            sequence must support ``len`` and indexing.
        gamma: Discount factor. The default of 1.0 gives the plain
            undiscounted sum of rewards.

    Returns:
        Mapping from ``(state, action)`` to the return following its first
        visit. An empty episode yields an empty dict.
    """
    # Backward pass: G_t = r_t + gamma * G_{t+1}, so every suffix return is
    # computed in O(1) instead of re-summing the tail for each first visit.
    returns = [0.0] * len(episode)
    ret = 0.0
    for t in reversed(range(len(episode))):
        ret = episode[t].reward + gamma * ret
        returns[t] = ret

    # Forward pass keeps only the earliest visit of each pair.
    G = {}
    for step, g in zip(episode, returns):
        G.setdefault((step.state, step.action), g)
    return G

In [27]:
def get_greedy_policy(
    mdp: FiniteMDP[InventoryState, int],
    Q: Dict[Tuple[InventoryState, int], float]
) -> FinitePolicy[InventoryState, int]:
    """Construct the deterministic greedy policy implied by ``Q``.

    For every non-terminal state, pick the legal action whose
    ``(state, action)`` entry in ``Q`` is largest; ties go to the first such
    action in ``mdp.actions(state)`` iteration order. Terminal states map
    to ``None``.

    Args:
        mdp: The finite MDP whose states and legal actions are enumerated.
        Q: Action-value estimates keyed by ``(state, action)``. Legal pairs
            missing from ``Q`` are not considered.

    Returns:
        A ``FinitePolicy`` mapping each non-terminal state to a ``Constant``
        distribution on its greedy action.

    Raises:
        ValueError: If a non-terminal state has no legal action with an
            entry in ``Q``.
    """
    mapping = {}
    for state in mdp.states():
        if mdp.is_terminal(state):
            mapping[state] = None
            continue
        # Look up only this state's legal actions rather than scanning every
        # key of Q for every state (O(|S| * |Q|)).
        candidates = [a for a in mdp.actions(state) if (state, a) in Q]
        if not candidates:
            raise ValueError(f"No Q-values available for state {state}")
        greedy = max(candidates, key=lambda a: Q[(state, a)])
        mapping[state] = Constant(greedy)

    return FinitePolicy(mapping)

In [28]:
def monte_carlo_control(
    mdp: FiniteMDP[InventoryState, int],
    num_episodes: int = 100,
    episode_length: int = 1009,
) -> FinitePolicy[InventoryState, int]:
    """Find a policy by first-visit Monte Carlo control.

    Starts from random action values and a random deterministic policy. Each
    iteration simulates one episode under the current policy. It then sets
    ``Q`` for every (state, action) pair visited in that episode to the mean
    of all first-visit returns observed for it so far. Finally it makes the
    policy greedy with respect to ``Q``.

    Args:
        mdp: The finite MDP to solve.
        num_episodes: Number of simulate-evaluate-improve iterations.
        episode_length: Maximum number of steps taken from each simulated
            episode (the simulation is cut off with ``islice``).

    Returns:
        The greedy policy after the final iteration (the random initial
        policy if ``num_episodes`` is 0).

    NOTE(review): there is no explicit exploration (no epsilon-greedy or
    exploring starts); a pair the greedy policy never selects keeps its
    random initial Q value. Confirm this is intended.
    """
    # Random initial action values for every legal non-terminal pair.
    Q = {
        (s, a): np.random.randn()
        for s in mdp.non_terminal_states
        for a in mdp.actions(s)
    }

    # Initial policy: one uniformly random legal action per non-terminal
    # state. Indexing the list (rather than np.random.choice) keeps each
    # action's original type and also works for non-scalar actions.
    policy_map = {}
    for state in mdp.states():
        if mdp.is_terminal(state):
            policy_map[state] = None
        else:
            actions = list(mdp.actions(state))
            policy_map[state] = Constant(actions[np.random.randint(len(actions))])
    policy = FinitePolicy(policy_map)

    # Start states are drawn from the non-terminal states; Choose is assumed
    # to pick uniformly among the options it is given.
    starts = Choose(set(mdp.non_terminal_states))

    # Running sums/counts give the same mean as storing every return.
    return_sums = defaultdict(float)
    return_counts = defaultdict(int)

    for _ in range(num_episodes):
        episode = list(
            it.islice(mdp.simulate_actions(starts, policy), episode_length)
        )
        for pair, g in compute_G(episode).items():
            return_sums[pair] += g
            return_counts[pair] += 1
            Q[pair] = return_sums[pair] / return_counts[pair]
        policy = get_greedy_policy(mdp, Q)

    return policy

In [29]:
# Inventory MDP instance that the Monte Carlo control cell solves.
mdp = SimpleInventoryMDPCap(
    stockout_cost=10.0,
    holding_cost=1.0,
    poisson_lambda=1.0,
    capacity=2,
)

In [31]:
# Learn a policy with Monte Carlo control; the bare name on the last line
# makes the notebook display it.
mcc_policy = monte_carlo_control(mdp=mdp)
mcc_policy

For State InventoryState(on_hand=0, on_order=0):
  Do Action 1 with Probability 1.000
For State InventoryState(on_hand=0, on_order=1):
  Do Action 1 with Probability 1.000
For State InventoryState(on_hand=0, on_order=2):
  Do Action 0 with Probability 1.000
For State InventoryState(on_hand=1, on_order=0):
  Do Action 0 with Probability 1.000
For State InventoryState(on_hand=1, on_order=1):
  Do Action 0 with Probability 1.000
For State InventoryState(on_hand=2, on_order=0):
  Do Action 0 with Probability 1.000