# In [39]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
import matplotlib.pyplot as plt


# In [53]:

########################################
# Global Variables and Constants
########################################

verbose = 0  # truthy -> print round-by-round narration

# Shoe and betting configuration
NUM_DECKS = 2
DEFAULT_WAGER = 25
STARTING_BANKROLL = 10000
SHUFFLE_POINT = 0.25  # fraction of the shoe remaining when the cut card is reached
TOTAL_CARDS = NUM_DECKS * 52
NUM_SHOES_TO_PLAY = 100

# Hi-Lo counting tags keyed by rank: low cards +1, 7-9 neutral, tens/aces -1
COUNT_VALUES = {
    rank: tag
    for ranks, tag in (
        (('2', '3', '4', '5', '6'), +1),
        (('7', '8', '9'), 0),
        (('10', 'J', 'Q', 'K', 'A'), -1),
    )
    for rank in ranks
}

# Action indices shared by the environment and the behaviour-cloning data
HIT, STAND, DOUBLE, SPLIT = range(4)

# Module-level counters
running_count = shoes_played = hands_played = 0
wins = losses = pushes = 0

def create_deck():
    """Return an unshuffled shoe of NUM_DECKS standard decks.

    Cards are ``(rank, suit)`` tuples, grouped suit by suit with ranks in
    ascending order; the single deck is repeated NUM_DECKS times.
    """
    ranks = ('2', '3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A')
    suits = ('Hearts', 'Diamonds', 'Clubs', 'Spades')
    one_deck = [(rank, suit) for suit in suits for rank in ranks]
    return one_deck * NUM_DECKS

def shuffle_deck(deck):
    """Shuffle *deck* in place using the ``random`` module and hand the same list back."""
    random.shuffle(deck)
    shuffled = deck
    return shuffled

def get_card_value(card):
    """Return the blackjack point value of *card* (a ``(rank, suit)`` tuple).

    Aces count 11, tens and face cards 10, pip cards their face value.
    """
    rank = card[0]
    if rank == 'A':
        return 11
    if rank in ('10', 'J', 'Q', 'K'):
        return 10
    return int(rank)

def can_split(hand):
    """Return True when *hand* is exactly two cards of equal point value.

    Unlike ``is_pair`` this treats any two ten-valued cards (e.g. K and 10)
    as splittable.
    """
    if len(hand) != 2:
        return False
    points = {
        '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9,
        '10': 10, 'J': 10, 'Q': 10, 'K': 10, 'A': 11,
    }
    first_rank, second_rank = hand[0][0], hand[1][0]
    return points[first_rank] == points[second_rank]

def calculate_hand_value(hand):
    """Return ``(total, is_soft)`` for a list of ``(rank, suit)`` cards.

    Every ace starts at 11; aces are then demoted to 1 one at a time while the
    total exceeds 21. The hand is soft if at least one ace still counts 11.
    """
    non_ace_points = {
        '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
        '8': 8, '9': 9, '10': 10, 'J': 10, 'Q': 10, 'K': 10,
    }
    high_aces = sum(1 for rank, _suit in hand if rank == 'A')
    total = 11 * high_aces + sum(non_ace_points[rank] for rank, _suit in hand if rank != 'A')
    while total > 21 and high_aces > 0:
        total -= 10
        high_aces -= 1
    return total, high_aces > 0

def is_blackjack(hand):
    """Return True when *hand* is a natural: exactly two cards totalling 21."""
    total = calculate_hand_value(hand)[0]
    if len(hand) != 2:
        return False
    return total == 21

def is_pair(hand):
    """Return True when *hand* is two cards of the same rank (K-K yes, K-Q no)."""
    if len(hand) != 2:
        return False
    first, second = hand
    return first[0] == second[0]

def dealer_upcard_value(card):
    """Return the dealer up-card's value as used by the strategy chart (2-11, ace = 11)."""
    rank = card[0]
    high_cards = {'10': 10, 'J': 10, 'Q': 10, 'K': 10, 'A': 11}
    return high_cards[rank] if rank in high_cards else int(rank)

