In [8]:
import numpy as np
# from gym.envs.toy_text import BlackjackEnv
import gym
from typing import Literal, List, Tuple

In [9]:
# env = BlackjackEnv()
np.random.seed(0)
env = gym.make('Blackjack-v1')
env.seed(0)

obs = env.reset()


In [10]:
obs

(18, 1, False)

In [11]:
MySum = Literal[12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
DealerShowing = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
UsableAce = bool

State = Tuple[MySum, DealerShowing, UsableAce]
Action = Literal[0, 1]
Reward = Literal[-1, 0, 1]
EPISODE = List[Tuple[State, Action, Reward]]

TOTAL_NUM_OF_STATES = 200

In [12]:
def get_state_index(s: State) -> int:
    my_sum, dealer_show, usable_ace = s
    return int(usable_ace) + (dealer_show - 1) * 2 + (my_sum - 12) * 20


In [13]:
class Agent:
    def take_action(self, s: State) -> Literal[0, 1]:
        my_sum, deal_show, usable_ace = s
        assert my_sum <= 21, f"bad my_sum feed into agent: {my_sum}"
        return 0 if my_sum == 20 or my_sum == 21 else 1
        # return 0 if my_sum == 20 else 1

    def state_exists(self, s: State, ls: List[State]):
        for _s in ls:
            if s[0] == _s[0] and s[1] == _s[1] and s[2] == _s[2]:
                return True

        return False

    def predict(self, episodes: List[EPISODE], gamma: float):
        V = np.random.random(TOTAL_NUM_OF_STATES)
        R = [[] for _ in range(TOTAL_NUM_OF_STATES)]

        for episode in episodes:
            G = 0
            l = len(episode)
            for i, v in enumerate(reversed(episode)):
                state, action, reward = v
                G = gamma * G + reward
                if not self.state_exists(state, episode[: l - i - 1]):
                    idx = get_state_index(state)
                    R[idx].append(G)
                    V[idx] = np.average(R[idx])

        return (V, R)


In [14]:
TOTAL_EPISODES = 1000

agent = Agent()

result: List[Tuple[bool, int]] = []


episodes: List[EPISODE] = []

for i in range(TOTAL_EPISODES):
    episode: EPISODE = []
    (my_sum, dealer_showup, usable_ace) = env.reset()

    assert my_sum <= 21, f"bad sum appears in init state: {my_sum}"

    print(f"episode {i}: {my_sum}, {dealer_showup}, {usable_ace}")

    while True:
        act = agent.take_action((my_sum, dealer_showup, usable_ace))

        print(f"take action: {'hit' if act == 1 else 'stick'}")

        (obs, rwd, done, info) = env.step(act)
        episode.append(((my_sum, dealer_showup, usable_ace), act, rwd))

        (my_sum, dealer_showup, usable_ace) = obs

        print(f"after taking action: {obs}, {rwd}, {done}, {info}")

        if done:
            print(f"iteration {i} end, reward: {rwd}")
            result.append((rwd > 0, my_sum))
            break
    episodes.append(episode)


env.close()


episode 0: 14, 10, False
take action: hit
after taking action: (20, 10, False), 0.0, False, {}
take action: stick
after taking action: (20, 10, False), 1.0, True, {}
iteration 0 end, reward: 1.0
episode 1: 19, 10, False
take action: hit
after taking action: (29, 10, False), -1.0, True, {}
iteration 1 end, reward: -1.0
episode 2: 4, 3, False
take action: hit
after taking action: (15, 3, True), 0.0, False, {}
take action: hit
after taking action: (17, 3, True), 0.0, False, {}
take action: hit
after taking action: (20, 3, True), 0.0, False, {}
take action: stick
after taking action: (20, 3, True), -1.0, True, {}
iteration 2 end, reward: -1.0
episode 3: 17, 5, False
take action: hit
after taking action: (26, 5, False), -1.0, True, {}
iteration 3 end, reward: -1.0
episode 4: 21, 10, True
take action: stick
after taking action: (21, 10, True), 1.0, True, {}
iteration 4 end, reward: 1.0
episode 5: 10, 4, False
take action: hit
after taking action: (15, 4, False), 0.0, False, {}
take action: h

In [15]:
result

[(True, 20),
 (False, 29),
 (False, 20),
 (False, 26),
 (True, 21),
 (True, 20),
 (False, 23),
 (False, 24),
 (False, 24),
 (False, 28),
 (True, 21),
 (False, 25),
 (False, 20),
 (False, 20),
 (True, 21),
 (False, 26),
 (False, 26),
 (False, 24),
 (False, 27),
 (False, 20),
 (False, 24),
 (True, 20),
 (True, 21),
 (False, 24),
 (False, 26),
 (False, 25),
 (False, 22),
 (False, 21),
 (False, 28),
 (False, 23),
 (False, 23),
 (False, 21),
 (False, 23),
 (True, 21),
 (False, 22),
 (False, 23),
 (False, 24),
 (False, 24),
 (True, 21),
 (False, 22),
 (False, 23),
 (False, 23),
 (False, 26),
 (False, 29),
 (False, 26),
 (False, 24),
 (True, 20),
 (False, 29),
 (True, 21),
 (True, 20),
 (True, 21),
 (True, 21),
 (True, 20),
 (False, 24),
 (False, 26),
 (True, 21),
 (True, 20),
 (True, 20),
 (False, 24),
 (False, 23),
 (True, 21),
 (False, 22),
 (True, 20),
 (False, 25),
 (False, 22),
 (False, 27),
 (True, 21),
 (True, 21),
 (False, 25),
 (False, 25),
 (False, 24),
 (False, 23),
 (True, 21),
 

In [16]:
episodes

[[((14, 10, False), 1, 0.0), ((20, 10, False), 0, 1.0)],
 [((19, 10, False), 1, -1.0)],
 [((4, 3, False), 1, 0.0),
  ((15, 3, True), 1, 0.0),
  ((17, 3, True), 1, 0.0),
  ((20, 3, True), 0, -1.0)],
 [((17, 5, False), 1, -1.0)],
 [((21, 10, True), 0, 1.0)],
 [((10, 4, False), 1, 0.0),
  ((15, 4, False), 1, 0.0),
  ((20, 4, False), 0, 1.0)],
 [((6, 10, False), 1, 0.0), ((13, 10, False), 1, -1.0)],
 [((14, 10, False), 1, -1.0)],
 [((18, 10, False), 1, -1.0)],
 [((7, 10, False), 1, 0.0),
  ((9, 10, False), 1, 0.0),
  ((18, 10, False), 1, -1.0)],
 [((7, 4, False), 1, 0.0), ((11, 4, False), 1, 0.0), ((21, 4, False), 0, 1.0)],
 [((16, 5, False), 1, -1.0)],
 [((10, 2, False), 1, 0.0), ((20, 2, False), 0, 0.0)],
 [((15, 10, False), 1, 0.0), ((20, 10, False), 0, 0.0)],
 [((14, 5, False), 1, 0.0), ((21, 5, False), 0, 1.0)],
 [((8, 10, False), 1, 0.0), ((18, 10, False), 1, -1.0)],
 [((18, 10, True), 1, 0.0),
  ((13, 10, False), 1, 0.0),
  ((16, 10, False), 1, -1.0)],
 [((14, 10, False), 1, -1.0)],

In [17]:
win_times = np.sum(np.fromiter((r[0] for r in result), dtype="int"))
print(f"win: {win_times}")
print(f"win rate: {win_times/TOTAL_EPISODES}")


win: 296
win rate: 0.296
