In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

In [None]:
class BalckJack:
    """Simplified blackjack environment with fixed threshold policies.

    Cards are drawn with replacement from a 13-rank deck. Face cards count
    as 10 and an ace counts as 11 until the hand would bust, at which point
    it is demoted to 1 (see ``check_ace``).

    A state is the list ``[usable_ace, player_sum, dealer_card]`` where
    ``usable_ace`` is a bool, ``player_sum`` is the player's current total
    and ``dealer_card`` is the value of the dealer's visible card (2-11,
    with an ace shown as 11).

    Args:
        dealer_stick_at: Smallest total at which the dealer sticks.
        player_stick_at: Smallest total at which the player sticks.
        gamma: Discount factor; stored for consumers of the trajectories.
    """

    def __init__(self, dealer_stick_at=17, player_stick_at=20, gamma=1):
        self.dealer_stick_at = dealer_stick_at
        self.player_stick_at = player_stick_at
        self.gamma = gamma
        # Build the lookup tables here as well, so that play() and
        # get_card() are usable even before the first reset().
        self._configure()

    def _configure(self):
        """(Re)build the action/reward tables, the deck and both policies."""
        self.action_space = {'hit': 0, 'stick': 1}
        self.rewards = {'win': 1, 'lose': -1, 'draw': 0}
        self.cards = np.array(['2', '3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A'])
        self.policy_player = self._threshold_policy(self.player_stick_at)
        self.policy_dealer = self._threshold_policy(self.dealer_stick_at)

    def _threshold_policy(self, stick_at):
        """Return a policy indexed by hand total (0-21): hit below ``stick_at``, stick from it on."""
        policy = np.zeros(22, dtype=np.int32)
        policy[max(stick_at, 0):] = self.action_space['stick']
        return policy

    def reset(self):
        """Deal a new hand and return the initial state.

        The policies are rebuilt first so that thresholds changed on the
        instance after construction take effect on the next episode.

        Returns:
            list: ``[usable_ace, player_sum, dealer_card]``.
        """
        self._configure()

        player_ace_num = 0
        player_cards = []
        for _ in range(2):
            card, ace = self.get_card()
            player_cards.append(card)
            player_ace_num += ace
        # Two aces start at 22; check_ace demotes one of them to 1.
        player_card_sum, player_ace_num = self.check_ace(sum(player_cards), player_ace_num)
        dealer_current_card, _ = self.get_card()

        return [player_ace_num > 0, player_card_sum, dealer_current_card]

    def get_card(self):
        """Draw one card uniformly at random (with replacement).

        Returns:
            tuple: ``(value, ace_count)`` where ``value`` is 11 for an ace,
            10 for J/Q/K and the pip value otherwise, and ``ace_count`` is
            1 for an ace and 0 for every other card.
        """
        card = np.random.choice(self.cards)
        if card == 'A':
            return 11, 1  # count ace as 11, number of ace
        elif card in ['J', 'Q', 'K']:
            return 10, 0
        else:
            return int(card), 0

    def check_ace(self, card_sum, ace_num):
        """Demote aces from 11 to 1 while the hand is bust.

        Args:
            card_sum: Hand total with every remaining ace counted as 11.
            ace_num: Number of aces still counted as 11.

        Returns:
            tuple: ``(card_sum, ace_num)`` after the demotions.
        """
        while card_sum > 21 and ace_num > 0:
            card_sum -= 10
            ace_num -= 1
        return card_sum, ace_num

    def result(self, player_sum, dealer_sum):
        """Return ``'win'``, ``'lose'`` or ``'draw'`` from the player's view.

        A player total above 21 always loses; otherwise a dealer total
        above 21 means the player wins; otherwise the total closer to 21
        wins and equal totals draw.
        """
        if player_sum == 21:
            if dealer_sum != 21:
                return 'win'
            else:
                return 'draw'
        elif player_sum < 21:
            if dealer_sum < 21:
                if abs(21 - player_sum) < abs(21 - dealer_sum):
                    return 'win'
                elif abs(21 - player_sum) == abs(21 - dealer_sum):
                    return 'draw'
                else:
                    return 'lose'
            elif dealer_sum == 21:
                return 'lose'
            else:
                return 'win'
        else:
            return 'lose'

    def play(self, state):
        """Play one episode from ``state`` under the fixed policies.

        Args:
            state: ``[usable_ace, player_sum, dealer_card]`` as returned by
                ``reset``.

        Returns:
            list: ``(state, action, reward)`` tuples, one per player
            decision. Every reward is 0 except the last, which carries the
            numeric outcome from ``self.rewards`` (1, 0 or -1). An initial
            total of 21 yields a single ``stick`` step rather than ``None``
            so callers can always iterate over the result.
        """
        trajectory = []
        usable_ace, player_card_sum, dealer_card = state
        player_action = int(self.policy_player[player_card_sum])
        player_ace_num = int(usable_ace)

        # player
        while True:
            trajectory.append((state, player_action, 0))
            if player_action == self.action_space['stick']:
                break
            player_card, player_ace = self.get_card()
            player_card_sum, player_ace_num = self.check_ace(
                player_card_sum + player_card, player_ace_num + player_ace)
            # Bust or exactly 21: the episode ends without a further decision.
            if player_card_sum >= 21:
                break
            state = [player_ace_num > 0, player_card_sum, dealer_card]
            player_action = int(self.policy_player[player_card_sum])

        # dealer
        dealer_card_sum = dealer_card
        dealer_ace_num = 1 if dealer_card == 11 else 0
        while True:
            # The first draw plays the role of the dealer's hidden card.
            dealer_new_card, dealer_ace = self.get_card()
            dealer_card_sum, dealer_ace_num = self.check_ace(
                dealer_card_sum + dealer_new_card, dealer_ace_num + dealer_ace)
            if dealer_card_sum > 21:
                break
            dealer_action = self.policy_dealer[dealer_card_sum]
            if dealer_action == self.action_space['stick']:
                break

        # Tuples are immutable, so rebuild the final step with the numeric
        # reward instead of assigning into it.
        reward = self.rewards[self.result(player_card_sum, dealer_card_sum)]
        last_state, last_action, _ = trajectory[-1]
        trajectory[-1] = (last_state, last_action, reward)

        return trajectory
  

