In [1]:
from dataclasses import dataclass, field
from typing import List, Dict, Tuple
from enum import Enum
from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt

import blackjack

In [2]:
%matplotlib inline  

In [8]:
@dataclass
class State:
    hand_total: int
    dealer_total: int
    #usable_ace: bool
    
    def __hash__(self) -> int:
        return hash((self.hand_total, self.dealer_total))

In [4]:
@dataclass
class QTable:
    q_table: Dict[State, List[float]] = field(default_factory=dict)
    
    def __post_init__(self) -> None:
        possible_player_totals = np.arange(2, 22)
    
    def __getitem__(self, state: State) -> List[float]:
        if state.hand_total > 21:
            return
        return self.q_table[state]
    
    def __setitem__(self, state: State, value: List[float]):
        if state.hand_total > 21:
            return
        self.q_table[state] = value
        
    def policy(self, state: State) -> int:
        if state.hand_total >= 21:
            return 1
        return np.argmax(self.q_table[state])

In [5]:
class Agent:
    def __init__(self, learning_rate: float = 0.1, discount_factor: float = 0.9, epsilon_greedy: float = 0.1, training_epochs: int = 1000) -> None:
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon_greedy = epsilon_greedy
        self.training_epochs = training_epochs
        self.current_epoch = 0
        
        self.q_table = QTable()
        
        self.deck = blackjack.Deck()
        self.dealer = blackjack.Dealer()
        self.player = blackjack.Player()    
        
        self.state = State(0, 0)
        
    def action(self, state: State) -> int:
        if np.random.random() < self.epsilon_greedy:
            return np.random.choice([0, 1])
        else:
            return self.q_table.policy(state)
    
    
    def start_turn(self) -> None:
        self.deck.reset()
        self.player.clear_hand()
        self.dealer.clear_hand()
        
        self.player.hit(self.deck)
        self.player.hit(self.deck)
        self.dealer.hit(self.deck)
        self.dealer.hit(self.deck)
        
    def play(self):
        self.start_turn()
        self.dealer.play(self.deck)
        self.state = State(self.player.hand.total, self.dealer.hand.total)
        
        while self.player.playing:
            action = self.action(self.state)
            self.player.play(action, self.deck)
            self.state.hand_total = self.player.hand.total
    
        print(self.player.actions_taken)
        
        
        

In [6]:
agent = Agent(
    learning_rate=0.1,
    discount_factor=0.9,
    epsilon_greedy=0.1,
    training_epochs=1000
)

In [7]:
agent.play()

TypeError: unhashable type: 'State'