In [1]:
from typing import Iterator, Mapping, Tuple, TypeVar, Sequence, List
from operator import itemgetter
import numpy as np

from rl.distribution import Distribution, Constant
from rl.function_approx import FunctionApprox
from rl.iterate import iterate
from rl.markov_process import (FiniteMarkovRewardProcess, MarkovRewardProcess,
                               RewardTransition)
from rl.markov_decision_process import (FiniteMarkovDecisionProcess, Policy,
                                        MarkovDecisionProcess,
                                        StateActionMapping)
from rl.dynamic_programming import greedy_policy_from_vf
from rl.approximate_dynamic_programming import evaluate_mrp
import itertools

In [2]:
# Generic type variables used in the annotations below: S is the state type
# and A is the action type of the decision process.
S = TypeVar('S')
A = TypeVar('A')


def policy_iteration(
    mdp: MarkovDecisionProcess[S, A],
    gamma: float,
    approx_0: FunctionApprox[S],
    non_terminal_states_distribution: Distribution[S],
    num_state_samples: int,
    num_policy_evaluation: int
) -> Iterator[FunctionApprox[S]]:
    '''Iteratively calculate the Optimal Value function for the given
    Markov Decision Process by approximate policy iteration, using the given
    FunctionApprox to approximate the value function at each step.

    Each step (1) builds the policy that is greedy with respect to the
    current value function approximation, (2) applies that policy to the MDP
    to obtain an MRP, and (3) runs approximate policy evaluation on that MRP,
    starting from the current approximation.

    Arguments:
      mdp -- the Markov Decision Process to solve
      gamma -- discount factor applied to the value of the next state
      approx_0 -- initial value function approximation
      non_terminal_states_distribution -- distribution over states, passed
        to evaluate_mrp together with num_state_samples
      num_state_samples -- number of states evaluate_mrp is asked to sample
      num_policy_evaluation -- number of approximations taken from
        evaluate_mrp per policy iteration step; the last one taken becomes
        the value function for the next step

    Returns the iterator produced by iterate(update, approx_0), i.e. the
    sequence of value function approximations across policy iteration steps.

    '''
    class GreedyPolicy(Policy[S, A]):
        '''Deterministic policy choosing, in each state, the action with the
        largest expected one-step return under the value function `vf`.

        '''
        mdp: MarkovDecisionProcess
        vf: FunctionApprox
        gamma: float

        def __init__(
            self,
            mdp: MarkovDecisionProcess,
            vf: FunctionApprox,
            gamma: float
        ):
            self.mdp = mdp
            self.vf = vf
            self.gamma = gamma

        def act(self, s: S) -> Distribution[A]:
            def return_(s_r: Tuple[S, float]) -> float:
                # One-step return: reward plus discounted value of the
                # next state under this policy's value function.
                s1, r = s_r
                return r + self.gamma * self.vf.evaluate([s1]).item()

            # max() keeps the first maximal pair, so ties resolve to the
            # earliest action yielded by mdp.actions(s).
            # NOTE(review): assumes s has at least one action and that
            # step(s, a) returns a distribution (not None) -- confirm that
            # only non-terminal states reach this method.
            best_action, _ = max(
                ((a, self.mdp.step(s, a).expectation(return_))
                 for a in self.mdp.actions(s)),
                key=itemgetter(1)
            )
            return Constant(best_action)

    def update(v: FunctionApprox[S]) -> FunctionApprox[S]:
        # Policy improvement: act greedily with respect to v.
        policy = GreedyPolicy(mdp=mdp, vf=v, gamma=gamma)
        mrp = mdp.apply_policy(policy)

        # Policy evaluation: drain the bounded sequence of approximations
        # and keep only the final one. If the sequence is empty
        # (num_policy_evaluation == 0), v is returned unchanged.
        # NOTE(review): if evaluate_mrp yields its starting approximation
        # first, the number of actual evaluation sweeps is
        # num_policy_evaluation - 1 -- confirm against evaluate_mrp.
        evaluated = v
        for evaluated in itertools.islice(
            evaluate_mrp(
                mrp,
                gamma,
                v,
                non_terminal_states_distribution,
                num_state_samples
            ),
            num_policy_evaluation
        ):
            pass
        return evaluated

    return iterate(update, approx_0)