In [7]:
%pylab notebook
import numpy as np
from copy import copy

Populating the interactive namespace from numpy and matplotlib


In [6]:
def random_choice(a, p=None):
    return list(a)[np.random.choice(len(a), p=p)]

def trajectory(s, R, policy):
    traj = [ ]
    while True:
        a = policy(s)
        traj.append((s, a))
        if terminal(s):
            break
        s, r = R(s, a)
    return traj

def mc_es(S, terminal, A, R, gamma, Q, policy, niter):
    
    count = { }
    
    for _ in xrange(niter):
        # Generate episode
        episode = []
        s = random_choice(S)
        a = random_choice(A(s))
        r = 0.
        while True:
            episode.append((s, a, r))
            s, r = R(s, a)
            if terminal(s):
                episode.append((s, None, r))
                break
            a = policy[s]
        
        G = 0.
        Gsa = { }
        for (s, a, r) in episode[::-1]:
            G += r
            Gsa[s, a] = G
        for ((s, a), G) in Gsa.items():
            n = count[s, a] = count.get((s, a), 0) + 1
            Q[s, a] = ((n - 1) * Q.get((s, a), 0) + G) / n
        for s in set(s for (s, a, r) in episode):
            policy[s] = max(A(s), key = lambda a: Q[s, a])


In [61]:
# Blackjack

GAMMA = 1.

S = [ (player_sum, dealer_sum, usable_ace, terminal) 
     for player_sum in xrange(1, 1+31) 
     for dealer_sum in xrange(1, 11) 
     for usable_ace in (False, True) # True if player_sum can be increased by 10 by using an ace
     for terminal in (False, True)]

A = lambda (player_sum, dealer_sum, usable_ace, terminal): (False, True) if (player_sum < 21) and (not terminal) else (False,)

terminal = lambda (player_sum, dealer_sum, usable_ace, terminal): terminal

random_card = lambda: random_choice(range(1, 1+10) + [10]*3)

def R(s, a):
    player_sum, dealer_sum, usable_ace, terminal = s
    if a: # card
        card = random_card()
        if card == 1:
            usable_ace = True
        player_sum += card
        if player_sum > 21:
            terminal = True
            r = -1
        else:
            terminal = False
            r = 0
        s2 = player_sum, dealer_sum, usable_ace, terminal
        return s2, r
    else:
        terminal = True
        # use ace?
        if usable_ace and (player_sum + 10 <= 21):
            player_sum += 10
        s2 = player_sum, dealer_sum, False, terminal
        # draw dealer cards
        while dealer_sum < 17:
            dealer_sum += random_card()
        if dealer_sum > 21:
            r = +1
        else:
            if (player_sum > dealer_sum) and (player_sum <= 21):
                # player wins
                r = +1
            elif player_sum == dealer_sum:
                r = 0
            else:
                r = -1
        return s2, r

initial_guess = lambda (player_sum, dealer_sum, usable_ace, terminal): (0 if terminal or player_sum > 21 else 0.5)
Q0 = { (s, a): initial_guess(s) for s in S for a in A(s) }
policy0 = { (player_sum, dealer_sum, usable_ace, terminal): player_sum < 17 for (player_sum, dealer_sum, usable_ace, terminal) in S }

In [62]:
Q = copy(Q0)
policy = copy(policy0)

In [72]:
%%time
mc_es(S, terminal, A, R, gamma, Q, policy, 10000000)
print len(policy)

1240
CPU times: user 4min 37s, sys: 36 ms, total: 4min 37s
Wall time: 4min 37s


In [73]:
sym = { False: '.', True: 'X' }

for usable_ace in (True, False):
    print
    print
    print "%susable ace" % ("" if usable_ace else "no ")
    for player_sum in range(11, 1+21)[::-1]:
        for dealer_sum in range(1, 1+10):
            print sym[policy[player_sum - (10 if usable_ace else 0), dealer_sum, usable_ace, False]],
        print player_sum
    for dealer_sum in range(1, 1+10):
        print dealer_sum % 10,



usable ace
. . . . . . . . . . 21
. . . . . . . . . . 20
. . . . . . . . . . 19
. . . . . . . . X X 18
X X X X X X X X X X 17
X X X X X X X X X X 16
X X X X X X X X X X 15
X X X X X X X X X X 14
X X X X X X X X X X 13
X X X X X X X X X X 12
X X X X X X X X X X 11
1 2 3 4 5 6 7 8 9 0

no usable ace
. . . . . . . . . . 21
. . . . . . . . . . 20
. . . . . . . . . . 19
. . . . . . . . . . 18
. . . . . . . . . . 17
. . . . . . X X X . 16
. . . . . . X X X X 15
. . . . . . X X X X 14
X . . . . . X X X X 13
X . . . . . X X X X 12
X X X X X X X X X X 11
1 2 3 4 5 6 7 8 9 0


In [81]:
for usable_ace in (True, False):
    print
    print
    print "%susable ace" % ("" if usable_ace else "no ")
    for player_sum in range(11, 1+21)[::-1]:
        player_sum -= (10 if usable_ace else 0)
        for dealer_sum in range(1, 1+10):
            s = player_sum, dealer_sum, usable_ace, False
            print "%2d" % (100*0.5*(1 + max([Q[s, a] for a in A(s)]))),
        print "|", player_sum
    for dealer_sum in range(1, 1+10):
        print "--",
    print
    for dealer_sum in range(1, 1+10):
        print "%2d" % (dealer_sum % 10),



usable ace
91 94 94 94 94 95 96 96 96 96 | 11
76 83 83 83 84 85 88 89 90 76 | 10
65 71 72 72 73 73 79 80 68 51 | 9
52 58 58 59 62 61 70 59 47 43 | 8
45 51 52 53 55 57 54 47 44 39 | 7
44 49 51 51 54 56 49 46 44 38 | 6
45 51 52 53 56 56 52 49 45 40 | 5
46 52 53 55 56 58 54 52 48 42 | 4
47 54 55 55 57 59 55 53 48 43 | 3
50 54 57 56 57 60 59 57 52 45 | 2
64 70 71 70 71 74 73 71 67 62 | 1
-- -- -- -- -- -- -- -- -- --
 1  2  3  4  5  6  7  8  9  0

no usable ace
91 94 94 94 94 95 96 96 96 96 | 21
76 83 82 83 84 84 88 88 89 76 | 20
63 70 71 73 74 74 79 81 68 52 | 19
51 57 58 59 61 62 71 59 43 41 | 18
36 44 45 47 49 51 47 33 31 29 | 17
29 36 39 41 42 44 29 27 25 22 | 16
30 36 39 41 43 45 32 29 27 24 | 15
30 38 38 39 43 45 33 31 29 25 | 14
31 37 39 41 42 44 36 33 32 28 | 13
34 38 39 41 43 44 39 36 34 30 | 12
57 62 63 64 66 68 64 62 58 55 | 11
-- -- -- -- -- -- -- -- -- --
 1  2  3  4  5  6  7  8  9  0
