# Off-policy estimation 

In [2]:
import numpy as np
import pandas as pd
import blackjack as b21
import matplotlib.pyplot as plt
from numba import njit, prange

In [3]:
# Render inline matplotlib figures at high (retina) resolution.
%config InlineBackend.figure_format = "retina"
# Reload edited modules (e.g. the local `blackjack` module) before each cell runs.
%load_ext autoreload
%autoreload 2

## Example 5.4: Off-policy estimation of a Blackjack state value

In [4]:
# Behaviour policy b: in every state, stick (action 0) or hit (action 1)
# with equal probability. Because b gives non-zero probability to every
# action, it covers the target policy defined below, which is what makes
# off-policy estimation possible.
policy_init = np.full(
    (
        b21.PLAY_MAXVAL - b21.PLAY_MINVAL + 1,  # player's sum
        2,   # has usable ace
        10,  # dealer's showing card
        2,   # action: 0 = stick, 1 = hit
    ),
    0.5,
)

In [5]:
# Start state s (Example 5.4): player's sum 13 with a usable ace, dealer showing a 2.
value_cards_player, has_usable_ace, dealers_card = 13, True, 2

# One sample episode from s under the random behaviour policy.
b21.play_single_hist(value_cards_player, has_usable_ace, dealers_card, policy_init)

(-1, (28, 20), ([[13, 1, 2], [19, 1, 2], [28, 1, 2]], [1, 1, 0], [0, 0, -1]))

## The target policy

In [6]:
# Target policy pi: stick if the player's sum is 20 or 21, and hit otherwise.
STICK_THRESHOLD = 20
# Row i of the player-sum axis holds sum PLAY_MINVAL + i, so this is the first
# row (sum 20) at which the policy sticks.
stick_ix = STICK_THRESHOLD - b21.PLAY_MINVAL

policy_target = np.zeros((
    b21.PLAY_MAXVAL - b21.PLAY_MINVAL + 1,  # player's sum
    2,   # has usable ace
    10,  # dealer's showing card
    2,   # action: 0 = stick, 1 = hit
))

policy_target[:stick_ix, ..., 1] = 1  # hit for sums below the threshold
policy_target[stick_ix:, ..., 0] = 1  # stick for sums at or above it (20, 21)

In [7]:
# Number of distinct player sums, i.e. the length of the policy tables' first axis.
b21.PLAY_MAXVAL - b21.PLAY_MINVAL + 1

18

In [8]:
# Axes: (player's sum, usable ace, dealer's card, action)
np.shape(policy_target)

(18, 2, 10, 2)

In [9]:
@njit(parallel=True)
def _sum_returns(n_runs, player_sum, usable_ace, dealer_card, policy):
    """Sum the returns of ``n_runs`` independent episodes from one start state.

    Every argument is passed explicitly: numba treats global variables as
    compile-time constants, so a kernel that read notebook globals would keep
    using their values from first compilation even after they are reassigned.

    Parameters
    ----------
    n_runs : int
        Number of episodes to simulate.
    player_sum, usable_ace, dealer_card
        Start state, forwarded unchanged to ``b21.play_single``.
    policy : np.ndarray
        Action-probability table, forwarded to ``b21.play_single``.

    Returns
    -------
    float
        Sum over all runs of the reward returned by ``b21.play_single``.
    """
    total = 0.0
    for _ in prange(n_runs):
        reward, _final_values = b21.play_single(player_sum, usable_ace, dealer_card, policy)
        total += reward  # prange turns this accumulation into a parallel reduction
    return total


def multiple_runs_target(n_runs, player_sum=None, usable_ace=None, dealer_card=None, policy=None):
    """Simulate ``n_runs`` episodes and return ``(sum_of_rewards, n_episodes)``.

    Any state/policy argument left as ``None`` falls back to the notebook
    globals ``value_cards_player``, ``has_usable_ace``, ``dealers_card`` and
    ``policy_target``. These are looked up at call time, so redefining them in
    a later cell takes effect without re-compiling anything.

    Parameters
    ----------
    n_runs : int
        Number of episodes to simulate.
    player_sum, usable_ace, dealer_card : optional
        Start state; defaults to the notebook's start-state globals.
    policy : np.ndarray, optional
        Policy to follow; defaults to ``policy_target``.

    Returns
    -------
    tuple[float, float]
        Total reward and number of episodes run (a float, as before), so that
        ``rewards / count`` is the Monte Carlo value estimate.
    """
    player_sum = value_cards_player if player_sum is None else player_sum
    usable_ace = has_usable_ace if usable_ace is None else usable_ace
    dealer_card = dealers_card if dealer_card is None else dealer_card
    policy = policy_target if policy is None else policy

    rewards = _sum_returns(n_runs, player_sum, usable_ace, dealer_card, policy)
    # prange(n) with n <= 0 runs no iterations, so the episode count is never negative.
    return rewards, float(max(n_runs, 0))

In [10]:
%%time
# Estimate v_pi(s) by brute force: run 10^8 episodes from the start state
# under the target policy. The returned count is the denominator of the
# Monte Carlo average computed in the next cells.
n_sims = 100_000_000
rewards_target, count = multiple_runs_target(n_sims)

CPU times: user 54.1 s, sys: 26.3 ms, total: 54.1 s
Wall time: 2.4 s


In [11]:
# Monte Carlo estimate of the start state's value under the target policy.
# NOTE(review): Sutton & Barto (Example 5.4) report ~ -0.27726 for this state
# (also from 1e8 episodes); the value shown here (~ -0.359) disagrees.
# Check the b21 simulator, e.g. policy lookups made at bust sums.
v_target = rewards_target / count
v_target

-0.35914118

In [12]:
# Number of simulated episodes, shown in compact scientific notation.
f"{count:0.0e}"

'1e+08'

In [14]:
# Probe the state -> array-index mapping for a bust player sum (25).
# NOTE(review): the player index returned (21) lies beyond the 18-row
# player-sum axis of the policy tables. Confirm b21 never indexes a policy
# with a bust sum: numba does not bounds-check array accesses by default.
b21.state_to_ix(25, 1, 1)

(21, 1, 0)

In [118]:
# One sample episode from the start state under the target policy.
r, value, hist = b21.play_single_hist(
    value_cards_player, has_usable_ace, dealers_card, policy_target
)

# Final reward, final values, then the episode history (presumably
# states, actions, rewards), separated by blank lines.
print(r, value, hist, sep="\n" * 2)

0

(20, 20)

([[13, 1, 2], [20, 1, 2], [20, 1, 2]], [1, 0, 0], [0, 0, 0])
