In [19]:
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Tuple

import numpy as np

import blackjack

In [20]:
@dataclass(unsafe_hash=True)
class State:
    """Observation the Q-table is keyed on: player's total plus dealer's visible total.

    ``unsafe_hash=True`` makes instances hashable from the tuple
    ``(hand_total, delear_partial_total)`` while they remain mutable, so a
    State must not be mutated while it is being used as a dict key.
    """

    # Total of the player's hand for this observation.
    hand_total: int
    # Dealer's visible (partial) total. NOTE: "delear" is a misspelling of
    # "dealer"; the name is kept so existing attribute access keeps working.
    delear_partial_total: int

In [21]:
class Action(Enum):
    """Moves available to the player.

    The integer ``value`` doubles as the column index into a Q-value list
    (``QTable.policy`` converts an argmax index back via ``Action(index)``).
    """

    HIT = 0  # draw another card
    STAND = 1  # stop drawing

In [22]:
@dataclass
class QTable:
    """Tabular Q-values for the blackjack agent, keyed by ``State``.

    Each entry maps a ``State`` to a list of Q-values indexed by
    ``Action.value`` (index 0 = HIT, index 1 = STAND). Every hand total above
    21 is folded into the single bucket ``BUST_TOTAL`` on lookup and
    assignment, so all busted hands share one row.
    """

    table: Dict[State, List[float]] = field(default_factory=dict)

    # Bucket every busted hand total (> 21) is mapped to. Unannotated, so it is
    # a plain class attribute and not a dataclass field.
    BUST_TOTAL = 22

    def __post_init__(self) -> None:
        """Ensure every (hand_total, dealer_visible_total) state has a row.

        Rows already present in a caller-supplied ``table`` are preserved; only
        missing states are initialised to zeros.
        """
        n_actions = len(Action)
        dealer_visible_totals = range(1, 12)  # 1..11 inclusive
        hand_totals = range(2, self.BUST_TOTAL + 1)  # 2..22; 22 is the bust bucket

        for total in hand_totals:
            for dealer_total in dealer_visible_totals:
                # setdefault: do not wipe values passed in via `table=...`.
                self.table.setdefault(State(total, dealer_total), [0.0] * n_actions)

    def _canonical(self, state: State) -> State:
        """Return ``state`` with a bust total folded into ``BUST_TOTAL``.

        A new ``State`` is built instead of mutating the argument, so the
        caller's object (which may itself be a dict key elsewhere) is untouched.
        """
        if state.hand_total > 21:
            return State(self.BUST_TOTAL, state.delear_partial_total)
        return state

    def __getitem__(self, state: State) -> List[float]:
        """Return the Q-value list for ``state``.

        Raises:
            KeyError: if the (clamped) state is outside the initialised ranges.
        """
        return self.table[self._canonical(state)]

    def __setitem__(self, state: State, value: List[float]) -> None:
        """Store the Q-value list for ``state`` (bust totals share one row)."""
        self.table[self._canonical(state)] = value

    def policy(self, state: State, epsilon: Optional[float] = None) -> Action:
        """Choose an action epsilon-greedily.

        Args:
            state: Current observation.
            epsilon: Probability of picking a uniformly random action. ``None``
                or ``0`` means purely greedy.

        Returns:
            The chosen ``Action``. In the greedy case ties resolve to the lowest
            index (``np.argmax`` returns the first maximum), i.e. HIT.
        """
        # Short-circuit keeps the RNG untouched when epsilon is falsy.
        if epsilon and np.random.random() < epsilon:
            return Action(np.random.choice([0, 1]))
        return Action(np.argmax(self[state]))

