In [53]:
import numpy as np

# from gym.envs.toy_text import BlackjackEnv
import gym
from typing import Literal, List, Tuple, cast, Dict, Optional, Callable, Protocol, Union
import plotly.graph_objects as go
from copy import deepcopy
import math
import sys
import plotly.express as px

import io


In [54]:
# Seed NumPy's global RNG once, up front: both RandomWalk.sample_shift and the
# random weight initialisation in Agent.clear draw from np.random.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)


In [55]:
# Number of non-terminal positions, indexed 0 .. N_STATES - 1.
N_STATES = 1000
# Not referenced by the code in this cell.
EMPTY_MOVE = 0
# Probability that a sampled jump goes to the left.
P_LEFT = 0.5
# Reward for a transition that stays inside the walk.
R_STEP = 0
# Rewards for leaving the walk over the left / right edge.
R_LEFT = -1
R_RIGHT = 1
# Largest jump length; lengths are drawn uniformly from 1 .. SHIFT.
SHIFT = 100


class RandomWalk:
    """Random walk over ``n_states`` positions with absorbing ends.

    Each step moves the walker by a random jump of 1 .. SHIFT positions, to
    the left with probability P_LEFT and to the right otherwise. Jumping past
    either edge ends the episode. The action given to ``step`` is ignored.

    All randomness comes from NumPy's global RNG (``np.random``).
    """

    def __init__(self, n_states=None):
        """Create the walk; ``n_states`` defaults to ``N_STATES``."""
        self.n_states = N_STATES if n_states is None else n_states
        self.r_l = R_LEFT
        self.r_r = R_RIGHT
        self.reset()

    def sample_shift(self):
        """Return a signed jump: a non-zero int in [-SHIFT, SHIFT]."""
        # Pick the direction explicitly. The previous np.sign(u - P_LEFT)
        # evaluates to 0 when u == P_LEFT, which produced a jump of length 0.
        # The RNG is still consulted in the same order: direction, then length.
        direction = -1 if np.random.random() < P_LEFT else 1
        return direction * int(np.random.randint(1, SHIFT + 1))

    def step(self, action):
        """Advance the walk by one random jump.

        Args:
            action: accepted for interface compatibility; not used.

        Returns:
            ``(observation, reward, done, info)``. While the walker stays
            inside the walk: ``(new_state, R_STEP, False, {})``. When it
            leaves: observation ``-1`` with reward ``r_l`` (left edge) or
            observation ``n_states`` with reward ``r_r`` (right edge), and
            ``done`` is True. On termination ``self.state`` keeps the last
            non-terminal position, so call ``reset()`` before stepping again.
        """
        assert (
            self.state >= 0 and self.state < self.n_states
        ), f"unexpected state encouter: {self.state}"

        shift = self.sample_shift()
        new_state = self.state + shift
        if not (0 <= new_state < self.n_states):
            exited_right = new_state >= self.n_states
            r = self.r_r if exited_right else self.r_l
            r_s = self.n_states if exited_right else -1
            return r_s, r, True, {}

        self.state = new_state
        return self.state, R_STEP, False, {}

    def reset(self):
        """Put the walker back on the middle position and return it."""
        self.state = self.n_states // 2
        return self.state

    def seed(self, seed):
        """Seed the RNG the walk samples from (NumPy's global RNG).

        Previously this silently did nothing, so seeding the environment had
        no effect on the sampled trajectories. ``None`` leaves the RNG as is.
        """
        if seed is not None:
            np.random.seed(seed)


In [56]:
# Build the environment with the default number of states and hand it the
# shared seed; the bare `env` expression displays its repr as the cell output.
env = RandomWalk()
env.seed(RANDOM_SEED)
env


<__main__.RandomWalk at 0x14f693df0>

In [57]:
# reset() returns the start position, n_states // 2.
env.reset()

500

In [58]:
# A state is the integer position of the walker.
State = int
Observation = State

# 0: Left, 1: Right
Action = Literal[0, 1]
Reward = float
# One recorded transition: (state, action, reward). The terminal marker that
# closes an episode carries None for both action and reward.
Step = Tuple[State, Optional[Action], Optional[Reward]]
Episode = List[Step]

