### 3)

In [378]:
from dataclasses import dataclass
import numpy as np
import os
import pprint
import sys
sys.path.append(os.path.abspath("/Users/justincramer/Documents/Coding/CME241/RL-book/"))

from typing import Iterable, Iterator, TypeVar, Callable, Mapping, Iterable
from rl.distribution import Categorical, Choose
from rl.approximate_dynamic_programming import (ValueFunctionApprox,
                                                QValueFunctionApprox,
                                                NTStateDistribution)
from rl.iterate import converge, converged, last
from rl.markov_decision_process import MarkovDecisionProcess, Policy, \
    TransitionStep, NonTerminal
import rl.markov_process as mp
from rl.returns import returns
import itertools

S = TypeVar('S')

#### a)

In [379]:
def tabular_mc_prediction(
    traces: Iterable[Iterable[mp.TransitionStep[S]]],
    γ: float,
    episode_length_tolerance: float = 1e-6
) -> Iterator[Mapping[S, float]]:
    
    counts: Mapping[S, int] = dict() # State VF update frequency
    vf: Mapping[S, float] = dict() # State VF approximation
    
    episodes: Iterator[Iterator[mp.ReturnStep[S]]] = \
        (returns(trace, γ, episode_length_tolerance) for trace in traces)
    
    yield vf
    for episode in episodes:
        for step in episode:
            counts[step.state] = counts.get(step.state, 0) + 1
            alpha = 1 / counts[step.state]
            vf[step.state] = (1 - alpha) * vf.get(step.state, 0) + alpha * step.return_
            yield vf

#### b)

In [189]:
def tabular_td_prediction(
    traces: Iterable[Iterable[mp.TransitionStep[S]]],
    γ: float
) -> Iterator[Mapping[S, float]]:
    
    counts: Mapping[S, int] = dict() # State VF update frequency
    vf: Mapping[S, float] = dict() # State VF approximation
        
    yield vf
    for trace in traces:
        for step in trace:
            counts[step.state] = counts.get(step.state, 0) + 1
            alpha = 1 / counts[step.state]
            vf[step.state] = (1 - alpha) * vf.get(step.state, 0) \
                             + alpha * (step.reward + γ * vf.get(step.next_state, 0))
            yield vf

#### c)

In [23]:
from rl.chapter2.simple_inventory_mrp import SimpleInventoryMRPFinite

In [27]:
# Value function from SimpleInventoryMRPFinite
user_capacity = 2
user_poisson_lambda = 1.0
user_holding_cost = 1.0
user_stockout_cost = 10.0

user_gamma = 0.9

si_mrp = SimpleInventoryMRPFinite(
    capacity=user_capacity,
    poisson_lambda=user_poisson_lambda,
    holding_cost=user_holding_cost,
    stockout_cost=user_stockout_cost
)

si_mrp.display_value_function(gamma=user_gamma)


{NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -35.511,
 NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -27.932,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -28.345,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -28.932,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.345,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.345}


In [303]:
# Function to define convergence of VF dicts
def done(d1: Mapping[S, float], d2: Mapping[S, float], tol=1e-6):
    keys = d1.keys()
    for key in keys:
        if (abs(d1[key] - d2[key]) > tol): return False
    return True

In [381]:
init_distribution: Choose[S] = Choose(si_mrp.non_terminal_states)
traces: Iterable[Iterable[mp.TransitionStep[S]]] = si_mrp.reward_traces(init_distribution)
predictions_mc = tabular_mc_prediction(traces, γ=user_gamma)
predictions_td = tabular_td_prediction(traces, γ=user_gamma)

# can't get this to work; it doesn't return a dictionary
print(converged(predictions_mc, done), '\n')
# this works and converges correctly
pprint.pprint(last(itertools.islice(predictions_mc, 10000)))
print('\n')
# this doesn't converge correctly
pprint.pprint(last(itertools.islice(predictions_td, 10000)))

{NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -24.019125179698424} 

{NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -27.754361675907454,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.167499531405678,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -27.963602651795853,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.079887737385864,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -28.680867266975728,
 NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -35.1603606475684}


{NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -16.667807473317065,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -15.902976672608476,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -17.09986436532476,
 NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -15.56750237890788,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -17.89682783620878,
 NonTerminal(state=InventoryState(on_hand