# In [None]:
import numpy as np
from typing import Optional,Mapping,Sequence,Iterable, Iterator, Tuple, TypeVar, Dict, Callable,List
from rl.markov_decision_process import Policy
import math
from rl.distribution import (Bernoulli, Constant, Categorical, Choose,
                             Distribution, FiniteDistribution)
import numpy as np

from rl.distribution import (Bernoulli, Constant, Categorical, Choose,
                             Distribution, FiniteDistribution)
from dataclasses import dataclass, replace
from rl.markov_decision_process import FinitePolicy, TransitionStep
from rl.function_approx import FunctionApprox

# In [None]:
S = TypeVar('S')
A = TypeVar('A')

@dataclass(frozen=True)
class Linear_Approx_TDC():
    """Linear approximation of Q(s, a) together with the TDC correction weights.

    The primary estimate is ``Q(s, a) = weight · feature_func(s, a)``.
    ``theta`` is the secondary weight vector that TDC uses for its
    gradient-correction term, read through ``evaluate_theta``.
    Instances are immutable: the ``update_*`` methods return new instances
    built with ``dataclasses.replace``.
    """
    # maps (state, action) to a feature vector of length d
    feature_func: Callable[[S,A],Sequence[float]]
    # primary weight vector w (length d)
    weight: Sequence[float]
    # secondary / correction weight vector theta (length d)
    theta: Sequence[float]

    def update_weight(self, delta_weight: Sequence[float]) -> 'Linear_Approx_TDC':
        """Return a copy whose ``weight`` is ``weight + delta_weight``.

        Both operands are converted to float arrays so the sum is elementwise
        even when plain Python lists are passed (``list + list`` would
        concatenate instead).
        """
        new_weight = (np.asarray(self.weight, dtype=float)
                      + np.asarray(delta_weight, dtype=float))
        return replace(self, weight=new_weight)

    def update_theta(self, delta_theta: Sequence[float]) -> 'Linear_Approx_TDC':
        """Return a copy whose ``theta`` is ``theta + delta_theta``.

        ``weight`` is left unchanged.
        """
        new_theta = (np.asarray(self.theta, dtype=float)
                     + np.asarray(delta_theta, dtype=float))
        return replace(self, theta=new_theta)

    def evaluate(self, state: S, action: A) -> float:
        """Return ``weight · feature_func(state, action)``."""
        return float(np.dot(self.weight, self.feature_func(state, action)))

    def evaluate_theta(self, state: S, action: A) -> float:
        """Return ``theta · feature_func(state, action)``."""
        return float(np.dot(self.theta, self.feature_func(state, action)))

def policy_from_q(
        q: Linear_Approx_TDC,
        actions: Mapping[S,Iterable[A]],
        ϵ: float = 0.0
) -> Policy[S, A]:
    """Build an ϵ-greedy policy from the linear Q-approximation ``q``.

    With probability ``ϵ`` the action is drawn uniformly from the actions
    allowed in the current state. Otherwise the action maximising
    ``q.evaluate(s, a)`` is chosen (ties go to the first such action,
    as ``np.argmax`` does).

    q: approximation providing ``evaluate(state, action)``.
    actions: allowed actions for each state.
    ϵ: exploration probability.

    return: a Policy whose ``act(s)`` returns a distribution over the actions
        allowed in ``s``, or None when ``s`` has no allowed actions (it is
        treated as terminal).
    """

    explore = Bernoulli(ϵ)

    class QPolicy(Policy[S, A]):
        def act(self, s: S) -> Optional[Distribution[A]]:
            # Materialise once so any Iterable (e.g. a generator or set) can be
            # indexed and reused below.
            options = list(actions.get(s, ()))
            if not options:
                # No allowed actions: treat s as a terminal state.
                return None

            if explore.sample():
                return Choose(set(options))

            values = [q.evaluate(s, a) for a in options]
            return Constant(options[int(np.argmax(values))])

    return QPolicy()

