In [36]:
import numpy as np

# from gym.envs.toy_text import BlackjackEnv
import gym
from typing import Literal, List, Tuple, cast, Dict
import plotly.graph_objects as go
from copy import deepcopy
import math


In [37]:

# Seed NumPy's global RNG first so every later np.random call is repeatable.
np.random.seed(0)
# Build the cliff-walking environment registered under the gym_cliffwalking namespace.
env = gym.make('gym_cliffwalking:cliffwalking-v0')
# Seed the environment itself as well; it is seeded separately from NumPy here.
env.seed(0)

# Start the first episode and keep the initial observation for inspection.
obs = env.reset()


In [38]:
# Display the observation that env.reset() returned.
obs

0

In [39]:
# Cell indices of the grid, written one grid row (twelve columns) per line.
Position = Literal[
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
    12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
    24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
    36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
]
# States and observations use the same numbering as positions.
State = Position
Observation = Position

# 0: Right; 1: Down; 2: Left; 3: Up
Action = Literal[0, 1, 2, 3]

# A coordinate is (row, col): four rows of twelve columns.
Rows = Literal[0, 1, 2, 3]
Cols = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
Coordination = Tuple[Rows, Cols]

# An episode is the ordered list of (state, action, reward) steps.
Reward = float
Step = Tuple[State, Action, Reward]
Episode = List[Step]


In [47]:
# Every cell index of the grid: 0 through 47.
all_states = list(range(48))

# Action codes understood by the environment.
all_actions = [0, 1, 2, 3]
# One independent list of permitted actions per state (all four, everywhere).
allowed_action: List[List[Action]] = [list(all_actions) for _ in all_states]
# Width of one grid row; used to convert between indices and (row, col).
cols_per_row = 12
# Size of a flat table holding one entry per (state, action) pair.
nums_of_state_action_pair = len(all_states) * len(all_actions)
# Index of the terminal cell.
terminal_state: Literal[11] = 11


def get_pos_from_cor(cor: Coordination) -> Position:
    """Flatten a ``(row, col)`` pair into its row-major cell index.

    The index is ``row * cols_per_row + col``.
    """
    r, c = cor
    row_offset = r * cols_per_row
    return row_offset + c


