In [3]:
import torch
import random
import torch.nn as nn
import matplotlib.pyplot as plt

from torch.optim import RMSprop
from torch.utils.data import DataLoader, TensorDataset

### uses deep q learning

In [29]:
DECK = 4 * (list(range(2, 12)) + [10, 10, 10])

class QFunction(nn.Module):
    def __init__(self):
        """
        (2) splittable
        (3) doublable 
        (4) soft 
        (5) score 
        (6) doublable2 
        (7) soft2 
        (8) score2 
        (9) dealer upcard 
        (...) true count
        (10) current hand being processed
        (11) action
        """
        super().__init__()
        self.lin1 = nn.Linear(11, 12)
        self.lin2 = nn.Linear(12, 12)
        self.lin3 = nn.Linear(12, 12)
        self.lin4 = nn.Linear(12, 1)
        

    def forward(self, x):
        x = self.lin1(x)
        x = torch.relu(x)
        x = self.lin2(x)
        x = torch.relu(x)
        x = self.lin3(x)
        x = torch.relu(x)
        return self.lin4(x)

class ReplayMemory(nn.Module):
    def __init__(self, N = 30, mini_batch = 7):
        self.memory = []
        self.N = N
        self.mini_batch = mini_batch

    def push(self, tup):
        self.memory.append(tup)
        
        if len(self.memory) > self.N:
            self.memory = self.memory[1:]

    def sample(self):
        return random.sample(self.memory, min(self.memory.__len__(), self.mini_batch))

In [5]:
def get_true_count(count, shoe):
    """returns the true count exactly based on the running count and length of the shoe"""
    return round(count / round(len(shoe) / 52))

def update_count(count, card):
    """Adds -1, 0, or +1 to the count given the card value"""
    if card <= 6:
        return count + 1
    elif card >= 10:
        return count - 1

    return count

def get_multiplier(true_count):
    return 1. + (1. * max(0, true_count))

def is_soft(c1, c2):
    return c1 == 11 or c2 == 11

def get_score(hand, c, soft):
    """returns the score and its softness given the existing hand and a new card"""
    if c != 11 and not soft:
        return False, hand + c

    if c == 11 and not soft: # if valid with 1 or 11, then is still soft (and add 11). else add 1 only
        if hand + 11 <= 21: # and by extension vaild with hand + 1
            return True, hand + 11

        return False, hand + 1

    if c != 11 and soft: # if total > 21, then subtract 10 and it's no longer soft. else, add the card and it's still soft 
        if hand + c > 21:
            return False, hand + c - 10

        return True, hand + c

    # else c == 11 and it's soft
    # my argument: we add one and it remains soft since by construction, a hand cannot be soft
    # unless it is of total at least 11
    return True, hand + 1

def blackjack(c1, c2):
    return (c1 == 11 and c2 == 10) or (c1 == 10 and c2 == 11)

In [6]:
### TROUBLESHOOTING GET SCORE
#for p in range(2, 21): # isch guet
#    for c in range(2, 12):
#        s, np = get_score(p, c, False)
#        print(p, 'plus', c, 'is', np, s)

#for p in range(11, 21): # i think it's good but i don't want to check rn :)
#    for c in range(2, 12):
#        s, np = get_score(p, c, True)
#        print(p, 'plus', c, 'is', np, s)

