In [1]:
from dataclasses import dataclass
from rl.distribution import Constant
from rl.chapter2.simple_inventory_mrp import SimpleInventoryMRPFinite
from rl.markov_process import MarkovRewardProcess, NonTerminal, TransitionStep
from typing import Dict, Iterable, TypeVar

import itertools
import numpy as np


In [2]:
# Type variable standing for the state type; it parameterizes NonTerminal[S]
# and TransitionStep[S] in the annotations used by TabularMC below.
S = TypeVar("S")


@dataclass
class TabularMC:
    """Tabular Monte-Carlo estimate of a Markov reward process's value function.

    For every non-terminal state of ``mrp`` the estimator samples reward
    traces that start in that state, computes the discounted return of each
    trace, and stores the sample mean of those returns.
    """

    # Process to evaluate; must provide ``non_terminal_states`` and
    # ``reward_traces(start_state_distribution)``.
    mrp: MarkovRewardProcess
    # Discount factor applied by ``mc_tabular_vf``; expected in [0, 1].
    gamma: float

    def trace_reward(self, trace, gamma: float, tol: float = 1e-16) -> float:
        """Return the discounted sum of rewards along a single trace.

        Args:
            trace: Iterable of steps, each exposing a numeric ``reward``
                attribute. It is consumed by this call.
            gamma: Discount factor in ``[0, 1]``.
            tol: Truncation tolerance. For ``0 < gamma < 1`` only the first
                ``ceil(log(tol) / log(gamma))`` steps are read, i.e. the steps
                whose discount weight ``gamma ** k`` is still above ``tol``.
                Must be positive.

        Returns:
            ``sum(gamma ** k * reward_k)`` over the steps that were read.
            An empty trace yields ``0.0``.

        Raises:
            ValueError: If ``gamma`` is outside ``[0, 1]`` or ``tol`` is not
                positive.

        Note:
            With ``gamma == 1`` nothing is truncated, so the trace must be
            finite or this call never returns.
        """
        if not 0.0 <= gamma <= 1.0:
            raise ValueError(f"gamma must lie in [0, 1], got {gamma}")
        if not tol > 0.0:
            raise ValueError(f"tol must be positive, got {tol}")

        if gamma == 1.0:
            # Undiscounted return: every reward counts with weight one.
            return float(sum(step.reward for step in trace))

        if gamma == 0.0:
            # Only the immediate reward has non-zero weight; this also avoids
            # evaluating log(0) in the horizon formula below.
            horizon = 1
        else:
            horizon = int(np.ceil(np.log(tol) / np.log(gamma)))

        rewards = np.array(
            [step.reward for step in itertools.islice(trace, horizon)],
            dtype=float,
        )
        # Size the discount vector from the rewards actually collected: an
        # episode may terminate before the truncation horizon is reached.
        discounts = gamma ** np.arange(rewards.size)
        return np.dot(rewards, discounts)

    def mc_tabular_vf(self, n_traces: int = 1_000) -> "Dict[NonTerminal[S], float]":
        """Estimate the value of every non-terminal state by Monte Carlo.

        Args:
            n_traces: Number of traces sampled per start state; must be at
                least 1.

        Returns:
            A dict mapping each state in ``self.mrp.non_terminal_states`` to
            the mean discounted return (using ``self.gamma``) of the traces
            sampled from that state.

        Raises:
            ValueError: If ``n_traces`` is smaller than 1.
        """
        if n_traces < 1:
            raise ValueError(f"n_traces must be at least 1, got {n_traces}")

        vf_dict: Dict[NonTerminal[S], float] = {}
        for init_state in self.mrp.non_terminal_states:
            # Constant(init_state) pins the start-state distribution so that
            # every sampled trace begins in init_state.
            traces: Iterable[Iterable[TransitionStep[S]]] = self.mrp.reward_traces(
                Constant(init_state)
            )
            sampled_returns = [
                self.trace_reward(trace=trace, gamma=self.gamma)
                for trace in itertools.islice(traces, n_traces)
            ]
            vf_dict[init_state] = float(np.mean(sampled_returns))
        return vf_dict


In [3]:
# Build the inventory Markov reward process used for the experiment and wrap it
# in the Monte-Carlo estimator with discount factor 0.9.
# NOTE(review): the meaning of each keyword (capacity, Poisson demand rate,
# per-unit holding / stockout cost) is inferred from its name only — confirm
# against SimpleInventoryMRPFinite.
si_mrp = SimpleInventoryMRPFinite(
    capacity=2,
    poisson_lambda=1.0,
    holding_cost=1.0,
    stockout_cost=10.0,
)
tmc = TabularMC(mrp=si_mrp, gamma=0.9)


In [4]:
# Estimate the value of every non-terminal state from 5,000 sampled traces per
# start state; as the cell's last expression, the resulting dict is displayed.
# NOTE(review): no random seed is set in the visible cells, so the displayed
# numbers are presumably not exactly reproducible between runs — confirm.
tmc.mc_tabular_vf(n_traces=5_000)


{NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -35.51460522703688,
 NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -27.903565953345836,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -28.320435951976002,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -28.89292092757909,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.340624146825913,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.478455937938158}