In [1]:
import numpy as np
# from gym.envs.toy_text import BlackjackEnv
import gym
from typing import Literal, List, Tuple, cast
import plotly.graph_objects as go
from copy import deepcopy
import math

In [2]:
# env = BlackjackEnv()
np.random.seed(0)
env = gym.make('Blackjack-v1')
env.seed(0)

obs = env.reset()


In [3]:
obs

(18, 1, False)

In [4]:
MySum = Literal[12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
DealerShowing = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
UsableAce = bool

Observation = Tuple[MySum, DealerShowing, UsableAce]

Values = List[float]
Action_Values = List[float]

State = Tuple[MySum, DealerShowing, UsableAce]
Action = Literal[0, 1]
Reward = Literal[-1, 0, 1]
Policy = List[Action]
Step = Tuple[State, Action, Reward]
Episode = List[Step]

TOTAL_NUM_OF_STATES = 200
TOTAL_NUM_OF_STATE_ACTION_PAIR = 400


In [5]:
def get_state_index(s: State) -> int:
    my_sum, dealer_show, usable_ace = s
    idx = int(usable_ace) + (dealer_show - 1) * 2 + (my_sum - 12) * 20
    assert idx >= 0, "idx is illegal"
    return idx


def get_state_from_index(i: int) -> State:
    my_sum = cast(MySum, i // 20 + 12)
    dealer_show = cast(DealerShowing, (i % 20) // 2 + 1)
    usable_ace = cast(UsableAce, i % 20 % 2)
    assert my_sum >= 12 and my_sum <= 21, f"my_sum is invalid: {my_sum}"
    assert (
        dealer_show >= 1 and dealer_show <= 10
    ), f"dealer show is invalid: {dealer_show}"
    return (my_sum, dealer_show, usable_ace)


def get_state_index_with_action(s: State, a: Action) -> int:
    my_sum, dealer_show, usable_ace = s
    idx = a + int(usable_ace) * 2 + (dealer_show - 1) * 4 + (my_sum - 12) * 40
    assert idx >= 0, "idx is illegal"
    return idx


def get_state_from_index_with_action(i: int) -> Tuple[State, Action]:
    my_sum = cast(MySum, i // 40 + 12)
    dealer_show = cast(DealerShowing, (i % 40) // 4 + 1)
    usable_ace = cast(UsableAce, (i % 40 % 4) // 2)
    act = cast(Action, i % 40 % 4 % 2)
    assert my_sum >= 12 and my_sum <= 21, f"my_sum is invalid: {my_sum}"
    assert (
        dealer_show >= 1 and dealer_show <= 10
    ), f"dealer show is invalid: {dealer_show}"
    return ((my_sum, dealer_show, usable_ace), act)


def state_equal(s1: State, s2: State):
    return s1[0] == s2[0] and s1[1] == s2[1] and s1[2] == s2[2]


In [6]:
class Agent:
    def __init__(self, env: gym.Env, gamma: float, improve: bool = False):
        self.env = env
        self.gamma = gamma
        self.improve = improve

        self.clear()

    def close(self):
        self.env.close()

    def step(self, obs: Observation, random_act=False) -> Tuple[Observation, bool]:
        assert not self.end, "cannot step on a ended agent, reset it first"
        (my_sum, dealer_showup, usable_ace) = obs

        assert my_sum <= 21, f"bad sum appears in step: {my_sum}"

        act = (
            self.take_action((my_sum, dealer_showup, usable_ace))
            if not random_act
            else self.take_random_action()
        )

        (obs, rwd, done, info) = self.env.step(act)

        self.current_episode.append(
            ((my_sum, dealer_showup, usable_ace), act, cast(Reward, rwd))
        )

        (my_sum, dealer_showup, usable_ace) = obs

        if done:
            self.win_times += 1 if rwd > 0 else 0
            self.evaluate(self.improve)
            self.episodes.append(self.current_episode)
            self.current_episode = []
            self.end = True

        return ((my_sum, dealer_showup, usable_ace), done)

    def reset(self) -> Observation:
        assert (
            len(self.current_episode) == 0
        ), "when reset, current episode has some dirty elements, step it first or clear"

        # self.current_episode = []
        self.end = False

        while True:
            (my_sum, dealer_showup, usable_ace) = self.env.reset()
            assert my_sum <= 21, f"too large sum appears in reset: {my_sum}"
            if my_sum < 12:
                continue

            assert my_sum >= 12, f"too small sum appears in reset: {my_sum}"

            return (my_sum, dealer_showup, usable_ace)

    def clear(self):
        self.current_episode: Episode = []
        self.episodes: List[Episode] = [[] for _ in range(TOTAL_NUM_OF_STATES)]

        self.Q = [float("nan") for _ in range(TOTAL_NUM_OF_STATE_ACTION_PAIR)]
        self.q_returns = [[] for _ in range(TOTAL_NUM_OF_STATE_ACTION_PAIR)]

        self.pai = [1 for i in range(TOTAL_NUM_OF_STATES)]

        (my_sum, dealer_showup, usable_ace) = self.env.reset()
        assert my_sum <= 21, f"bad sum appears in init state: {my_sum}"

        self.win_times = 0
        self.end = True

    def take_action(self, s: State) -> Action:
        my_sum, deal_show, usable_ace = s
        assert my_sum <= 21, f"bad my_sum feed into agent: {my_sum}"

        new_act = self.pai[get_state_index((my_sum, deal_show, usable_ace))]
        # old_act = 0 if my_sum == 20 or my_sum == 21 else 1

        # assert old_act == new_act, "act is not same"
        return cast(Action, new_act)

    def take_random_action(self) -> Action:
        return np.random.choice([0, 1])

    def state_exists(self, s: State, es: List[Episode]):
        for e in es:
            (_s) = e[0]
            if s[0] == _s[0] and s[1] == _s[1] and s[2] == _s[2]:
                return True

        return False

    def state_action_exists(self, s: State, a: Action, es: Episode):
        for e in es:
            (_s, _a) = (e[0], e[1])
            if s[0] == _s[0] and s[1] == _s[1] and s[2] == _s[2] and _a == a:
                return True

        return False

    def evaluate(self, improve=False):
        episode = self.current_episode

        G = 0
        l = len(episode)
        for i, v in enumerate(reversed(episode)):
            state, action, reward = v
            G = self.gamma * G + reward
            if not self.state_action_exists(state, action, episode[: l - i - 1]):
                idx = get_state_index_with_action(state, action)
                state_idx = get_state_index(state)
                self.q_returns[idx].append(G)
                self.Q[idx] = np.mean(self.q_returns[idx])
                if improve:
                    # self.pai[state_idx] = np.argmax(
                    new_act = np.argmax(
                        [
                            self.Q[get_state_index_with_action(state, cast(Action, a))]
                            for a in [0, 1]
                        ]
                    )
                    # old_act = self.pai[state_idx]
                    self.pai[state_idx] = cast(int, new_act)


In [7]:
TOTAL_EPISODES = 400_000

agent = Agent(env, 1, True)


for i in range(TOTAL_EPISODES):
    (my_sum, dealer_showup, usable_ace) = agent.reset()

    start = True
    while True:
        (obs, done) = agent.step(
            (my_sum, dealer_showup, usable_ace), random_act=True if start else False
        )
        start = False

        (my_sum, dealer_showup, usable_ace) = obs

        if done:
            break


agent.close()


In [63]:
# win_times = np.sum(np.fromiter((r[0] for r in result), dtype="int"))
print(f"win: {agent.win_times}")
print(f"win rate: {agent.win_times/TOTAL_EPISODES}")


win: 134305
win rate: 0.3357625


In [64]:
(Q, R) = (agent.Q, agent.q_returns)

In [65]:
for r in R:
    assert len(r) > 0

for q in Q:
    assert not math.isnan(q)


In [66]:
# V = [(get_state_from_index(i), v) for (i, v) in enumerate(V)]
Q = [(get_state_from_index_with_action(i), v) for (i, v) in enumerate(Q)]
i = 0
V = []
while i < len(Q):
    q_act0 = Q[i]
    q_act1 = Q[i + 1]
    assert state_equal(
        q_act0[0][0], q_act1[0][0]
    ), "error encountered during Q collapse"
    state_value = (q_act0[0][0], np.max([q_act0[1], q_act1[1]]))
    V.append(state_value)
    i += 2


In [69]:
fig = go.Figure(
    data=[
        go.Mesh3d(
            x=[v[0][0] for v in V if v[0][2] % 2 == 0],
            y=[v[0][1] for v in V if v[0][2] % 2 == 0],
            z=[v[1] for v in V if v[0][2] % 2 == 0],
            opacity=0.5,
            color="rgba(244,22,100,0.6)",
        )
    ]
)

fig.update_layout(
    title="no usable ace",
    scene=dict(
        xaxis=dict(
            nticks=4,
            title="my sum",
            range=[10, 25],
        ),
        yaxis=dict(
            title="dealer show",
            nticks=4,
            range=[0, 12],
        ),
        zaxis=dict(
            title="state value",
            nticks=4,
            range=[-1, 1],
        ),
    ),
    width=700,
)

fig.show()


In [70]:
fig = go.Figure(
    data=[
        go.Mesh3d(
            x=[v[0][0] for v in V if v[0][2] % 2 == 1],
            y=[v[0][1] for v in V if v[0][2] % 2 == 1],
            z=[v[1] for v in V if v[0][2] % 2 == 1],
            opacity=0.5,
            color="rgba(244,22,100,0.6)",
        )
    ]
)

fig.update_layout(
    title="has usable ace",
    scene=dict(
        xaxis=dict(
            nticks=4,
            title="my sum",
            range=[10, 25],
        ),
        yaxis=dict(
            title="dealer show",
            nticks=4,
            range=[0, 12],
        ),
        zaxis=dict(
            title="state value",
            nticks=4,
            range=[-1, 1],
        ),
    ),
    width=700,
)

fig.show()


In [83]:
agent.pai[get_state_index((12, 2, False))]


1