def initialize_bankroll(initial_amount=STARTING_BANKROLL):
    """Return the opening bankroll: STARTING_BANKROLL unless another amount is given."""
    opening_bankroll = initial_amount
    return opening_bankroll

def place_wager(bankroll, wager_amount=DEFAULT_WAGER):
    """Take *wager_amount* out of *bankroll*.

    Returns:
        ``(remaining_bankroll, wager_amount)``.

    Raises:
        ValueError: if the wager is larger than the bankroll.
    """
    if bankroll < wager_amount:
        raise ValueError("Wager exceeds available bankroll.")
    remaining = bankroll - wager_amount
    return remaining, wager_amount

def update_bankroll(bankroll, wager, outcome):
    """Settle a hand whose *wager* was already removed from *bankroll*.

    A win returns stake plus equal winnings (2 x wager), a push returns the
    stake, and a loss returns nothing.

    Raises:
        ValueError: if *outcome* is not 'win', 'push' or 'lose'.
    """
    if outcome == 'win':
        return bankroll + 2 * wager
    if outcome == 'push':
        return bankroll + wager
    if outcome == 'lose':
        return bankroll
    raise ValueError("Invalid outcome provided.")

def determine_winner(player_value, dealer_value):
    """Compare final totals, narrate the result when ``verbose`` is set, and return it.

    Args:
        player_value: the player's final hand total.
        dealer_value: the dealer's final hand total.

    Returns:
        'lose' if the player busted (checked first, so a player bust loses even
        when the dealer also busts), 'win' if the dealer busted or the player
        is higher, 'lose' if the player is lower, otherwise 'push'. These are
        the outcome strings accepted by ``update_bankroll``.
    """
    if verbose: print("\n--- Final Results ---")
    if player_value > 21:
        if verbose: print("You busted. Dealer wins.")
        return 'lose'
    if dealer_value > 21:
        if verbose: print("Dealer busted. You win!")
        return 'win'
    if player_value > dealer_value:
        if verbose: print("You win!")
        return 'win'
    if player_value < dealer_value:
        if verbose: print("Dealer wins.")
        return 'lose'
    if verbose: print("It's a push!")
    return 'push'

def deal_card(deck):
    """Pop the next card from *deck* and add its counting tag to ``running_count``.

    Reaching the cut card only prints a notice here; the episode loop in this
    file reshuffles between rounds. If a single round ever exhausts the shoe,
    a fresh shoe is started (running count reset, ``shoes_played`` bumped)
    instead of raising IndexError on the empty list.

    Args:
        deck: list of ``(rank, suit)`` cards; mutated in place.

    Returns:
        The dealt ``(rank, suit)`` card.
    """
    global running_count, shoes_played
    if len(deck) <= int(TOTAL_CARDS * SHUFFLE_POINT):
        if verbose: print("Cut card reached! Reshuffle after this round.")
    if not deck:
        deck[:] = shuffle_deck(create_deck())
        running_count = 0
        shoes_played += 1
    card = deck.pop()
    running_count += COUNT_VALUES.get(card[0], 0)
    return card