In [53]:
def update_q(Q, D):
    sample = D.sample()
        
    phi_j = torch.stack([x[0] for x in sample])
    a_j = torch.tensor([x[1] for x in sample]).reshape(-1, 1)
    r_j = torch.tensor([x[2] for x in sample]).reshape(-1, 1)
    phi_j_plus_one = torch.stack([x[3] for x in sample])

    # for phi j and phi j + 1, need to preprocess things so the range makes sense
    phi_j[:,3] /= 21 # player score
    phi_j_plus_one[:,3] /= 21

    phi_j[:,6] /= 21 # player 2 score
    phi_j_plus_one[:,6] /= 21

    phi_j[:,7] /= 11 # dealer score
    phi_j_plus_one[:,7] /= 11

    phi_j[:,8] /= 6 # true count score
    phi_j_plus_one[:,8] /= 6

    y_j = 69. * torch.ones_like(a_j) # 69 is a placeholder to make sure the rewards propagate correctly
    is_terminal = torch.all(phi_j_plus_one == 0., axis = -1)

    y_j[is_terminal] = r_j[is_terminal]

    for i in torch.arange(D.mini_batch)[~is_terminal]:
        nonterminal_state_action_input_block = phi_j_plus_one[i].repeat(4).reshape(4, -1)
        actions = torch.arange(4).reshape(-1, 1)
        nonterminal_state_action_input_block = torch.hstack((nonterminal_state_action_input_block, actions))

        q_values = Q(nonterminal_state_action_input_block).detach()
        max_q = q_values.max()

        y_j[i] = max_q

    ## now perform the gradient descent step :P
    optimizer = RMSprop(Q.parameters(), lr = 0.001)

    for _ in range(10):
        optimizer.zero_grad()
    
        q_values = Q(torch.hstack((phi_j, a_j)))
    
        loss = torch.sum((q_values - y_j)**2)
    
        loss.backward()
        optimizer.step()

In [67]:
def do_game(qfunction, memory, verbose = False, train = False):
    """
    does rounds of blackjack until the shoe is sufficiently low (here defined to be 20 cards) 
    """
    shoe = 8 * DECK
    random.shuffle(shoe)

    count = 0.
    winnings = 0.

    while len(shoe) > 60: # must not be less than 26 so rounding works

        true_count = get_true_count(count, shoe)   
        multiplier = get_multiplier(true_count)

        if verbose:
            print('\n=== NEW ROUND ===')
            print('True Count:', true_count, '| Bet:', multiplier * 10)
              
        reward, count = blackjack_round(qfunction, memory, shoe, 1., count, verbose)

        if train:
            update_q(Q, D)

        winnings += (multiplier * reward * 10.)

    return winnings