# Derive the state table from the environment constant instead of repeating
# the literal 1000, so the two cannot drift apart.
all_states: List[State] = list(range(N_STATES))
all_actions: List[Action] = [0, 1]
nums_of_all_state = len(all_states)
nums_of_all_state_action = len(all_states) * len(all_actions)
# One independent list per state: mutating the actions of one state must not
# change every other state (the previous version shared a single list object).
allowed_actions: List[List[Action]] = [
    list(all_actions) for _ in range(nums_of_all_state)
]


# Feature vectors and weight vectors are both plain NumPy arrays.
Feature = np.ndarray
Weight = np.ndarray


In [59]:
# One feature per 100 states; not referenced by the other visible cells.
nums_of_features = int(nums_of_all_state / 100)

In [60]:
# Names of the two feature constructions; the classes below (StateAgg, Fourier)
# are what Policy actually dispatches on.
Algorithm = Literal["state_agg", "fourier"]


class MonteCarlo:
    """Tag object selecting Monte Carlo learning; carries no settings."""


class TD:
    """Tag object selecting one-step temporal-difference learning; no settings."""


class TDN:
    """Settings object for the "tdn" policy kind; stores the step count ``n``."""

    def __init__(self, n: int):
        # NOTE(review): presumably the number of steps of the n-step return -
        # confirm against the code that consumes it.
        self.n = n


class Fourier:
    """Selects cosine (Fourier) features of the given order.

    Policy turns this into ``order + 1`` features
    ``cos(k * pi * s / nums_of_all_state)`` for ``k = 0 .. order``.
    """

    def __init__(self, order: int):
        self.order = order


class StateAgg:
    """Selects state aggregation: states are bucketed into ``n_groups`` groups.

    Policy turns this into a one-hot feature vector of length ``n_groups``.

    Raises:
        ValueError: if ``n_groups`` is smaller than 1. Such a value could never
            be used: building features divides by ``n_groups`` and indexes a
            list of that length.
    """

    def __init__(self, n_groups: int):
        if n_groups < 1:
            raise ValueError(f"n_groups must be at least 1, got {n_groups}")
        self.n_groups = n_groups


class Linear:
    """Tag object selecting a linear value function (features dot weights)."""


