# [Easy21](https://www.davidsilver.uk/wp-content/uploads/2020/03/Easy21-Johannes.pdf) solution

In [2]:
from enum import Enum
import logging
import random

import pandas as pd
import numpy as np

## Implement the environment

In [3]:
class Action(Enum):
    """Moves available to the player on each turn."""

    # Take no further cards; Round.step then plays out the dealer's hand.
    STICK = 0
    # Draw one more card from the deck.
    HIT = 1
    
class Color(Enum):
    """Card colors: BLACK cards add their value to a sum, RED cards subtract it."""

    # The numeric values double as list indices (e.g. colors[card.color.value]),
    # so they must stay 0 and 1.
    RED = 0
    BLACK = 1
    
class Card(object):
    """A single card: a color and a face value."""

    def __init__(self, color, value):
        self.value = value
        self.color = color

    def __str__(self):
        template = "Card: {} {}"
        return template.format(self.color, self.value)
    
class Reward(object):
    """Rewards emitted when an episode ends (non-terminal steps yield 0)."""

    LOSS = -1
    DRAW = 0
    WIN = 1
    
class State(object):
    """Observable game state.

    Attributes:
        dealer_first_card: The dealer's first (face-up) Card.
        player_sum: The player's running total.
        is_terminal: Whether the episode has ended.
    """

    def __init__(self,
                 dealer_first_card=None,
                 player_sum=None,
                 is_terminal=False):
        # Terminal states produced when the player sticks carry no card/sum,
        # hence the None defaults.
        self.is_terminal = is_terminal
        self.player_sum = player_sum
        self.dealer_first_card = dealer_first_card

    def __str__(self):
        fields = [
            "Dealer's first card: {}".format(self.dealer_first_card),
            "Player sum: {}".format(self.player_sum),
            "Is terminal: {}".format(self.is_terminal),
        ]
        return "\n\t".join(["State:"] + fields) + "\n"
    
def draw(v=None, c=None):
    """Returns a card from an infinite deck of cards.

    Each draw results in a value between 1 and 10 (uniformly distributed)
    with a color of red (p = 1/3) or black (p = 2/3).

    Args:
        v: A pre-determined value, or None to draw one at random.
        c: A pre-determined Color, or None to draw one at random.

    Returns:
        A Card with the chosen color and value.
    """
    # Test against None explicitly: `v or ...` would silently replace an
    # explicit falsy argument (e.g. v=0) with a random draw.
    value = random.randint(1, 10) if v is None else v
    if c is None:
        color = Color.RED if random.random() < 1 / 3 else Color.BLACK
    else:
        color = c

    return Card(color=color, value=value)
   
class Round(object):
    """Transition dynamics for Easy21.

    A Round holds no game state of its own; all state travels in State
    objects and ``step`` maps (state, action) to (next state, reward).
    """

    # A sum strictly below BUST_LOWER_BOUND or above BUST_UPPER_BOUND busts.
    BUST_LOWER_BOUND = 1
    BUST_UPPER_BOUND = 21

    # The dealer sticks on any sum of 17 or greater and hits otherwise.
    DEALER_STICK_VALUE = 17

    def __init__(self):
        pass

    @staticmethod
    def get_deck_sum(deck):
        """Returns the sum of a list of Card elements.

        Values are added (black) or subtracted (red).
        """
        # Reuse update_sum so the color rule lives in exactly one place.
        total = 0
        for card in deck:
            total = Round.update_sum(total, card)
        return total

    @staticmethod
    def is_busted(deck_sum):
        """Returns True if ``deck_sum`` lies outside [1, 21]."""
        return (deck_sum > Round.BUST_UPPER_BOUND or
                deck_sum < Round.BUST_LOWER_BOUND)

    @staticmethod
    def update_sum(cur_sum, card):
        """Returns ``cur_sum`` plus the card's value (black) or minus it (red)."""
        card_value = card.value if card.color == Color.BLACK else -card.value
        return cur_sum + card_value

    def step(self, s, a):
        """Returns a sample of the next state s' and reward r.

        Args:
            s: A non-terminal State.
            a: An Action. Action.HIT draws a card; any other value sticks.

        Returns:
            A tuple (next_state, reward). The reward is 0 while the game
            continues and a Reward constant once it ends.
        """
        if a == Action.HIT:
            return self._player_hits(s)
        return self._player_sticks(s)

    def _player_hits(self, s):
        """Player draws a card; the episode ends only if she busts."""
        card = draw()
        logging.debug("Player hits and draws {}".format(card))

        # The value of the player's cards are added (black) or
        # subtracted (red cards).
        player_sum = Round.update_sum(s.player_sum, card)
        logging.debug("Player sum: {}".format(player_sum))

        if Round.is_busted(player_sum):
            logging.debug("Busted! Womp womp.")
            ns = State(dealer_first_card=s.dealer_first_card,
                       player_sum=player_sum,
                       is_terminal=True)
            return (ns, Reward.LOSS)

        ns = State(dealer_first_card=s.dealer_first_card,
                   player_sum=player_sum,
                   is_terminal=False)
        logging.debug("Next state: {}".format(ns))
        return (ns, 0)  # Game is still going, no reward.

    @staticmethod
    def _dealer_plays(dealer_first_card):
        """Plays out the dealer's turn and returns the dealer's final sum."""
        dealer_sum = dealer_first_card.value

        while (dealer_sum < Round.DEALER_STICK_VALUE and
               not Round.is_busted(dealer_sum)):
            card = draw()
            dealer_sum = Round.update_sum(dealer_sum, card)
            logging.debug("\tDealer draws {}, sum: {}".format(
                card, dealer_sum))

        if Round.is_busted(dealer_sum):
            logging.debug("Dealer busted!")
        else:
            logging.debug("Dealer is sticking with sum {}".format(
                dealer_sum))
        return dealer_sum

    def _player_sticks(self, s):
        """Player takes no more cards; the dealer plays and the game ends."""
        player_sum = s.player_sum
        logging.debug("Player is sticking with sum {}".format(player_sum))

        ns = State(is_terminal=True)

        # Check the player first: a player already out of bounds loses
        # outright, and a subsequent dealer bust must not turn that into a
        # win. (Unreachable from legal play, where a hit-bust ends the game.)
        if Round.is_busted(player_sum):
            r = Reward.LOSS
        else:
            dealer_sum = Round._dealer_plays(s.dealer_first_card)
            if Round.is_busted(dealer_sum):
                r = Reward.WIN
            elif player_sum == dealer_sum:
                r = Reward.DRAW
            elif player_sum > dealer_sum:
                r = Reward.WIN
            else:
                r = Reward.LOSS

        logging.debug("Reward: {}".format(r))
        return (ns, r)