def blackjack_round(qfunction, memory, shoe, bet, count, verbose = False):
    """
    one round of blackjack
    """
    
    epsilon = 0.1 # probability we don't choose the greedy action
    penalty = -4. # punishment for choosing an invalid action
    bet2 = 0. # initialized in case we case to split
    terminal_state = torch.zeros(10)
    
    # deal player cards
    c1 = shoe.pop(0)
    c2 = shoe.pop(0)

    count = update_count(count, c1)
    count = update_count(count, c2)

    player = 0.
    dealer = 0.
    
    splittable = c1 == c2

    soft, player = get_score(player, c1, False)
    soft, player = get_score(player, c2, soft)

    # deal dealer cards
    d1 = shoe.pop(0)
    d2 = shoe.pop(0)

    count = update_count(count, d1)

    soft_dealer, dealer = get_score(dealer, d1, False)

    # initiallize state vector
    # splittable | doublable | soft | score | doublable2 | soft2 | score2 | dealer upcard | true count | hand in question
    state = torch.tensor(
        [
            splittable, # can split or not
            True, # can double 
            soft, # soft or not
            player, # player score 
            False, False, 0.,
            dealer,
            get_true_count(count, shoe),
            0.
        ])


    # check for blackjacks
    if blackjack(d1, d2) and not blackjack(c1, c2): # dealer, not player
        if verbose:
            print('Dealer blackjack')
            print('Payout:', -1 * bet)

        count = update_count(count, d2)
            
        return -1 * bet, count

    elif blackjack(d1, d2) and blackjack(c1, c2): # both
        if verbose:
            print('Double blackjack')
            print('Payout:', 0.)

        count = update_count(count, d2)
            
        return 0., count

    elif blackjack(c1, c2):
        if verbose:
            print('Blackjack!')
            print('Payout:', 1.5 * bet)

        count = update_count(count, d2)
        
        return 1.5 * bet, count

    # print what's up
    if verbose:
        print('Player:', 'soft' if soft else '', player, '| Dealer:', d1, '| Running Count:', count)


    # player gets to 21 by doing actions
    while state[3] < 21: 

        SARS1 = [None, None, None, None] 
        
        # get action - 0 stand, 1 hit, 2 double, 3 split
        state_action_input_block = torch.stack([state, state, state, state])
        actions = torch.arange(4).reshape(-1, 1)
        state_action_input_block = torch.hstack((state_action_input_block, actions))

        q_values = qfunction(state_action_input_block)
        q_max = q_values.argmax().item()

        alternative_actions = [0, 1, 2, 3]
        alternative_actions.remove(q_max)
        
        if torch.rand(1) < epsilon:
            action = random.sample(alternative_actions, 1)[0]

        else:
            action = q_max

        # update tuple
        SARS1 = [state, action, None, None]
        state = state.clone()
            

        if verbose:
            print('Action:', ['stand', 'hit', 'double', 'split'][action])

        if action == 0:
            break

        if action == 1:
            c = shoe.pop(0)
            count = update_count(count, c)
            
            state[2], state[3] = get_score(state[3], c, state[2])
            state[8] = get_true_count(count, shoe)

            # update so no splits or doubles
            state[0] = False
            state[1] = False

            if verbose:
                print('Player:', 'soft' if state[2] else '', state[3].item())

            if state[3] < 21:
                SARS1[2] = 0.
                SARS1[3] = state
    
                memory.push(SARS1)


        if action == 2:
            if not state[1]: # if we doubled when we were not permitted
                if verbose:
                    print('Invalid!')

                SARS1[2] = penalty
                SARS1[3] = terminal_state

                memory.push(SARS1)

                return penalty, count

            # else it's just the same as above
            c = shoe.pop(0)
            count = update_count(count, c)
            
            state[2], state[3] = get_score(state[3], c, state[2])
            state[8] = get_true_count(count, shoe)

            # update so no splits or doubles
            state[0] = False
            state[1] = False

            # double the bet
            bet *= 2

            if verbose:
                print('Player:', 'soft' if state[2] else '', state[3].item())

            break

        if action == 3:
            if not state[0]:
                if verbose:
                    print('Invalid!')

                SARS1[2] = penalty
                SARS1[3] = terminal_state

                memory.push(SARS1)

                return penalty, count

            # can no longer split
            state[0] = False

            # redistribute cards
            if state[2] and state[3] == 12: # if double aces
                state[3] = 11.
                state[6] = 11.

                state[2] = True
                state[5] = True

            else: # otherwise it's simplly half the score
                state[3] /= 2
                state[6] = state[3]

                state[2] = False
                state[5] = False

            # add two new cards
            c = shoe.pop(0)
            count = update_count(count, c)
            state[2], state[3] = get_score(state[3], c, state[2])

            c = shoe.pop(0)
            count = update_count(count, c)
            state[5], state[6] = get_score(state[6], c, state[5])

            state[8] = get_true_count(count, shoe)

            # update doubleables
            state[1] = True
            state[4] = True

            # instantiate the second bet
            bet2 = 1. * bet

            if verbose:
                print('Player:', 'soft' if state[2] else '', state[3].item(), '| Player 2:', 'soft' if state[5] else '', state[6].item())

            SARS1[2] = 0.
            SARS1[3] = state

            memory.push(SARS1)



    # see if the second hand must be processed
    if state[6] != 0:

        state[-1] = 1.

        # push the state here
        SARS1[2] = 0.
        SARS1[3] = state

        memory.push(SARS1)
        
        while state[6] < 21:

            SARS1 = [None, None, None, None] 

            # select action again
            state_action_input_block = torch.stack([state, state, state, state])
            actions = torch.arange(4).reshape(-1, 1)
            state_action_input_block = torch.hstack((state_action_input_block, actions))
    
            q_values = qfunction(state_action_input_block)
            q_max = q_values.argmax().item()
    
            alternative_actions = [0, 1, 2, 3]
            alternative_actions.remove(q_max)
    
            
            if torch.rand(1) < epsilon:
                action = random.sample(alternative_actions, 1)[0]
    
            else:
                action = q_max
    
            # update tuple
            SARS1 = [state, action, None, None]
            state = state.clone()
            
    
            if verbose:
                print('Action 2:', ['stand', 'hit', 'double', 'split'][action])
    
            if action == 0:
                break
    
            if action == 1:
                c = shoe.pop(0)
                count = update_count(count, c)
                
                state[5], state[6] = get_score(state[6], c, state[5])
                state[8] = round(count / round(len(shoe) / 52))
    
                # update so no doubles
                state[4] = False
    
                if verbose:
                    print('Player 2:', 'soft' if state[5] else '', state[6].item())

                if state[6] < 21:
                    SARS1[2] = 0.
                    SARS1[3] = state
        
                    memory.push(SARS1)

    
            if action == 2:
                if not state[4]: # if we doubled when we were not permitted
                    if verbose:
                        print('Invalid!')

                    SARS1[2] = penalty
                    SARS1[3] = terminal_state
        
                    memory.push(SARS1)
    
                    return penalty, count
    
                # else it's just the same as above
                c = shoe.pop(0)
                count = update_count(count, c)
                
                state[5], state[6] = get_score(state[6], c, state[5])
                state[8] = round(count / round(len(shoe) / 52))
    
                # update so no doubles
                state[4] = False
                bet2 *= 2
    
                if verbose:
                    print('Player 2:', 'soft' if state[5] else '', state[6].item())
    
                break
    
            if action == 3:
                if verbose:
                    print('Invalid!')

                SARS1[2] = penalty
                SARS1[3] = terminal_state
    
                memory.push(SARS1)
    
                return penalty, count



    # we may not have to process the dealer hand
    # (1) hand 1 busts and hand 2 doesn't exist ; (2) both hands exist, and both bust
    if state[3] > 21 and state[6] == 0:
        if verbose:
            print('Bust. Payout:', -1 * bet)

        SARS1[2] = -1 * bet
        SARS1[3] = terminal_state

        memory.push(SARS1)

        soft_dealer, dealer = get_score(dealer, d2, soft_dealer)

        return -1 * bet, count

    if state[3] > 21 and state[6] > 21:
        if verbose:
            print('Both hands bust. Payout:', -1 * (bet + bet2))

        SARS1[2] = -1 * (bet + bet2)
        SARS1[3] = terminal_state

        memory.push(SARS1)

        soft_dealer, dealer = get_score(dealer, d2, soft_dealer)
        return -1 * (bet + bet2), count

    
    # otherwise, we must process the dealer hand
    soft_dealer, dealer = get_score(dealer, d2, soft_dealer)
    count = update_count(count, d2)

    while dealer < 17:
        d = shoe.pop(0)
        count = update_count(count, d)
        
        soft_dealer, dealer = get_score(dealer, d, soft_dealer)

    if verbose:
        print('Dealer:', dealer)

    
    # check for dealer bust
    if dealer > 21:
        if verbose:
            print('Player victory. Payout:', (2 * (state[3].item() <= 21) - 1) * bet + (2 * (state[6].item() <= 21) - 1) * bet2)

        SARS1[2] = ((2 * (state[3].item() <= 21) - 1) * bet) + ((2 * (state[6].item() <= 21) - 1) * bet2)
        SARS1[3] = terminal_state

        memory.push(SARS1)

        return ((2 * (state[3].item() <= 21) - 1) * bet) + ((2 * (state[6].item() <= 21) - 1) * bet2), count


    # do comparisons
    # only first hand ~OR~ first hand and second hand bust
    if state[3] <= 21 and (state[6] == 0 or state[6] > 21): # then it's win loss or tie (already accountued for bust case)
        if state[3] > dealer:
            if verbose:
                print('Payout:', bet - bet2)

            SARS1[2] = bet - bet2
            SARS1[3] = terminal_state
    
            memory.push(SARS1)

            return bet - bet2, count

        elif state[3] == dealer:
            if verbose:
                print('Payout:', 0. - bet2)

            SARS1[2] = 0 - bet2
            SARS1[3] = terminal_state
    
            memory.push(SARS1)

            return 0. - bet2, count

        else: 
            if verbose:
                print('Payout:', -1 * bet - bet2)

            SARS1[2] = -1 * bet - bet2
            SARS1[3] = terminal_state
    
            memory.push(SARS1)

            return -1 * bet - bet2, count

    # first hand bust, second hand valid
    if state[3] > 21 and state[6] <= 21:
        if state[6] > dealer:
            if verbose:
                print('Payout:', bet2 - bet)

            SARS1[2] = bet2 - bet
            SARS1[3] = terminal_state
    
            memory.push(SARS1)

            return bet - bet2, count

        elif state[6] == dealer:
            if verbose:
                print('Payout:', 0. - bet)

            SARS1[2] = 0 - bet
            SARS1[3] = terminal_state
    
            memory.push(SARS1)

            return 0. - bet, count

        else: 
            if verbose:
                print('Payout:', -1 * bet2 - bet)

            SARS1[2] = -1 * bet - bet2
            SARS1[3] = terminal_state
    
            memory.push(SARS1)

            return -1 * bet2 - bet, count

    # both hands valid
    if state[3] == dealer:
        if state[6] == dealer:
            if verbose:
                print('Payout:', 0.)

            SARS1[2] = 0.
            SARS1[3] = terminal_state
    
            memory.push(SARS1)
                
            return 0., count
            
        # else we won or lost
        if verbose:
            print('Payout:', bet2 * (2 * (state[6] > dealer).item() - 1))

        SARS1[2] = bet2 * (2 * (state[6] > dealer).item() - 1)
        SARS1[3] = terminal_state

        memory.push(SARS1)

        return bet2 * (2 * (state[6] > dealer).item() - 1), count

    elif state[6] == dealer:
        # then just account for first hand, already accounted for equality case
        if verbose:
            print('Payout:', bet * (2 * (state[3] > dealer) - 1).item())

        SARS1[2] = bet * (2 * (state[3] > dealer) - 1).item()
        SARS1[3] = terminal_state

        memory.push(SARS1)
            
        return bet * (2 * (state[3] > dealer) - 1).item(), count

    # else both are not ties and eligible for consideration
    if verbose:
        print('Payout:', bet * (2 * (state[3] > dealer) - 1).item() + (bet2 * (2 * (state[6] > dealer) - 1).item()))

    SARS1[2] = bet * (2 * (state[3] > dealer) - 1).item() + (bet2 * (2 * (state[6] > dealer) - 1)).item()
    SARS1[3] = terminal_state

    memory.push(SARS1)

    return bet * (2 * (state[3] > dealer) - 1).item() + (bet2 * (2 * (state[6] > dealer) - 1)).item(), count