In [23]:
class Agent:
    """Blackjack agent with a ``QTable`` and a ``blackjack.Table`` to play on.

    ``play_hand_recursive`` explores every possible total of the player's hand
    as a separate branch and records per-action rewards into ``self.episodes``.
    """

    def __init__(self, learning_rate: float = 0.1, discount_factor: float = 0.99, epsilon_greedy: float = 0.1, epochs: int = 1000) -> None:
        """
        Args:
            learning_rate: Step size for Q-value updates (stored; not yet used here).
            discount_factor: Discount for future rewards (stored; not yet used here).
            epsilon_greedy: Exploration probability passed to ``QTable.policy``.
            epochs: Number of training epochs (stored; not yet used here).
        """
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon_greedy = epsilon_greedy
        self.epochs = epochs
        self.q_table = QTable()
        self.table = blackjack.Table()
        # NOTE(review): placeholder sentinel state; not read anywhere in this class.
        self.states: List[State] = [State(-1, -1)]
        # Flat list of recorded action dicts. Keys observed: "action", "total",
        # "reward" (from the player) and "state" (added when a branch finishes).
        self.episodes: List[Dict[str, State | Action | float]] = []

    def reward_win_loss(self, player_total: int, dealer_total: int) -> float:
        """Terminal reward from final totals: +1 win, 0 push, -1 loss.

        A player bust always loses (checked first, so it loses even when the
        dealer also busts). Otherwise a dealer bust is a player win.
        """
        if player_total > 21:
            return -1.0
        if dealer_total > 21:
            return 1.0
        if player_total > dealer_total:
            return 1.0
        if player_total == dealer_total:
            return 0.0
        return -1.0

    def reward_proximity(self, player_total: int, dealer_total: int) -> float:
        """Shaping reward that grows as ``player_total`` approaches 21.

        Computes ``1 / (21 - total + 0.1)``; busted totals (>= 22) are scored as
        if the total were 0, giving the smallest reward. ``dealer_total`` is
        accepted for signature parity with ``reward_win_loss`` but ignored.
        """
        player_total = player_total if player_total < 22 else 0
        # The 0.1 offset avoids division by zero at exactly 21.
        return 1 / (21 - player_total + 1e-1)

    def _record_finished_player(self, player: blackjack.RecursivePlayer) -> None:
        """Stamp reward and state onto every action of a finished branch and log it.

        The reward is ``reward_proximity`` of the lowest possible hand total.
        """
        reward = self.reward_proximity(min(player.hand.possible_totals), self.table.dealer.total)
        for action in player.actions_taken:
            action["reward"] = reward
            action["state"] = State(action["total"], self.table.dealer.partial_total)
            # `in` compares dicts by value, so equal records are logged only once.
            if action not in self.episodes:
                self.episodes.append(action)

    def play_hand_recursive(
        self,
        players: Optional[List[blackjack.RecursivePlayer]] = None,
        model_player: Optional[blackjack.Player] = None,
        verbose: bool = True,
    ) -> None:
        """Advance every branch of the player's hand by one decision.

        When ``players`` is empty/None, one ``RecursivePlayer`` branch is
        created per possible total of ``model_player.hand``. Otherwise
        still-playing branches are split again per possible total, and finished
        branches have their rewards recorded into ``self.episodes``. Then a
        single card is drawn and each still-playing branch hits (with that same
        card) or stands according to the epsilon-greedy ``QTable.policy``.

        Args:
            players: Branches from a previous step; ``None`` or empty starts a new hand.
            model_player: Player whose hand seeds the branches; required when
                ``players`` is empty.
            verbose: Print the per-step debug trace (on by default to keep the
                existing notebook output).

        Raises:
            ValueError: if ``players`` is empty and ``model_player`` is None.
        """
        if not players:
            if model_player is None:
                raise ValueError("model_player is required when players is empty")
            players = [
                blackjack.RecursivePlayer(chosen_total, model_player.hand)
                for chosen_total in model_player.hand.possible_totals
            ]
        else:
            new_players = []
            for player in players:
                if player.playing:
                    # Branch again: one child per total the hand could count as.
                    for total in player.possible_totals:
                        new_players.append(
                            blackjack.RecursivePlayer(total, player.hand, player.actions_taken, playing=player.playing)
                        )
                else:
                    self._record_finished_player(player)
            players = new_players

        if verbose:
            print("--" * 50)
            print([player.chosen_total for player in players])
            print([player.actions_taken for player in players])
            print([player.possible_totals for player in players])
            print("--" * 50)
        if not players:
            return

        # One shared card: every branch that hits in this step receives the same draw.
        next_card = self.table.deck.draw()

        for player in players:
            if verbose:
                print("action taken")
            if not player.playing:
                continue

            state = State(player.chosen_total, self.table.dealer.partial_total)
            action = self.q_table.policy(state, self.epsilon_greedy)
            if action == Action.HIT:
                player.hit(next_card)
            else:
                player.stand()

        # TODO: recurse via `self.play_hand_recursive(players)` once branching is complete.

In [24]:
# Fixed seed so the epsilon-greedy choices made through np.random are reproducible.
# NOTE(review): if blackjack's deck shuffles with its own RNG (e.g. `random`), seed that too.
SEED = 42
np.random.seed(SEED)

agent = Agent(
    learning_rate=0.1,
    discount_factor=0.99,
    epsilon_greedy=0.9,
    epochs=1000
)

player = blackjack.Player()
# Agent.__init__ already creates a fresh blackjack.Table, so use that one directly.
agent.table.add_player(player)
agent.table.start_turn()
agent.table.dealer.play(agent.table.deck)

agent.play_hand_recursive(model_player=player)
#agent.episodes

----------------------------------------------------------------------------------------------------
[16]
[[{'action': <Action.HIT: 0>, 'total': 11, 'reward': 0}, {'action': <Action.STAND: 1>, 'total': 7, 'reward': 0}, {'action': <Action.STAND: 1>, 'total': 7, 'reward': 0}]]
[[16]]
----------------------------------------------------------------------------------------------------
action taken