## Tests

In [4]:
# Verify draw distributions by tallying many independent draws.
# numbers[k] counts draws of value k + 1 (expected: roughly uniform).
numbers = [0] * 10
# colors is indexed by Color.value (expected: about 1/3 red, 2/3 black).
colors = [0] * 2

for i in range(100000):
    card = draw()
    numbers[card.value - 1] += 1
    colors[card.color.value] += 1

print(numbers)
print(colors)

[9994, 9969, 9988, 9840, 9814, 10074, 10193, 10075, 10012, 10041]
[33176, 66824]


In [5]:
# Verify Round.get_deck_sum().
cards = [Card(Color.BLACK, 1), Card(Color.RED, 4), Card(Color.BLACK, 3)]
# Expected 0 (1 + 3 - 4).
print(Round.get_deck_sum(cards))

# Verify is_busted() on a color-aware sum: 23 - 4 + 3 = 22 > 21.
# A plain sum of card values would ignore the red card's sign.
cards = [Card(Color.BLACK, 23), Card(Color.RED, 4), Card(Color.BLACK, 3)]
s = Round.get_deck_sum(cards)
print(Round.is_busted(s))

0
True


In [7]:
def strategy_always_hit(s):
    """Policy that ignores the state and always requests another card.

    Args:
        s: The current State (unused).
    """
    return Action.HIT

def strategy_always_stick(s):
    """Policy that ignores the state and always stands on the opening card.

    Args:
        s: The current State (unused).
    """
    return Action.STICK

def strategy_dealer_mirror(s):
    """Mirrors the dealer's fixed policy: stick at 17+, otherwise hit.

    Args:
        s: The current non-terminal State.

    Returns:
        Action.STICK if the player's sum has reached the dealer's stick
        threshold, else Action.HIT.
    """
    # Reuse the dealer's threshold so the two policies cannot drift apart.
    if s.player_sum >= Round.DEALER_STICK_VALUE:
        return Action.STICK
    return Action.HIT

def play_game(strategy, show_debug=False):
    """Plays one full episode of Easy21 and returns its final reward.

    Args:
        strategy: Callable mapping a State to an Action.
        show_debug: If True, sets the root logger to DEBUG so every step is
            logged; otherwise sets it to INFO.

    Returns:
        The reward of the terminal transition (a Reward constant).
    """
    # NOTE: this reconfigures the *root* logger on every call.
    logging.getLogger().setLevel(logging.DEBUG if show_debug else logging.INFO)

    # Initialize.
    # Player starts with a black card.
    player_first_card = draw(c=Color.BLACK)
    player_sum = player_first_card.value
    logging.debug("Player first card: {}".format(player_first_card))

    # Dealer starts with a black card.
    dealer_first_card = draw(c=Color.BLACK)
    logging.debug("Dealer first card: {}".format(dealer_first_card))

    game = Round()
    s = State(dealer_first_card=dealer_first_card,
              player_sum=player_sum,
              is_terminal=False)

    # Pre-bind r so it is defined even if the loop body never runs.
    r = None
    while not s.is_terminal:
        s, r = game.step(s, strategy(s))

    return r

def play_games(n, strategy, show_debug=False):
    """Plays ``n`` independent episodes and returns their rewards in order.

    Args:
        n: Number of episodes to play.
        strategy: Callable mapping a State to an Action.
        show_debug: Forwarded to play_game.
    """
    return [play_game(strategy, show_debug) for _ in range(n)]

# Evaluate the dealer-mirroring policy by averaging its reward over many games.
rs = play_games(5000, strategy_dealer_mirror, show_debug=False)
average_reward = sum(rs) / len(rs)
print("Average score: {}".format(average_reward))

Average score: -0.3006
