In [2]:
from dataclasses import dataclass, field
from typing import List, Dict, Tuple
from enum import Enum
from copy import deepcopy

import numpy as np

import blackjack

In [3]:
@dataclass
class State:
    hand_total: int
    delear_partial_total: int
    
    def __hash__(self) -> int:
        return hash((self.hand_total, self.delear_partial_total))

In [4]:
class Action(Enum):
    HIT = 0
    STAND = 1

In [5]:
@dataclass
class QTable:
    table: Dict[State, List[float]] = field(default_factory=dict)
    
    def __post_init__(self) -> None:
        possible_actions = [Action.HIT, Action.STAND]
        possible_dealer_visible_totals = list(range(1, 12)) # Between 1 and 11
        possible_hand_totals = list(range(2, 23)) # If hand_total > 21, then it is 22 in the q table, no matter how bigger than 21 it is (not that it will do anything since the game will be automatically ended)
        
        for total in possible_hand_totals:
            for dealer_total in possible_dealer_visible_totals:
                state = State(total, dealer_total)
                self.table[state] = [0.0] * len(possible_actions)
    
    
    def __getitem__(self, state: State) -> List[float]:
        if state.hand_total > 21:
            state.hand_total = 22
        return self.table[state]
    
    def __setitem__(self, state: State, value: List[float]):
        if state.hand_total > 21:
            state.hand_total = 22
        self.table[state] = value
                
    def policy(self, state: State, epsilon: float = None) -> Action:
        if state.hand_total > 21:
            return Action.STAND
        if epsilon:
            if np.random.random() < epsilon:
                return Action(np.random.choice([0, 1]))
        return Action(np.argmax(self[state]))

In [10]:
class Agent:
    def __init__(self, learning_rate: float = 0.1, discount_factor: float = 0.99, epsilon_greedy: float = 0.1, epochs: int = 1000) -> None:
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon_greedy = epsilon_greedy
        self.epochs = epochs
        
        self.q_table = QTable()
        self.table = blackjack.Table()
        self.states: List[State] = [State(-1, -1)]
        self.episodes: List[List[Dict[str, State | Action | float]]] = []
    
    
    def reward_win_loss(self, player_total: int, dealer_total: int) -> float:
       
        if player_total > 21:
            return -1
        
        if player_total > dealer_total:
            return 1
        
        if dealer_total > 21:
            return 1
        
        if dealer_total > player_total:
            return -1
        
        if player_total == dealer_total:
            return 0
        
        
        
        
        
    def reward_proximity(self, player_total: int, dealer_total: int) -> float:
        player_total = player_total if player_total < 22 else 0
        return 1 / (21 - player_total + 1e-1)
                      

    
        
    def play_hand_recursive(self, players: List[blackjack.RecursivePlayer] = [], model_player: blackjack.Player = None) -> None:
        if len(players) == 0:
            players = [blackjack.RecursivePlayer(chosen_total, model_player.hand, actions_taken=[]) for chosen_total in model_player.hand.possible_totals]
        else:
            #return players
            new_players = []
            for player in players:
                if player.playing:
                    for total in player.possible_totals:
                        new_players.append(blackjack.RecursivePlayer(total, player.hand, player.actions_taken, playing=player.playing))
                        
                else:
                    reward = self.reward_win_loss(min(player.hand.possible_totals), self.table.dealer.total)
                    for action in player.actions_taken:
                        action["reward"] = reward
                        action["state"] = State(action["total"], self.table.dealer.partial_total)
                        action["current_total"] = min(player.hand.possible_totals)
                        if action not in self.episodes:
                            self.episodes.append(action)
                    
            players = new_players
        
       
        if len(players) == 0:
            return 
            
        
        
        next_card = self.table.deck.draw()
        
        for i, player in enumerate(players):
            if player.playing:
                action = self.q_table.policy(State(player.chosen_total, self.table.dealer.partial_total), self.epsilon_greedy)
                if player.chosen_total > 21:
                    player.stand()
                    continue
                
                elif action == Action.HIT:  
                    player.hit(next_card)
                else:   
                    player.stand()

        
                
        return self.play_hand_recursive(players)
    
    def decay_epsilon(self) -> None:
        self.epsilon_greedy *= 0.99965
    
    def update_q_table(self) -> None:
        for episode in self.episodes:
            state = episode["state"]
            action = episode["action"]
            reward = episode["reward"]
            new_total = episode["new_total"]
            #print((1 - self.learning_rate) * self.q_table[state][action.value] + self.learning_rate * (reward + self.discount_factor * max(self.q_table[state]) - self.q_table[state][action.value]))
            self.q_table[state][action.value] = (1 - self.learning_rate) * self.q_table[state][action.value] + self.learning_rate * (reward + self.discount_factor * max(self.q_table[State(new_total, self.table.dealer.partial_total)]) - self.q_table[state][action.value])
            
    
    def train(self) -> None:
        for epoch in range(self.epochs):
            self.table = blackjack.Table()
            model_player = blackjack.Player()
            self.table.add_player(model_player)
            self.table.start_turn()
            self.table.dealer.play(self.table.deck)
            self.play_hand_recursive(model_player=model_player)
            self.update_q_table()
            #self.decay_epsilon()
            
            if epoch % 250 == 0:
                table = np.array(list(self.q_table.table.values()))
                expected_value = np.mean(table[np.argmax(table, axis=1)])
                print(f"Epoch: {epoch}")
                print(f"Ev: {expected_value}")


In [11]:
agent = Agent(
    learning_rate=0.1,
    discount_factor=0.99,
    epsilon_greedy=0.3,
    epochs=10000
)



In [12]:
agent.train()

Epoch: 0
Ev: 0.0
Epoch: 250
Ev: 0.0
Epoch: 500
Ev: 0.0
Epoch: 750
Ev: 0.0
Epoch: 1000
Ev: 0.0
Epoch: 1250
Ev: 0.0
Epoch: 1500
Ev: 0.0
Epoch: 1750
Ev: -0.207967378766778
Epoch: 2000
Ev: -0.20738184524775088
Epoch: 2250
Ev: -0.2626420216860046
Epoch: 2500
Ev: -0.09972374779962462
Epoch: 2750
Ev: -0.1935127319976479
Epoch: 3000
Ev: -0.11724858935906113
Epoch: 3250
Ev: -0.16638628029956032
Epoch: 3500
Ev: -0.03483174383839696