def get_player_decision(hand, dealer_card, bankroll, current_wager, can_double, can_split_hand, splits_done, max_splits=3):
    """Return the basic-strategy action for *hand* against *dealer_card*.

    Pairs are split when the chart recommends it and a split is permitted.
    Otherwise, including pairs that cannot or should not be split, the hand
    is played by its soft or hard total. So an unsplittable 7-7 vs 5 stands
    on hard 14, and an unsplittable A-A hits soft 12.

    Args:
        hand: list of ``(rank, suit)`` cards held by the player.
        dealer_card: the dealer's up card as ``(rank, suit)``.
        bankroll: chips available for an extra stake (double or split).
        current_wager: the stake on this hand; doubling/splitting needs this much more.
        can_double: whether doubling is allowed for this hand.
        can_split_hand: whether the two cards may be split.
        splits_done: splits already made this round.
        max_splits: maximum number of splits per round.

    Returns:
        One of 'hit', 'stand', 'double', 'split'.
    """
    player_value, is_soft = calculate_hand_value(hand)
    if player_value == 21:
        return 'stand'
    dealer_value = dealer_upcard_value(dealer_card)

    extra_stake_ok = bankroll >= current_wager
    split_ok = can_split_hand and splits_done < max_splits and extra_stake_ok
    double_ok = can_double and extra_stake_ok

    # Pairs: split against these dealer up cards; anything else falls through.
    if is_pair(hand) and split_ok:
        split_against = {
            11: range(2, 12),           # A-A: always
            9: (2, 3, 4, 5, 6, 8, 9),
            8: range(2, 12),            # 8-8: always
            7: range(2, 8),
            6: range(2, 7),
            4: (5, 6),
            3: range(2, 8),
            2: range(2, 8),
        }                               # 10-10 and 5-5 are never split
        if dealer_value in split_against.get(get_card_value(hand[0]), ()):
            return 'split'

    if is_soft:
        soft_double_against = {
            19: (6,),
            18: range(2, 7),
            17: (3, 4, 5, 6),
            16: (4, 5, 6),
            15: (4, 5, 6),
            14: (5, 6),
            13: (5, 6),
        }
        if double_ok and dealer_value in soft_double_against.get(player_value, ()):
            return 'double'
        if player_value >= 19:
            return 'stand'
        if player_value == 18:
            return 'hit' if dealer_value >= 9 else 'stand'
        return 'hit'

    # Hard totals (dealer_value is always in 2..11)
    if player_value >= 17:
        return 'stand'
    if 13 <= player_value <= 16:
        return 'stand' if dealer_value <= 6 else 'hit'
    if player_value == 12:
        return 'stand' if dealer_value in (4, 5, 6) else 'hit'
    if player_value == 11:
        return 'double' if double_ok else 'hit'
    if player_value == 10:
        return 'double' if double_ok and dealer_value <= 9 else 'hit'
    if player_value == 9:
        return 'double' if double_ok and 3 <= dealer_value <= 6 else 'hit'
    return 'hit'

def _reshuffle_if_cut(deck):
    """Replace *deck* in place with a fresh shuffled shoe once the cut card is reached."""
    global running_count, shoes_played
    if len(deck) <= int(TOTAL_CARDS * SHUFFLE_POINT):
        deck[:] = shuffle_deck(create_deck())
        running_count = 0
        shoes_played += 1


