In [132]:
import numpy as np
# from gym.envs.toy_text import BlackjackEnv
import gym
from typing import Literal, List, Tuple
import plotly.graph_objects as go

In [133]:
# Reproducibility: seed numpy's global RNG and the environment with the same value.
SEED = 0
np.random.seed(SEED)

env = gym.make("Blackjack-v1")
# NOTE(review): env.seed is the older gym API (this notebook also uses the
# 4-tuple env.step); newer gym releases seed via env.reset(seed=...) — confirm version.
env.seed(SEED)

obs = env.reset()


In [134]:
# Initial observation: (player sum, dealer's showing card, usable ace).
obs

(18, 1, False)

In [135]:
# --- State components ------------------------------------------------------
# Player sums modelled by the value table (the policy always hits below 12).
MySum = Literal[12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
# Dealer's face-up card, with 1 standing for an ace.
DealerShowing = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
UsableAce = bool
State = Tuple[MySum, DealerShowing, UsableAce]

# --- Transitions -------------------------------------------------------------
Action = Literal[0, 1]  # 0 = stick, 1 = hit (see Agent.take_action)
Reward = Literal[-1, 0, 1]
EPISODE = List[Tuple[State, Action, Reward]]

# --- Value table ---------------------------------------------------------------
Values = List[float]

# 10 player sums x 10 dealer cards x 2 usable-ace flags.
TOTAL_NUM_OF_STATES = 10 * 10 * 2

In [136]:
def get_state_index(s: State) -> int:
    """Flatten a ``(my_sum, dealer_show, usable_ace)`` state into a table index.

    Layout: ``(my_sum - 12) * 20 + (dealer_show - 1) * 2 + int(usable_ace)``,
    so ``my_sum`` in 12..21 and ``dealer_show`` in 1..10 give an index in
    ``[0, 200)``. Values outside those ranges are not rejected and yield
    out-of-range indices (negative for ``my_sum < 12``), so callers must filter.
    """
    player_total, dealer_card, has_usable_ace = s
    row = player_total - 12  # 20 slots per player sum
    col = dealer_card - 1  # 2 slots per dealer card
    return row * 20 + col * 2 + int(has_usable_ace)


def get_state_from_index(i: int) -> State:
    """Decode a flat table index back into a ``(my_sum, dealer_show, usable_ace)`` state.

    This is the inverse of the layout used by ``get_state_index``:
    ``i = (my_sum - 12) * 20 + (dealer_show - 1) * 2 + usable_ace``.

    Args:
        i: index in ``[0, TOTAL_NUM_OF_STATES)``.

    Returns:
        The decoded state. ``usable_ace`` is returned as a ``bool`` to match the
        ``UsableAce`` alias and the env's own observations (e.g. ``(18, 1, False)``).

    Raises:
        ValueError: if ``i`` is outside ``[0, TOTAL_NUM_OF_STATES)``.
    """
    if not 0 <= i < TOTAL_NUM_OF_STATES:
        raise ValueError(f"state index out of range: {i}")

    sum_offset, rest = divmod(i, 20)
    dealer_offset, ace_bit = divmod(rest, 2)
    return (sum_offset + 12, dealer_offset + 1, bool(ace_bit))


In [137]:
class Agent:
    """Fixed blackjack policy plus first-visit Monte Carlo state-value prediction.

    The policy sticks (action 0) on a player sum of 20 or 21 and hits
    (action 1) on anything lower.
    """

    def take_action(self, s: State) -> Literal[0, 1]:
        """Return 0 (stick) when the player sum is 20 or 21, otherwise 1 (hit).

        Raises:
            AssertionError: if the player sum exceeds 21 (a bust hand should
                never be fed to the policy).
        """
        my_sum, _dealer_show, _usable_ace = s
        assert my_sum <= 21, f"bad my_sum feed into agent: {my_sum}"
        return 0 if my_sum in (20, 21) else 1

    def state_exists(self, s: State, ls: List[State]) -> bool:
        """Return True if any state in ``ls`` equals ``s`` on all three fields."""
        return any(
            s[0] == other[0] and s[1] == other[1] and s[2] == other[2]
            for other in ls
        )

    def predict(
        self, episodes: List[EPISODE], gamma: float
    ) -> Tuple[Values, List[List[Reward]]]:
        """Estimate state values with first-visit Monte Carlo prediction.

        Args:
            episodes: time-ordered lists of ``(state, action, reward)`` steps.
            gamma: discount factor applied while accumulating returns backwards.

        Returns:
            ``(V, R)``. ``R[idx]`` lists the first-visit returns (floats)
            observed for the state whose ``get_state_index`` is ``idx``.
            ``V[idx]`` is their mean, or 0.0 for a state that was never visited.

        Only states with a player sum in 12..21 are recorded. Steps with a
        lower sum do occur, because the env can deal totals below 12 and the
        policy hits on them. ``get_state_index`` would map those steps to
        negative indices, which wrap around onto unrelated high-sum rows, so
        they are skipped. The return ``G`` is still accumulated through them.
        """
        V = np.zeros(TOTAL_NUM_OF_STATES)
        R = [[] for _ in range(TOTAL_NUM_OF_STATES)]

        for episode in episodes:
            # States alone, so the first-visit check compares state-to-state
            # rather than state-to-(state, action, reward).
            states = [step[0] for step in episode]
            G = 0
            # Walk backwards so G is the discounted return from step t onward.
            for t in range(len(episode) - 1, -1, -1):
                state, _action, reward = episode[t]
                G = gamma * G + reward
                if not 12 <= state[0] <= 21:
                    continue
                if not self.state_exists(state, states[:t]):
                    R[get_state_index(state)].append(G)

        # Average once per state instead of re-averaging after every append.
        for idx, returns in enumerate(R):
            if returns:
                V[idx] = np.mean(returns)

        return (V, R)


In [138]:
# Play TOTAL_EPISODES games with the fixed policy, recording every step.
TOTAL_EPISODES = 200_000

agent = Agent()

# One (won?, final player sum) pair per game.
result: List[Tuple[bool, int]] = []
# One list of (state, action, reward) steps per game.
episodes: List[EPISODE] = []

for _ in range(TOTAL_EPISODES):
    player_sum, dealer_card, has_ace = env.reset()
    assert player_sum <= 21, f"bad sum appears in init state: {player_sum}"
    state = (player_sum, dealer_card, has_ace)

    trajectory: EPISODE = []
    done = False
    while not done:
        act = agent.take_action(state)
        next_obs, rwd, done, _info = env.step(act)
        # Record the state the action was taken in, not the one it led to.
        trajectory.append((state, act, rwd))
        player_sum, dealer_card, has_ace = next_obs
        state = (player_sum, dealer_card, has_ace)

    # rwd / state now hold the terminal step's reward and observation.
    result.append((rwd > 0, state[0]))
    episodes.append(trajectory)

env.close()


In [139]:
# Each result entry is (won?, final sum); count the games that paid out.
won_flags = np.fromiter((won for won, _final_sum in result), dtype="int")
win_times = won_flags.sum()
print(f"win: {win_times}")
print(f"win rate: {win_times/TOTAL_EPISODES}")


win: 59337
win rate: 0.296685


In [140]:
# Average the undiscounted (gamma = 1) returns per state over all recorded games.
V, R = agent.predict(episodes, gamma=1)


In [141]:
# Pair each table index's decoded state with its value estimate.
# NOTE: this rebinds V (value array -> list of (state, value) pairs). Re-running
# this cell on its own output produces nested garbage; re-run the predict cell first.
V = list(zip(map(get_state_from_index, range(len(V))), V))


In [142]:
# Value surface over (my sum, dealer show) for states WITHOUT a usable ace.
no_ace_points = [(state, value) for state, value in V if state[2] % 2 == 0]

fig = go.Figure(
    data=[
        go.Mesh3d(
            x=[state[0] for state, _ in no_ace_points],
            y=[state[1] for state, _ in no_ace_points],
            z=[value for _, value in no_ace_points],
            opacity=0.5,
            color="rgba(244,22,100,0.6)",
        )
    ]
)

fig.update_layout(
    title="no usable ace",
    scene=dict(
        xaxis=dict(title="my sum", nticks=4, range=[10, 25]),
        yaxis=dict(title="dealer show", nticks=4, range=[0, 12]),
        zaxis=dict(title="state value", nticks=4, range=[-1.2, 1.2]),
    ),
    width=700,
)

fig.show()


In [143]:
# Value surface over (my sum, dealer show) for states WITH a usable ace.
ace_points = [(state, value) for state, value in V if state[2] % 2 == 1]

fig = go.Figure(
    data=[
        go.Mesh3d(
            x=[state[0] for state, _ in ace_points],
            y=[state[1] for state, _ in ace_points],
            z=[value for _, value in ace_points],
            opacity=0.5,
            color="rgba(244,22,100,0.6)",
        )
    ]
)

fig.update_layout(
    title="has usable ace",
    scene=dict(
        xaxis=dict(title="my sum", nticks=4, range=[10, 25]),
        yaxis=dict(title="dealer show", nticks=4, range=[0, 12]),
        zaxis=dict(title="state value", nticks=4, range=[-1.2, 1.2]),
    ),
    width=700,
)

fig.show()
