In [1]:
from classes.cards import *
from classes.policies import *

import torch
import torch.nn as nn
from torch.optim import Adam
import matplotlib.pyplot as plt

from torch.distributions import Poisson, Categorical, Binomial
from torch.utils.data import DataLoader, TensorDataset

## We need to train three networks
### 1 betting network - how much do we bet based on the count
### 2 counting network - how do we count in the first place
### 3 action network - do we hit, stand, double, et cetera

In [44]:
def print_things(verbose, *args):
    if verbose:
        for a in args:
            print(a)

def encode_state_action(hand, dealer_upcard, count): # h2, dh.upcard().oho, count
    footprint_numbers = [c.oho for c in hand.cards]
    stacked = torch.vstack(footprint_numbers).sum(axis = 0).reshape(1, -1)

    return torch.cat((stacked, dealer_upcard.reshape(1, -1), count), 1)

def encode_state_bet(budget, count, shoe_estimate):
    return torch.cat((budget, count, shoe_estimate), 1)

def encode_state_counter(card_one_hot, current_score):
    return torch.cat((card_one_hot.reshape(1, 13), current_score.reshape(1, 1)), 1)

def single_full_deck(POT, 
                     value_bet, policy_bet, 
                     value_game, policy_game, 
                     value_counter, policy_counter, 
                     hist_bet = None, hist_game = None, hist_counter = None, # these are are state/action/reward tuples
                     verbose = True):
    
    d = Decks(8)
    time_series = []
    count = torch.zeros(1, 1)

    punishment = -2000 # how much we take away if going below the budget

    if hist_bet == None:
        hist_bet = HistoryBet()

    if hist_counter == None:
        hist_counter = HistoryCounter()

    if hist_game == None:
        hist_game = HistoryGame()
        
    
    while (len(d) > d.shoe_cutoff):

        print_things(verbose, '')

        print_things(verbose, 'total remaining', POT)
        time_series.append(POT)

        """
        (1) place an initial bet...
        """
        # POT + count + len.d (noisy estimate)
        state_bet = encode_state_bet(torch.tensor([[POT],]), count, len(d) + 26 * torch.randn(1, 1))
        prob_bet, x_bet, bet_raw = policy_bet(state_bet) 
        
        bet = 10. #bet_raw.item() """CHANGE THIS BACK LATEER!!!!"""
        bet2 = 1. * bet # if there is a second that we must process

        # let's think harder about this one since we mayy need to change thigns
        # print('betting network | action:', bet, 'state:', state)
        print_things(verbose, 'initial bet', bet)

        series_of_counts_to_factor_reward_at_end = []
        series_of_actions_to_factor_reward_at_end = []

        """
        (2) deal hands
        """
        dh = Hand()
        h = Hand()
        h2 = None

        h1_busted = False
        h2_busted = False

        # dealer upcard
        c = d.deal()
        dh.add_card(c)
        
        counter_state = encode_state_counter(c.oho, count)
        count_probs, count_add = policy_counter(counter_state)
        count += count_add
        t = (counter_state, count_probs, count_add)
        series_of_counts_to_factor_reward_at_end.append(t)

        # dealer downcard (we add it but don't count it yet 
        c_down = d.deal()
        dh.add_card(c_down)

        for _ in range(2): # player
            c = d.deal()
            h.add_card(c)
            
            counter_state = encode_state_counter(c.oho, count)
            _, count_add = policy_counter(counter_state)
            count += count_add
            t = (counter_state, count_probs, count_add)
            series_of_counts_to_factor_reward_at_end.append(t)


        """
        (3) check for blackjack(s)
        """
        if dh.points().max() == 21 and h.points().max() != 21:
            print_things(verbose, 'dealer blackjack', '==> player', h, '==> dealer', dh, '-' + str(bet))
            POT -= bet

            counter_state = encode_state_counter(c_down.oho, count)
            count_probs, count_add = policy_counter(counter_state)
            count += count_add
            t = (counter_state, count_probs, count_add)
            series_of_counts_to_factor_reward_at_end.append(t)
        
            hist_bet.add(state_bet, prob_bet, x_bet, (-1 * bet) + punishment * (POT <= 0 and POT + bet > 0)) # state, action, reward
            hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, -1 * bet)
            continue

        if dh.points().max() == 21 and h.points().max() == 21:
            print_things(verbose, 'double blackjack', '==> player', h, '==> dealer', dh, '+' + str(0))
            POT += 0

            counter_state = encode_state_counter(c_down.oho, count)
            count_probs, count_add = policy_counter(counter_state)
            count += count_add
            t = (counter_state, count_probs, count_add)
            series_of_counts_to_factor_reward_at_end.append(t)
            
            hist_bet.add(state_bet, prob_bet, x_bet, 0 * bet) 
            hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, 0 * bet)
            continue
            
        if h.points().max() == 21:
            print_things(verbose, 'blackjack!! ==> player', h, '==> dealer', dh, '+' + str(1.5 * bet))
            POT += (1.5 * bet)

            counter_state = encode_state_counter(c_down.oho, count)
            count_probs, count_add = policy_counter(counter_state)
            count += count_add
            t = (counter_state, count_probs, count_add)
            series_of_counts_to_factor_reward_at_end.append(t)
            
            hist_bet.add(state_bet, prob_bet, x_bet, 1.5 * bet) 
            hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, 1.5 * bet)
            continue
            
    
        """
        (4) do actions
        """
        print_things(verbose, 'start game', 'player:', h,' | dealer:', dh.upcard())
        
        while h.points().min() < 21:
            state_action = encode_state_action(h, dh.upcard().oho, count)
            actions = policy_game(state_action) # right now, it is just the cards, we are not counting
            actions_probs = actions.detach()

            my_action = Categorical(probs = actions_probs[0]).sample().item()
            # 0 - split if possible -> double if possible -> hit
            # 1 - split if possible -> do not double -> hit
            # 2 - dont split -> double if possible -> hit
            # 3 - dont split -> do not double -> hit
            # 4 - 7 (same thing but with double)

            if (my_action == 0) or (my_action == 1) or (my_action == 4) or (my_action == 5):
                if h.can_split and torch.all(h.card_i(0) == h.card_i(1)): # split
                    print_things(verbose, 'splitting...')
                    c = h.take() 
    
                    h2 = Hand()
                    h2.add_card(c) # already saw the card
    
                    c = d.deal()
                    h2.add_card(c)
                    counter_state = encode_state_counter(c.oho, count)
                    count_probs, count_add = policy_counter(counter_state)
                    count += count_add
    
                    t = (counter_state, count_probs, count_add)
                    series_of_counts_to_factor_reward_at_end.append(t)
    
    
                    c = d.deal()
                    h.add_card(c)
                    counter_state = encode_state_counter(c.oho, count)
                    count_probs, count_add = policy_counter(counter_state)
                    count += count_add
    
                    t = (counter_state, count_probs, count_add)
                    series_of_counts_to_factor_reward_at_end.append(t)
    
                    h.can_split = False
                    h2.can_split = False
    
                    t = (state_action, actions_probs, my_action)
                    series_of_actions_to_factor_reward_at_end.append(t)

                elif ((my_action == 0) or (my_action == 4)) and (len(h) == 2): # double
                    c = d.deal()
                    h.add_card(c)
                    counter_state = encode_state_counter(c.oho, count)
                    count_probs, count_add = policy_counter(counter_state)
                    count += count_add
                    print_things(verbose, 'doubling...')
    
                    t = (counter_state, count_probs, count_add)
                    series_of_counts_to_factor_reward_at_end.append(t)
    
                    t = (state_action, actions_probs, my_action)
                    series_of_actions_to_factor_reward_at_end.append(t)
                    
                    bet *= 2
                    break

                elif (my_action == 0) or (my_action == 1): # hit
                    c = d.deal()
                    h.add_card(c)
                    counter_state = encode_state_counter(c.oho, count)
                    count_probs, count_add = policy_counter(counter_state)
                    count += count_add
                    print_things(verbose, 'hitting...')
    
                    t = (counter_state, count_probs, count_add)
                    series_of_counts_to_factor_reward_at_end.append(t)
    
                    t = (state_action, actions_probs, my_action)
                    series_of_actions_to_factor_reward_at_end.append(t)

                else: # stand
                    print_things(verbose, 'standing...')
                
                    t = (state_action, actions_probs, my_action)
                    series_of_actions_to_factor_reward_at_end.append(t)
                    break

            else:
                if (my_action % 2 == 0) and (len(h) == 2): # double
                    c = d.deal()
                    h.add_card(c)
                    counter_state = encode_state_counter(c.oho, count)
                    count_probs, count_add = policy_counter(counter_state)
                    count += count_add
                    print_things(verbose, 'doubling...')
    
                    t = (counter_state, count_probs, count_add)
                    series_of_counts_to_factor_reward_at_end.append(t)
    
                    t = (state_action, actions_probs, my_action)
                    series_of_actions_to_factor_reward_at_end.append(t)
                    
                    bet *= 2
                    break

                elif (my_action == 2) or (my_action == 3): # hit
                    c = d.deal()
                    h.add_card(c)
                    counter_state = encode_state_counter(c.oho, count)
                    count_probs, count_add = policy_counter(counter_state)
                    count += count_add
                    print_things(verbose, 'hitting...')
    
                    t = (counter_state, count_probs, count_add)
                    series_of_counts_to_factor_reward_at_end.append(t)
    
                    t = (state_action, actions_probs, my_action)
                    series_of_actions_to_factor_reward_at_end.append(t)

                else: # stand
                    print_things(verbose, 'standing...')
                
                    t = (state_action, actions_probs, my_action)
                    series_of_actions_to_factor_reward_at_end.append(t)
                    break
                

        """
        (4.5) do actions on the second hand, if we have one. above code minus potential to split
        """
        if h2 != None:
            while h2.points().min() < 21:

                state_action = encode_state_action(h2, dh.upcard().oho, count)
                actions = policy_game(state_action) # right now, it is just the cards, we are not counting
                actions_probs = actions.detach()
    
                my_action = Categorical(probs = actions_probs[0]).sample().item()

                if ((my_action == 2) or (my_action == 6)) and (len(h2) == 2): # double
                    c = d.deal()
                    h2.add_card(c)
                    counter_state = encode_state_counter(c.oho, count)
                    count_probs, count_add = policy_counter(counter_state)
                    count += count_add
                    print_things(verbose, 'doubling second hand...')
    
                    t = (counter_state, count_probs, count_add)
                    series_of_counts_to_factor_reward_at_end.append(t)
    
                    t = (state_action, actions_probs, my_action)
                    series_of_actions_to_factor_reward_at_end.append(t)
                    
                    bet2 *= 2
                    break

                elif my_action <= 3: # hit
                    c = d.deal()
                    h2.add_card(c)
                    counter_state = encode_state_counter(c.oho, count)
                    count_probs, count_add = policy_counter(counter_state)
                    count += count_add
                    print_things(verbose, 'hitting second hand...')
    
                    t = (counter_state, count_probs, count_add)
                    series_of_counts_to_factor_reward_at_end.append(t)
    
                    t = (state_action, actions_probs, my_action)
                    series_of_actions_to_factor_reward_at_end.append(t)

                else: # stand
                    print_things(verbose, 'standing second hand...')
                
                    t = (state_action, actions_probs, my_action)
                    series_of_actions_to_factor_reward_at_end.append(t)
                    break

        """
        (5) check if we're here because we busted
        """
        if h.points().min() > 21:
            print_things(verbose, 'bust', '==> player', h, '==> dealer', dh, '-' + str(bet))
            POT -= bet
            hist_bet.add(state_bet, prob_bet, x_bet, (-1 * bet) + punishment * (POT <= 0  and POT + bet > 0))
            
            h1_busted = True

        if (h2 != None) and (h2.points().min() > 21):
            print_things(verbose, 'second hand bust', '==> player', h2, '==> dealer', dh, '-' + str(bet2))
            POT -= bet2
            hist_bet.add(state_bet, prob_bet, x_bet, (-1 * bet2) + punishment * (POT <= 0 and POT + bet2 > 0))

            h2_busted = True

        if h1_busted and (h2 == None or h2_busted):
            # dealer exposes upcard
            counter_state = encode_state_counter(c_down.oho, count)
            count_probs, count_add = policy_counter(counter_state)
            count += count_add
            t = (counter_state, count_probs, count_add)
            series_of_counts_to_factor_reward_at_end.append(t)

            hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, -1 * bet)
            hist_game.add_batch(series_of_actions_to_factor_reward_at_end, -1 * bet)

            if (h2 != None) and h2_busted:
                hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, -1 * bet2)
                hist_game.add_batch(series_of_actions_to_factor_reward_at_end, -1 * bet2)
            
            continue

        
        """
        (6) otherwise, dealer hits to 17
        """
        less_than_21_set = (dh.points() < 21)
        less_than_21_array = dh.points()[less_than_21_set]

        counter_state = encode_state_counter(c_down.oho, count) # dealer flips their downcard :)
        count_probs, count_add = policy_counter(counter_state)
        count += count_add
        t = (counter_state, count_probs, count_add)
        series_of_counts_to_factor_reward_at_end.append(t)
        
        while dh.points().min() < 21:

            if less_than_21_array.shape != 0:
                if less_than_21_array.max() > 17:
                    break # the second condition
                    
            c = d.deal()
            dh.add_card(c)
            counter_state = encode_state_counter(c.oho, count)
            count_probs, count_add = policy_counter(counter_state)
            count += count_add

            t = (counter_state, count_probs, count_add)
            series_of_counts_to_factor_reward_at_end.append(t)
            
            less_than_21_set = (dh.points() < 21)
            less_than_21_array = dh.points()[less_than_21_set]

        """
        (7) compare dealer and player, if dealer busted
        """
        if dh.points().min() > 21 and not h1_busted:
            print_things(verbose, 'dealer bust', '==> player', h, '==> dealer', dh, '+' + str(bet))
            POT += bet

            if h2 == None:
                continue

        if dh.points().min() > 21 and h2 != None:
            if not h2_busted:
                print_things(verbose, 'second hand', '==> player 2', h2, '+' + str(bet2))
                POT += bet2
                
            continue

        less_than_21_set_player = (h.points() <= 21)
        less_than_21_set_dealer = (dh.points() <= 21)

        # compare the first one
        if not h1_busted:
            if h.points()[less_than_21_set_player].max() > dh.points()[less_than_21_set_dealer].max(): # player win
                print_things(verbose, 'win', '==> player', h, '==> dealer', dh, '+' + str(bet))
                POT += bet
                hist_bet.add(state_bet, prob_bet, x_bet, 1 * bet) # state, action, reward
                hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, 1 * bet)
                hist_game.add_batch(series_of_actions_to_factor_reward_at_end, 1 * bet)
    
            elif h.points()[less_than_21_set_player].max() < dh.points()[less_than_21_set_dealer].max(): # dealer win
                print_things(verbose, 'loss', '==> player', h, '==> dealer', dh, '-' + str(bet))
                POT -= bet
                hist_bet.add(state_bet, prob_bet, x_bet, (-1 * bet) + punishment * (POT <= 0 and POT + bet > 0))
                hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, (-1 * bet))
                hist_game.add_batch(series_of_actions_to_factor_reward_at_end, (-1 * bet))
    
            else:
                print_things(verbose, 'push', '==> player', h, '==> dealer', dh, '+' + str(0))
                POT += 0
                hist_bet.add(state_bet, prob_bet, x_bet, 0 * bet)
                hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, 0 * bet)
                hist_game.add_batch(series_of_actions_to_factor_reward_at_end, 0 * bet)

        # compare the second one
        if h2 != None and not h2_busted:
            less_than_21_set_player_2 = (h2.points() <= 21)

            if h2.points()[less_than_21_set_player_2].max() > dh.points()[less_than_21_set_dealer].max(): # player win
                print_things(verbose, 'second hand win', '==> player 2', h2, '==> dealer', dh, '+' + str(bet2))
                POT += bet2
                hist_bet.add(state_bet, prob_bet, x_bet, 1 * bet2)
                hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, 1 * bet2)
                hist_game.add_batch(series_of_actions_to_factor_reward_at_end, 1 * bet2)
    
            elif h2.points()[less_than_21_set_player_2].max() < dh.points()[less_than_21_set_dealer].max(): # dealer win
                print_things(verbose, 'second hand loss', '==> player 2', h2, '==> dealer', dh, '-' + str(bet2))
                POT -= bet2
                hist_bet.add(state_bet, prob_bet, x_bet, (-1 * bet2) + punishment * (POT <= 0 and POT + bet2 > 0))
                hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, (-1 * bet2))
                hist_game.add_batch(series_of_actions_to_factor_reward_at_end, (-1 * bet2))
    
            else:
                print_things(verbose, 'second hand push', '==> player 2', h2, '==> dealer', dh, '+' + str(0))
                POT += 0
                hist_bet.add(state_bet, prob_bet, x_bet, 0 * bet2)
                hist_counter.add_batch(series_of_counts_to_factor_reward_at_end, 0 * bet2)
                hist_game.add_batch(series_of_actions_to_factor_reward_at_end, 0 * bet2)

        
    return {
        'time_series': time_series,
        'history_bet': hist_bet,
        'history_counter': hist_counter,
        'history_action': hist_game
    }    