class Policy:
    """Bundles the learning kind with the feature map and value-function form.

    ``n_of_omega`` is the length of the weight vector implied by the chosen
    feature construction: ``n_groups`` for state aggregation and ``order + 1``
    for Fourier features.
    """

    def __init__(
        self,
        kind: Literal["mc", "td", "tdn"],
        detail: Union[MonteCarlo, TD, TDN],
        predict_algorithm: Linear,
        feature_algorithm: Union[StateAgg, Fourier],
    ):
        self.kind = kind
        self.detail = detail
        self.predict_algorithm = predict_algorithm
        self.feature_algorithm = feature_algorithm

        # Work out how many weights the feature map needs; -1 flags an
        # unsupported feature algorithm and trips the assertion below.
        extractor = self.feature_algorithm
        if isinstance(extractor, StateAgg):
            weight_count = extractor.n_groups
        elif isinstance(extractor, Fourier):
            weight_count = extractor.order + 1
        else:
            weight_count = -1
        self.n_of_omega = weight_count
        assert self.n_of_omega != -1, f"bad n_of_omage encountered: {self.n_of_omega}"

    def to_feature(self, s: State) -> Feature:
        """Map a non-terminal state to its feature vector."""
        assert 0 <= s < len(all_states), f"unexpected state encounter: {s}"
        extractor = self.feature_algorithm

        if isinstance(extractor, StateAgg):
            # One-hot vector marking the group that contains ``s``.
            group_count = extractor.n_groups
            one_hot = [0] * group_count
            group_width = math.ceil(nums_of_all_state / group_count)
            one_hot[s // group_width] = 1
            return np.asarray(one_hot)

        if isinstance(extractor, Fourier):
            # Cosine basis evaluated at the state rescaled into [0, 1).
            highest = extractor.order
            position = s / nums_of_all_state
            harmonics = range(highest + 1)
            return np.asarray([np.cos(k * np.pi * position) for k in harmonics])

        raise ValueError(
            f"unexpected feature transformation encountered: {type(self.feature_algorithm)}"
        )

    def predict(self, s: State, w: Weight) -> float:
        """Estimated value of ``s`` under weights ``w``; terminal states are 0."""
        assert -1 <= s <= len(all_states), f"unexpected state encounter: {s}"

        # -1 and len(all_states) are the two terminal observations.
        is_terminal = s == -1 or s == len(all_states)
        if is_terminal:
            return 0

        if not isinstance(self.predict_algorithm, Linear):
            raise ValueError(
                f"unexpected predict mode encountered: {self.kind}, {type(self.predict_algorithm)}"
            )
        return np.inner(self.to_feature(s), w)

    def gradient(self, s: State, w: Weight) -> np.ndarray:
        """Gradient of the value estimate with respect to the weights.

        For the linear form this is simply the feature vector of ``s``.
        """
        if not isinstance(self.predict_algorithm, Linear):
            raise ValueError(
                f"unexpected gradient mode encountered: {type(self.predict_algorithm)}"
            )
        return self.to_feature(s)


In [61]:
class Agent:
    """Learns a linear state-value estimate for ``env`` from sampled episodes.

    The agent always submits action ``0`` and adjusts the weight vector
    ``omega`` with updates of the form
    ``omega += alpha * (target - v(s)) * grad v(s)``:

    * ``policy.kind == "mc"``  - once per finished episode, target = return.
    * ``policy.kind == "td"``  - after every transition, one-step target.
    * ``policy.kind == "tdn"`` - n-step target, ``n`` taken from ``policy.detail``.

    Attributes:
        alpha: step size of the weight update.
        gamma: discount factor.
        omega: current weight vector (length ``policy.n_of_omega``).
        episode: transitions of the episode in progress.
        episodes: every finished episode, each closed by a
            ``(terminal_state, None, None)`` marker.
        cur_state: latest observation returned by the environment.
        end: True once the current episode has terminated.
    """

    # Spellings of ``policy.kind`` that select n-step TD. ``Policy`` declares
    # "tdn"; "td_n" is what ``step`` used to compare against, which made the
    # n-step branch unreachable. Both are accepted now.
    _TD_N_KINDS = ("tdn", "td_n")

    # Look-ahead handed to ``td_n_evaluate`` when ``policy.detail`` is not a
    # ``TDN`` (this is the value that used to be hard-coded in ``step``).
    _DEFAULT_LOOKAHEAD = 4

    def __init__(
        self,
        env: gym.Env,
        alpha: float,
        policy: Policy,
        gamma: float = 1.0,
    ):
        self.alpha = alpha
        self.env = env
        self.policy = policy
        self.gamma = gamma
        self.clear()

    def reset(self):
        """Start a new episode; the learned weights are kept."""
        self.cur_state: State = self.env.reset()
        self.end = False
        self.episode: Episode = []

    def clear(self):
        """Start over: new episode, random weights in [0, 1), empty history."""
        self.reset()

        self.omega = np.asarray(
            [np.random.random() for _ in range(self.policy.n_of_omega)]
        )
        self.episodes: List[Episode] = []

    def _td_lookahead(self) -> int:
        """Translate ``policy.detail`` into the ``n`` of ``td_n_evaluate``.

        ``td_n_evaluate(n, ...)`` sums ``n + 1`` rewards, so an n-step return
        (``TDN(n)``) corresponds to a look-ahead of ``n - 1``; ``TDN(1)`` is
        therefore the same as one-step TD.
        """
        detail = self.policy.detail
        if isinstance(detail, TDN):
            if detail.n < 1:
                raise ValueError(f"n-step TD needs n >= 1, got {detail.n}")
            return detail.n - 1
        return self._DEFAULT_LOOKAHEAD

    def step(self) -> Tuple[Observation, bool]:
        """Take one environment step and run the updates due at this point.

        Returns:
            ``(new_state, stop)`` as reported by the environment.
        """
        assert not self.end, "cannot step on a ended agent"

        (new_state, rwd, stop, _) = self.env.step(0)

        self.episode.append((self.cur_state, 0, rwd))

        self.cur_state = new_state

        kind = self.policy.kind
        if kind == "td":
            self.td_evaluate(self.gamma)
        elif kind in self._TD_N_KINDS:
            lookahead = self._td_lookahead()
            self.td_n_evaluate(lookahead, self.gamma)
            if stop:
                # The episode is over, so the states visited during the last
                # `lookahead` transitions will never see a full window. Update
                # them now with the shorter returns that remain (oldest first).
                # This must happen before the terminal marker is appended,
                # because td_n_evaluate slices the raw transition list.
                for shorter in range(lookahead - 1, -1, -1):
                    self.td_n_evaluate(shorter, self.gamma)

        if stop:
            self.episode.append((self.cur_state, None, None))
            self.episodes.append(self.episode)
            self.episode = []
            self.end = True
            if kind == "mc":
                self.mc_evaluate()

        return (cast(State, new_state), stop)

    def td_n_evaluate(self, n: int, gamma: float):
        """Update the state visited ``n`` transitions before the latest one.

        The target is the discounted sum of the last ``n + 1`` rewards plus
        ``gamma ** (n + 1)`` times the current estimate of ``self.cur_state``
        (0 for a terminal state). Nothing happens while the running episode
        holds fewer than ``n + 1`` transitions.
        """
        this_s = self.cur_state
        history = self.episode[-(1 + n) :]
        if len(history) != (1 + n):
            return
        old_s = history[0][0]

        rwd = np.sum(
            [cast(Reward, r) * gamma ** (i) for (i, (_, _, r)) in enumerate(history)]
        )

        target = rwd + (gamma ** (n + 1)) * self.policy.predict(this_s, self.omega)
        self.omega = self.omega + (
            self.alpha
            * (target - self.policy.predict(old_s, self.omega))
            * self.policy.gradient(old_s, self.omega)
        )

    def td_evaluate(self, gamma: float):
        """One-step TD update for the transition that was just taken."""
        return self.td_n_evaluate(0, gamma)

    def mc_evaluate(self):
        """Monte Carlo update over the most recently finished episode.

        Each visited state is moved towards its discounted return
        ``G_t = r_t + gamma * G_{t+1}``. With ``gamma == 1`` and rewards that
        are zero until termination this equals the final reward, which is the
        target the previous implementation used for every state.
        """
        epi = self.episodes[-1]
        transitions = epi[:-1]  # drop the terminal marker

        # Accumulate the returns back to front, then restore visiting order.
        returns: List[float] = []
        g = 0.0
        for (_, _, reward) in reversed(transitions):
            assert isinstance(
                reward, (int, float, np.integer, np.floating)
            ) and not math.isnan(reward), f"unexpected rwd encountered: {reward}"
            g = cast(Reward, reward) + self.gamma * g
            returns.append(g)
        returns.reverse()

        for (state, _, _), target in zip(transitions, returns):
            self.omega = self.omega + (
                self.alpha
                * (target - self.policy.predict(state, self.omega))
                * self.policy.gradient(state, self.omega)
            )

    def predict(self, s: State) -> float:
        """Value estimate of ``s`` under the current weights."""
        return self.policy.predict(s, self.omega)


In [62]:
# Number of training episodes. Every finished episode is kept in
# agent.episodes, so memory use grows with this value.
TOTAL_EPISODES = 100_000

# Monte Carlo learner with a linear value function over order-10 Fourier
# features (order + 1 = 11 weights) and step size alpha = 5e-5.
agent = Agent(
    cast(gym.Env, env),
    5e-5,
    Policy(
        kind="mc",
        detail=MonteCarlo(),
        predict_algorithm=Linear(),
        feature_algorithm=Fourier(10),
    ),
)

for _ in range(TOTAL_EPISODES):
    # reset() starts a new episode but leaves the learned weights untouched.
    agent.reset()
    end = False
    while not end:
        # step() returns (new_state, stop); only the stop flag is needed here.
        _, end = agent.step()


In [63]:
# Peek at the first 20 recorded episodes; each entry is (state, action, reward)
# and an episode ends with a (terminal_state, None, None) marker.
agent.episodes[:20]


[[(500, 0, 0),
  (589, 0, 0),
  (563, 0, 0),
  (553, 0, 0),
  (634, 0, 0),
  (714, 0, 0),
  (797, 0, 0),
  (747, 0, 0),
  (727, 0, 0),
  (767, 0, 0),
  (777, 0, 0),
  (810, 0, 0),
  (834, 0, 0),
  (758, 0, 0),
  (787, 0, 0),
  (788, 0, 0),
  (782, 0, 0),
  (800, 0, 0),
  (843, 0, 0),
  (841, 0, 0),
  (883, 0, 0),
  (895, 0, 0),
  (803, 0, 0),
  (903, 0, 0),
  (860, 0, 0),
  (791, 0, 0),
  (839, 0, 0),
  (835, 0, 0),
  (888, 0, 0),
  (872, 0, 0),
  (813, 0, 0),
  (799, 0, 0),
  (848, 0, 0),
  (806, 0, 0),
  (902, 0, 0),
  (903, 0, 0),
  (940, 0, 0),
  (846, 0, 0),
  (889, 0, 0),
  (911, 0, 0),
  (912, 0, 0),
  (868, 0, 0),
  (844, 0, 0),
  (745, 0, 0),
  (840, 0, 0),
  (757, 0, 0),
  (778, 0, 0),
  (750, 0, 0),
  (809, 0, 0),
  (820, 0, 0),
  (808, 0, 0),
  (889, 0, 0),
  (890, 0, 0),
  (843, 0, 0),
  (900, 0, 0),
  (822, 0, 0),
  (819, 0, 0),
  (720, 0, 0),
  (679, 0, 0),
  (775, 0, 0),
  (842, 0, 0),
  (910, 0, 0),
  (925, 0, 0),
  (857, 0, 0),
  (944, 0, 1),
  (1000, None, None)],
 [

In [64]:
# Learned weight vector, one entry per feature.
agent.omega

array([-8.14783073e-03, -7.35545929e-01, -4.70655038e-03, -8.12773836e-02,
       -5.25338114e-04, -3.40245449e-02, -5.86849945e-03, -2.94476691e-02,
        8.73368805e-03, -6.62441574e-03, -4.47526655e-03])

In [65]:
# Reference values loaded from a file next to the notebook; allow_pickle=False
# makes np.load refuse pickled object arrays.
# NOTE(review): presumably one value per state, in state order - confirm
# against whatever produced the file.
true_values = np.load("./true_values_arr.npy", allow_pickle=False)
true_values


array([-9.21893793e-01, -9.20650119e-01, -9.19397542e-01, -9.18136046e-01,
       -9.16865618e-01, -9.15586241e-01, -9.14297900e-01, -9.13000582e-01,
       -9.11694271e-01, -9.10378953e-01, -9.09054614e-01, -9.07721240e-01,
       -9.06378818e-01, -9.05027334e-01, -9.03666775e-01, -9.02297129e-01,
       -9.00918381e-01, -8.99530521e-01, -8.98133535e-01, -8.96727411e-01,
       -8.95312138e-01, -8.93887703e-01, -8.92454096e-01, -8.91011306e-01,
       -8.89559320e-01, -8.88098129e-01, -8.86627722e-01, -8.85148088e-01,
       -8.83659218e-01, -8.82161101e-01, -8.80653729e-01, -8.79137091e-01,
       -8.77611178e-01, -8.76075981e-01, -8.74531493e-01, -8.72977703e-01,
       -8.71414604e-01, -8.69842188e-01, -8.68260446e-01, -8.66669372e-01,
       -8.65068959e-01, -8.63459197e-01, -8.61840082e-01, -8.60211606e-01,
       -8.58573763e-01, -8.56926546e-01, -8.55269950e-01, -8.53603969e-01,
       -8.51928596e-01, -8.50243828e-01, -8.48549659e-01, -8.46846084e-01,
       -8.45133099e-01, -

In [66]:
# Compare the reference values with the agent's learned estimates.
fig = go.Figure()
# Take the states from the shared table instead of repeating the literal 1000.
s = list(all_states)
# States are shown 1-based on the x axis; build the axis once for both traces.
state_numbers = [i + 1 for i in s]
fig.add_trace(
    go.Scatter(x=state_numbers, y=true_values, mode="lines", name="true values")
)
fig.add_trace(
    go.Scatter(
        x=state_numbers,
        y=[agent.predict(i) for i in s],
        mode="lines",
        name="monte-carlo prediction",
    )
)
# A figure should be readable on its own: give it a title and axis labels.
fig.update_layout(
    title="Random walk: true vs. predicted state values",
    xaxis_title="state",
    yaxis_title="value",
)
fig.show()
