In [1]:
from dataclasses import dataclass
from typing import Dict, Iterator, Mapping, Tuple
from rl.distribution import Categorical
from rl.markov_decision_process import FiniteMarkovDecisionProcess
from rl.policy import FinitePolicy, FiniteDeterministicPolicy

import itertools as it
import numpy as np


In [84]:
# create the state mapping:


@dataclass
class Froggie_State(int):
    lilypad_idx: int

    def __hash__(self):
        return super().__hash__()


@dataclass
class Froggie_Croak(str):
    croak: str

    def __hash__(self):
        return super().__hash__()


@dataclass
class Froggie_StateActionTransition:

    num_lilypads: int
    reward_method: str

    @property
    def get_stateaction_map(
        self,
    ) -> Mapping[
        Froggie_State, Mapping[Froggie_Croak, Categorical[Tuple[Froggie_State, float]]]
    ]:
        self.stateaction_map = {}
        for init_pad in np.arange(1, self.num_lilypads):
            for croak in {"A", "B"}:
                self.add_stateaction(state=Froggie_State(init_pad), croak=croak)
        return self.stateaction_map

    def add_stateaction(self, state: Froggie_State, croak: str) -> None:
        if state in self.stateaction_map.keys():

            self.stateaction_map[state].update(
                {croak: self.prob(action=croak, state=state)}
            )
        else:
            self.stateaction_map[state] = {croak: self.prob(action=croak, state=state)}

    def reward(self, state: Froggie_State, next_state: Froggie_State):
        if self.reward_method == "linear":
            return next_state.lilypad_idx
        elif self.reward_method == "comparative_linear":
            return next_state.lilypad_idx - state.lilypad_idx
        elif self.reward_method == "escape":
            return next_state.lilypad_idx == self.num_lilypads
        else: raise NotImplementedError()

    def prob(
        self,
        action: str,
        state: Froggie_State,
    ) -> Categorical[Tuple[Froggie_State, float]]:
        dist = {}
        if action == "A":
            down_state = Froggie_State(state.lilypad_idx - 1)
            dist[(down_state, self.reward(state, down_state))] = (
                state.lilypad_idx / self.num_lilypads
            )

            up_state = Froggie_State(state.lilypad_idx + 1)
            dist[(up_state, self.reward(state, up_state))] = (
                self.num_lilypads - state.lilypad_idx
            ) / self.num_lilypads
        elif action == "B":
            for pad_idx in range(0, self.num_lilypads + 1):
                if pad_idx == state.lilypad_idx:
                    continue
                else:
                    next_state = Froggie_State(pad_idx)
                    dist[(next_state, self.reward(state, next_state))] = (
                        1 / self.num_lilypads
                    )

        else:
            raise NotImplementedError()
        return Categorical(dist)

    # need use of reward: dictionary (state, reward) : probability


In [108]:
f = Froggie_StateActionTransition(6, "escape")

In [109]:
mdp = FiniteMarkovDecisionProcess(f.get_stateaction_map)

In [110]:
# create all possible finite policies

def get_all_policy_maps(f: Froggie_StateActionTransition) -> Iterator[FinitePolicy]:
    def croak_policy_from_subset(A_subset: Tuple) -> FiniteDeterministicPolicy:
        policy_map: Dict[Froggie_State, Froggie_Croak] = {}
        for init_pad in np.arange(1, f.num_lilypads):
            if init_pad in A_subset:
                policy_map[Froggie_State(init_pad)] = Froggie_Croak("A")
            else:
                policy_map[Froggie_State(init_pad)] = Froggie_Croak("B")
        return FiniteDeterministicPolicy(policy_map)

    policy_iterator = it.chain([], [])
    for r in np.arange(1, f.num_lilypads):
        policy_iterator = it.chain(
            policy_iterator, it.combinations(iterable=np.arange(1, f.num_lilypads), r=r)
        )
    return map(croak_policy_from_subset, policy_iterator)#; croak_policy_from_subset(A_subset=x))


In [111]:
for policy in get_all_policy_maps(f=f):
    mrp = mdp.apply_finite_policy(policy)
    print(np.max(mrp.get_value_function_vec(gamma = 1)))
    

0.4615384615384616
0.5000000000000002
0.5000000000000002
0.5000000000000002
0.6153846153846154
0.4333333333333334
0.4615384615384616
0.4615384615384617
0.5833333333333334
0.5000000000000002
0.5000000000000001
0.6153846153846156
0.5000000000000002
0.6153846153846156
0.6666666666666666
0.3999999999999999
0.43333333333333335
0.5595854922279793
0.4615384615384616
0.5833333333333334
0.6373056994818653
0.5000000000000002
0.6153846153846156
0.6666666666666667
0.7142857142857143
0.3404255319148935
0.5312499999999999
0.6153846153846152
0.6875
0.7872340425531914
0.615384615384615


In [89]:
mrp.get_value_function_vec(gamma = 1)

array([0.42857143, 0.47619048, 0.47619048, 0.47619048, 0.47619048,
       0.47619048, 0.47619048, 0.47619048, 0.47619048])