def simulate_basic_strategy_episodes(n_episodes=5000):
    """Play rounds with ``get_player_decision`` and record (observation, action) pairs.

    One round is played per episode from a persistent shoe (reshuffled at the
    cut card). Naturals are settled without recording a decision. Every
    player decision records the observation
    ``[player total, soft flag, dealer up-card value, two-card flag, true count]``
    together with the chosen action index. Play stops early if the bankroll
    can no longer cover DEFAULT_WAGER.

    Args:
        n_episodes: maximum number of rounds to play.

    Returns:
        ``(obs, acts)``: a float32 array of shape (decisions, 5) (or shape
        (0,) if no decision was made) and the matching action indices
        (HIT, STAND, DOUBLE, SPLIT).
    """
    max_splits = 3
    action_index = {'hit': HIT, 'stand': STAND, 'double': DOUBLE, 'split': SPLIT}

    deck = shuffle_deck(create_deck())
    bankroll = STARTING_BANKROLL
    obs_data = []
    act_data = []

    for _ in range(n_episodes):
        if bankroll <= 0:
            break
        try:
            bankroll, wager = place_wager(bankroll)
        except ValueError:
            break

        player_hand = [deal_card(deck), deal_card(deck)]
        dealer_hand = [deal_card(deck), deal_card(deck)]

        # Naturals end the round immediately; no decision is recorded.
        if is_blackjack(dealer_hand):
            outcome = 'push' if is_blackjack(player_hand) else 'lose'
            bankroll = update_bankroll(bankroll, wager, outcome)
            _reshuffle_if_cut(deck)
            continue
        if is_blackjack(player_hand):
            bankroll += wager * 1.5 + wager  # 3:2 winnings plus the returned stake
            _reshuffle_if_cut(deck)
            continue

        dealer_up = dealer_hand[0]
        up_value = dealer_upcard_value(dealer_up)
        player_hands = [player_hand]
        wagers = [wager]
        splits_done = 0
        hand_idx = 0

        while hand_idx < len(player_hands):
            hand = player_hands[hand_idx]
            if len(hand) == 1:
                # Second half of a split receives its second card when it becomes active.
                hand.append(deal_card(deck))

            while True:
                p_val, is_soft = calculate_hand_value(hand)
                true_count = running_count / max(len(deck) / 52, 0.0001)
                obs_data.append(np.array(
                    [p_val, 1 if is_soft else 0, up_value, 1 if len(hand) == 2 else 0, true_count],
                    dtype=np.float32,
                ))

                can_split_hand = can_split(hand)
                action_str = get_player_decision(
                    hand, dealer_up, bankroll, wagers[hand_idx],
                    len(hand) == 2, can_split_hand, splits_done, max_splits,
                )
                act_data.append(action_index[action_str])

                if action_str == 'stand':
                    break
                if action_str == 'double':
                    bankroll -= wagers[hand_idx]
                    wagers[hand_idx] *= 2
                    hand.append(deal_card(deck))
                    break
                if (action_str == 'split' and can_split_hand and splits_done < max_splits
                        and bankroll >= wagers[hand_idx]):
                    bankroll -= wagers[hand_idx]
                    left, right = [hand[0]], [hand[1]]
                    player_hands[hand_idx] = left
                    player_hands.insert(hand_idx + 1, right)
                    wagers.insert(hand_idx + 1, wagers[hand_idx])  # keep wagers aligned with hands
                    splits_done += 1
                    left.append(deal_card(deck))
                    # Continue deciding on the hand that is now in player_hands (and
                    # will be settled), not on the discarded pair.
                    hand = left
                    continue
                # 'hit' (or a split that is no longer permitted): take one card.
                hand.append(deal_card(deck))
                if calculate_hand_value(hand)[0] > 21:
                    break
            hand_idx += 1

        # Dealer only draws if some player hand is still live; hits soft 17.
        if any(calculate_hand_value(h)[0] <= 21 for h in player_hands):
            while True:
                d_total, d_soft = calculate_hand_value(dealer_hand)
                if d_total > 17 or (d_total == 17 and not d_soft):
                    break
                dealer_hand.append(deal_card(deck))

        d_total, _ = calculate_hand_value(dealer_hand)
        for settled_hand, stake in zip(player_hands, wagers):
            p_val, _ = calculate_hand_value(settled_hand)
            if p_val > 21:
                outcome = 'lose'
            elif d_total > 21 or p_val > d_total:
                outcome = 'win'
            elif p_val < d_total:
                outcome = 'lose'
            else:
                outcome = 'push'
            bankroll = update_bankroll(bankroll, stake, outcome)

        _reshuffle_if_cut(deck)

    return np.array(obs_data), np.array(act_data)

class BCPolicy(nn.Module):
    """Behaviour-cloning MLP: observation -> action logits.

    Two hidden layers of 64 ReLU units. ``self.net`` is an ``nn.Sequential``
    whose Linear layers sit at indices 0, 2 and 4, which is what the PPO
    weight transfer reads.
    """

    def __init__(self, input_dim=5, output_dim=4):
        super().__init__()
        hidden = 64
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

def train_behavior_cloning(obs,acts,epochs=10,batch_size=64):
    """Fit a fresh BCPolicy to imitate (observation, action) pairs.

    Trains with Adam (lr 1e-3) on cross-entropy over shuffled mini-batches.

    Args:
        obs: array-like of observations, shape (N, 5).
        acts: array-like of integer action labels, shape (N,).
        epochs: full passes over the data.
        batch_size: mini-batch size.

    Returns:
        The trained BCPolicy.
    """
    model = BCPolicy()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    features = torch.tensor(obs, dtype=torch.float32)
    labels = torch.tensor(acts, dtype=torch.long)
    batches = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels),
        batch_size=batch_size,
        shuffle=True,
    )

    for _epoch in range(epochs):
        for batch_obs, batch_act in batches:
            optimizer.zero_grad()
            batch_loss = loss_fn(model(batch_obs), batch_act)
            batch_loss.backward()
            optimizer.step()
    return model

