# CME 241 (Winter 2021) -- Assignment 11

In [1]:
from typing import Iterable, TypeVar, Mapping, List, Tuple
import itertools as it
from collections import defaultdict
from pprint import pprint

from rl.markov_process import TransitionStep, ReturnStep
from rl.returns import returns
from rl.chapter2.simple_inventory_mrp import (
    SimpleInventoryMRPFinite, 
    InventoryState
)
from rl.distribution import Choose
from rl.function_approx import LinearFunctionApprox
from rl.monte_carlo import mc_prediction
from rl.td import td_prediction

## Question 1

In [2]:
S = TypeVar("S")


def tabular_mc_prediction(
    traces: Iterable[Iterable[TransitionStep[S]]],
    gamma: float = 1
) -> Mapping[S, float]:
    """Estimate a tabular value function by every-visit Monte Carlo.

    Each trace is converted to return steps by `returns(trace, gamma, 1e-5)`
    (the third argument is presumably the truncation tolerance -- confirm
    against `rl.returns`).  Every return step then moves the estimate of its
    state towards the observed return using an incremental sample mean:

        V[s] <- V[s] + (G - V[s]) / N[s]

    where `N[s]` counts the return steps seen so far for `s`, including the
    current one, across all traces.

    Args:
        traces: iterable of episodes, each an iterable of transition steps.
        gamma: discount factor handed to `returns`.

    Returns:
        A `defaultdict(float)` mapping every visited state to the running
        average of its observed returns.  Looking up a state that was never
        visited yields 0.0.
    """
    episodes = (returns(trace, gamma, 1e-5) for trace in traces)
    vf = defaultdict(float)
    visits = defaultdict(int)
    for episode in episodes:
        for rs in episode:
            # Count the visit first so the divisor is the number of samples
            # averaged so far (1 on the first visit).
            visits[rs.state] += 1
            vf[rs.state] += (rs.return_ - vf[rs.state]) / visits[rs.state]
    return vf

## Question 2

In [3]:
def tabular_td_prediction(
    traces: Iterable[TransitionStep[S]],
    gamma: float = 1,
    step_size_scale: float = 20
) -> Mapping[S, float]:
    """Estimate a tabular value function with TD(0) from a stream of transitions.

    For the n-th transition `t` of the stream (n counted from 1 over the whole
    stream, not per state) the update is

        V[t.state] += (step_size_scale / n)
                      * (t.reward + gamma * V[t.next_state] - V[t.state])

    Args:
        traces: one flat iterable of transitions, i.e. objects exposing
            `state`, `reward` and `next_state`.  Despite the name this is not
            an iterable of episodes; callers flatten episodes beforehand.
        gamma: discount factor.
        step_size_scale: numerator of the decaying step size.  The default of
            20 reproduces the constant that used to be hard-coded; note that
            with it the step size is at least 1 for the first 20 transitions.

    Returns:
        A `defaultdict(float)` holding an estimate for every state that
        appeared as `state` or `next_state`.  Looking up any other state
        yields 0.0.
    """
    V = defaultdict(float)
    for n, t in enumerate(traces, start=1):
        # Same operation order as the original expression, so the default
        # scale yields bit-identical floats.
        V[t.state] += step_size_scale * (t.reward + gamma * V[t.next_state] - V[t.state]) / n
    return V

## Question 3

In [4]:
def _heading(title, rule):
    """Print a section title followed by the dashed rule underlining it."""
    print(title)
    print(rule)


def _indicator(target):
    """Build a feature function returning 1 for `target` and 0 for any other input."""
    def feature(x):
        return int(x == target)
    return feature


def _indicator_approx():
    """Create a fresh linear approximator over the indicator features f0..f5."""
    return LinearFunctionApprox.create(
        feature_functions=[f0, f1, f2, f3, f4, f5]
    )


