In [113]:
import numpy as np

# from gym.envs.toy_text import BlackjackEnv
import gym
from typing import Literal, List, Tuple, cast, Dict, Optional, Callable, Protocol, Union
import plotly.graph_objects as go
from copy import deepcopy
from abc import abstractmethod, ABC
import math
import sys
import plotly.express as px

import io


In [114]:
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)


In [115]:
N_STATES = 1000
EMPTY_MOVE = 0
P_LEFT = 0.5
R_STEP = 0
R_LEFT = -1
R_RIGHT = 1
SHIFT = 100


class RandomWalk:
    def __init__(self, n_states=None):
        self.n_states = N_STATES if n_states is None else n_states
        self.reset()
        self.r_l = R_LEFT
        self.r_r = R_RIGHT

    def sample_shift(self):
        return int(
            np.sign(np.random.random() - P_LEFT) * np.random.randint(1, SHIFT + 1)
        )

    def step(self, action):
        assert (
            self.state >= 0 and self.state < self.n_states
        ), f"unexpected state encouter: {self.state}"

        shift = self.sample_shift()
        new_state = self.state + shift
        if not (0 <= new_state < self.n_states):
            r = self.r_r if (new_state >= self.n_states) else self.r_l

            r_s = self.n_states if new_state >= self.n_states else -1
            return r_s, r, True, {}

        self.state = new_state
        return self.state, R_STEP, False, {}

    def reset(self):
        self.state = self.n_states // 2
        return self.state

    def seed(self, seed):
        pass


In [116]:
env = RandomWalk()
env.seed(RANDOM_SEED)
env


<__main__.RandomWalk at 0x113654640>

In [117]:
env.reset()

500

In [118]:
State = int
Observation = State

# 0: Left, 1: Right
Action = Literal[0, 1]
Reward = float
Step = Tuple[State, Optional[Action], Optional[Reward]]
Episode = List[Step]

all_states: List[State] = list(range(1000))
all_actions: List[Action] = [0, 1]
nums_of_all_state = len(all_states)
nums_of_all_state_action = len(all_states) * len(all_actions)
allowed_actions: List[List[Action]] = [all_actions for _ in range(nums_of_all_state)]


Feature = np.ndarray
Weight = np.ndarray


In [119]:
nums_of_features = int(nums_of_all_state / 100)

In [120]:
class FeatureInterface(Protocol):
    len: int

    @abstractmethod
    def to_feature(self, s: State) -> Feature:
        raise NotImplementedError()


class Fourier(FeatureInterface):
    def __init__(self, order: int):
        self.order = order
        self.len = self.order + 1

    def to_feature(self, s: State) -> Feature:
        order = self.order
        s_c = s / nums_of_all_state
        return np.asarray([np.cos(o * np.pi * s_c) for o in range(order + 1)])