class BlackjackEnv(gym.Env):
    """Gymnasium environment in which one episode is one round of blackjack.

    ``reset`` takes a DEFAULT_WAGER stake and deals two cards each to player
    and dealer. ``step`` applies one action to the first unfinished player
    hand. When every player hand is finished, the dealer draws (hitting soft
    17) and the round is settled, which ends the episode.

    Observation (float32, shape (5,)):
        [player total, soft flag (0/1), dealer up-card value (2-11),
         initial flag (1 for a hand that has not been hit yet), true count]
    Actions: HIT=0, STAND=1, DOUBLE=2, SPLIT=3.

    Reward: the round's net result in units of DEFAULT_WAGER. A won plain hand
    is +1, a won double +2, and a natural +1.5, so doubling and splitting are
    valued by what they actually risk. An illegal action ends the round with
    a flat -1.

    The shoe, and therefore the running count, persists across episodes. It
    is only replaced once the cut card (SHUFFLE_POINT of TOTAL_CARDS) is
    reached. ``bankroll`` is reset to STARTING_BANKROLL on every ``reset``, so
    it only reflects the current round. Dealer naturals are not peeked for:
    they are only seen at settlement, so extra stakes from doubling or
    splitting are lost to them.
    """

    metadata = {"render_modes": []}
    MAX_SPLITS = 3  # same cap used when generating the behaviour-cloning data

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Box(low=-10, high=50, shape=(5,), dtype=np.float32)
        self.action_space = spaces.Discrete(4)

        self.deck = shuffle_deck(create_deck())
        self.running_count = 0
        self.shoes_played = 0
        self.bankroll = STARTING_BANKROLL
        # Lifetime statistics; not cleared by reset().
        self.hands_played = 0
        self.wins = 0
        self.losses = 0
        self.pushes = 0

        self.current_hands = []
        self.dealer_hand = []
        self.done = False
        self.splits_done = 0

    # ------------------------------------------------------------------ shoe
    def _deal_card(self):
        """Pop the next card from the shoe and update the running count."""
        if len(self.deck) <= int(TOTAL_CARDS * SHUFFLE_POINT):
            if verbose: print("Cut card reached! Reshuffle after this round.")
        if not self.deck:
            # A single very long round emptied the shoe; start a new one rather than crash.
            self._new_shoe()
        card = self.deck.pop()
        self.running_count += COUNT_VALUES.get(card[0], 0)
        return card

    def _new_shoe(self):
        """Replace the shoe in place with a freshly shuffled one and reset the count."""
        self.deck[:] = shuffle_deck(create_deck())
        self.running_count = 0
        self.shoes_played += 1

    def _maybe_reshuffle(self):
        """Start a new shoe if the cut card has been reached."""
        if len(self.deck) <= int(TOTAL_CARDS * SHUFFLE_POINT):
            if verbose: print("Reshuffling deck now...")
            self._new_shoe()
            if verbose: print(f"*** Completed Shoe {self.shoes_played} ***")

    def _deal_initial_hands(self):
        """Deal two cards to the player, then two to the dealer."""
        player_hand = [self._deal_card(), self._deal_card()]
        dealer_hand = [self._deal_card(), self._deal_card()]
        return player_hand, dealer_hand

    @staticmethod
    def _new_hand(cards, wager, blackjack=False):
        """Build the per-hand bookkeeping dict used throughout the environment."""
        return {
            'hand': cards,
            'wager': wager,
            'done': False,
            'initial': True,
            'blackjack': blackjack,
        }

    # ----------------------------------------------------------- observation
    def _true_count(self):
        """Running count divided by decks remaining (guarded against an empty shoe)."""
        decks_remaining = len(self.deck) / 52
        if decks_remaining == 0:
            decks_remaining = 0.0001
        return self.running_count / decks_remaining

    def _get_current_hand(self):
        """Return the first unfinished hand dict, or None if all are done."""
        for h in self.current_hands:
            if not h['done']:
                return h
        return None

    def _get_obs(self):
        """Observation for the active hand; all zeros once no hand is active."""
        hand = self._get_current_hand()
        if hand is None:
            return np.zeros((5,), dtype=np.float32)
        p_val, is_soft = calculate_hand_value(hand['hand'])
        d_val = dealer_upcard_value(self.dealer_hand[0])
        initial_flag = 1 if hand['initial'] else 0
        return np.array([p_val, 1 if is_soft else 0, d_val, initial_flag, self._true_count()], dtype=np.float32)

    # ------------------------------------------------------------- Gym API
    def reset(self, *, seed=None, options=None):
        """Start a new round from the persistent shoe and return ``(obs, info)``."""
        super().reset(seed=seed)
        # Keep the shoe between rounds so the true count carries information;
        # only start a fresh shoe once the cut card has been reached.
        self._maybe_reshuffle()
        self.splits_done = 0
        self.done = False
        self.bankroll = STARTING_BANKROLL  # per-round accounting, as before

        player_hand, self.dealer_hand = self._deal_initial_hands()
        self.bankroll -= DEFAULT_WAGER  # initial stake
        self.current_hands = [
            self._new_hand(player_hand, DEFAULT_WAGER, blackjack=is_blackjack(player_hand))
        ]
        return self._get_obs(), {}

    def step(self, action):
        """Apply *action* to the first unfinished hand.

        Returns the Gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)``.
        The reward is 0 until the round is settled.
        """
        action = int(action)  # model.predict() yields numpy scalars / 0-d arrays
        hand = self._get_current_hand()
        if hand is None:
            return self._get_obs(), 0.0, True, False, {}

        cards = hand['hand']
        can_double = len(cards) == 2
        can_split_hand = (
            can_split(cards) and hand['initial'] and self.splits_done < self.MAX_SPLITS
        )
        if action == DOUBLE and not can_double:
            return self._invalid_action()
        if action == SPLIT and not can_split_hand:
            return self._invalid_action()

        if action == STAND:
            hand['done'] = True
        elif action == HIT:
            cards.append(self._deal_card())
            hand['initial'] = False
            hand['blackjack'] = False  # a hand that drew a card is no longer a natural
            if calculate_hand_value(cards)[0] > 21:
                hand['done'] = True
        elif action == DOUBLE:
            # Take the extra stake from the bankroll before the wager grows.
            self.bankroll -= hand['wager']
            hand['wager'] *= 2
            cards.append(self._deal_card())
            hand['initial'] = False
            hand['blackjack'] = False
            hand['done'] = True
        elif action == SPLIT:
            self._split(hand)

        # A split hand that just became active still holds one card.
        self._deal_pending_card()

        if all(h['done'] for h in self.current_hands):
            return self._dealers_turn()
        return self._get_obs(), 0.0, False, False, {}

    # --------------------------------------------------------------- helpers
    def _split(self, hand):
        """Replace *hand* with two one-card hands, each carrying the original wager."""
        idx = next(i for i, h in enumerate(self.current_hands) if h is hand)
        first_card, second_card = hand['hand']
        wager = hand['wager']
        self.bankroll -= wager  # the second hand needs its own stake
        self.splits_done += 1
        left = self._new_hand([first_card], wager)
        right = self._new_hand([second_card], wager)
        self.current_hands[idx:idx + 1] = [left, right]
        left['hand'].append(self._deal_card())

    def _deal_pending_card(self):
        """Give the active hand its second card if it is a one-card split hand."""
        hand = self._get_current_hand()
        if hand is not None and len(hand['hand']) == 1:
            hand['hand'].append(self._deal_card())

    def _invalid_action(self):
        """End the round after an illegal action with a flat -1 reward.

        Remaining hands stop where they are; a one-card split hand first
        receives its second card. The dealer then plays out and the round is
        settled normally, so bankroll and win/loss statistics stay consistent.
        """
        for h in self.current_hands:
            if len(h['hand']) == 1:
                h['hand'].append(self._deal_card())
            h['done'] = True
        obs, _reward, _terminated, truncated, info = self._dealers_turn()
        return obs, -1.0, True, truncated, info

    def _dealers_turn(self):
        """Dealer draws (hitting soft 17) if any player hand is live, then settle."""
        if any(calculate_hand_value(h['hand'])[0] <= 21 for h in self.current_hands):
            while True:
                d_val, d_soft = calculate_hand_value(self.dealer_hand)
                if d_val > 17 or (d_val == 17 and not d_soft):
                    break
                self.dealer_hand.append(self._deal_card())
        return self._end_round()

    def _end_round(self):
        """Settle every hand, update bankroll/statistics and end the episode.

        Naturals are resolved before plain totals: a natural beats a drawn 21
        (paid 3:2), two naturals push, and a dealer natural beats any
        non-natural hand.
        """
        d_val, _ = calculate_hand_value(self.dealer_hand)
        dealer_natural = is_blackjack(self.dealer_hand)
        reward = 0.0
        for h in self.current_hands:
            p_val, _ = calculate_hand_value(h['hand'])
            wager = h['wager']
            natural = h.get('blackjack', False) and len(h['hand']) == 2

            if p_val > 21:
                outcome = 'lose'
            elif natural and dealer_natural:
                outcome = 'push'
            elif natural:
                outcome = 'blackjack'
            elif dealer_natural:
                outcome = 'lose'
            elif d_val > 21 or p_val > d_val:
                outcome = 'win'
            elif p_val < d_val:
                outcome = 'lose'
            else:
                outcome = 'push'

            if outcome == 'blackjack':
                self.wins += 1
                self.bankroll += wager * 2.5  # stake back plus 3:2 winnings
                reward += 1.5 * wager / DEFAULT_WAGER
            elif outcome == 'win':
                self.wins += 1
                self.bankroll += wager * 2  # stake back plus equal winnings
                reward += wager / DEFAULT_WAGER
            elif outcome == 'lose':
                self.losses += 1  # stake(s) were already taken from the bankroll
                reward -= wager / DEFAULT_WAGER
            else:
                self.pushes += 1
                self.bankroll += wager
            self.hands_played += 1

        self.done = True
        self._maybe_reshuffle()
        obs = self._get_obs()
        info = {'wins': self.wins, 'losses': self.losses, 'pushes': self.pushes, 'hands_played': self.hands_played}
        return obs, float(reward), True, False, info

    def render(self):
        pass