def get_cor_from_pos(pos: Position) -> Coordination:
    """Split a cell index into ``(row, col)`` given ``cols_per_row`` columns per row."""
    return (pos // cols_per_row, pos % cols_per_row)


def get_pos_from_state(s: State) -> Position:
    """Return the cell index of state ``s``; the value is passed through unchanged.

    Raises:
        AssertionError: if ``s`` is not listed in ``all_states``.
    """
    is_known = s in all_states
    assert is_known, "wrong state encountered"
    return s


def get_state_from_pos(idx: int) -> State:
    """Treat the cell index ``idx`` as a state after checking that it is a known one.

    Raises:
        AssertionError: if ``idx`` is not listed in ``all_states``.
    """
    is_known = idx in all_states
    assert is_known, "wrong state encountered"
    # cast() only informs the type checker; the value itself is returned as-is.
    return cast(State, idx)


def get_pos_from_state_action(state: State, action: Action) -> Position:
    """Encode a ``(state, action)`` pair as an index into a flat table.

    Pairs are laid out state-major: ``state * len(all_actions) + action``.

    Raises:
        AssertionError: if ``state`` or ``action`` is not a known value.
    """
    assert state in all_states, "wrong state encountered"
    assert action in all_actions, "wrong action encountered"

    actions_per_state = len(all_actions)
    flat_index = state * actions_per_state + action
    return cast(Position, flat_index)


def get_state_action_from_pos(pos: Position) -> Tuple[State, Action]:
    """Decode a flat table index back into its ``(state, action)`` pair.

    This is the inverse of ``get_pos_from_state_action``, which computes
    ``state * len(all_actions) + action``. The index therefore has to be split
    by the number of actions; splitting by the number of states, as was done
    before, rejects almost every valid index.

    Args:
        pos: flat index in ``range(len(all_states) * len(all_actions))``.

    Returns:
        The ``(state, action)`` pair stored at ``pos``.

    Raises:
        AssertionError: if the decoded state or action is not a known value
            (for example when ``pos`` is negative or too large).
    """
    state, action = divmod(pos, len(all_actions))

    assert state in all_states, "wrong state encountered"
    assert action in all_actions, "wrong action encountered"

    return (cast(State, state), cast(Action, action))


In [41]:
class Agent:
    """Random-acting agent for the 4x12 cliff grid, with a tabular model of it.

    The agent wraps a gym-style environment, records the running episode as
    ``(state, action, reward)`` steps, exposes a deterministic transition model
    through ``dynamic`` and keeps randomly initialised value tables.

    Args:
        env: environment whose ``reset()`` returns a cell index and whose
            ``step(action)`` returns ``(obs, reward, done, info)``.
        alpha: stored on the instance; not read by any method of this class.
        sigma: stored on the instance; not read by any method of this class.
    """

    # (row delta, column delta) per action code. Row 0 is the bottom row, so
    # "up" increases the row. The letter codes are the ones this model has
    # always used; the integers are the Action encoding
    # (0: Right; 1: Down; 2: Left; 3: Up) and are accepted as equivalents.
    _MOVES = {
        "r": (0, 1),
        "d": (-1, 0),
        "l": (0, -1),
        "u": (1, 0),
        0: (0, 1),
        1: (-1, 0),
        2: (0, -1),
        3: (1, 0),
    }

    def __init__(self, env: gym.Env, alpha: float, sigma: float):
        self.env = env
        self.alpha = alpha
        self.sigma = sigma

        self.clear()

    def take_action(self, s: State) -> Action:
        """Draw one action uniformly at random from ``all_actions``; ``s`` is ignored."""
        act = np.random.choice(all_actions)
        return act

    def take_random_action(self, s: State) -> Dict[Action, float]:
        """Return the uniform random policy as ``{action: probability}``.

        The keys are the letter codes "r", "d", "l" and "u" (the same codes
        ``dynamic`` accepts), each with probability 0.25. ``s`` is ignored.
        """
        return {
            cast(Action, "r"): 0.25,
            cast(Action, "d"): 0.25,
            cast(Action, "l"): 0.25,
            cast(Action, "u"): 0.25,
        }

    def dynamic(self, s: State, a: Action) -> Dict[Tuple[State, Reward], float]:
        """Return the transition model ``{(next_state, reward): probability}``.

        Moves are deterministic, so the result always holds a single entry with
        probability 1. A move that would leave the grid keeps the agent where
        it is. Stepping onto cells 1-10 of the bottom row (the cliff) sends the
        agent to state 0 with reward -100; every other move costs -1.

        Args:
            s: current state. States 1-11 (the cliff cells and the terminal
                cell) have no outgoing transitions and are rejected.
            a: action, either a letter code ("r", "d", "l", "u") or the
                matching integer code (0, 1, 2, 3).

        Raises:
            ValueError: if ``s`` has no transitions or ``a`` is not an action.
        """
        # The state is validated before the action, as the lookup always did.
        if s not in all_states or s in range(1, cols_per_row):
            raise ValueError(
                f'unexpected state encountered in dynamic: {s}')

        try:
            d_row, d_col = self._MOVES[a]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which are not actions either.
            raise ValueError(
                f'unexpected action encountered in dynamic: {a}') from None

        n_rows = len(all_states) // cols_per_row
        row, col = divmod(int(s), cols_per_row)
        # Clamp to the grid: walking into a wall leaves the coordinate unchanged.
        new_row = min(max(row + d_row, 0), n_rows - 1)
        new_col = min(max(col + d_col, 0), cols_per_row - 1)

        # Bottom-row cells strictly between the first and last column are the cliff.
        if new_row == 0 and 0 < new_col < cols_per_row - 1:
            return {(0, -100): 1}
        return {(new_row * cols_per_row + new_col, -1): 1}

    def reset(self) -> State:
        """Start a new episode and return its initial state.

        Resets the environment, clears the recorded episode and marks the
        agent as not ended.
        """
        self.end: bool = False

        pos = self.env.reset()
        self.current_state: State = get_state_from_pos(pos)
        self.current_episode: Episode = []
        return self.current_state

    def dp_evaluate(self):
        """Placeholder for dynamic-programming policy evaluation.

        The method was declared without a body, which made the whole class a
        syntax error; it now fails explicitly until it is implemented.
        """
        raise NotImplementedError("dp_evaluate is not implemented yet")

    def clear(self):
        """Reset the episode state and re-initialise the value tables.

        ``V`` and ``Q`` are filled with uniform random numbers in [0, 1), and
        the ``Q`` entries of the terminal state are then set to 0.
        """
        self.reset()

        # NOTE(review): V has one entry per (state, action) pair, the same size
        # as Q. A state-value table would presumably need only len(all_states)
        # entries - confirm before relying on V. The size is left unchanged so
        # that the sequence of random draws stays the same.
        self.V = cast(
            List[float], np.random.random(size=nums_of_state_action_pair))
        self.Q = cast(
            List[float], np.random.random(size=nums_of_state_action_pair)
        )
        for act in allowed_action[terminal_state]:
            self.Q[get_pos_from_state_action(terminal_state, act)] = 0

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def step(self) -> Tuple[Observation, bool]:
        """Take one random action in the environment and record the step.

        Returns:
            ``(observation, done)``, where ``done`` is True once the
            environment reports the end of the episode.

        Raises:
            AssertionError: if called after the episode has ended and before
                ``reset()``.
        """
        assert not self.end, "cannot step on a ended agent"

        act = self.take_action(self.current_state)
        obs, reward, done, _info = self.env.step(act)

        # Log the state the action was taken from, not the one it led to.
        self.current_episode.append((self.current_state, act, reward))

        obs = cast(Observation, obs)
        self.current_state = obs

        if done:
            self.end = True
            return obs, True
        return obs, False


In [42]:
# Number of complete episodes to play.
TOTAL_RUNS = 1000

agent = Agent(env, alpha=1, sigma=0.5)

for i in range(TOTAL_RUNS):
    # NOTE(review): start_pos is not read anywhere in the cells shown here.
    start_pos = agent.reset()

    # Keep stepping until the agent reports that the episode is over.
    done = False
    while not done:
        obs, done = agent.step()


agent.close()


In [48]:
# Sanity check: state 23 should come back as row 1, column 11 (23 = 1 * 12 + 11).
get_cor_from_pos(get_pos_from_state(cast(State, 23)))


(1, 11)