In [53]:
import numpy as np

# from gym.envs.toy_text import BlackjackEnv
import gym
from typing import Literal, List, Tuple, cast, Dict, Optional, Callable
import plotly.graph_objects as go
from copy import deepcopy
import math
import sys
import plotly.express as px

import io


In [54]:
# Notebook-wide seed for NumPy's global random state; the environment's
# jump sampling and the agent's weight initialisation both draw from it.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)


In [55]:
# Default number of non-terminal states (valid states are 0 .. N_STATES - 1).
N_STATES = 1000
# Not referenced by any cell visible in this notebook section.
EMPTY_MOVE = 0
# Probability that a sampled jump goes left (uniform draw below this value).
P_LEFT = 0.5
# Reward for a step that stays inside the state range.
R_STEP = 0
# Reward for leaving the state range on the left side.
R_LEFT = -1
# Reward for leaving the state range on the right side.
R_RIGHT = 1
# Largest jump size; jump magnitudes are drawn from 1 .. SHIFT inclusive.
SHIFT = 100


class RandomWalk:
    """Episodic random walk over the integer states ``0 .. n_states - 1``.

    Each step jumps between 1 and ``SHIFT`` states to the left or right; the
    action passed to :meth:`step` is ignored. Jumping below state 0 ends the
    episode with reward ``r_l`` and reports state ``-1``; jumping to or past
    ``n_states`` ends it with reward ``r_r`` and reports state ``n_states``.
    All other steps yield ``R_STEP``.

    Randomness is drawn from NumPy's global ``np.random`` state.
    """

    def __init__(self, n_states=None):
        """Create the walk.

        Args:
            n_states: number of non-terminal states; ``N_STATES`` when None.
        """
        self.n_states = N_STATES if n_states is None else n_states
        # Set the reward attributes before reset() so the object is fully
        # initialised by the time any method runs on it.
        self.r_l = R_LEFT
        self.r_r = R_RIGHT
        self.reset()

    def sample_shift(self):
        """Return a signed jump: magnitude in 1..SHIFT, negative with probability P_LEFT."""
        # Draw order (direction first, then magnitude) matters for
        # reproducibility of the global random stream.
        direction = np.sign(np.random.random() - P_LEFT)
        magnitude = np.random.randint(1, SHIFT + 1)
        return int(direction * magnitude)

    def step(self, action):
        """Advance the walk by one random jump.

        Args:
            action: accepted for API compatibility; not used.

        Returns:
            ``(state, reward, done, info)``. On termination ``state`` is
            ``-1`` (left exit) or ``n_states`` (right exit) and ``self.state``
            is left at the last in-range state.
        """
        assert (
            self.state >= 0 and self.state < self.n_states
        ), f"unexpected state encouter: {self.state}"

        new_state = self.state + self.sample_shift()

        if new_state >= self.n_states:
            # Walked off the right edge.
            return self.n_states, self.r_r, True, {}
        if new_state < 0:
            # Walked off the left edge.
            return -1, self.r_l, True, {}

        self.state = new_state
        return self.state, R_STEP, False, {}

    def reset(self):
        """Put the walk back on the middle state and return that state."""
        self.state = self.n_states // 2
        return self.state

    def seed(self, seed):
        """Seed the random source used by :meth:`sample_shift`.

        The walk samples from NumPy's global random state, so that is what
        gets seeded (previously this method silently ignored ``seed``).
        """
        np.random.seed(seed)


In [56]:
# Build the walk with the default number of states and pass the notebook-wide
# seed to its seed hook; the trailing expression displays the object's repr.
env = RandomWalk()
env.seed(RANDOM_SEED)
env


<__main__.RandomWalk at 0x155315760>

In [57]:
# Reset the walk; the returned start state is n_states // 2.
env.reset()

500

In [58]:
State = int
Observation = State

# 0: Left, 1: Right
Action = Literal[0, 1]
Reward = float
# One recorded transition: (state, action taken, reward received). The agent
# also appends a terminal marker whose action and reward are both None.
Step = Tuple[State, Optional[Action], Optional[Reward]]
Episode = List[Step]

# Derived from N_STATES so the state table cannot drift from the environment's
# default size.
all_states: List[State] = list(range(N_STATES))
all_actions: List[Action] = [0, 1]
nums_of_all_state = len(all_states)
nums_of_all_state_action = len(all_states) * len(all_actions)
# Each state gets its own copy, so editing one state's actions does not
# silently change every other state's list.
allowed_actions: List[List[Action]] = [list(all_actions) for _ in all_states]