# Name used by evaluate_model() and the __main__ block.
RLBlackjackEnv = BlackjackEnv

def print_strategy_chart(model, soft_flag=0, initial_flag=1, tc=0):
    """Print the model's greedy action for each player total (4-21) vs dealer up card (2-11).

    Args:
        model: object with an SB3-style ``predict(obs, deterministic=True)``
            returning ``(action, state)``.
        soft_flag: value for the soft-hand feature.
        initial_flag: value for the initial-hand feature.
        tc: true count fed to the model.
    """
    labels = {0: 'Hit', 1: 'Stand', 2: 'Double', 3: 'Split'}
    upcards = list(range(2, 12))

    print(f"\nStrategy Chart (soft={soft_flag}, initial={initial_flag}, tc={tc})")
    print("Dealer: ", " ".join(map(str, upcards)))

    for total in range(4, 22):
        cells = []
        for up in upcards:
            query = np.array([total, soft_flag, up, initial_flag, tc], dtype=np.float32)
            predicted, _state = model.predict(query, deterministic=True)
            cells.append(labels[int(predicted)])
        print(f"Player {total}: {' '.join(cells)}")


def evaluate_model(model, eval_episodes=10000):
    """Play *eval_episodes* rounds with the model's deterministic policy and report results.

    The net result of each round is measured from the bankroll the round
    started with. ``reset()`` has already taken DEFAULT_WAGER when we read the
    bankroll, so the stake is added back; otherwise a lost round would count
    as 0. Bankroll history is accumulated here from STARTING_BANKROLL, so it
    is cumulative even though the environment resets its own bankroll every
    round.

    Args:
        model: object with an SB3-style ``predict(obs, deterministic=True)``.
        eval_episodes: number of rounds to play.

    Returns:
        A dict of the printed summary statistics, or None if nothing was played.
    """
    test_env = RLBlackjackEnv()

    outcomes = []          # net chips won/lost per round
    bankroll_history = []  # cumulative bankroll after each round
    bankroll = STARTING_BANKROLL
    total_reward = 0.0

    for _ in range(eval_episodes):
        obs, info = test_env.reset()
        round_start = test_env.bankroll + DEFAULT_WAGER  # undo the stake reset() took
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = test_env.step(action)
            done = terminated or truncated
            total_reward += reward

        net = test_env.bankroll - round_start
        outcomes.append(net)
        bankroll += net
        bankroll_history.append(bankroll)

    hands = test_env.hands_played
    if not outcomes or hands == 0:
        print("\nNo hands were played; nothing to evaluate.")
        return None

    # EV and (population) variance of the per-round net result
    EV = float(np.mean(outcomes))
    variance = float(np.mean(np.square(outcomes)) - EV * EV)

    wins = test_env.wins
    losses = test_env.losses
    pushes = test_env.pushes
    final_bankroll = bankroll
    net_profit = final_bankroll - STARTING_BANKROLL

    print(f"\nEvaluation over {eval_episodes} episodes:")
    print(f"Hands Played: {hands}")
    print(f"Wins: {wins}, Losses: {losses}, Pushes: {pushes}")
    print(f"Win rate: {wins/hands*100:.2f}%")
    print(f"Average reward per hand: {total_reward/hands:.3f}")
    print(f"Final Bankroll: {final_bankroll}")
    print(f"Net Profit/Loss: {net_profit}")
    print(f"Average profit per hand: {net_profit/hands:.3f}")
    print(f"EV per hand: {EV:.3f}")
    print(f"Variance per hand: {variance:.3f}")

    # Plot cumulative bankroll over the evaluation rounds
    plt.figure(figsize=(10,5))
    plt.plot(bankroll_history)
    plt.xlabel('Hand Number')
    plt.ylabel('Bankroll')
    plt.title('Bankroll Over Time')
    plt.grid(True)
    plt.show()

    return {
        'hands_played': hands,
        'wins': wins,
        'losses': losses,
        'pushes': pushes,
        'total_reward': total_reward,
        'final_bankroll': final_bankroll,
        'net_profit': net_profit,
        'ev': EV,
        'variance': variance,
    }