In [68]:
Q = QFunction()
D = ReplayMemory()

do_game(Q, D, verbose = True, train = False)


=== NEW ROUND ===
True Count: 0 | Bet: 10.0
Player:  4.0 | Dealer: 5 | Running Count: 3.0
Action: stand
Dealer: 18.0
Payout: -1.0

=== NEW ROUND ===
True Count: 0 | Bet: 10.0
Player:  13.0 | Dealer: 8 | Running Count: 3.0
Action: stand
Dealer: 18.0
Payout: -1.0

=== NEW ROUND ===
True Count: 1 | Bet: 20.0
Player:  12.0 | Dealer: 8 | Running Count: 5.0
Action: split
Invalid!

=== NEW ROUND ===
True Count: 1 | Bet: 20.0
Player:  6.0 | Dealer: 10 | Running Count: 6.0
Action: stand
Dealer: 19.0
Payout: -1.0

=== NEW ROUND ===
True Count: 1 | Bet: 20.0
Player: soft 12.0 | Dealer: 7 | Running Count: 4.0
Action: stand
Dealer: 22.0
Player victory. Payout: 1.0

=== NEW ROUND ===
True Count: 0 | Bet: 10.0
Player:  13.0 | Dealer: 5 | Running Count: 5.0
Action: stand
Dealer: 17.0
Payout: -1.0