def do_games(num = 20, pot = 200, verbose = False):
    value_bet = ValueBet()
    policy_bet = PolicyBet()
    
    value_game = ValueGame()
    policy_game = PolicyGame()
    
    value_counter = ValueCounter()
    policy_counter = PolicyCounter()

    hist_counter = HistoryCounter()
    hist_bet = HistoryBet()
    hist_game = HistoryGame()

    t = []

    for _ in range(num):
        sd = single_full_deck(pot, value_bet, policy_bet, value_game, policy_game, value_counter, policy_counter, 
                              hist_bet = hist_bet, hist_counter = hist_counter, hist_game = hist_game, verbose = verbose)
        
        ti = sd['time_series']
        t += ti

        pot = ti[-1]

    return t, hist_counter, hist_bet, hist_game

In [45]:
t, hc, hb, hg = do_games(2, verbose = True)


total remaining
200
initial bet
10.0
start game
player:
['4S', 'AC']
 | dealer:
8D
doubling...
dealer bust
==> player
['4S', 'AC', 'AS']
==> dealer
['8D', '2C', '2C', '5D', '5H']
+20.0

total remaining
220.0
initial bet
10.0
start game
player:
['QD', '5D']
 | dealer:
10S
doubling...
bust
==> player
['QD', '5D', '10D']
==> dealer
['10S', 'JC']
-20.0

