# Blackjack - prediction

In this notebook, we are interested in *learning* the value-function $v_\pi(s)$ and action-value function $q_\pi(s, a)$ for a given policy $\pi$

In [1]:
from numba import jit
import numpy as np

In [2]:
np.set_printoptions(suppress=True)

* **Stick**: Player stops
* **Hit**: Request an additional card

In [3]:
n_vals = np.ones(10)
# n_vals[-1] = 4
deck_probs = n_vals / n_vals.sum()

In [4]:
from numba.core import types
from numba.typed import Dict

In [5]:
@jit(nopython=True)
def draw_card():
    return np.random.multinomial(1, deck_probs).argmax()


@jit(nopython=True)
def dealer_strategy(value_cards):
    """
    Dealer's fixed strategy
    """
    while value_cards < 17:
        value_cards = value_cards + draw_card()
    return value_cards


@jit(nopython=True)
def update_player_card(current_value, new_card):
    """
    Update the player's hand value.
    
    Parameters
    ----------
    current_value: int
        Current value of player's hand
    new_card: int
        Randomly-drawn card
    
    Returns
    -------
    tuple: (int, bool)
        1. New value of player's hand
        2. Whether the there is a usable ace.
    """
    has_usable_ace = False
    if new_card == 1.0 and current_value <= 10:
        new_card = 11
        has_usable_ace = True
        
    new_value = current_value + new_card
    return new_value, has_usable_ace

@jit(nopython=True)
def blackjack(player_value_cards, dealer_cards, policy, has_usable_ace):
    """
    Evaluate a single play of Blackjack.
    
    For some reason, a player can only have a minimum value of 12
    on her initial value cards.
    
    At the start of the game, we are given the initial value of the cards
    of the player, the initial dealer cards and a policy for the player.
    Furtheremore, we are given whether the player has a usable ace.
    
    Actions:
        0: hit
        1: stick
    

    Parameters
    ----------
    player_value_cards: float
        Current value cards for the players
    dealer_cards: jnp.array(2)
        Dealer's initial cards
    policy: jnp.array(G,A)
        2d-array specifying if having value g ∈ G the player
        should take action a ∈ A, i.e., policy[g,a] == 1.0
        if action a should be taken if the value of the cards
        is g.
    has_usable_ace: bool
        Wehether the initial player_value_cards contains a
        usable ace.
    """
    reward = 0
    
    hist_reward = [reward]
    hist_state = [(player_value_cards, has_usable_ace)]
    hist_action = [0]
    
    
    dealer_value_cards = np.sum(dealer_cards)
    
    # Stick if you have 21
    if player_value_cards == 21 and dealer_value_cards != 21:
        reward = 1
        
        hist_reward.append(reward)
        hist_state.append((player_value_cards, has_usable_ace))
        hist_action.append(1)
        
        return reward
    
    # Strickly speaking, the policy should depend on:
    #  1. The current state of the player, i.e., the value of her cards
    #  2. The only card we observe of the dealer
    # In this example, we consider a policy that only depends
    # on the current value of the player's cards.

    # Hit until you reach a 'stick' state or you lose (value of cards over 21)
    while policy[player_value_cards - 12][1] != 1.0:
        new_card = draw_card()
        player_value_cards, new_has_usable_ace = update_player_card(player_value_cards, new_card)
        has_usable_ace = has_usable_ace or new_has_usable_ace # keep usable ace if player already did have one.
        hist_reward.append(0)
        hist_action.append(0)
        hist_state.append((player_value_cards, has_usable_ace))
        
        if player_value_cards > 21:
            break
            
    dealer_value_cards = dealer_strategy(sum(dealer_cards))
    
    if player_value_cards > 21:
        reward = -1
    elif dealer_value_cards > 21:
        reward = 1
    else:
        reward = 1 if player_value_cards > dealer_value_cards else 0
    
    hist_reward.append(reward)
    
    return reward, hist_reward

