In [36]:

import numpy as np
# from gym.envs.toy_text import BlackjackEnv
import gym
from typing import Literal, List, Tuple, cast
import plotly.graph_objects as go
from copy import deepcopy
import math

In [37]:

# Seed NumPy's global RNG so draws made through np.random are repeatable.
np.random.seed(0)
# Create the cliff-walking environment by its registered id.
env = gym.make('gym_cliffwalking:cliffwalking-v0')
# Seed the environment too, with the same seed value as NumPy.
env.seed(0)

# Reset once up front and keep the first observation for inspection.
obs = env.reset()


In [38]:
# Display the observation returned by the initial reset.
obs

0

In [39]:
# Every grid cell as a literal type, laid out one grid row per line:
# 4 rows x 12 columns = 48 positions, numbered row by row.
# fmt: off
Position = Literal[
     0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
    12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
    24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
    36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
]
# fmt: on

# States and observations are both plain grid positions.
State = Position
Observation = Position

# 0: Right; 1: Down; 2: Left; 3: Up
Action = Literal[0, 1, 2, 3]

# Grid coordinates are (row, col) pairs.
Rows = Literal[0, 1, 2, 3]
Cols = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
Coordination = Tuple[Rows, Cols]

# An episode is the ordered list of (state, action, reward) steps taken.
Reward = float
Step = Tuple[State, Action, Reward]
Episode = List[Step]


In [47]:
# Every grid position 0..47; a state is identified by its position.
all_states = list(range(48))

# The four action ids, and the actions available in each state. Every action
# is allowed everywhere, and each state gets its own independent list.
all_actions = [0, 1, 2, 3]
allowed_action: List[List[Action]] = [[0, 1, 2, 3] for _ in all_states]

# Width of one grid row; used to convert between positions and (row, col).
cols_per_row = 12

# Size of a flat table holding one entry per (state, action) pair.
nums_of_state_action_pair = len(all_states) * len(all_actions)

# State whose action values Agent.clear() sets to zero.
terminal_state: Literal[11] = 11


def get_pos_from_cor(cor: Coordination) -> Position:
    """Flatten a ``(row, col)`` coordinate into its row-major position index."""
    row_index, col_index = cor
    row_offset = row_index * cols_per_row
    return row_offset + col_index


def get_cor_from_pos(pos: Position) -> Coordination:
    """Split a row-major position index into its ``(row, col)`` coordinate."""
    # divmod yields (pos // cols_per_row, pos % cols_per_row) in one call.
    return divmod(pos, cols_per_row)


def get_pos_from_state(s: State) -> Position:
    """Return the grid position of state ``s``.

    A state is its own position, so the value is passed through unchanged;
    the assertion rejects anything that is not listed in ``all_states``.
    """
    is_known_state = s in all_states
    assert is_known_state, "wrong state encountered"
    return s


def get_state_from_pos(idx: int) -> State:
    """Return the state located at position ``idx``.

    The value itself is unchanged (``cast`` only informs the type checker);
    the assertion rejects anything that is not listed in ``all_states``.
    """
    assert idx in all_states, "wrong state encountered"
    return cast(State, idx)


def get_pos_from_state_action(state: State, action: Action) -> Position:
    """Map a ``(state, action)`` pair to its index in a flat per-pair table.

    Pairs are laid out state-major: all actions of state 0 come first, then
    all actions of state 1, and so on.
    """
    assert state in all_states, "wrong state encountered"
    assert action in all_actions, "wrong action encountered"

    actions_per_state = len(all_actions)
    flat_index = state * actions_per_state + action
    return cast(Position, flat_index)