In [None]:
def monte_carlo_first_visit(episodes):
    """Estimate the state-value function of the fixed blackjack policy.

    Runs ``episodes`` games with ``BalckJack`` and averages, for every
    state, the discounted return that followed the first time the state
    was visited in each episode.

    Args:
        episodes: Number of episodes to simulate.

    Returns:
        numpy.ndarray: Array of shape ``(2, 22, 10)`` indexed by
        ``[usable_ace, player_sum, dealer_card - 1]`` with the dealer's ace
        at index 0. States that were never visited keep the value 0.
    """
    # usable ace, player's current sum <=21, dealer's card (A~10)
    returns_sum = np.zeros([2, 22, 10])
    visit_count = np.zeros([2, 22, 10])
    black_jack = BalckJack()
    gamma = black_jack.gamma

    for _ in tqdm(range(episodes)):
        state = black_jack.reset()
        trajectory = black_jack.play(state)
        if not trajectory:
            # Nothing to learn from an episode without recorded steps.
            continue

        # Hashable key per step, plus the time step at which each state
        # first appears, so only the first visit contributes.
        keys = [(int(S[0]), S[1], S[2]) for S, _, _ in trajectory]
        first_seen = {}
        for t, key in enumerate(keys):
            first_seen.setdefault(key, t)

        G = 0
        for t in range(len(trajectory) - 1, -1, -1):
            G = gamma * G + trajectory[t][2]
            if first_seen[keys[t]] != t:
                continue
            usable_ace, player_sum, dealer_card = keys[t]
            # Dealer card values are 2..11 (ace shown as 11); map them to
            # indices 0..9 with the ace at 0 so a ten does not overflow.
            dealer_index = 0 if dealer_card == 11 else dealer_card - 1
            returns_sum[usable_ace, player_sum, dealer_index] += G
            visit_count[usable_ace, player_sum, dealer_index] += 1

    # Average once at the end; skip unvisited states to avoid 0/0 NaNs.
    V = np.zeros_like(returns_sum)
    np.divide(returns_sum, visit_count, out=V, where=visit_count > 0)
    return V



    