In [70]:
import numpy as np
# from gym.envs.toy_text import BlackjackEnv
import gym
from typing import Literal, List, Tuple
import plotly.graph_objects as go

In [71]:
# env = BlackjackEnv()
np.random.seed(0)
env = gym.make('Blackjack-v1')
env.seed(0)

obs = env.reset()


In [72]:
obs

(18, 1, False)

In [73]:
MySum = Literal[12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
DealerShowing = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
UsableAce = bool

Values = List[float]

State = Tuple[MySum, DealerShowing, UsableAce]
Action = Literal[0, 1]
Reward = Literal[-1, 0, 1]
EPISODE = List[Tuple[State, Action, Reward]]

TOTAL_NUM_OF_STATES = 200

In [74]:
def get_state_index(s: State) -> int:
    my_sum, dealer_show, usable_ace = s
    return int(usable_ace) + (dealer_show - 1) * 2 + (my_sum - 12) * 20


def get_state_from_index(i: int) -> State:
    my_sum = i // 20 + 12
    dealer_show = (i % 20) // 2 + 1
    usable_ace = (i % 20 % 2) 

    return (my_sum, dealer_show, usable_ace)


In [75]:
class Agent:
    def take_action(self, s: State) -> Literal[0, 1]:
        my_sum, deal_show, usable_ace = s
        assert my_sum <= 21, f"bad my_sum feed into agent: {my_sum}"
        return 0 if my_sum == 20 or my_sum == 21 else 1
        # return 0 if my_sum == 20 else 1

    def state_exists(self, s: State, ls: List[State]):
        for _s in ls:
            if s[0] == _s[0] and s[1] == _s[1] and s[2] == _s[2]:
                return True

        return False

    def predict(
        self, episodes: List[EPISODE], gamma: float
    ) -> Tuple[Values, List[List[Reward]]]:
        V = np.random.random(TOTAL_NUM_OF_STATES)
        R = [[] for _ in range(TOTAL_NUM_OF_STATES)]

        for episode in episodes:
            G = 0
            l = len(episode)
            for i, v in enumerate(reversed(episode)):
                state, action, reward = v
                G = gamma * G + reward
                if not self.state_exists(state, episode[: l - i - 1]):
                    idx = get_state_index(state)
                    R[idx].append(G)
                    V[idx] = np.average(R[idx])

        return (V, R)


In [76]:
TOTAL_EPISODES = 10_000

agent = Agent()

result: List[Tuple[bool, int]] = []


episodes: List[EPISODE] = []

for i in range(TOTAL_EPISODES):
    episode: EPISODE = []
    (my_sum, dealer_showup, usable_ace) = env.reset()

    assert my_sum <= 21, f"bad sum appears in init state: {my_sum}"

    print(f"episode {i}: {my_sum}, {dealer_showup}, {usable_ace}")

    while True:
        act = agent.take_action((my_sum, dealer_showup, usable_ace))

        print(f"take action: {'hit' if act == 1 else 'stick'}")

        (obs, rwd, done, info) = env.step(act)
        episode.append(((my_sum, dealer_showup, usable_ace), act, rwd))

        (my_sum, dealer_showup, usable_ace) = obs

        print(f"after taking action: {obs}, {rwd}, {done}, {info}")

        if done:
            print(f"iteration {i} end, reward: {rwd}")
            result.append((rwd > 0, my_sum))
            break
    episodes.append(episode)


env.close()


episode 0: 14, 10, False
take action: hit
after taking action: (20, 10, False), 0.0, False, {}
take action: stick
after taking action: (20, 10, False), 1.0, True, {}
iteration 0 end, reward: 1.0
episode 1: 19, 10, False
take action: hit
after taking action: (29, 10, False), -1.0, True, {}
iteration 1 end, reward: -1.0
episode 2: 4, 3, False
take action: hit
after taking action: (15, 3, True), 0.0, False, {}
take action: hit
after taking action: (17, 3, True), 0.0, False, {}
take action: hit
after taking action: (20, 3, True), 0.0, False, {}
take action: stick
after taking action: (20, 3, True), -1.0, True, {}
iteration 2 end, reward: -1.0
episode 3: 17, 5, False
take action: hit
after taking action: (26, 5, False), -1.0, True, {}
iteration 3 end, reward: -1.0
episode 4: 21, 10, True
take action: stick
after taking action: (21, 10, True), 1.0, True, {}
iteration 4 end, reward: 1.0
episode 5: 10, 4, False
take action: hit
after taking action: (15, 4, False), 0.0, False, {}
take action: h

In [77]:
result

[(True, 20),
 (False, 29),
 (False, 20),
 (False, 26),
 (True, 21),
 (True, 20),
 (False, 23),
 (False, 24),
 (False, 24),
 (False, 28),
 (True, 21),
 (False, 25),
 (False, 20),
 (False, 20),
 (True, 21),
 (False, 26),
 (False, 26),
 (False, 24),
 (False, 27),
 (False, 20),
 (False, 24),
 (True, 20),
 (True, 21),
 (False, 24),
 (False, 26),
 (False, 25),
 (False, 22),
 (False, 21),
 (False, 28),
 (False, 23),
 (False, 23),
 (False, 21),
 (False, 23),
 (True, 21),
 (False, 22),
 (False, 23),
 (False, 24),
 (False, 24),
 (True, 21),
 (False, 22),
 (False, 23),
 (False, 23),
 (False, 26),
 (False, 29),
 (False, 26),
 (False, 24),
 (True, 20),
 (False, 29),
 (True, 21),
 (True, 20),
 (True, 21),
 (True, 21),
 (True, 20),
 (False, 24),
 (False, 26),
 (True, 21),
 (True, 20),
 (True, 20),
 (False, 24),
 (False, 23),
 (True, 21),
 (False, 22),
 (True, 20),
 (False, 25),
 (False, 22),
 (False, 27),
 (True, 21),
 (True, 21),
 (False, 25),
 (False, 25),
 (False, 24),
 (False, 23),
 (True, 21),
 

In [78]:
episodes

[[((14, 10, False), 1, 0.0), ((20, 10, False), 0, 1.0)],
 [((19, 10, False), 1, -1.0)],
 [((4, 3, False), 1, 0.0),
  ((15, 3, True), 1, 0.0),
  ((17, 3, True), 1, 0.0),
  ((20, 3, True), 0, -1.0)],
 [((17, 5, False), 1, -1.0)],
 [((21, 10, True), 0, 1.0)],
 [((10, 4, False), 1, 0.0),
  ((15, 4, False), 1, 0.0),
  ((20, 4, False), 0, 1.0)],
 [((6, 10, False), 1, 0.0), ((13, 10, False), 1, -1.0)],
 [((14, 10, False), 1, -1.0)],
 [((18, 10, False), 1, -1.0)],
 [((7, 10, False), 1, 0.0),
  ((9, 10, False), 1, 0.0),
  ((18, 10, False), 1, -1.0)],
 [((7, 4, False), 1, 0.0), ((11, 4, False), 1, 0.0), ((21, 4, False), 0, 1.0)],
 [((16, 5, False), 1, -1.0)],
 [((10, 2, False), 1, 0.0), ((20, 2, False), 0, 0.0)],
 [((15, 10, False), 1, 0.0), ((20, 10, False), 0, 0.0)],
 [((14, 5, False), 1, 0.0), ((21, 5, False), 0, 1.0)],
 [((8, 10, False), 1, 0.0), ((18, 10, False), 1, -1.0)],
 [((18, 10, True), 1, 0.0),
  ((13, 10, False), 1, 0.0),
  ((16, 10, False), 1, -1.0)],
 [((14, 10, False), 1, -1.0)],

In [79]:
win_times = np.sum(np.fromiter((r[0] for r in result), dtype="int"))
print(f"win: {win_times}")
print(f"win rate: {win_times/TOTAL_EPISODES}")


win: 3030
win rate: 0.303


In [80]:
(V, R) = agent.predict(episodes, 1)


In [81]:
V = [(get_state_from_index(i), v) for (i, v) in enumerate(V)]


In [82]:
V

[((12, 1, 0), -0.648936170212766),
 ((12, 1, 1), -0.3333333333333333),
 ((12, 2, 0), -0.504950495049505),
 ((12, 2, 1), 0.0),
 ((12, 3, 0), -0.5638297872340425),
 ((12, 3, 1), -0.6),
 ((12, 4, 0), -0.575),
 ((12, 4, 1), -1.0),
 ((12, 5, 0), -0.5858585858585859),
 ((12, 5, 1), -1.0),
 ((12, 6, 0), -0.5729166666666666),
 ((12, 6, 1), 1.0),
 ((12, 7, 0), -0.42696629213483145),
 ((12, 7, 1), -0.5),
 ((12, 8, 0), -0.5352112676056338),
 ((12, 8, 1), -0.3333333333333333),
 ((12, 9, 0), -0.5116279069767442),
 ((12, 9, 1), -0.42857142857142855),
 ((12, 10, 0), -0.5949720670391061),
 ((12, 10, 1), -0.55),
 ((13, 1, 0), -0.6702127659574468),
 ((13, 1, 1), -0.2),
 ((13, 2, 0), -0.6732673267326733),
 ((13, 2, 1), -0.2),
 ((13, 3, 0), -0.6222222222222222),
 ((13, 3, 1), -0.25),
 ((13, 4, 0), -0.5851063829787234),
 ((13, 4, 1), -0.3333333333333333),
 ((13, 5, 0), -0.5625),
 ((13, 5, 1), -0.45454545454545453),
 ((13, 6, 0), -0.5643564356435643),
 ((13, 6, 1), -0.5),
 ((13, 7, 0), -0.5617977528089888),

In [83]:
fig = go.Figure(
    data=[
        go.Mesh3d(
            x=[v[0][0] for v in V if v[0][2] % 2 == 0],
            y=[v[0][1] for v in V if v[0][2] % 2 == 0],
            z=[v[1] for v in V if v[0][2] % 2 == 0],
            opacity=0.5,
            color="rgba(244,22,100,0.6)",
        )
    ]
)

fig.update_layout(
    scene=dict(
        xaxis=dict(
            nticks=4,
            range=[10, 25],
        ),
        yaxis=dict(
            nticks=4,
            range=[0, 12],
        ),
        zaxis=dict(
            nticks=4,
            range=[-1.2, 1.2],
        ),
    ),
    width=700,
)

fig.show()