# Feature vectors and weight vectors are both plain NumPy arrays.
Feature = np.ndarray
Weight = np.ndarray
# Names of the feature constructions the agent understands.
Algorithm = Literal["state_agg"]
Algorithm = Literal["state_agg"]


In [59]:
# One feature (weight) per group of 100 consecutive states.
nums_of_features = nums_of_all_state // 100

In [60]:
class Agent:
    """Value-prediction agent using a linear model over state-aggregation features.

    The agent steps the environment, records every transition, and updates the
    weight vector ``omega`` according to ``mode``:

    * ``"mc"``   - gradient Monte-Carlo update once the episode has ended;
    * ``"td"``   - one-step semi-gradient TD update after every step;
    * ``"td_n"`` - multi-step semi-gradient TD update after every step;
    * ``None``   - no learning, transitions are only recorded.

    Attributes:
        alpha: step size of every weight update.
        gamma: discount factor.
        n_step: ``n`` handed to :meth:`td_n_evaluate` in ``"td_n"`` mode.
        omega: weight vector, one entry per aggregated group of states.
        episode: transitions of the episode in progress.
        episodes: all finished episodes, each ending with a
            ``(terminal_state, None, None)`` marker.
        cur_state: state the agent is currently in.
        end: True once the current episode has terminated.
    """

    def __init__(
        self,
        env: gym.Env,
        alpha: float,
        mode: Optional[Literal["mc", "td", "td_n"]],
        gamma: float = 1.0,
        n_step: int = 4,
    ):
        """Store the hyper-parameters and initialise weights and history.

        Args:
            env: environment offering ``reset()`` and ``step(action)``.
            alpha: step size.
            mode: learning rule, see the class docstring.
            gamma: discount factor.
            n_step: value of ``n`` used in ``"td_n"`` mode (see
                :meth:`td_n_evaluate` for its exact meaning).
        """
        self.alpha = alpha
        self.env = env
        self.mode = mode
        self.gamma = gamma
        self.n_step = n_step
        self.clear()

    def reset(self):
        """Start a new episode; learned weights and past episodes are kept."""
        self.cur_state: State = self.env.reset()
        self.end = False
        self.episode: Episode = []

    def clear(self):
        """Reset the episode and discard everything learned so far."""
        self.reset()

        # Random initial weights drawn from the global NumPy random state.
        self.omega = np.asarray([np.random.random() for _ in range(nums_of_features)])
        self.episodes: List[Episode] = []

    def step(self) -> Tuple[Observation, bool]:
        """Take one environment step, record it and run the per-step update.

        Returns:
            ``(new_state, done)`` as reported by the environment.
        """
        assert not self.end, "cannot step on a ended agent"

        # Action 0 is always sent to the environment and recorded.
        (new_state, rwd, stop, _) = self.env.step(0)

        self.episode.append((self.cur_state, 0, rwd))
        self.cur_state = new_state

        if self.mode == "td":
            self.td_evaluate(self.gamma)
        elif self.mode == "td_n":
            self.td_n_evaluate(self.n_step, self.gamma)

        if stop:
            # Close the episode with a terminal marker and archive it.
            self.episode.append((self.cur_state, None, None))
            self.episodes.append(self.episode)
            self.episode = []
            self.end = True
            if self.mode == "mc":
                self.mc_evaluate()

        return (cast(State, new_state), stop)

    def td_n_evaluate(self, n: int, gamma: float):
        """Semi-gradient TD update using the last ``n + 1`` recorded transitions.

        The target for the oldest state in that window is the discounted sum
        of its ``n + 1`` rewards plus ``gamma ** (n + 1)`` times the current
        estimate of ``self.cur_state``. Nothing happens until ``n + 1``
        transitions are available.

        NOTE(review): the states visited in the last ``n`` transitions of an
        episode are never the oldest entry of a full window, so they receive
        no update, and episodes shorter than ``n + 1`` steps produce none at
        all. Confirm whether the tail should be flushed with truncated returns.
        """
        window = self.episode[-(1 + n) :]
        if len(window) != (1 + n):
            return

        bootstrap_state = self.cur_state
        (old_s, _, _) = window[0]

        discounted_rewards = np.sum(
            [cast(Reward, r) * gamma ** (i) for (i, (_, _, r)) in enumerate(window)]
        )
        td_error = (
            discounted_rewards
            + (gamma ** (n + 1))
            * self.linear_predict(bootstrap_state, self.omega, self.state_aggregation)
            - self.linear_predict(old_s, self.omega, self.state_aggregation)
        )

        self.omega = self.omega + (
            self.alpha * td_error * self.transform_to_feature(old_s, "state_agg")
        )

    def td_evaluate(self, gamma: float):
        """One-step TD update (``td_n_evaluate`` with ``n = 0``)."""
        return self.td_n_evaluate(0, gamma)

    def mc_evaluate(self):
        """Gradient Monte-Carlo update over the most recently finished episode.

        Every visited state is moved towards the discounted return that
        followed it. With ``gamma == 1`` and zero reward on all non-final
        steps this is the final reward for every state.
        """
        epi = self.episodes[-1]
        transitions = epi[:-1]  # drop the (terminal_state, None, None) marker

        # Accumulate the return following each step, walking backwards.
        returns: List[float] = []
        running = 0.0
        for (_, _, reward) in reversed(transitions):
            running = cast(Reward, reward) + self.gamma * running
            returns.append(running)
        returns.reverse()

        for ((state, _, _), rwd) in zip(transitions, returns):
            assert not math.isnan(rwd), f"unexpected rwd encountered: {rwd}"
            # linear_predict requires the feature function; omitting it here
            # used to raise TypeError as soon as mode == "mc" was used.
            error = rwd - self.linear_predict(
                state, self.omega, self.state_aggregation
            )
            self.omega = self.omega + (
                self.alpha * error * self.transform_to_feature(state, "state_agg")
            )

    def state_aggregation(self, s: State) -> Feature:
        """Return the one-hot feature vector of the group that contains ``s``.

        The states are cut into ``nums_of_features`` consecutive groups of
        equal size (rounded up, so the last group may be smaller).
        """
        assert 0 <= s < len(all_states), f"unexpected state encounter: {s}"
        # Ceiling division keeps the group index below nums_of_features even
        # when the state count is not an exact multiple of the feature count.
        group_size = (len(all_states) + nums_of_features - 1) // nums_of_features
        v = np.zeros(nums_of_features, dtype=int)
        v[s // group_size] = 1
        return v

    def transform_to_feature(self, s: State, algr: Algorithm) -> Feature:
        """Build the feature vector of ``s`` with the named construction.

        Raises:
            ValueError: if ``algr`` is not a known construction.
        """
        if algr == "state_agg":
            return self.state_aggregation(s)
        else:
            raise ValueError(f"unexpected algorithm: {algr}")

    def linear_predict(
        self, s: State, w: Weight, to_feature: Callable[[State], Feature]
    ) -> float:
        """Estimate the value of ``s`` as the inner product of its features and ``w``.

        The two terminal states (``-1`` and ``len(all_states)``) have value 0.
        """
        assert -1 <= s <= len(all_states), f"unexpected state encounter: {s}"
        if s == -1 or s == len(all_states):
            return 0
        return np.inner(to_feature(s), w)


In [61]:
TOTAL_EPISODES = 100_000

# Learner in "td" mode with step size 2e-4, bound to the walk created earlier.
agent = Agent(cast(gym.Env, env), 2e-4, "td")

for _ in range(TOTAL_EPISODES):
    agent.reset()
    # Step until the agent reports that the episode has terminated.
    while True:
        _, end = agent.step()
        if end:
            break


In [62]:
# Inspect the first 20 recorded episodes; each is a list of
# (state, action, reward) steps closed by a (terminal_state, None, None) marker.
# NOTE(review): this prints every step of 20 episodes; a smaller slice would
# keep the notebook output readable.
agent.episodes[:20]


[[(500, 0, 0),
  (588, 0, 0),
  (506, 0, 0),
  (584, 0, 0),
  (563, 0, 0),
  (493, 0, 0),
  (541, 0, 0),
  (641, 0, 0),
  (671, 0, 0),
  (686, 0, 0),
  (752, 0, 0),
  (719, 0, 0),
  (743, 0, 0),
  (667, 0, 0),
  (696, 0, 0),
  (697, 0, 0),
  (691, 0, 0),
  (709, 0, 0),
  (752, 0, 0),
  (750, 0, 0),
  (792, 0, 0),
  (804, 0, 0),
  (712, 0, 0),
  (812, 0, 0),
  (769, 0, 0),
  (700, 0, 0),
  (748, 0, 0),
  (744, 0, 0),
  (797, 0, 0),
  (781, 0, 0),
  (722, 0, 0),
  (708, 0, 0),
  (757, 0, 0),
  (715, 0, 0),
  (811, 0, 0),
  (812, 0, 0),
  (849, 0, 0),
  (755, 0, 0),
  (798, 0, 0),
  (820, 0, 0),
  (821, 0, 0),
  (777, 0, 0),
  (753, 0, 0),
  (654, 0, 0),
  (749, 0, 0),
  (666, 0, 0),
  (687, 0, 0),
  (659, 0, 0),
  (718, 0, 0),
  (729, 0, 0),
  (717, 0, 0),
  (798, 0, 0),
  (799, 0, 0),
  (752, 0, 0),
  (809, 0, 0),
  (731, 0, 0),
  (728, 0, 0),
  (629, 0, 0),
  (588, 0, 0),
  (684, 0, 0),
  (751, 0, 0),
  (819, 0, 0),
  (834, 0, 0),
  (766, 0, 0),
  (853, 0, 0),
  (929, 0, 0),
  (954, 0,

In [63]:
# Learned weights: one value estimate per aggregated group of states.
agent.omega

array([-0.69904806, -0.4879639 , -0.32284916, -0.17802434, -0.04550284,
        0.07921137,  0.21769721,  0.37086087,  0.5335724 ,  0.72505295])

In [64]:
# Load the reference state values from a local .npy file (pickle loading disabled).
# NOTE(review): how this file was produced is not shown in the notebook;
# presumably it holds one value per state. Confirm and record its provenance.
true_values = np.load("./true_values_arr.npy", allow_pickle=False)
true_values


array([-9.21893793e-01, -9.20650119e-01, -9.19397542e-01, -9.18136046e-01,
       -9.16865618e-01, -9.15586241e-01, -9.14297900e-01, -9.13000582e-01,
       -9.11694271e-01, -9.10378953e-01, -9.09054614e-01, -9.07721240e-01,
       -9.06378818e-01, -9.05027334e-01, -9.03666775e-01, -9.02297129e-01,
       -9.00918381e-01, -8.99530521e-01, -8.98133535e-01, -8.96727411e-01,
       -8.95312138e-01, -8.93887703e-01, -8.92454096e-01, -8.91011306e-01,
       -8.89559320e-01, -8.88098129e-01, -8.86627722e-01, -8.85148088e-01,
       -8.83659218e-01, -8.82161101e-01, -8.80653729e-01, -8.79137091e-01,
       -8.77611178e-01, -8.76075981e-01, -8.74531493e-01, -8.72977703e-01,
       -8.71414604e-01, -8.69842188e-01, -8.68260446e-01, -8.66669372e-01,
       -8.65068959e-01, -8.63459197e-01, -8.61840082e-01, -8.60211606e-01,
       -8.58573763e-01, -8.56926546e-01, -8.55269950e-01, -8.53603969e-01,
       -8.51928596e-01, -8.50243828e-01, -8.48549659e-01, -8.46846084e-01,
       -8.45133099e-01, -

In [65]:
# Compare the reference values with the agent's learned approximation.
fig = go.Figure()
s = list(range(1000))
# States are numbered from 1 on the x axis.
state_axis = [i + 1 for i in s]

# Label the learned curve after the rule the agent was actually trained with;
# the old hard-coded "monte-carlo prediction" was wrong for a "td" agent.
prediction_labels = {
    "mc": "monte-carlo prediction",
    "td": "TD(0) prediction",
    "td_n": "n-step TD prediction",
}

fig.add_trace(
    go.Scatter(x=state_axis, y=true_values, mode="lines", name="true values")
)
fig.add_trace(
    go.Scatter(
        x=state_axis,
        y=[agent.linear_predict(i, agent.omega, agent.state_aggregation) for i in s],
        mode="lines",
        name=prediction_labels.get(agent.mode, "prediction"),
    )
)
fig.update_layout(xaxis_title="state", yaxis_title="value")
fig.show()