In [6]:
# We consider the policy that sticks if the 
# player's sum is 20 or 21, and otherwise hits
policy = np.zeros((10, 2))
policy[:-2, 0] = 1
policy[-2:, 1] = 1
policy

array([[1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [0., 1.],
       [0., 1.]])

## A Jax implementation

In [2]:
import jax
import jax.numpy as jnp
from jax.experimental import loops

In [42]:
# sum of player's hand (Sp) | dealer's shown card (Cd) | whether player has usable card (A)
state_0 = jnp.array([20, 2, 1])

In [52]:
hand_sum, dealer_card, has_ace = state_0
policy[hand_sum, dealer_card, has_ace]

DeviceArray(1., dtype=float32)

In [53]:
# We consider an initial policy that sticks if the
# player's sum is 20 or 21 and otherwise hits
# Sp | Cd | A
policy = jnp.zeros((10, 10, 2))
policy = (policy.at[:-2, :, 0].set(1) # Sum less than 20 => hits
                .at[-2:, :, 1].set(1)) # Sum == 20 or 21 => stick

In [184]:
jax.random.split(key)

DeviceArray([[ 407137227, 1028236956],
             [2469535296,  588450469]], dtype=uint32)

In [182]:
# 9 cards valued 1 through 9; 10, Jack, Queen and King are worth 10

def draw_cards(key, n=1):
    n_vals = jnp.ones(10).at[-1].set(4)
    deck_probs = n_vals / n_vals.sum()
    cards = jax.random.choice(key, 10, (n,), p=deck_probs)
    return cards
    

def init_player_state(key):
    """
    Random initialisiation of player's card. At the beginning we draw two cards.
    If the sum of the cards is less than 12, we continue hitting until we reach
    at least 12.
    
    Parameters
    ----------
    key: jax.random.PRNGKey
        Initial player's key
    
    Returns
    -------
    tuple(
        1. initial value of cards
        2. whether there was a usable ace
    )
    """
    init_cards = draw_cards(key) + 1

    value_cards = init_cards.sum()
    usable_ace = (init_cards[0] == 1) + (init_cards[1] == 1)
    value_cards = init_cards.sum() + 10 * usable_ace
    
    with loops.Scope() as s:
        s.value = value_cards
        s.usable_ace = usable_ace
        _, new_key = jax.random.split(key)
        for _ in s.while_range(lambda: s.value < 12):
            new_card = jax.random.choice(new_key, 10, p=deck_probs)
            new_card, usable_ace = jax.lax.cond((s.value <= 10) * (new_card == 1), lambda: (11, True), lambda: (1, False))
            s.usable_ace = s.usable_ace | usable_ace
            s.value = s.value + new_card
            _, new_key = jax.random.split(new_key)
    
    value_cards = s.value
    usable_ace = s.usable_ace
    
    return value_cards, usable_ace

In [183]:
key = jax.random.PRNGKey(12)
keys = jax.random.split(key, 12)


init_values, usable_aces = jax.vmap(init_player_state)(keys)
for init_val, usable_ace in zip(init_values, usable_aces):
    print(f"{init_val:02}", f"{usable_ace:01}")

12 0
12 0
12 0
12 0
12 0
12 0
12 0
12 0
12 0
12 0
12 0
12 0


In [None]:
def blackjack_step(state, policy):
    """
    Compute a single step of blackjack.
    
    Parameters
    ----------
    state: jnp.array(3)
        * Player's current sum
        * Dealer's one-showing card
        * Whether player has usable ace
    policy: jnp.array(N,M,K)
        Policy grid specifying where to "hit" or "stick"
    
    Returns
    -------
    tuple:
        * Array of next state
        * Reward of the game
        * whether the game has reached an end-state.
    """
    ...

In [None]:
def init_game(key):
    key_player, key_dealer = jax.random.split(key)
    player_state = init_player_state(key_player)
    dealer_card = 

In [168]:
jnp.array(init_player_state(key))

DeviceArray([12,  0], dtype=int32)