In [3]:
from dataclasses import dataclass
from typing import Callable, Tuple, Optional, Mapping
import numpy as np
import itertools
from rl.distribution import Categorical, Constant
from rl.markov_process import MarkovRewardProcess
from rl.gen_utils.common_funcs import get_logistic_func, get_unit_sigmoid_func


Define

$$a_k(s) =E[f(S_{t+k})|S_t = s]$$.

We have recursive relation

$\begin{equation}
E[f(S_{t+k})|S_t = s] = \sum_{S_{t+1}}P(S_{t}=s,S_{t+1}) E[f(S_{t+k})|S_{t+1}]
\end{equation}$

Therefore we can get

$\begin{equation}
a_k(s) = \sum_{s'} P(s,s') a_{k-1}(s')
\end{equation}$

In this stock price case, we have that

$\begin{equation}
a_k(s) = p\times a_{k-1}(s+1) + (1-p) \times a_{k-1}(s-1) \tag{1}
\end{equation}$

The value function is defined as

$\begin{equation}
V(s) = \sum_{k=1}^{\infty} \gamma^{k-1} a_k(s)
\end{equation}$

Sum $k$ on the both sides of Eq $(1)$ we get recursive equation for $V(s)$

$V(s) = pf(s+1) + (1-p)f(s-1) + p \gamma V(s+1) + (1-p)\gamma V(s-1)$

We can calculate $V(s)$ recursively. The problem is that we have infinite number of states and there are no "initial condition" provided.
Noting that $gamma < 1, p<1$, we can solve this problem by presetting some accuracy requirement. And if $\gamma^n < accuracy$, we stop the recursion.

In [9]:
@dataclass(frozen=True)
class StateMP1:
    price: int

@dataclass
class StockPriceMRP1(MarkovRewardProcess[StateMP1]):
    level_param: int  # level to which price mean-reverts
    gamma: float   # gamma for gain
    f: Callable[[int], float]     # function for reward
    alpha1: float = 0.25  # strength of mean-reversion (non-negative value)
    accuracy: float = 1e-3

    def up_prob(self, state: StateMP1) -> float:
        return get_logistic_func(self.alpha1)(self.level_param - state.price)

    def transition_reward(self, state: StateMP1) -> Categorical[Tuple[StateMP1, float]]:
        up_p = self.up_prob(state)

        return Categorical({
            (StateMP1(state.price + 1),self.f(StateMP1(state.price + 1))): up_p,
            (StateMP1(state.price - 1),self.f(StateMP1(state.price + 1))): 1 - up_p
        })

    def get_value_function(self,state:StateMP1)\
            ->float:
        def helper(state:StateMP1,order:int):
            result = 0
            if self.gamma**order <= self.accuracy:
                return result

            p = self.up_prob(state)
            result = p*self.f(state.price+1) + (1-p)*self.f(state.price+1)+ \
                     p*self.gamma*helper(StateMP1(price = state.price+1),order+1) + \
                     (1-p)*self.gamma*helper(StateMP1(price=   state.price-1),order+1)
            return result
        return helper(state,0)


In [20]:
def reward_function(price:float)->float:
    return price/3

gamma:float = 0.5
level_param: int = 100
accuracy:float = 1e-3


mp = StockPriceMRP1(gamma = gamma,level_param=level_param,f = reward_function,accuracy=accuracy)
mp.get_value_function(StateMP1(price=50))




34.62629541699808