=== NEW ROUND ===
True Count: 1 | Bet: 20.0
Player:  18.0 | Dealer: 9 | Running Count: 6.0
Action: hit
Player:  21.0
Dealer: 18.0
Payout: 1.0

=== NEW ROUND ===
True Count: 1 | Bet: 20.0
Player:  14.0 | Dea

-610.0

In [69]:
#for x in D.memory:
#    print(x)
#    print()

In [70]:
EPOCHS = 300

for epoch in range(EPOCHS):
    w = do_game(Q, D, verbose = False, train = True)

    if epoch % 10 == 0:
        print('=== Epoch', epoch, '| Payout:', w, '===')

=== Epoch 0 | Payout: -1375.0 ===
=== Epoch 10 | Payout: -250.0 ===
=== Epoch 20 | Payout: -1175.0 ===
=== Epoch 30 | Payout: -180.0 ===
=== Epoch 40 | Payout: -425.0 ===
=== Epoch 50 | Payout: -1075.0 ===
=== Epoch 60 | Payout: -460.0 ===
=== Epoch 70 | Payout: -395.0 ===
=== Epoch 80 | Payout: -860.0 ===
=== Epoch 90 | Payout: -515.0 ===
=== Epoch 100 | Payout: -235.0 ===
=== Epoch 110 | Payout: -545.0 ===
=== Epoch 120 | Payout: -180.0 ===
=== Epoch 130 | Payout: -1060.0 ===
=== Epoch 140 | Payout: -195.0 ===
=== Epoch 150 | Payout: -125.0 ===
=== Epoch 160 | Payout: -1295.0 ===
=== Epoch 170 | Payout: -725.0 ===
=== Epoch 180 | Payout: -850.0 ===
=== Epoch 190 | Payout: -175.0 ===
=== Epoch 200 | Payout: -560.0 ===
=== Epoch 210 | Payout: -580.0 ===
=== Epoch 220 | Payout: -255.0 ===
=== Epoch 230 | Payout: -305.0 ===
=== Epoch 240 | Payout: -530.0 ===
=== Epoch 250 | Payout: -550.0 ===
=== Epoch 260 | Payout: -365.0 ===
=== Epoch 270 | Payout: -1420.0 ===
=== Epoch 280 | Payout: -