def get_state_action_from_pos(pos: Position) -> Tuple[State, Action]:
    """Recover the ``(state, action)`` pair stored at flat index ``pos``.

    This is the inverse of ``get_pos_from_state_action``, which encodes a
    pair as ``state * len(all_actions) + action``.

    Args:
        pos: flat index of a state-action pair. The ``Position`` annotation
            is kept for interface compatibility, although valid indices run
            up to ``len(all_states) * len(all_actions) - 1``.

    Returns:
        The ``(state, action)`` pair that encodes to ``pos``.

    Raises:
        AssertionError: if the decoded state or action is not a known one.
    """
    # The stride of the encoding is the number of actions per state, so the
    # decoding must divide by len(all_actions) -- not len(all_states).
    state, action = divmod(pos, len(all_actions))

    assert state in all_states, "wrong state encountered"
    assert action in all_actions, "wrong action encountered"

    return (cast(State, state), cast(Action, action))


In [41]:
class Agent:
    """Agent that walks a gym environment and records each episode.

    Actions are currently drawn uniformly at random; the policy table
    ``pai`` and the value table ``Q`` are initialised but not consulted
    when acting.
    """

    def __init__(self, env: gym.Env, alpha: float, sigma: float):
        """Store the environment and parameters, then initialise all state.

        Args:
            env: environment to act in; it is reset during construction.
            alpha: stored on the instance; not read elsewhere in this class.
            sigma: stored on the instance; not read elsewhere in this class.
        """
        self.env = env
        self.alpha = alpha
        self.sigma = sigma

        self.clear()

    def take_action(self, s: State) -> Action:
        """Choose an action for state ``s``.

        The choice is uniform over ``all_actions`` and ignores both ``s``
        and the policy table ``self.pai``.
        """
        # NOTE: self.pai is not consulted yet; acting is purely random.
        act = np.random.choice(all_actions)
        return act

    def reset(self) -> State:
        """Start a new episode and return its initial state.

        Resets the environment, clears the end flag and empties the episode
        buffer. The tables ``pai`` and ``Q`` are left untouched.

        Returns:
            The state the new episode starts in.
        """
        # True once the environment has reported the episode as done.
        self.end: bool = False

        pos = self.env.reset()
        # State the agent currently occupies (validated against all_states).
        self.current_state: State = get_state_from_pos(pos)
        # (state, action, reward) steps taken so far in this episode.
        self.current_episode: Episode = []

        return self.current_state

    def clear(self):
        """Reset the episode and re-initialise the policy and value tables."""
        self.reset()

        # One randomly chosen action per state.
        self.pai: List[Action] = [
            np.random.choice(all_actions) for _ in range(len(all_states))
        ]

        # Flat table with one random value per (state, action) pair, indexed
        # via get_pos_from_state_action. At runtime this is a NumPy array;
        # the cast only affects static typing.
        self.Q: List[float] = cast(
            List[float], np.random.random(size=nums_of_state_action_pair)
        )
        # Every action value of the terminal state starts at zero.
        for act in allowed_action[terminal_state]:
            self.Q[get_pos_from_state_action(terminal_state, act)] = 0

    def close(self):
        """Close the underlying environment."""
        self.env.close()

    def step(self) -> Tuple[Observation, bool]:
        """Take one action in the environment and record the transition.

        Returns:
            ``(observation, done)`` as reported by the environment.

        Raises:
            AssertionError: if the current episode has already ended.
        """
        assert not self.end, "cannot step on a ended agent"

        act = self.take_action(self.current_state)
        (obs, reward, done, info) = self.env.step(act)

        # Record the state the action was taken FROM, before moving on.
        self.current_episode.append((self.current_state, act, reward))

        obs = cast(Observation, obs)
        self.current_state = obs

        # self.end was False on entry, so it simply mirrors `done` from here.
        self.end = bool(done)
        return obs, self.end


In [42]:
# Number of episodes to play.
TOTAL_RUNS = 1000

agent = Agent(env, alpha=1, sigma=0.5)

for i in range(TOTAL_RUNS):
    # Begin a fresh episode.
    start_pos = agent.reset()

    # Keep stepping until the agent reports the episode as finished; every
    # episode takes at least one step.
    done = False
    while not done:
        obs, done = agent.step()


agent.close()


In [48]:
# Sanity check of the helpers: state 23 sits at row 1, column 11.
get_cor_from_pos(get_pos_from_state(cast(State, 23)))


(1, 11)