def TDC(feature_func: Callable[[S,A],Sequence[float]],     # feature functions
         simulator: Callable[[S,A],Tuple[S,float]],
         w0: Sequence[float],
         theta0: Sequence[float],
         actions: Mapping[S,Iterable[A]],
         gamma: float,
         state_distribution: Distribution[S],
         learning_rate_alpha: Callable[[int],float],
         learning_rate_beta: Callable[[int],float],
         tolerance: float = 1e-6,
         nstop: Optional[int] = None
         )->Iterator[Linear_Approx_TDC]:
    """
    TDC (TD with gradient correction) for linear Q-function approximation,
    off-policy control.

    Each episode starts from ``state_distribution``. The behaviour policy is
    ϵ-greedy in the current ``q`` with ϵ = 1/episode. The target uses the
    greedy action in the next state. Per step, with φ = φ(s, a),
    φ' = φ(s', a'), and δ = r + γ w·φ' − w·φ:
        w     += α (δ φ − γ (θ·φ) φ')
        theta += β (δ − θ·φ) φ

    feature_func: (S, A) -> R^d. A next state with no allowed actions is
        treated as terminal: its features and value are taken to be 0, and
        the episode ends there.
    simulator: takes a state and an action, returns (next_state, reward).
    w0: R^d, initial primary weights.
    theta0: R^d, initial correction weights.
    actions: allowed actions for each state.
    gamma: discount factor.
    state_distribution: distribution of episode start states.
    learning_rate_alpha: learning rate for the weights as a function of the
        number of visits to a (state, action) pair.
    learning_rate_beta: learning rate for theta as a function of the number
        of visits to a (state, action) pair.
    tolerance: for 0 < gamma < 1, episodes are truncated once gamma**k falls
        below this value.
    nstop: episode length used when gamma is not in (0, 1).

    return: infinite iterator yielding the Linear_Approx_TDC after each episode.
    raises: ValueError if gamma is not in (0, 1) and nstop is None.
    """

    # Episode horizon: number of steps until gamma**k drops below tolerance.
    if 0 < gamma < 1:
        max_steps = max(1, round(math.log(tolerance) / math.log(gamma)))
    elif nstop is not None:
        max_steps = nstop
    else:
        raise ValueError("nstop must be provided when gamma is not in (0, 1)")

    q = Linear_Approx_TDC(feature_func=feature_func,
                          weight=np.asarray(w0, dtype=float),
                          theta=np.asarray(theta0, dtype=float))

    visit_count: Dict[Tuple[S, A], int] = {}
    episode = 0

    while True:
        episode += 1
        # Exploration decays as 1/episode.
        epsilon = 1.0 / episode
        state = state_distribution.sample()

        for _ in range(max_steps):
            # The policy closes over the current q, so rebuild it every step.
            behaviour = policy_from_q(q, actions=actions, ϵ=epsilon)
            action_dist = behaviour.act(state)
            if action_dist is None:
                # No allowed actions in this state: episode is over.
                break
            action = action_dist.sample()
            next_state, reward = simulator(state, action)

            key = (state, action)
            visit_count[key] = visit_count.get(key, 0) + 1
            n_visits = visit_count[key]

            phi = np.asarray(feature_func(state, action), dtype=float)

            next_options = list(actions.get(next_state, ()))
            if next_options:
                # Off-policy target: greedy action in next_state under current q.
                next_values = [q.evaluate(next_state, a) for a in next_options]
                best = int(np.argmax(next_values))
                phi_next = np.asarray(feature_func(next_state, next_options[best]),
                                      dtype=float)
                q_next = next_values[best]
            else:
                # Terminal next state: zero features and zero value.
                phi_next = np.zeros_like(phi)
                q_next = 0.0

            delta = reward + gamma * q_next - q.evaluate(state, action)
            theta_phi = q.evaluate_theta(state, action)
            alpha = learning_rate_alpha(n_visits)
            beta = learning_rate_beta(n_visits)

            delta_weight = alpha * (delta * phi - gamma * theta_phi * phi_next)
            delta_theta = beta * (delta - theta_phi) * phi

            q = q.update_weight(delta_weight=delta_weight)
            q = q.update_theta(delta_theta=delta_theta)
            state = next_state

            if not next_options:
                break

        yield q