## visualize a learned game

In [71]:
do_game(Q, D, verbose = True, train = False)


=== NEW ROUND ===
True Count: 0 | Bet: 10.0
Player:  13.0 | Dealer: 6 | Running Count: 2.0
Action: stand
Dealer: 21.0
Payout: -1.0

=== NEW ROUND ===
True Count: 0 | Bet: 10.0
Player: soft 13.0 | Dealer: 5 | Running Count: 3.0
Action: split
Invalid!

=== NEW ROUND ===
True Count: 0 | Bet: 10.0
Player:  18.0 | Dealer: 10 | Running Count: 1.0
Action: stand
Dealer: 20.0
Payout: -1.0

=== NEW ROUND ===
True Count: 0 | Bet: 10.0
Player:  20.0 | Dealer: 2 | Running Count: -1.0
Action: stand
Dealer: 18.0
Payout: 1.0

=== NEW ROUND ===
True Count: 0 | Bet: 10.0
Player:  8.0 | Dealer: 10 | Running Count: 2.0
Action: hit
Player:  18.0
Action: double
Invalid!

=== NEW ROUND ===
True Count: 0 | Bet: 10.0
Player:  9.0 | Dealer: 7 | Running Count: 3.0
Action: stand
Dealer: 25.0
Player victory. Payout: 1.0

