In [14]:
import itertools
from operator import itemgetter
from pprint import pprint
from typing import Callable, Iterable, Iterator, Mapping, Sequence, Set, Tuple, TypeVar

import numpy as np

import rl.iterate as iterate
from rl.approximate_dynamic_programming import NTStateDistribution
from rl.distribution import Categorical
from rl.distribution import Choose
from rl.function_approx import learning_rate_schedule, LinearFunctionApprox, Weights
from rl.function_approx import Tabular
from rl.markov_decision_process import MarkovDecisionProcess, Policy, \
    TransitionStep, NonTerminal, FiniteMarkovDecisionProcess
from rl.monte_carlo import greedy_policy_from_qvf
from rl.policy import DeterministicPolicy

In [11]:
# Generic type parameters shared by the functions below:
# S is the state type, A is the action type.
S, A = TypeVar('S'), TypeVar('A')

### LSPI (Least-Squares Policy Iteration) Algorithm

In [13]:
# LSTDQ update
def lstdq_update(
    transitions: Iterable[TransitionStep[S, A]],
    feature_functions: Sequence[Callable[[Tuple[NonTerminal[S], A]], float]],
    target_policy: DeterministicPolicy[S, A],
    gamma: float,
    epsilon: float
) -> LinearFunctionApprox[Tuple[NonTerminal[S], A]]:
    """Evaluate ``target_policy``'s Q-value function with LSTDQ on a batch.

    Accumulates, over the batch,
        A = sum_t phi(s_t, a_t) (phi(s_t, a_t) - gamma * phi(s'_t, pi(s'_t)))^T
        b = sum_t phi(s_t, a_t) * r_t
    while maintaining A^{-1} incrementally with the Sherman-Morrison formula,
    and returns the linear approximation with weights w = A^{-1} b.

    Args:
        transitions: batch of experience; iterated exactly once.
        feature_functions: each maps a (NonTerminal state, action) tuple to a
            float; the i-th function gives the i-th feature.
        target_policy: deterministic policy being evaluated. Its
            ``action_for`` is called with the raw next state (type S).
        gamma: discount factor.
        epsilon: A^{-1} is initialised to I / epsilon, i.e. A starts at
            epsilon * I, so the initial inverse exists.

    Returns:
        LinearFunctionApprox over (NonTerminal[S], A) with the LSTDQ weights.
    """
    num_features = len(feature_functions)
    a_inv = np.eye(num_features) / epsilon
    # Keep b 1-D: with shape (k, 1), adding the (k,) vector phi1 * reward
    # would broadcast b into a (k, k) matrix.
    b_vec = np.zeros(num_features)
    for tr in transitions:
        phi1 = np.array([f((tr.state, tr.action)) for f in feature_functions])
        if isinstance(tr.next_state, NonTerminal):
            # DeterministicPolicy[S, A].action_for takes the unwrapped state S,
            # not the NonTerminal wrapper.
            next_action = target_policy.action_for(tr.next_state.state)
            phi2 = phi1 - gamma * np.array(
                [f((tr.next_state, next_action)) for f in feature_functions])
        else:
            # A terminal next state contributes no bootstrapped value.
            phi2 = phi1
        # Sherman-Morrison with u = phi1, v = phi2:
        #   (A + u v^T)^{-1} = A^{-1} - (A^{-1} u)(v^T A^{-1}) / (1 + v^T A^{-1} u)
        # phi1/phi2 are 1-D, so the numerator needs np.outer; chaining .dot
        # between two 1-D vectors would collapse it to an inner product.
        a_inv_u = a_inv.dot(phi1)
        v_a_inv = phi2.dot(a_inv)
        a_inv = a_inv - np.outer(a_inv_u, v_a_inv) / (1 + v_a_inv.dot(phi1))
        b_vec = b_vec + phi1 * tr.reward
    weights = a_inv.dot(b_vec)
    return LinearFunctionApprox.create(feature_functions=feature_functions,
                                       weights=Weights.create(weights))

In [15]:
def lspi(
    transitions: Iterable[TransitionStep[S, A]],
    actions: Callable[[NonTerminal[S]], Iterable[A]],
    initial_target_policy: DeterministicPolicy[S, A],
    feature_functions: Sequence[Callable[[Tuple[NonTerminal[S], A]], float]],
    gamma: float,
    epsilon: float
) -> Iterator[LinearFunctionApprox[Tuple[NonTerminal[S], A]]]:
    """Least-Squares Policy Iteration over a fixed batch of transitions.

    Each iteration evaluates the current target policy with ``lstdq_update``
    on the same batch, switches the target policy to the greedy policy with
    respect to the resulting Q-value approximation, and yields that
    approximation.

    Args:
        transitions: batch of experience, replayed on every iteration.
        actions: maps a non-terminal state to the actions available in it;
            used to build the greedy policy.
        initial_target_policy: policy evaluated on the first iteration.
        feature_functions: each maps a (NonTerminal state, action) tuple to a
            float; passed through to ``lstdq_update``.
        gamma: discount factor.
        epsilon: LSTDQ regularisation (A^{-1} starts at I / epsilon).

    Yields:
        The Q-value approximation from each iteration. The generator never
        terminates on its own; the caller decides when to stop
        (e.g. with ``itertools.islice``).
    """
    target_policy = initial_target_policy
    # Materialise once: the same batch is reused every iteration and
    # ``transitions`` may be a one-shot iterator.
    transition_seq = list(transitions)
    while True:
        q = lstdq_update(
            transitions=transition_seq,
            feature_functions=feature_functions,
            target_policy=target_policy,
            gamma=gamma,
            epsilon=epsilon
        )
        target_policy = greedy_policy_from_qvf(q, actions)
        yield q