In [1]:
from dataclasses import dataclass
from rl.distribution import Distribution, Constant
from rl.chapter2.simple_inventory_mrp import InventoryState, SimpleInventoryMRPFinite
from rl.markov_process import TransitionStep, MarkovRewardProcess, NonTerminal
from typing import Dict, Iterable, TypeVar

import itertools

In [2]:
# Type variable standing for the (unspecified) state type; it parameterises the
# NonTerminal[S] / TransitionStep[S] annotations used in this notebook.
S = TypeVar("S")


@dataclass
class DynamicLearningRate:
    """Decaying step-size schedule ``alpha / (1 + ((n - 1) / H) ** beta)``.

    ``n`` is a 1-based update count. The schedule is only well defined for
    ``n >= 1`` and ``H > 0``: otherwise ``(n - 1) / H`` is negative (or a
    division by zero) and a fractional ``beta`` would silently produce a
    complex number.

    Raises:
        ValueError: on construction, if ``H`` is not strictly positive.
    """

    alpha: float  # numerator of the schedule; the step size at n = 1 when beta > 0
    beta: float  # exponent applied to (n - 1) / H
    H: int  # divisor that rescales the update count before exponentiation

    def __post_init__(self):
        # Fail early instead of hitting ZeroDivisionError / complex values later.
        if self.H <= 0:
            raise ValueError(f"H must be positive, got {self.H!r}")

    def evaluate(self, n: int) -> float:
        """Return the step size for the ``n``-th update (``n`` starts at 1).

        Args:
            n: 1-based count of updates performed so far, including this one.

        Returns:
            ``alpha / (1 + ((n - 1) / H) ** beta)``.

        Raises:
            ValueError: if ``n < 1``; a negative base raised to a fractional
                power would otherwise yield a complex number.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n!r}")
        return self.alpha / (1 + ((n - 1) / self.H) ** self.beta)


@dataclass
class TabularTD:
    """Tabular temporal-difference value prediction for a reward process.

    Keeps one value estimate and one visit counter per non-terminal state of
    ``mrp`` and applies the update
    ``V(s) <- V(s) + lr(count(s)) * (r + gamma * V(s') - V(s))``
    for every transition step it is fed.

    Annotations that name project types are written as strings (forward
    references) so they are not evaluated when the class is defined.
    """

    # Process to evaluate; this class reads its ``non_terminal_states`` and
    # calls its ``reward_traces(init_state_dist)``.
    mrp: "MarkovRewardProcess"
    # Discount factor applied to the value of the next state.
    gamma: float
    # Step-size schedule, evaluated at the visit count of the state being updated.
    dlr: "DynamicLearningRate"

    def __post_init__(self):
        # Materialise once so both tables are built from the same sequence,
        # even if ``non_terminal_states`` is a one-shot iterable.
        states = list(self.mrp.non_terminal_states)
        self.vf_dict: Dict[NonTerminal[S], float] = {s: 0.0 for s in states}
        self.counter_dict: Dict[NonTerminal[S], int] = {s: 0 for s in states}

    def process_trace(
        self,
        trace: "Iterable[TransitionStep[S]]",
        n_iter: int = 1_000,
    ):
        """Apply the TD update for at most ``n_iter`` steps of ``trace``.

        Args:
            trace: iterable of steps exposing ``state``, ``reward`` and
                ``next_state``.
            n_iter: maximum number of steps consumed from ``trace``.

        Raises:
            KeyError: if a step's ``state`` is not one of the tracked
                non-terminal states.
        """
        for step in itertools.islice(trace, n_iter):
            current = step.state
            self.counter_dict[current] += 1
            step_size = self.dlr.evaluate(self.counter_dict[current])
            # A next state that is absent from the table (i.e. a terminal
            # state) contributes a value of 0 instead of raising KeyError.
            next_value = self.vf_dict.get(step.next_state, 0.0)
            td_error = step.reward + self.gamma * next_value - self.vf_dict[current]
            self.vf_dict[current] = self.vf_dict[current] + td_error * step_size

    def td_tabular_vf(
        self,
        init_state_dist: "Distribution[NonTerminal[S]]",
        n_traces: int = 1_000,
        n_iter: int = 1_000,
    ):
        """Update the value table from ``n_traces`` traces of the process.

        Args:
            init_state_dist: passed unchanged to ``mrp.reward_traces``.
            n_traces: number of traces taken from ``mrp.reward_traces``.
            n_iter: maximum number of steps processed per trace.

        The estimates are accumulated in ``self.vf_dict``; nothing is returned.
        """
        traces: Iterable[Iterable[TransitionStep[S]]] = self.mrp.reward_traces(
            init_state_dist
        )
        for trace in itertools.islice(traces, n_traces):
            self.process_trace(trace=trace, n_iter=n_iter)


In [3]:
# Inventory reward process built with capacity 2 and the cost/demand
# parameters below.
si_mrp = SimpleInventoryMRPFinite(
    capacity=2, poisson_lambda=1.0, holding_cost=1.0, stockout_cost=10.0
)

# Step-size schedule handed to the TD learner.
dlr = DynamicLearningRate(alpha=0.03, beta=0.5, H=10000)

# Fixed start state (nothing on hand, nothing on order); it is wrapped in a
# Constant distribution when passed to the learner.
start_state = NonTerminal(InventoryState(on_hand=0, on_order=0))

# NOTE(review): the traces are presumably sampled randomly and no seed is set
# here, so the displayed numbers will likely differ between runs — confirm.
ttd = TabularTD(mrp=si_mrp, gamma=0.9, dlr=dlr)
ttd.td_tabular_vf(init_state_dist=Constant(start_state), n_traces=1_000)

# Last expression of the cell: show the learned value table.
ttd.vf_dict


{NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -35.605140581688616,
 NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -28.02149311468794,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -28.343959269000017,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -29.051903582844695,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.350196698337818,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.512487959886048}