=== NEW ROUND ===
True Count: 1 | Bet: 20.0
Player:  15.0 | Dealer: 5 | Running Count: 6.0
Action: stand
Dealer: 19.0
Payout: -1.0

=== NEW ROUND ===
True Count: 1 | Bet: 20.0
Blackjack!
Payout: 

-115.0

## some obvious test cases to see if the model is learning
### Hand: hard 20 on 6 with P true and D irrelevant
### Policy: should stand
### Value: should be basically 10 (if doublable, since this indicates 20 was dealt initially)

In [72]:
"""
(2) splittable
(3) doublable 
(4) soft 
(5) score 
(6) doublable2 
(7) soft2 
(8) score2 
(9) dealer upcard 
(...) true count
(10) current hand being processed
"""

test1 = torch.tensor([1., 0., 0., 20., 0., 0., 0., 6., 0., 0.]).repeat(4).reshape(4, -1)
actions = torch.arange(4).reshape(-1, 1)
test1 = torch.hstack((test1, actions))


with torch.no_grad():
    print('Test 1')
    print('Q-Values:', Q(test1)) # get action - 0 stand, 1 hit, 2 double, 3 split


test2 = torch.tensor([1., 1., 0., 20., 0., 0., 0., 6., 0., 0.]).repeat(4).reshape(4, -1)
actions = torch.arange(4).reshape(-1, 1)
test2 = torch.hstack((test2, actions))

with torch.no_grad():
    print('')
    print('Test 2')
    print('Q-Values:', Q(test2))


Test 1
Q-Values: tensor([[-0.5175],
        [-0.5175],
        [-0.5175],
        [-0.5175]])

Test 2
Q-Values: tensor([[-0.5175],
        [-0.5175],
        [-0.5175],
        [-0.5175]])


### Hand: hard 11 on 8 with P false and D true
### Policy: should double

In [73]:
test3 = torch.tensor([0., 1., 0., 11., 0., 0., 0., 8., 0., 0.]).repeat(4).reshape(4, -1)
actions = torch.arange(4).reshape(-1, 1)
test3 = torch.hstack((test3, actions))

with torch.no_grad():
    print('')
    print('Test 3')
    print('Q-Values:', Q(test3))


Test 3
Q-Values: tensor([[-0.5175],
        [-0.5175],
        [-0.5175],
        [-0.5175]])


### Hand: soft 19 on 4 with P true and D true
### Policy: should stand

In [74]:
test4 = torch.tensor([1., 1., 1., 19., 0., 0., 0., 4., 0., 0.]).repeat(4).reshape(4, -1)
actions = torch.arange(4).reshape(-1, 1)
test4 = torch.hstack((test4, actions))

with torch.no_grad():
    print('')
    print('Test 4')
    print('Q-Values:', Q(test4))


Test 4
Q-Values: tensor([[-0.5175],
        [-0.5175],
        [-0.5175],
        [-0.5175]])


### Hand: hard 16 on 3 with P true and D true
### Policy: should split (since 16 is a crappy hand)

In [62]:
test5 = torch.tensor([1., 1., 1., 16., 0., 0., 0., 3., 0., 0.]).repeat(4).reshape(4, -1)
actions = torch.arange(4).reshape(-1, 1)
test5 = torch.hstack((test5, actions))

with torch.no_grad():
    print('')
    print('Test 5')
    print('Q-Values:', Q(test5))


Test 5
Q-Values: tensor([[55.0492],
        [52.6395],
        [50.2298],
        [47.8200]])