if __name__=="__main__":
    # Generate data for BC training
    obs_data, act_data = simulate_basic_strategy_episodes(n_episodes=10000)
    bc_model = train_behavior_cloning(obs_data, act_data, epochs=5)

    def make_env():
        return RLBlackjackEnv()

    env = DummyVecEnv([make_env])

    # SB3's MlpPolicy uses Tanh hidden activations by default, while BCPolicy
    # was trained with ReLU. Build the PPO policy with ReLU so the copied BC
    # weights compute the same function they were trained as.
    model = PPO("MlpPolicy", env, verbose=0, policy_kwargs=dict(activation_fn=nn.ReLU))

    # Load BC weights into the PPO policy. Assumes SB3's default policy hidden
    # sizes (64, 64) match BCPolicy; copy_ raises on a shape mismatch.
    with torch.no_grad():
        policy_net = model.policy.mlp_extractor.policy_net
        pi_state_dict = policy_net.state_dict()
        pi_state_dict['0.weight'].copy_(bc_model.net[0].weight)
        pi_state_dict['0.bias'].copy_(bc_model.net[0].bias)
        pi_state_dict['2.weight'].copy_(bc_model.net[2].weight)
        pi_state_dict['2.bias'].copy_(bc_model.net[2].bias)
        policy_net.load_state_dict(pi_state_dict)

        model.policy.action_net.weight.copy_(bc_model.net[4].weight)
        model.policy.action_net.bias.copy_(bc_model.net[4].bias)

    # Train PPO
    model.learn(total_timesteps=100000)

    # Evaluate model
    evaluate_model(model, eval_episodes=10000)

    # Print strategy charts for different counts
    counts_to_check = [0, 2, 5]
    for tc in counts_to_check:
        print_strategy_chart(model, soft_flag=0, initial_flag=1, tc=tc)  # Hard totals
        print_strategy_chart(model, soft_flag=1, initial_flag=1, tc=tc)  # Soft totals


# KeyboardInterrupt:  (notebook output from an interrupted run; not code)