def _values_by_state(approx):
    """Evaluate `approx` on each state of `si_mrp`, one state per call."""
    return {s: approx.evaluate([s]) for s in si_mrp.states()}


# Problem setup: inventory MRP parameters and the discount factor used below.
user_capacity, user_poisson_lambda = 2, 1.0
user_holding_cost, user_stockout_cost = 1.0, 10.0
user_gamma = 0.9

si_mrp = SimpleInventoryMRPFinite(
    capacity=user_capacity,
    poisson_lambda=user_poisson_lambda,
    holding_cost=user_holding_cost,
    stockout_cost=user_stockout_cost,
)
_heading("True Value Function", "--------------")
si_mrp.display_value_function(gamma=user_gamma)
print()

# One lazy stream of `numsims` reward traces, split three ways with tee.
# tee duplicates the outer stream only: all three copies hand out the very
# same trace objects.  NOTE(review): if those traces are one-shot iterators,
# each later consumer continues where the earlier one stopped -- confirm
# against `reward_traces`.  The consumption order below is kept on purpose.
numsims = 2_000
distrib = Choose({s for s in si_mrp.states()})
traces = it.islice(si_mrp.reward_traces(distrib), numsims)
iters = it.tee(traces, 3)
mc_tabular_traces, mc_approx_traces, td_traces = iters

_heading("Tabular MC Value Function", "-------------------------")
pprint(tabular_mc_prediction(mc_tabular_traces, user_gamma))
print()

# Exhaustive list of states, identity functions
f0, f1, f2, f3, f4, f5 = (
    _indicator(InventoryState(on_hand, on_order))
    for on_hand, on_order in [(0, 0), (1, 0), (2, 0), (1, 1), (0, 1), (0, 2)]
)

# Keep every intermediate approximation and report the last one.
mc_approximator = _indicator_approx()
mc_candidates = list(mc_prediction(mc_approx_traces, mc_approximator, user_gamma))
mc_sol = mc_candidates[-1]
_heading("Approx. MC Value Function", "-------------------------")
pprint(_values_by_state(mc_sol))
print()

# TD consumes single transitions: take at most 500 steps from each trace and
# chain them into one flat stream, shared by the two TD estimators.
flattened = it.chain.from_iterable(it.islice(x, 500) for x in td_traces)
flats = it.tee(flattened)
td_tabular_steps, td_approx_steps = flats

td_vf = tabular_td_prediction(td_tabular_steps, user_gamma)
_heading("Tabular TD Value Function", "-------------------------")
pprint(td_vf)
print()

td_approximator = _indicator_approx()
td_candidates = list(td_prediction(td_approx_steps, td_approximator, user_gamma))
td_sol = td_candidates[-1]
_heading("Approx. TD Value Function", "-------------------------")
pprint(_values_by_state(td_sol))

True Value Function
--------------
{InventoryState(on_hand=0, on_order=0): -35.511,
 InventoryState(on_hand=1, on_order=0): -28.932,
 InventoryState(on_hand=0, on_order=1): -27.932,
 InventoryState(on_hand=0, on_order=2): -28.345,
 InventoryState(on_hand=2, on_order=0): -30.345,
 InventoryState(on_hand=1, on_order=1): -29.345}

Tabular MC Value Function
-------------------------
defaultdict(<class 'float'>,
            {InventoryState(on_hand=0, on_order=2): -28.349594723613464,
             InventoryState(on_hand=2, on_order=0): -30.31382739372025,
             InventoryState(on_hand=0, on_order=1): -27.953161963072763,
             InventoryState(on_hand=1, on_order=1): -29.328717040092105,
             InventoryState(on_hand=0, on_order=0): -35.51232707680943,
             InventoryState(on_hand=1, on_order=0): -28.92084417311761})

Approx. MC Value Function
-------------------------
{InventoryState(on_hand=0, on_order=0): array([-1.90212916]),
 InventoryState(on_hand=1, on_order=0)