total remaining
200.0
initial bet
10.0
start game
player:
['8S', '10H']
 | dealer:
7D
standing...
push
==> player
['8S', '10H']
==> dealer
['7D', 'AC']
+0

total remaining
200.0
initial bet
10.0
start game
player:
['8D', '6C']
 | dealer:
7C
doubling...
bust
==> player
['8D', '6C', '10S']
==> dealer
['7C', '9D']
-20.0

total remaining
180.0
initial bet
10.0
start game
player:
['3S', '7D']
 | dealer:
6D
hitting...
standing...
loss
==> player
['3S', '7D', '6D']
==> dealer
['6D', 'AH', 'KD', '4C']
-10.0

total remaining
170.0
initial bet
10.0
start game
player:
['QC', '2D']
 | dealer:
JD
standing...
dealer bust
==> player
['QC', '2D']
==> deal

## training loop -- ppo clip

In [49]:
def g(A, epsilon):
    return (1 + epsilon) * A * (A >= 0) + (1 - epsilon) * A * (A < 0)
    
def train(
    epochs,
    value_bet,
    policy_bet,
    value_counter,
    policy_counter,
    value_game,
    policy_game,
    verbose = False
         ):
    """trains a blackjack using ppo clip for a specified number of epochs"""

    batch_size = 20
    epochs = 10
    epochs2 = 100

    optimizer_value_bet = Adam(value_bet.parameters(), lr = 1e-4)
    optimizer_policy_bet = Adam(policy_bet.parameters(), lr = 1e-4)

    optimizer_value_counter = Adam(value_counter.parameters(), lr = 1e-4)
    optimizer_policy_counter = Adam(policy_counter.parameters(), lr = 1e-4)

    optimizer_value_game = Adam(value_game.parameters(), lr = 1e-4)
    optimizer_policy_game = Adam(policy_game.parameters(), lr = 1e-4)
    
    for k in range(epochs):

        hist_bet = HistoryBet()
        hist_game = HistoryGame()
        hist_counter = HistoryCounter()

        budget = 5000

        """
        GAMEPLAY
        """
        for i in range(40): # play __n__ rounds of blackjack
            sd = single_full_deck(budget, 
                 value_bet, policy_bet, 
                 value_game, policy_game, 
                 value_counter, policy_counter, 
                 hist_bet = hist_bet, hist_counter = hist_counter, hist_game = hist_game, verbose = verbose)

            budget = sd['time_series'][-1]
            if budget == 0:
                break

        """
        POLICY --- bet update
        state_bet = hist_bet.get_state()
        action_bet = hist_bet.get_action() # which is the random variable x
        reward_bet = hist_bet.get_reward()

        advantage_estimate = reward_bet - value_bet(state_bet).detach()
        round_k_log_probs = hist_bet.round_k_log_probs(policy_bet)
        n, p = hist_bet.get_binomial_values(policy_bet)

        dataset = DataLoader(TensorDataset(state_bet, advantage_estimate, round_k_log_probs, action_bet, n), batch_size = batch_size)

        for i in range(epochs2):
            #loss_policy_bet = 0

            for state, adv, lprob, x, ns in dataset:

                optimizer_policy_bet.zero_grad()
                probs_hat, _, _ = policy_bet(state)

                bern = Binomial(ns.ravel(), probs_hat.ravel())

                first_argument = torch.exp(bern.log_prob(x) - lprob) * adv.ravel()
                second_argument = g(adv, 1e-2)

                loss = -1 * torch.sum(torch.min(first_argument, second_argument)) # multiply by negative 1 so it goes up
                
                loss.backward()
                optimizer_policy_bet.step()

                #loss_policy_bet += loss.detach()
        """    
        
        """         
        POLICY --- counter update
        """
        state_counter = hist_counter.get_state()
        reward_counter = hist_counter.get_reward()
        
        advantage_estimate = reward_counter - value_counter(state_counter).detach()
        round_k_probabilities = hist_counter.round_k_probs()
        actions_index = hist_counter.get_actions_index()

        dataset = DataLoader(TensorDataset(state_counter, advantage_estimate, actions_index, round_k_probabilities), batch_size = batch_size, shuffle = True)

        for i in range(epochs2):
            loss_policy_counter = 0
            
            for state, adv, act_id, prob_k in dataset:
                N = state.shape[0]

                optimizer_policy_counter.zero_grad()
                probs, _ = policy_counter(state)
                
                first_argument = adv * probs[torch.arange(N), act_id].reshape(-1, 1) / prob_k
                second_argument = g(adv, 1e-2)
                
                loss = -1 * torch.sum(torch.min(first_argument, second_argument)) # multiply by negative 1 so it goes up
                
                loss.backward()
                optimizer_policy_counter.step()

                loss_policy_counter += loss.detach()

        """
        POLICY -- action update
        """
        
        state_action = hist_game.get_state()
        reward_action = hist_game.get_reward()

        advantage_estimate = reward_action - value_game(state_action).detach()
        round_k_probs = hist_game.round_k_probs()
        actions_index = hist_game.get_actions_index()

        dataset = DataLoader(TensorDataset(state_action, advantage_estimate, actions_index, round_k_probs), batch_size = batch_size, shuffle = True)

        for _ in range(epochs2):
            loss_policy_action = 0

            for state, adv, act_id, prob_k in dataset:
                N = state.shape[0]

                optimizer_policy_game.zero_grad()
                probs = policy_game(state)

                first_argument = adv * probs[torch.arange(N), act_id].reshape(-1, 1) / prob_k[torch.arange(N), act_id].reshape(-1, 1)
                second_argument = g(adv, 1e-2)

                loss = -1 * torch.sum(torch.min(first_argument, second_argument))
                
                loss.backward()
                optimizer_policy_game.step()
                
                loss_policy_action += -1 * loss.detach()

            #print(loss_policy_action)


        """
        VALUE --- bet update
        
        dataset = DataLoader(TensorDataset(state_bet, reward_bet), batch_size = batch_size)

        for i in range(epochs2):

            loss_value_bet = 0

            for state, reward in dataset:

                optimizer_value_bet.zero_grad()
                
                value = value_bet(state)
                loss = torch.mean((value - reward)**2)

                loss.backward()
                optimizer_value_bet.step()

                loss_value_bet += loss.detach()

        """
    

        """         
        VALUE --- counter update
        """
        dataset = DataLoader(TensorDataset(state_counter, reward_counter), batch_size = batch_size)

        for i in range(epochs2):

            loss_value_counter = 0

            for state, reward in dataset:

                optimizer_value_counter.zero_grad()
                
                value = value_counter(state)
                loss = torch.mean((value - reward)**2)

                loss.backward()
                optimizer_value_counter.step()

                loss_value_counter += loss.detach()

        """
        VALUE -- action update
        """
        dataset = DataLoader(TensorDataset(state_action, reward_action), batch_size = batch_size)

        for i in range(epochs2):

            loss_value_game = 0

            for state, reward in dataset:

                optimizer_value_game.zero_grad()
                
                value = value_game(state)
                loss = torch.mean((value - reward)**2)
                
                loss.backward()
                optimizer_value_game.step()

                loss_value_game += loss.detach()

        """
        PRINT BUDGET
        """
        print('=== Epoch', k, '| Bank:', budget,'===')
                
            

    return hist_counter

In [50]:
value_bet = ValueBet()
policy_bet = PolicyBet()

value_game = ValueGame()
policy_game = PolicyGame()

value_counter = ValueCounter()
policy_counter = PolicyCounter()

hc = train(epochs = 2,
    value_bet = value_bet,
    policy_bet = policy_bet,
    value_counter = value_counter,
    policy_counter = policy_counter,
    value_game = value_game,
    policy_game = policy_game,
          verbose = False)

=== Epoch 0 | Bank: -4805.0 ===
=== Epoch 1 | Bank: -5740.0 ===
=== Epoch 2 | Bank: -3640.0 ===
=== Epoch 3 | Bank: -4645.0 ===
=== Epoch 4 | Bank: -3570.0 ===
=== Epoch 5 | Bank: -4740.0 ===
=== Epoch 6 | Bank: -4655.0 ===
=== Epoch 7 | Bank: -3705.0 ===
=== Epoch 8 | Bank: -4635.0 ===
=== Epoch 9 | Bank: -5315.0 ===
