In [10]:
import itertools
import numpy as np
from typing import Iterable, Callable, Mapping, TypeVar, List, Dict
import rl.markov_process as mp
from rl.returns import returns
from rl.function_approx import Tabular
from rl.iterate import last
from collections import defaultdict
from rl.monte_carlo import mc_prediction
from rl.td import td_prediction
from rl.chapter2.simple_inventory_mrp import SimpleInventoryMRPFinite, InventoryState
from rl.distribution import Constant

In [67]:
S = TypeVar('S')

def tabular_mc_prediction(traces: Iterable[Iterable[mp.ReturnStep[S]]],
                  count_to_weight_func:Callable[[int], float],
                  gamma: float,
                  tolerance: float = 1e-6) -> List[Dict[S,float]]:
    episodes = (returns(trace, gamma, tolerance) for trace in traces)
    vf_list: List[Mapping[S,float]] = []
    occur: Dict[S,float] = defaultdict(int)
    for epi in episodes:
        if vf_list:
            vn = vf_list[-1]
        else:
            vn: Dict[S,float] = defaultdict(int)
        for si in epi:
            occur[si.state]+=1
            vn[si.state] = vn[si.state]+(si.return_
                                         -vn[si.state])*count_to_weight_func(occur[si.state])
        vf_list.append(vn)
        
    return(vf_list)  

In [68]:
si_mrp = SimpleInventoryMRPFinite(
    capacity = 2,
    poisson_lambda = 1.0,
    holding_cost = 1.0,
    stockout_cost = 10.0)

start_state = InventoryState(on_hand = 0, on_order = 0)
samples = [list(itertools.islice(si_mrp.simulate_reward(Constant(start_state)),
                                 1000)) for _ in range(100)]

In [69]:
mcp = last(mc_prediction(sample,Tabular(),0.9)).values_map
print(mcp)

{InventoryState(on_hand=0, on_order=0): -108.33831543184229, InventoryState(on_hand=0, on_order=2): -98.13087306461235, InventoryState(on_hand=1, on_order=0): -100.76881049885476, InventoryState(on_hand=0, on_order=1): -98.73014965964698, InventoryState(on_hand=1, on_order=1): -99.69178857203374, InventoryState(on_hand=2, on_order=0): -101.5532032548051}


In [70]:
tabular_mcp = tabular_mc_prediction(sample, lambda n: 1./n, 0.9)[-1]
print(tabular_mcp)

defaultdict(<class 'int'>, {InventoryState(on_hand=0, on_order=0): -108.33831543184242, InventoryState(on_hand=0, on_order=2): -98.13087306461223, InventoryState(on_hand=1, on_order=0): -100.76881049885482, InventoryState(on_hand=0, on_order=1): -98.7301496596479, InventoryState(on_hand=1, on_order=1): -99.69178857203383, InventoryState(on_hand=2, on_order=0): -101.55320325480523})