class StateAgg(FeatureInterface):
    def __init__(self, n_groups: int):
        self.n_groups = n_groups
        self.len = n_groups

    def to_feature(self, s: State) -> Feature:
        n_groups = self.n_groups

        v = [0 for _ in range(n_groups)]
        states_per_group = math.ceil(nums_of_all_state / n_groups)

        v[s // states_per_group] = 1

        return np.asarray(v)


class CoarseCoding(FeatureInterface):
    def __init__(self, slots: List[Tuple[float, float]]):
        self.slots = slots
        self.len = len(self.slots)

    def to_feature(self, s: State) -> Feature:
        s_c = s / nums_of_all_state

        v = [1 if l <= s_c <= r else 0 for (l, r) in self.slots]

        return np.asarray(v)


class RadialBasis(FeatureInterface):
    def __init__(self, norms: List[Tuple[float, float]]):
        self.norms = norms
        self.len = len(self.norms)

    def to_feature(self, s: State) -> Feature:
        s_c = s / nums_of_all_state

        v = [np.exp(-(np.abs(s_c - c) ** 2) / (2 * sigma)) for (c, sigma) in self.norms]
        return np.asarray(v)


class Tiling(FeatureInterface):
    def __init__(self, n_tilings: int):
        if n_tilings // 2 == 0:
            n_tilings = n_tilings - 1

        assert n_tilings > 0, f"number of tilings cannot be lower than 1: {n_tilings}"

        self.n_tilings = n_tilings
        self.n_tiles_per_tiling = n_tilings
        self.len = n_tilings * (2 * n_tilings - 1)

        self.compute_tilings()

    def compute_tilings(self):
        states_per_tile = math.ceil(nums_of_all_state / self.n_tiles_per_tiling)
        pivot_tiling = [
            (i * states_per_tile, (i + 1) * states_per_tile)
            for i in range(self.n_tiles_per_tiling)
        ]
        all_tiles = [
            [pivot_tiling]
            if i == 0
            else [
                self.move_tiling(pivot_tiling, i * 0.13 * nums_of_all_state),
                self.move_tiling(pivot_tiling, -i * 0.13 * nums_of_all_state),
            ]
            for i in range(self.n_tilings)
        ]
        self.all_tiles = [tile for tiles in all_tiles for tile in tiles]

    def to_feature(self, s: State) -> Feature:

        v = [1 if lp <= s <= rp else 0 for tile in self.all_tiles for (lp, rp) in tile]
        return np.asarray(v)

    def move_tiling(
        self, tiling: List[Tuple[int, int]], delta: float
    ) -> List[Tuple[int, int]]:
        return [(np.round(lp + delta), np.round(rp + delta)) for (lp, rp) in tiling]


class AppxInterface(Protocol):
    @abstractmethod
    def predict(self, f: Feature, w: Weight) -> float:
        raise NotImplemented()

    @abstractmethod
    def gradient(self, f: Feature, w: Weight) -> np.ndarray:
        raise NotImplemented()


class Linear(AppxInterface):
    def predict(self, f: Feature, w: Weight) -> float:

        return np.inner(f, w)

    def gradient(self, f: Feature, w: Weight) -> np.ndarray:
        return f


class AlgorithmInterface(Protocol):
    n_of_omega: int
    appx: AppxInterface
    feature: FeatureInterface

    @abstractmethod
    def after_step(self, cur_state: State, episode: Episode, omega: np.ndarray):
        raise NotImplementedError()

    @abstractmethod
    def on_termination(self, episode: Episode, omega: np.ndarray):
        raise NotImplementedError()

    def predict(self, s: State, w: Weight) -> float:
        assert -1 <= s <= len(all_states), f"unexpected state encounter: {s}"

        if s == -1 or s == len(all_states):
            return 0

        return self.appx.predict(self.feature.to_feature(s), w)

    def gradient(self, s: State, w: Weight) -> np.ndarray:
        assert 0 <= s < len(all_states), f"unexpected state encounter: {s}"

        return self.appx.gradient(self.feature.to_feature(s), w)


class MonteCarlo(AlgorithmInterface):
    def __init__(
        self,
        alpha: float,
        appx_algorithm: AppxInterface,
        feature_algorithm: FeatureInterface,
    ):
        self.alpha = alpha
        self.appx = appx_algorithm
        self.feature = feature_algorithm

        self.n_of_omega = self.feature.len

    def after_step(self, cur_state: State, episode: Episode, omega: np.ndarray):
        pass

    def on_termination(self, episode: Episode, omega: np.ndarray):
        # epi = self.episodes[-1]
        epi = episode

        rwd = cast(Reward, epi[-2][2])
        assert (type(rwd) is float or type(rwd) is int) and not math.isnan(
            rwd
        ), f"unexpected rwd encountered: {rwd}"

        # delta_omega = np.asarray([])
        for (state, _, _) in epi[:-1]:
            # omega = self.omega
            omega += (
                self.alpha
                * (rwd - self.appx.predict(self.feature.to_feature(state), omega))
                * self.appx.gradient(self.feature.to_feature(state), omega)
            )


class TD(AlgorithmInterface):
    def __init__(
        self,
        alpha: float,
        appx_algorithm: AppxInterface,
        feature_algorithm: FeatureInterface,
        gamma: float = 1.0,
    ):
        self.alpha = alpha
        self.gamma = gamma
        self.appx = appx_algorithm
        self.feature = feature_algorithm

        self.n_of_omega = self.feature.len

    def after_step(self, cur_state: State, episode: Episode, omega: np.ndarray):
        n = 1
        this_s = cur_state
        history = episode[-(1 + n) :]
        gamma = self.gamma

        if len(history) != (1 + n):
            return

        (old_s, _, r) = history[0]

        rwd = np.sum(
            [cast(Reward, r) * gamma ** (i) for (i, (_, _, r)) in enumerate(history)]
        )

        omega = omega + (
            self.alpha
            * (
                rwd
                + (gamma ** (n + 1))
                * self.appx.predict(self.feature.to_feature(this_s), omega)
                - self.appx.predict(self.feature.to_feature(old_s), omega)
            )
            * self.appx.gradient(self.feature.to_feature(old_s), omega)
        )

    def on_termination(self):
        pass


class TDN(AlgorithmInterface):
    def __init__(
        self,
        n: int,
        alpha: float,
        appx_algorithm: AppxInterface,
        feature_algorithm: FeatureInterface,
        gamma: float = 1.0,
    ):
        self.n = n
        self.alpha = alpha
        self.appx = appx_algorithm
        self.feature = feature_algorithm
        self.gamma = gamma

        self.n_of_omega = self.feature.len

    def on_termination(self, episode: Episode, omega: np.ndarray):
        n = self.n

        rwd = cast(Reward, episode[-2][2])
        assert (type(rwd) is float or type(rwd) is int) and not math.isnan(
            rwd
        ), f"unexpected rwd encountered: {rwd}"

        for i in range((n + 1), 1, -1):
            history = episode[-i:-1]
            assert (
                1 <= len(history) < n + 1
            ), f"unexpected length encountered during termination: {len(history)}, {n+1}"

            (old_s, _, r) = history[0]

            omega += (
                self.alpha
                * (rwd - self.appx.predict(self.feature.to_feature(old_s), omega))
                * self.appx.gradient(self.feature.to_feature(old_s), omega)
            )

    def after_step(self, this_s: State, episode: Episode, omega: np.ndarray):
        n = self.n
        history = episode[-(1 + n) :]
        gamma = self.gamma

        if len(history) != (1 + n):
            return

        (old_s, _, r) = history[0]

        rwd = np.sum(
            [cast(Reward, r) * gamma ** (i) for (i, (_, _, r)) in enumerate(history)]
        ) + (gamma ** (n + 1)) * self.appx.predict(
            self.feature.to_feature(this_s), omega
        )

        omega += (
            self.alpha
            * (rwd - self.appx.predict(self.feature.to_feature(old_s), omega))
            * self.appx.gradient(self.feature.to_feature(old_s), omega)
        )


In [121]:
class Agent:
    def __init__(
        self,
        env: gym.Env,
        algm: AlgorithmInterface,
        gamma: float = 1.0,
    ):
        self.env = env
        self.algm = algm
        self.gamma = gamma
        self.clear()

    def reset(self):
        self.cur_state: State = self.env.reset()
        self.end = False
        self.episode: Episode = []

    def clear(self):
        self.reset()

        self.omega = np.asarray(
            [np.random.random() for _ in range(self.algm.n_of_omega)]
        )
        self.episodes: List[Episode] = []

    def step(self) -> Tuple[Observation, bool]:
        assert not self.end, "cannot step on a ended agent"

        (new_state, rwd, stop, _) = self.env.step(0)

        self.episode.append((self.cur_state, 0, rwd))

        self.cur_state = new_state

        self.algm.after_step(self.cur_state, self.episode, self.omega)

        if stop:
            self.episode.append((self.cur_state, None, None))
            self.episodes.append(self.episode)
            self.episode = []
            self.end = True

            self.algm.on_termination(self.episodes[-1], self.omega)

        return (cast(State, new_state), stop)

    def predict(self, s: State) -> float:
        return self.algm.predict(s, self.omega)


In [122]:
TOTAL_EPISODES = 10_0000

agent = Agent(
    cast(gym.Env, env),
    TDN(9, 2e-4, Linear(), Fourier(20))
    # Policy(
    #     kind="mc",
    #     detail=MonteCarlo(),
    #     predict_algorithm=Linear(),
    #     # feature_algorithm=Tiling(5)
    #     # feature_algorithm=RadialBasis(
    #     #     [
    #     #         (0.2, 0.1),
    #     #         (0.3, 0.3),
    #     #         (0.5, 0.25),
    #     #         (0.5, 0.5),
    #     #         (0.5, 1),
    #     #         (0.7, 0.3),
    #     #         (0.8, 0.1),
    #     #     ]
    #     # )
    #     # feature_algorithm=CoarseCoding(
    #     #     [
    #     #         (0, 0.25),
    #     #         (0.1, 0.35),
    #     #         (0.15, 0.25),
    #     #         (0.25, 0.5),
    #     #         (0.3, 0.55),
    #     #         (0.4, 0.6),
    #     #         (0.44, 0.67),
    #     #         (0.5, 0.75),
    #     #         (0.6, 0.9),
    #     #         (0.66, 0.95),
    #     #         (0.7, 1.0),
    #     #         (0.83, 1.0),
    #     #     ]
    #     # ),
    #     feature_algorithm=Fourier(10),
    # ),
)

for _ in range(TOTAL_EPISODES):
    agent.reset()
    end = False
    while not end:
        _, end = agent.step()


In [123]:
agent.episodes[:20]


[[(500, 0, 0),
  (530, 0, 0),
  (545, 0, 0),
  (611, 0, 0),
  (578, 0, 0),
  (602, 0, 0),
  (526, 0, 0),
  (555, 0, 0),
  (556, 0, 0),
  (550, 0, 0),
  (568, 0, 0),
  (611, 0, 0),
  (609, 0, 0),
  (651, 0, 0),
  (663, 0, 0),
  (571, 0, 0),
  (671, 0, 0),
  (628, 0, 0),
  (559, 0, 0),
  (607, 0, 0),
  (603, 0, 0),
  (656, 0, 0),
  (640, 0, 0),
  (581, 0, 0),
  (567, 0, 0),
  (616, 0, 0),
  (574, 0, 0),
  (670, 0, 0),
  (671, 0, 0),
  (708, 0, 0),
  (614, 0, 0),
  (657, 0, 0),
  (679, 0, 0),
  (680, 0, 0),
  (636, 0, 0),
  (612, 0, 0),
  (513, 0, 0),
  (608, 0, 0),
  (525, 0, 0),
  (546, 0, 0),
  (518, 0, 0),
  (577, 0, 0),
  (588, 0, 0),
  (576, 0, 0),
  (657, 0, 0),
  (658, 0, 0),
  (611, 0, 0),
  (668, 0, 0),
  (590, 0, 0),
  (587, 0, 0),
  (488, 0, 0),
  (447, 0, 0),
  (543, 0, 0),
  (610, 0, 0),
  (678, 0, 0),
  (693, 0, 0),
  (625, 0, 0),
  (712, 0, 0),
  (788, 0, 0),
  (813, 0, 0),
  (787, 0, 0),
  (849, 0, 0),
  (816, 0, 0),
  (887, 0, 0),
  (873, 0, 0),
  (848, 0, 0),
  (867, 0,

In [124]:
agent.omega

array([-1.39384229e-03, -8.13967752e-01, -6.14174797e-04, -5.69954545e-02,
       -3.41769713e-03, -2.64941981e-02, -4.44287444e-03, -1.07264144e-02,
        5.59458977e-03, -7.10588800e-03, -2.82886789e-03,  3.92836664e-04,
       -4.14427753e-03, -3.61335222e-03,  9.07743275e-04, -3.30978768e-03,
        4.67761280e-03,  1.96290329e-03, -3.23052283e-03,  2.17469121e-03,
        2.71073165e-03])

In [125]:
true_values = np.load("./true_values_arr.npy", allow_pickle=False)
true_values


array([-9.21893793e-01, -9.20650119e-01, -9.19397542e-01, -9.18136046e-01,
       -9.16865618e-01, -9.15586241e-01, -9.14297900e-01, -9.13000582e-01,
       -9.11694271e-01, -9.10378953e-01, -9.09054614e-01, -9.07721240e-01,
       -9.06378818e-01, -9.05027334e-01, -9.03666775e-01, -9.02297129e-01,
       -9.00918381e-01, -8.99530521e-01, -8.98133535e-01, -8.96727411e-01,
       -8.95312138e-01, -8.93887703e-01, -8.92454096e-01, -8.91011306e-01,
       -8.89559320e-01, -8.88098129e-01, -8.86627722e-01, -8.85148088e-01,
       -8.83659218e-01, -8.82161101e-01, -8.80653729e-01, -8.79137091e-01,
       -8.77611178e-01, -8.76075981e-01, -8.74531493e-01, -8.72977703e-01,
       -8.71414604e-01, -8.69842188e-01, -8.68260446e-01, -8.66669372e-01,
       -8.65068959e-01, -8.63459197e-01, -8.61840082e-01, -8.60211606e-01,
       -8.58573763e-01, -8.56926546e-01, -8.55269950e-01, -8.53603969e-01,
       -8.51928596e-01, -8.50243828e-01, -8.48549659e-01, -8.46846084e-01,
       -8.45133099e-01, -

In [126]:
fig = go.Figure()
s = list(range(1000))
fig.add_trace(
    go.Scatter(x=[i + 1 for i in s], y=true_values, mode="lines", name="true values")
)
fig.add_trace(
    go.Scatter(
        x=[i + 1 for i in s],
        y=[agent.predict(i) for i in s],
        mode="lines",
        name="monte-carlo prediction",
    )
)
fig.show()
