# Optimal Blackjack with Reinforcement Learning 

You will enjoy this notebook if you:
- Are interested in learning more about the card game Blackjack
- Would like to see reinforcement learning applied to simple but highly stochastic games
- Are interested in Monte Carlo methods and how they are used to evaluate and improve policies

Existing Reinforcement Learning (RL) environments often miss nuances of the real game of Blackjack. Drawing every card independently with fixed probabilities, rather than removing dealt cards from the deck, inevitably introduces inaccuracies.

### Why a Custom Environment?
To better account for the complexities of real-world Blackjack, we've built an environment that incorporates card counting. This enhancement over standard environments brings us closer to the dynamic nature of the game, setting the stage for more effective RL strategies.

Features:
- **Card Dealing with Memory**: Unlike other environments where each card drawn is an independent event with equal probabilities, our environment emulates a real-life card dealing scenario. Once a card is drawn, it is removed from the deck until reshuffling, significantly affecting the probabilities of future card draws.
- **Redeal Mechanics**: Passing `redeal=True` into our `reset()` function provides an alternative to the conventional `reset()` behavior. Instead of reshuffling and redealing from a full deck, it continues dealing from the same deck until a certain proportion of cards `(1-shuffle_frequency)` has been dealt. This simulates realistic gameplay, where hands are usually dealt from the same deck until a significant number of cards have been used.
- **Flexible Observation Output**: Our environment is versatile, capable of outputting both conventional and comprehensive observations. It can provide the standard `(player_total, dealer_card, usable_ace)` tuple or a more detailed observation pattern that includes the remaining deck: `(player_hand, dealer_card, remaining_cards)`.
- **Beginning From a Specified State**: We can pass a state to the `reset()` method to start the hand from that state, even when redealing from a depleted deck. This allows us to explore specific states without having to wait to encounter them at random.
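
To make the effect of dealing with memory concrete, here is a minimal sketch that is separate from the environment itself. It assumes a single 52-card deck with the same card encoding used in this notebook (ace = 11), and the helper name `ten_probability` is purely illustrative.

```python
from collections import Counter

def ten_probability(deck):
    """Probability that the next card drawn from `deck` is worth ten."""
    return Counter(deck)[10] / len(deck)

# One 52-card deck, using the same encoding as the environment (ace = 11).
deck = [2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 11] * 4

print(f"Full deck: {ten_probability(deck):.3f}")

# Remove four ten-valued cards, as if they had already been dealt.
for _ in range(4):
    deck.remove(10)

print(f"After four tens are dealt: {ten_probability(deck):.3f}")
```

The chance of drawing a ten-valued card falls from 16/52 (about 0.308) to 12/48 (0.250), a shift that an environment drawing with fixed probabilities would never see.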

In [1]:
import random

class Blackjack:
    '''
    A class representing a Blackjack environment for reinforcement learning.
 
    The environment follows OpenAI's Gym API conventions, allowing for easy interaction and 
    integration with reinforcement learning algorithms.
    '''
    def __init__(self, stake=10, n_decks=4, shuffle_frequency=0.5, simple_observations=True):
        # Instantiate instance variables
        self.stake=stake
        self.n_decks = n_decks
        self.shuffle_frequency = shuffle_frequency
        self.simple_observations = simple_observations
        
        # Initialize game state
        self.deck = []
        self.player_hand = []
        self.dealer_hand = []
        self.game_over = False
        
    def reset(self, state=None, redeal=False):
        '''
        Resets the game state, shuffles the deck, and deals initial cards.
        Returns the initial observations.
        Parameters:
            state: An optional parameter representing a new game state. 
                We can optionally choose the initial state of the game as a
                (player_total, dealer_card, usable_ace) tuple. The required cards are removed from the deck,
                and the dealer is given the specified upcard and one random hole card.
            redeal: An optional parameter, if True will deal following hands from the same deck without shuffling.
                The deck is only reshuffled once the proportion of cards remaining falls to shuffle_frequency or below.
        '''
        self.game_over = False
        
        if redeal:
            # Shuffle deck and initialize player hands if there are few enough cards left
            proportion_of_deck_left = len(self.deck) / (self.n_decks * 52)
            if proportion_of_deck_left <= self.shuffle_frequency:
                self.deck = self.shuffle_deck()
        else:
            self.deck = self.shuffle_deck()
            
        self.player_hand = []
        self.dealer_hand = []
        
        if state: # If we're reading in a specific state to use:
            player_total = state[0]
            dealer_card = state[1]
            usable_ace = state[2]
            
            if dealer_card not in self.deck:
                raise Exception("Dealer upcard not in deck.")
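            self.deal_card(self.dealer_hand, dealer_card)  # Requested upcard
            self.deal_card(self.dealer_hand)  # Random hole card
            # get_observations() treats index 0 as the hole card and index 1 as the upcard
            self.dealer_hand.reverse()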
            
            
            if usable_ace:
                player_card1 = 11
                player_card2 = player_total - 11
                if player_card2 == 1: # Bug fixing the possible case that player is dealt two aces. 
                    player_card2 = 11
            else:
                choice = random.choice(self.deck)
                attempts = 1 # Error handling if there aren't any eligible cards in the deck.
                while (choice == 11) or (player_total - choice < 2) or (player_total - choice > 10) or ((player_total - choice) not in self.deck): 
                    choice = random.choice(self.deck) # This may be less optimal but it's most accurate
                    attempts += 1
                    if attempts > 1000:
                        raise Exception("Player hand not possible with current deck.")
                player_card1 = choice
                player_card2 = player_total - player_card1
            
            self.deal_card(self.player_hand, player_card1)
            self.deal_card(self.player_hand, player_card2)
            
            
        else:
            for _ in range(2): # Deal 2 random cards to the player and dealer
                self.deal_card(self.player_hand)
                self.deal_card(self.dealer_hand)
            
        if self.calculate_hand(self.player_hand) == 21 or self.calculate_hand(self.dealer_hand) == 21:
            self.game_over = True
            
        return self.get_observations()
    
    def shuffle_deck(self):
        '''
        Returns a shuffled deck with the specified number of decks.
        '''
        deck = [2,3,4,5,6,7,8,9,10,10,10,10,11] * 4 * self.n_decks
        random.shuffle(deck)
        return deck
    
    def deal_card(self, hand, card=None):
        '''
        Deals a card from the deck to a specified hand.
        '''
        if card:
            card_index = self.deck.index(card)
            hand.append(self.deck.pop(card_index))
            
        else:
            hand.append(self.deck.pop())
   
    def stick(self):
        '''
        Ends the player's turn without drawing any more cards.
        '''
        self.game_over = True
        
    def hit(self):
        '''
        Deals another card to the player's hand. If the total exceeds 21, the game ends.
        '''
        self.deal_card(self.player_hand)
        if self.calculate_hand(self.player_hand) > 21:
            self.game_over = True
    
    def calculate_hand(self, hand):
        '''
        Calculates the total points in a hand.
        If the hand's total is over 21 and contains an ace, the value of that ace
        is reduced to 1.
        '''
        total = sum(hand)
        # Handle "soft" aces (value 11)
        while total > 21 and 11 in hand:
            for i in range(len(hand)):
                if hand[i] == 11:
                    hand[i] = 1
                    break # In case of more than 1 ace, don't set both to 1.
            total = sum(hand)
        return total
    
    def resolve_dealer_hand(self):
        '''
        Draws cards for the dealer until they reach at least 17 points.
        '''
        while self.calculate_hand(self.dealer_hand) < 17:
            self.deal_card(self.dealer_hand)
            
        
    def action(self, action):
        '''
        Handles an action taken by the player. It throws an error if the game is over.
        Returns observations, reward, done, info.
        '''
        if self.is_done():
            raise Exception("Game is over, actions are not possible.")
    
        if action == 0:
            self.stick()
            self.resolve_dealer_hand()
        if action == 1:
            self.hit()
            
        observations = self.get_observations()
        reward = self.calculate_reward()
        done = self.is_done()
        info = self.get_info()
        
        return observations, reward, done, info
    
    def get_observations(self):
        '''
        Returns the current game state. If the game is over, all of the dealer's cards are revealed.
        '''
        def format_deck(deck):
            '''
            Reformats a long list of cards into a more readable nested tuple format.
            '''
            counter_dict = {i: 0 for i in range(2,12)}
            for card in deck:
                # An ace already demoted to 1 by calculate_hand still counts as an ace (11)
                counter_dict[11 if card == 1 else card] += 1

            return tuple((card, count) for card, count in counter_dict.items())

        if self.simple_observations:
            obs1 = self.calculate_hand(self.player_hand)
            obs2 = self.dealer_hand[1] if not self.game_over else self.calculate_hand(self.dealer_hand)
            obs3 = 11 in self.player_hand
        else:
            obs1 = tuple(self.player_hand)
            obs2 = tuple(self.dealer_hand[1:]) if not self.game_over else tuple(self.dealer_hand)
            deck = self.deck.copy()
            if not self.game_over:
                deck.append(self.dealer_hand[0])
            obs3 = format_deck(deck)

        return obs1, obs2, obs3
        
    def get_actions(self):
        '''
        Returns the available actions for the player. 
        '''
        # If game is over, no actions can be performed
        if self.game_over:
            return []
        else:
            # 0 stands for 'stick', 1 stands for 'hit'
            return [0, 1]

    def calculate_reward(self):
        '''
        Calculates the reward after the game ends. The reward depends on the result of the game.
        If the game has not yet finished, this method will return 0.
        '''
        if self.is_done() == False:
            return 0
        player_n_cards = len(self.player_hand)
        player_total = self.calculate_hand(self.player_hand)
        dealer_total = self.calculate_hand(self.dealer_hand)
        
        if player_total == dealer_total:
            return 0 # Push 
        elif player_total == 21 and player_n_cards == 2:
            return self.stake * 1.5 # Dealt Blackjack
        elif player_total > 21:
            return -self.stake # Player bust
        elif dealer_total > 21:
            return self.stake # Dealer bust
        elif player_total < dealer_total:
            return -self.stake # Dealer has a higher total
        elif player_total > dealer_total:
            return self.stake # Player has a higher total
        
    def is_done(self):
        '''
        Checks if the game is over.
        '''
        return self.game_over
    
    def get_info(self):
        '''
        Returns a string of the game's current status.
        This method is used to make interpreting game states more simple.
        '''
        player_n_cards = len(self.player_hand)
        player_total = self.calculate_hand(self.player_hand)
        dealer_total = self.calculate_hand(self.dealer_hand)
        
        if player_total == 21 and dealer_total == 21 and player_n_cards == 2:
            return "Player and dealer both dealt 21."
        elif player_total == 21 and player_n_cards == 2:
            return "Player dealt Blackjack."
        elif dealer_total == 21 and len(self.dealer_hand) == 2:
            return "Dealer dealt 21."
        
        if self.is_done():
            if player_total > 21:
                return "Player busted."
            elif dealer_total > 21:
                return "Dealer busted."
            elif player_total > dealer_total:
                return f"Player's {player_total} beats dealer's {dealer_total}."
            elif player_total < dealer_total:
                return f"Player's {player_total} loses to dealer's {dealer_total}."
            elif player_total == dealer_total:
                return "Push."
        else:
            return f"Player has {sum(self.player_hand)} vs dealer's {self.dealer_hand[1]} and x."

## Agents and Policies
We're about to introduce a straightforward agent whose strategy for playing Blackjack adheres to a simple policy:

- The agent "sticks" if the total value of the cards held is 20 or higher.
- The agent "hits" if the total value is less than 20.

Let's delve into this policy and explore its implications in the game:

In [2]:
from collections import defaultdict

class SimpleAgent:
    def __init__(self):
        self.reward = 0
        self.policy = defaultdict(int)  # Initialize policy dictionary
        
        # Populate policy dictionary with all possible states for plotting
        for player_total in range(4, 22):
            for dealer_upcard in range(2, 12):
                for usable_ace in [False, True]:
                    if player_total <= 19:
                        self.policy[(player_total, dealer_upcard, usable_ace)] = 1
                    else:
                        self.policy[(player_total, dealer_upcard, usable_ace)] = 0
                        
    def reset(self):
        self.reward = 0

    def step(self, env: Blackjack):
        # Get observations from the environment to base the decision on
        observations = env.get_observations()

        # Use policy dictionary to make decision
        player_total, dealer_upcard, usable_ace = observations
        action = self.policy[(player_total, dealer_upcard, usable_ace)]

        obs, reward, done, info = env.action(action)
        self.reward += reward
        
import pandas as pd
import numpy as np
import plotly.graph_objects as go

def plot_policy(policy, usable_ace=False, policy_name=""):
    # Convert policy dictionary to DataFrame
    data = pd.DataFrame([(obs[0], obs[1], obs[2], action) 
                         for obs, action in policy.items()],
                        columns=["player_total", "dealer_card", "usable_ace", "action"])

    # Filter data based on usability of ace
    data = data[data['usable_ace'] == usable_ace]
    
    # Create grid of player totals and dealer cards
    player_totals = np.arange(4, 22)
    dealer_cards = np.arange(2, 12)

    # Create 2D grid representing the action at each state
    policy_grid = data.pivot(index='player_total', columns='dealer_card', values='action').reindex(index=player_totals, columns=dealer_cards)

    # Prepare hover text
    hovertext = np.vectorize(lambda x, y, z: f"Player Total: {y}<br>Dealer Upcard: {x}<br>Action: {'Hit' if z==1 else 'Stick'}")(np.array(policy_grid.columns)[None, :], np.array(policy_grid.index)[:, None], policy_grid.values)

    # Create heatmap
    fig = go.Figure(data=go.Heatmap(
        z=policy_grid.T,  # Transpose policy grid
        x=player_totals,  # Swap x and y
        y=dealer_cards,  # Swap x and y
        hovertext=hovertext.T,  # Transpose hover text
        hoverinfo='text',  # show only custom hover text
        colorscale=['#cccccc', '#eb7632'],
        zmin=0,
        zmax=1,
        xgap=4,
        ygap=4,
        showscale=False  # hide color bar
    ))

    fig.update_layout(
        title="Blackjack Strategy " + policy_name,
        xaxis_title='Player Total',  # Swap titles
        yaxis_title='Dealer Upcard',  # Swap titles
        autosize=True,
    )

    # Show the figure
    fig.show()

# Call the function with policy dictionary
plot_policy(SimpleAgent().policy, usable_ace=False, policy_name="- Always Hit Under 20")


Our current strategy isn't exactly optimal, but what's the real impact? How much might we stand to lose if we stick to it?

One approach to answer this question is using Monte Carlo simulations.

Here's the setup: Picture us playing Blackjack for a couple of hours each night, taking on 50 hands with each hand having a $10 stake. Blackjack, by nature, has quite a bit of variance, which means there's a slim chance we could end up with more than we started, even with a less than ideal strategy.

To get a comprehensive view of this, let's run our nightly scenario 5000 times, which is like playing a session of Blackjack every night for nearly 14 years. This should provide a solid understanding of how our strategy would fare in the long run.

In [3]:
def run_simulations(agent, env, n_episodes, n_simulations):
    total_rewards = []

    for _ in range(n_simulations):
        total_reward = 0
        for _ in range(n_episodes):
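            # Note: a hand that ends on the deal (either side is dealt 21) skips this
            # loop, so its reward is never added to the agent's total.
            while not env.is_done():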
                agent.step(env)
            total_reward += agent.reward
            agent.reset()
            env.reset()
        total_rewards.append(total_reward)

    return total_rewards

def plot_rewards_histogram(rewards, bins=10):
    # Calculate histogram data
    hist, bin_edges = np.histogram(rewards, bins=bins)

    # Calculate statistics
    mean = np.mean(rewards)
    
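    # Create bar plot, placing each bar at the centre of its bin
    # (bin_edges has one more entry than hist, so it cannot be used as x directly)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    fig = go.Figure(data=go.Bar(x=bin_centers, y=hist))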
    
    fig.update_layout(
        title="Total Rewards Distribution",
        xaxis_title="Reward",
        yaxis_title="Count",
    )

    # Limit x-axis range
    fig.update_xaxes(range=[-500, 500])

    # Add mean line
    fig.add_shape(
        type='line',
        x0=mean, y0=0, x1=mean, y1=max(hist),
        line=dict(color='Black', width=1.5, dash='dot')
    )
    
    # Add break-even reference line at zero
    fig.add_shape(
        type='line',
        x0=0, y0=0, x1=0, y1=max(hist),
        line=dict(color='#cccccc', width=0.5)
    )

    # Show the figure
    fig.show()
    
    
agent = SimpleAgent()
env = Blackjack(stake=10)
env.reset()
n_episodes = 50
n_simulations = 5000


total_rewards = run_simulations(agent, env, n_episodes, n_simulations)

plot_rewards_histogram(total_rewards, bins=15)

print(f"Mean reward of: ${(sum(total_rewards)/n_simulations):.2f}")
print(f"Profitable {(100*sum([i>0 for i in total_rewards])/n_simulations):.2f}% of the time.")

Mean reward of: $-177.97
Profitable 0.14% of the time.


#### Observations:
- The dotted line indicates the mean profit.
- On average we lose about $\$$180 a night.
- Over the course of 5000 sessions, we end up ahead only about 0.14% of the time (7 sessions).

### Finding an optimal policy
This strategy is clearly suboptimal, and since we aspire for more, we want to identify an optimal strategy for Blackjack, enabling mathematically perfect decisions. However, due to Blackjack's inherent randomness and numerous possible game states, deriving an optimal solution through traditional dynamic programming methods is quite complex. So, let's outline the approach we're going to use to uncover the optimal policy:

1. We begin with a provisional policy—say, always stick if our hand total is 20 or higher.
2. We assess this policy's performance using reinforcement learning to approximate a state-action value function, which maps each state-action pair to the return we expect from taking that action and then following our policy.
3. We update our current policy by selecting actions that yield the highest state-action value.
4. We return to step 2 and continue iterating.


## Monte Carlo Methods
Monte Carlo methods are a class of computational algorithms that rely on repeated random sampling to make numerical estimations of unknown parameters. 
The term "Monte Carlo" is often used broadly to describe any estimation method involving significant randomness.
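
As a small illustration of the idea, the sketch below estimates the probability of being dealt a two-card 21 from a single 52-card deck by repeated sampling, and compares it with the exact value. The function name, the single-deck assumption, and the sample size are illustrative choices, not part of the notebook's environment.

```python
import random

def estimate_natural_probability(n_samples, seed=0):
    """Estimate P(two-card hand totals 21) from one 52-card deck by sampling."""
    rng = random.Random(seed)
    deck = [2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 11] * 4
    hits = 0
    for _ in range(n_samples):
        hand = rng.sample(deck, 2)  # two cards drawn without replacement
        if sum(hand) == 21:
            hits += 1
    return hits / n_samples

exact = 2 * (4 / 52) * (16 / 51)  # ace then ten, or ten then ace
estimate = estimate_natural_probability(100_000)
print(f"Exact: {exact:.4f}, Monte Carlo estimate: {estimate:.4f}")
```

With enough samples the estimate should land close to the exact value of roughly 0.048, and the gap shrinks as the number of samples grows.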

First, we will explore the prediction problem, where we aim to compute $v_π$ and $q_π$ for a fixed, arbitrary policy $π$. After understanding policy prediction, we will delve into policy improvement and finally address the control problem, solved via Generalized Policy Iteration (GPI).

### Monte Carlo prediction
Our initial focus will be Monte Carlo methods for estimating the state-value function (V function) for a given policy. Remember, a state's value signifies the expected return - the anticipated cumulative future discounted reward beginning from that state. A logical method to estimate this from experience is by averaging the returns observed following visits to the state. As more returns accumulate, this average should converge to the expected value. This principle underlies all Monte Carlo methods.
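
In standard notation (the symbols $\gamma$, $R_t$, $G_t$ and $N(s)$ are assumed here and are not defined elsewhere in this notebook), the return from time step $t$ and the Monte Carlo estimate of a state's value can be written as:

$$G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \dots = \sum_{k=0}^{\infty} \gamma^k R_{t+k+1}$$

$$v_\pi(s) = \mathbb{E}_\pi\left[G_t \mid S_t = s\right] \approx \frac{1}{N(s)} \sum_{i=1}^{N(s)} G^{(i)}(s)$$

Here $\gamma$ is the discount factor, $N(s)$ is the number of recorded visits to $s$, and $G^{(i)}(s)$ is the return observed after the $i$-th of those visits. A Blackjack hand is short and only pays out at the end, so taking $\gamma = 1$ is a natural choice and the return is simply the final reward of the hand.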

Consider a scenario where we want to estimate $v_π(s)$, the value of a state $s$ under policy $π$. We have a set of episodes that have been collected by following $π$ and include visits to $s$. An episode might include multiple visits to $s$. We distinguish the first occurrence of $s$ in an episode as the first visit to $s$.

The first-visit Monte Carlo (MC) method estimates $v_π(s)$ as the average of the returns following the first visits to $s$, while the every-visit MC method averages the returns after all visits to $s$. Though these two MC methods are quite similar, they exhibit slightly different theoretical properties.
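
The difference between the two estimators can be sketched on a made-up episode. The function `mc_state_values` and the toy states `"A"` and `"B"` are hypothetical and exist only for this illustration; returns are undiscounted.

```python
from collections import defaultdict

def mc_state_values(episodes, first_visit=True):
    """Average undiscounted returns per state from (state, reward) episodes."""
    returns = defaultdict(list)
    for episode in episodes:
        # Return following each time step: sum of rewards from that step on.
        G = 0.0
        step_returns = []
        for state, reward in reversed(episode):
            G += reward
            step_returns.append((state, G))
        step_returns.reverse()  # back to chronological order

        seen = set()
        for state, G_t in step_returns:
            if first_visit and state in seen:
                continue  # first-visit: only the earliest visit counts
            seen.add(state)
            returns[state].append(G_t)
    return {s: sum(g) / len(g) for s, g in returns.items()}

# A made-up episode that visits state "A" twice.
episodes = [[("A", 1), ("B", 0), ("A", 2)]]
print(mc_state_values(episodes, first_visit=True))   # {'A': 3.0, 'B': 2.0}
print(mc_state_values(episodes, first_visit=False))  # {'A': 2.5, 'B': 2.0}
```

In Blackjack with the simple `(player_total, dealer_card, usable_ace)` observation, a state should not recur within a single hand, so the two methods are expected to give the same estimates here.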

Let's consider an agent following the simple "stick on 20 or higher" strategy we saw earlier.

Using this simple agent and policy, we will use the first-visit Monte Carlo method to compute the state-value function. We will then visualize this state-value function with a 3D surface plot. This visual exploration will give us insights into how the policy performs in different game states.

In [4]:
def first_visit_mc(env: Blackjack, agent, n_iterations=10000):
    '''
    Performs First-Visit Monte Carlo prediction for a Blackjack environment and agent.
    
    Args:
        env: Blackjack environment in which the agent acts.
        agent: The agent that interacts with the environment.
        n_iterations: The number of episodes to generate for the evaluation.
    
    Returns:
        value_estimates: A dictionary where keys are states and values are the estimated values of those states.
        average_return_per_episode: The average return per episode over all episodes.
    '''
    
    # Initialize dictionaries for the cumulative return and the count of visits for each state
    value_sums = defaultdict(float)
    state_counts = defaultdict(float)

    # Variable to keep track of the total return from all episodes
    total_return = 0 

    for i in range(n_iterations):
        # Reset the environment and agent for the new episode
        observations = env.reset()
        agent.reset()
        
        # Generate an episode by following the agent's policy
        episode = []
        while not env.is_done():
            # Get the current action, using the agent's policy
            action = agent.policy[observations]

            # Perform the action and get the reward
            next_observations, reward, _, _ = env.action(action)
            
            # Add the observed state-action-reward tuple to the episode
            episode.append((observations, action, reward))
            
            # Move to the next state
            observations = next_observations

        # Initialize the total return for the current episode
        G = 0   
        
        # Create a set to store visited states of the current episode
        visited_states = set()  
        
        # Loop backwards through the episode
        for t in range(len(episode)-1, -1, -1):
            # Get the state, action, and reward of the current step
            obs, action, reward = episode[t]
            
            # Update the total return
            G += reward 

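            # Record the return once per state per episode. Because we iterate backwards,
            # this would keep the last visit if a state repeated; with the simple
            # observation a Blackjack state should not recur within a hand, so this
            # coincides with the first visit.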
            if obs not in visited_states:
                value_sums[obs] += G
                state_counts[obs] += 1
                visited_states.add(obs)

        # Add the total return of the current episode to the total return from all episodes
        total_return += G

    # Compute the value estimates by averaging the cumulative returns over the number of visits for each state
    value_estimates = {state: value_sums[state] / state_counts[state]
                        for state in value_sums}
    
    # Calculate the average return per episode
    average_return_per_episode = total_return / n_iterations

    return value_estimates, average_return_per_episode


import pandas as pd
import numpy as np
import plotly 
plotly.offline.init_notebook_mode(connected=True)
import plotly.graph_objects as go


def plot_value_function(value_estimates, usable_ace=False):
    
    colorscale = [[0, '#0C3584'],   # Blue
                  [0.3, '#4394A9'], # Teal
                  [0.5, '#BCBCBC'], # Light gray
                  [0.7, '#F29B39'], # Orange
                  [1, '#DA201F']]   # Red
    
    # Convert the dictionary to a pandas DataFrame
    data = pd.DataFrame([(obs[0], obs[1], obs[2], value) 
                        for obs, value in value_estimates.items()],
                        columns=["player_total", "dealer_card", "usable_ace", "value"])

    # Create a grid of player totals and dealer cards
    player_totals = np.arange(4, 22)
    dealer_cards = np.arange(2, 12)

    # Create the figure; a single surface is drawn for the selected usable_ace case
    fig = go.Figure()

    # Filter data based on the usability of the ace and pivot to create a 2D array of values
    filtered_data = data[data['usable_ace'] == usable_ace]
    z_values = filtered_data.pivot(index='player_total', columns='dealer_card', values='value').reindex(index=player_totals, columns=dealer_cards)

    fig.add_trace(go.Surface(x=dealer_cards, y=player_totals, z=z_values, 
                         showscale=True, colorscale=colorscale, opacity=0.85, cmid=0,
                         contours = {
                             "x": {"show": True, "start": dealer_cards[0], "end": dealer_cards[-1], "size": 1, "color":"rgba(200, 200, 200, 0.95)"},
                             "y": {"show": True, "start": player_totals[0], "end": player_totals[-1], "size": 1, "color":"rgba(200, 200, 200, 0.95)"},
                             "z": {"show": False, "start": z_values.min().min(), "end": z_values.max().max(), "size": 1, "color":"white"}
                         }))
    fig.update_layout(
    title="State Value Estimates",
    autosize=False,
    width=700,
    height=600,
    scene=dict(
        xaxis_title='Dealer Card',
        yaxis_title='Player Total',
        zaxis_title='Value',
        aspectratio=dict(x=1, y=1, z=0.3),
        camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=1.4, y=-0.75, z=0.4)
        )
    )
)

    # Show the figure
    fig.show()
    
    
# Call the function
env = Blackjack(stake=1, n_decks=4)
agent_simple = SimpleAgent()
value_estimates_simple_agent, avg_return = first_visit_mc(env, agent_simple, n_iterations=10000)

plot_value_function(value_estimates_simple_agent, usable_ace=False)

#### Observations:
- Since we're evaluating our policy using simulations, we need to choose the right number of simulations to use. Too few simulations means we're not going to evaluate rare states enough. Think about how often we're going to evaluate the state `player_total=4`, `dealer_card=2` for example.
- This surface shows many peaks and troughs in areas that should be fairly flat, so let's increase the number of simulations to smooth that out somewhat.
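
As a rough guide to how much smoothing extra simulations buy, the sketch below uses the usual standard error of a sample mean. It assumes the returns observed from a state are independent with a standard deviation of about 1 (plausible for `stake=1`, but an assumption), and `standard_error` is an illustrative helper.

```python
import math

def standard_error(n_visits, reward_std=1.0):
    """Approximate standard error of a sample mean over `n_visits` returns."""
    return reward_std / math.sqrt(n_visits)

for n in (10, 100, 10_000):
    print(f"{n:>6} visits -> standard error about {standard_error(n):.3f}")
```

Under these assumptions the noise only shrinks with the square root of the visit count, so cutting it by a factor of ten takes roughly a hundred times as many visits to that state.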

In [5]:
env = Blackjack(stake=1, n_decks=4)
agent_simple = SimpleAgent()
value_estimates_simple_agent, avg_return = first_visit_mc(env, agent_simple, n_iterations=200000)

plot_value_function(value_estimates_simple_agent, usable_ace=False)

#### Observations:
- We can see our V function has smoothed out significantly, indicating that our state-value estimates have converged much closer to the true values.

Now that we're able to evaluate policies, let's move on to how we can improve a policy using *Monte Carlo control*.

## The Theory of Monte Carlo Control
Monte Carlo control methods estimate the value of a policy by taking averages from actual returns instead of using the Bellman equation, which is used in dynamic programming methods. Monte Carlo methods require only experience: sample sequences of states, actions, and rewards from actual or simulated interaction with an environment.

The core idea of Monte Carlo control is to use Monte Carlo methods to estimate the action-value function (Q function) and then refine the policy based on these estimates. One of the main advantages of Monte Carlo methods is that they do not require a model of the environment, in contrast to dynamic programming methods. They can learn using sample episodes alone.

A breakdown of Monte Carlo control methods is as follows:
- Initialize a starting policy and an arbitrary Q function.
- Generate an episode following the current policy.
- For each state-action pair in the episode:
  - Compute the return (total discounted reward from this point onwards).
  - Update the Q value of this state-action pair based on this calculated return.
- Improve the policy based on new action-value function estimates.
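
The steps above can be sketched as a single update applied to one finished episode. This is an every-visit variant written for illustration only, not the agent's own `learn` method; the function `mc_control_update` and the made-up hand are assumptions, while the `returns_sum` and `returns_count` names mirror the agent defined in the next cell.

```python
from collections import defaultdict

def mc_control_update(episode, returns_sum, returns_count, Q, policy, gamma=1.0):
    """Apply one every-visit Monte Carlo control update from a finished episode."""
    G = 0.0
    for state, action, reward in reversed(episode):
        G = gamma * G + reward  # return from this step onwards
        returns_sum[(state, action)] += G
        returns_count[(state, action)] += 1
        # Q value is the average return seen so far for this state-action pair.
        Q[state][action] = returns_sum[(state, action)] / returns_count[(state, action)]
        # Greedy improvement: pick the action with the highest estimate.
        policy[state] = max(range(len(Q[state])), key=lambda a: Q[state][a])

returns_sum = defaultdict(float)
returns_count = defaultdict(int)
Q = defaultdict(lambda: [0.0, 0.0])  # two actions: 0 = stick, 1 = hit
policy = defaultdict(int)

# A made-up hand: hit on 14, then stick on 19 and win one unit.
episode = [((14, 10, False), 1, 0), ((19, 10, False), 0, 1)]
mc_control_update(episode, returns_sum, returns_count, Q, policy)
print(Q[(14, 10, False)], policy[(14, 10, False)])  # [0.0, 1.0] 1
```

A purely greedy policy like this one can stop exploring too early, which is why an epsilon-greedy policy is commonly used while generating episodes.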

In [6]:
from collections import defaultdict

class MonteCarloAgent:
    def __init__(self, env, epsilon=0.25):
        '''
        Initialize a new agent.

        Parameters:
            env: The environment in which the agent acts.
            epsilon: The probability with which the agent will take a random action (exploration). 
                     The remaining (1-epsilon) probability will be to choose the best known action (exploitation).

        '''
        self.reward = 0 
        self.epsilon = epsilon
        self.action_space_size = 2 # Number of actions, in this case just two: hit and stick.
        
        # Action-value (Q) table: each state maps to a vector of Q-values, one per action.
        # For each possible action, there is a Q-value, which is an estimate of the total reward 
        # the agent can achieve starting from the given state and taking the given action.
        self.Q_table = defaultdict(lambda: np.zeros(self.action_space_size))
        
        # Cumulative sum of rewards for state-action pairs.
        self.returns_sum = defaultdict(float)
        
        # Number of times each state-action pair has been visited.
        self.returns_count = defaultdict(float)
        
        # The policy: Mapping from each state to an action.
        self.policy = defaultdict(int)  # Initialize policy dictionary
        
        # Using the simple stick on 20 strategy as an initial policy.
        for player_total in range(4, 22):
            for dealer_upcard in range(2, 12):
                for usable_ace in [False, True]:
                    if player_total < 20:
                        self.policy[(player_total, dealer_upcard, usable_ace)] = 1

    def reset(self):
        '''
        Resets the cumulative reward.
        '''
        self.reward = 0
        
    def get_probs(self, Q_s):
        '''
        This function implements an epsilon-greedy policy for a given state.

        Parameters:
        - Q_s: A numpy array containing the Q-values of all possible actions in the given state.

        Returns:
        - policy_s: A numpy array representing a probability distribution over all possible actions in the given state. 
                     The action with the highest Q-value has a probability of 1 - epsilon + epsilon / action_space_size, 
                     and all other actions have a probability of epsilon / action_space_size.
        '''
        # Initialize an array of equal probabilities for each possible action.
        # Initially, set all actions to have a probability of epsilon / action_space_size.
        # This represents the 'exploration' part of the policy where we choose a random action.
        policy_s = np.ones(self.action_space_size) * self.epsilon / self.action_space_size
        
        # Find the action that has the highest Q-value in the given state.
        best_a = np.argmax(Q_s)
        
        # Update the probability of the best action to be 1 - epsilon (the 'exploitation' part where we choose the best known action),
        # plus epsilon / action_space_size (the small chance of selecting the best action randomly during the 'exploration' part).
        policy_s[best_a] = 1 - self.epsilon + (self.epsilon / self.action_space_size)
        return policy_s
        
    def generate_episode(self, env):
        '''
        Generate an episode following the epsilon-greedy policy.

        Parameters:
            env: The environment in which the agent acts.

        Returns:
            episode: List of tuples, where each tuple is a (state, action, reward) triplet.
            initial_reward: Reward from the initial state, before any actions were taken.

        '''
        state = env.reset()
        initial_reward, done = env.calculate_reward(), env.is_done()

        episode = []

        # If the game ends before the agent can take an action, return the reward but do not generate an episode
        if done:
            return episode, initial_reward

        while not done:
            probs = self.get_probs(self.Q_table[state])
            action = np.random.choice(np.arange(2), p=probs)
            next_state, reward, done, _ = env.action(action)
            episode.append((state, action, reward))
            state = next_state

        return episode, 0
    
    def learn(self, env, num_episodes, discount_factor=1.0):
        '''
        Learn from multiple episodes.

        Parameters:
            env: The environment in which the agent acts.
            num_episodes: Number of episodes from which to learn.
            discount_factor: The factor by which to discount future rewards.

        '''
        for i_episode in range(1, num_episodes + 1):
            episode, initial_reward = self.generate_episode(env)
            self.reward += initial_reward
            
            # Check if episode is empty and skip to next iteration if so.
            if not episode:
                continue
            
            states, actions, rewards = zip(*episode)
            discounts = np.array([discount_factor**i for i in range(len(rewards)+1)])
            for i, state in enumerate(states):
                self.returns_sum[(state, actions[i])] += sum(rewards[i:]*discounts[:-(1+i)])
                self.returns_count[(state, actions[i])] += 1.0
                self.Q_table[state][actions[i]] = self.returns_sum[(state, actions[i])] / self.returns_count[(state, actions[i])]
            for state in self.Q_table:
                self.policy[state] = np.argmax(self.Q_table[state])
    
    def update_policy(self, state):
        '''
        Update the policy for a specific state.

        Parameters:
            state: The state for which to update the policy.

        '''
        self.policy[state] = np.argmax(self.Q_table[state])
        
    def step(self, env: Blackjack):
        '''
        Take one step in the environment following the policy.

        Parameters:
            env: The environment in which the agent acts.
        '''
        # Get current state
        state = env.get_observations()
        # Choose an action based on the current policy
        action = self.policy[state]
        # Take action and get reward
        next_state, reward, done, _ = env.action(action)
        # Update cumulative reward
        self.reward += reward
        
env = Blackjack(n_decks=3)
agent_monte_carlo = MonteCarloAgent(env)
agent_monte_carlo.learn(env, num_episodes=100000)

value_estimates_monte_carlo_agent, avg_reward = first_visit_mc(env, agent_monte_carlo, n_iterations=100000)
plot_value_function(value_estimates_monte_carlo_agent, usable_ace=False)
plot_policy(agent_monte_carlo.policy, usable_ace=False, policy_name="- Blackjack Optimal Strategy (Partially Converged)")

env.reset()
n_episodes = 50
n_simulations = 5000
total_rewards = run_simulations(agent_monte_carlo, env, n_episodes, n_simulations)
plot_rewards_histogram(total_rewards, bins=15)
print(f"Mean reward of: ${(sum(total_rewards)/n_simulations):.2f}")
print(f"Profitable {(100*sum([i>0 for i in total_rewards])/n_simulations):.2f}% of the time.")

Mean reward of: $-26.74
Profitable 31.42% of the time.


In theory, provided our agent keeps a non-zero exploration rate, it will converge to the optimal strategy given an infinite number of episodes.

This is a good example of where reinforcement learning theory doesn't line up neatly with practice.

We could let this agent run for hours and hours until its policy approaches optimal, but let's instead try to work out how we could speed up convergence. Finding optimal policies in Blackjack runs into a few issues, however:

**Issue 1**:
- Blackjack is a very high variance game. Imagine we're trying to determine the Q-value for the state `(12, 6, False)`. Here the player busts on a hit roughly 31% of the time (on any ten-valued card), and the dealer busts roughly 40% of the time across all possible hole cards. Either action can easily win or lose, so the returns from this spot are very high variance. Now imagine that during training we get unlucky and bust 5 times in a row after hitting, which produces a very negative average reward for that action. The action will then rarely be selected, and since the true difference between hit and stick may be only a percent or two, it will be very difficult for the state-action values to converge to their true values. For high variance spots this can take a very long time.
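
To get a feel for how slowly such estimates settle, here is a rough back-of-the-envelope sketch. It assumes returns with a standard deviation of about 1 (wins and losses of one unit at close to even odds) and asks how many returns per action are needed before a gap between the two mean returns exceeds roughly two standard errors. The helper `samples_needed` and the 1.96 threshold are illustrative choices, not part of the notebook.

```python
import math

def samples_needed(gap, std=1.0, z=1.96):
    """Rough number of returns needed per action to resolve `gap`.

    The standard error of the difference between two sample means, each
    based on n returns with standard deviation `std`, is std * sqrt(2 / n).
    We solve z * std * sqrt(2 / n) <= gap for n.
    """
    return math.ceil(2 * (z * std / gap) ** 2)

for gap in (0.10, 0.05, 0.02):
    print(f"gap of {gap:.2f} units: ~{samples_needed(gap):,} returns per action")
```

Under these assumptions, separating two actions whose values differ by two percent of a bet takes roughly twenty thousand visits per action.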

**Issue 2**:
- Some spots are rare and are not encountered often in the game. For example, the state `(4, 2, False)` requires the player's two cards and the dealer's upcard to all be 2s. Each card is a 2 with probability of about $\frac{1}{13}$, so we encounter this state with probability of approximately $(\frac{1}{13})^3 \approx 4.6 \times 10^{-4}$, or roughly once in every 2,200 games. The visit counts for one low-total state can be observed below. In Blackjack these rare spots also happen to be very high variance, as both the player and dealer may be hitting several times. The high variance coupled with the rarity of states like this means that our Monte Carlo control method is unlikely to find accurate Q-values for the state-action pairs in this state.

**Issue 3**:
- High variance also makes it difficult to evaluate rare states sufficiently using first-visit Monte Carlo. In high variance spots where the player or dealer has a roughly 50% chance of busting, we need to simulate the spot thousands of times to get a reasonably accurate estimate, yet we only reach a state like `(4, 2, False)` once in roughly every 2,200 games. As policies get closer and closer to optimal, we would need an infeasibly large number of simulations to compare their performances.
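
As a sanity check on how rare a starting state such as `(4, 2, False)` is, the chance that the player's two cards and the dealer's upcard are all 2s can be computed exactly. This sketch assumes a freshly shuffled shoe and cards dealt without replacement; `prob_three_of_rank` is a name introduced here for illustration.

```python
from fractions import Fraction

def prob_three_of_rank(n_decks):
    """Exact chance that three cards dealt from a fresh shoe all share one given rank."""
    rank_cards = 4 * n_decks      # e.g. the number of 2s in the shoe
    total_cards = 52 * n_decks
    p = Fraction(1)
    for drawn in range(3):
        # Each card already dealt removes one matching card and one card overall.
        p *= Fraction(rank_cards - drawn, total_cards - drawn)
    return p

for n_decks in (1, 3, 4):
    p = prob_three_of_rank(n_decks)
    print(f"{n_decks} deck(s): p = {float(p):.6f} (about 1 in {round(1 / p):,})")
```

With fewer decks the state is rarer still, because each 2 that is dealt leaves a smaller share of 2s behind.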

In [7]:
poor_choice_found = False
state = (0,0,False)

for i in range(11, 4, -1): # It makes no sense to do anything other than hit in these spots.
    if poor_choice_found:
        break
    for j in range(2, 12):
        if np.argmax(agent_monte_carlo.Q_table[(i,j,False)]) == 0:
            state = (i, j, False)
            poor_choice_found = True
            break
            
print("Suboptimal choice found at:", state)
print(f"Q-values for this suboptimal state: {agent_monte_carlo.Q_table[state]}")
print(f"Suboptimal state visited with stick: {agent_monte_carlo.returns_count[(state,0)]} times.")
print(f"Suboptimal state visited with hit: {agent_monte_carlo.returns_count[(state,1)]} times.")

Suboptimal choice found at: (8, 4, False)
Q-values for this suboptimal state: [-1.93717277 -2.72727273]
Suboptimal state visited with stick: 191.0 times.
Suboptimal state visited with hit: 33.0 times.


#### Observations:
- We could address issues like this by increasing the exploration rate. However, since we're using an on-policy method, doing so would also decrease the accuracy of our Q-values, because they estimate returns under the exploratory policy rather than the greedy one. Enough iterations would eventually compensate, but this isn't a satisfactory solution.
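
The trade-off can be made concrete with a little arithmetic. In this sketch the two Q-values are made-up numbers and `epsilon_soft_value` is a hypothetical helper. It computes the expected return from a state when actions are chosen epsilon-greedily, which is the quantity an on-policy method ends up propagating back into earlier states.

```python
import numpy as np

def epsilon_soft_value(q_values, epsilon):
    """Expected return from a state when acting epsilon-greedily on q_values."""
    q_values = np.asarray(q_values, dtype=float)
    n_actions = len(q_values)
    # Every action gets epsilon / n_actions; the greedy one gets the remainder.
    probs = np.full(n_actions, epsilon / n_actions)
    probs[np.argmax(q_values)] += 1 - epsilon
    return float(probs @ q_values)

q = [-0.20, -0.60]  # hypothetical Q-values for (stick, hit)
for eps in (0.0, 0.1, 0.25, 0.5):
    value = epsilon_soft_value(q, eps)
    print(f"epsilon={eps:.2f}: value under the behaviour policy = {value:.3f}")
```

The larger epsilon is, the further the value seen during training drifts from the greedy value of the state.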

What we're going to do is change 2 things:

1. **Prioritized Experience Replay**: We aim to tackle high-variance states by introducing a prioritization scheme in the state initialization. "High-priority" states are those with the highest variance, which therefore require a substantial number of episodes to converge.

2. **Stochastic Initialization**: We propose a stochastic environment initialization strategy, controlled by a parameter $\theta \in [0,1]$. A uniform random number is generated for each episode. If the number is less than $\theta$, we use a standard random initialization, which favors frequent states. If the number is greater than $\theta$, we initialize the game in a pre-selected high-priority state.
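
The switch between the two kinds of initialization can be sketched in isolation. The states, weights, and the helper `choose_initial_state` below are made up for illustration; returning `None` stands in for "deal a normal random hand".

```python
import random

def choose_initial_state(theta, priority_states, weights, rng=random):
    """Return None for a standard random deal, or a prioritized starting state."""
    if rng.random() < theta:
        return None  # the caller deals a normal random hand
    # Otherwise pick one high-priority state, weighted by its score.
    return rng.choices(priority_states, weights=weights, k=1)[0]

rng = random.Random(0)
states = [(12, 6, False), (16, 10, False)]   # made-up high-priority states
weights = [1.0, 3.0]                          # made-up priority scores
starts = [choose_initial_state(0.5, states, weights, rng) for _ in range(10_000)]

share_standard = sum(s is None for s in starts) / len(starts)
print(f"standard deals: {share_standard:.1%}, prioritized starts: {1 - share_standard:.1%}")
```

With $\theta = 0.5$ about half of the episodes start from a normal deal, so frequent states are still sampled while the rest of the budget goes to the prioritized ones.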

#### Calculating the priority
The steps to calculate the "priority" of a state and initialize the game accordingly are as follows:

a. Compute the probability of the dealer busting for each possible upcard, denoted by $p_{\text{dealer bust}}(\text{upcard})$.

b. Compute the probability of the player busting for each possible hand total, denoted by $p_{\text{player bust}}(\text{total})$. Note that for $\text{total} \le 11$, $p_{\text{player bust}}(\text{total}) = 0$, as the player can't bust on the next hit.

c. Assign a priority score to each state. The priority score reflects the variance or learning potential of the state, calculated as $\text{priority}(\text{state}) = \left(0.5 - \left|0.5 - p_{\text{dealer bust}}(\text{upcard})\right|\right) + \left(0.5 - \left|0.5 - p_{\text{player bust}}(\text{total})\right|\right)$. The priority score is set to 0 for states where the player has a hand total $\le 11$, as the optimal strategy is to always hit in these states.

d. Using this, rather than randomly initializing the environment, we select initial states based on priority. We can convert the priority scores into a probability distribution and randomly select an initial state, weighted by its priority.

e. With the initial state selected, the rest of the episode is simulated using standard Monte Carlo control.
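
Steps c and d can be sketched on a handful of states. The bust probabilities below are made-up placeholders rather than simulated values, and `priority` is a standalone helper written for this example.

```python
def priority(p_dealer_bust, p_player_bust):
    """Score that peaks at 1.0 when both bust probabilities equal 0.5."""
    dealer_term = 0.5 - abs(0.5 - p_dealer_bust)
    player_term = 0.5 - abs(0.5 - p_player_bust)
    return dealer_term + player_term

# Made-up bust probabilities, for illustration only.
examples = {
    (16, 6, False): priority(0.42, 0.62),
    (12, 10, False): priority(0.23, 0.31),
    (19, 2, False): priority(0.35, 0.85),
}

# Normalize the scores so they form a probability distribution over start states.
total = sum(examples.values())
start_distribution = {state: score / total for state, score in examples.items()}

for state, prob in start_distribution.items():
    print(state, f"score={examples[state]:.2f}", f"start probability={prob:.2f}")
```

States whose bust probabilities sit close to 50% for both sides receive the largest share of prioritized starts.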

This strategy seeks to balance the need to sample all states for unbiased learning with the need to focus on high-variance states for efficient learning. The exact choice of $\theta$ and the epsilon decay schedule should be tuned based on empirical performance.

### Calculating bust probabilities
Here we're going to calculate and plot the probabilities of busting for the player and dealer. Calculations assume 4 decks, and dealer sticking on a soft 17.

In [8]:
def calculate_dealer_bust(n_trials=10000, n_decks=4):
    
    probabilities = {i : 0 for i in range(2, 12)}
    busts = {i : 0 for i in range(2, 12)}
    visits = {i : 0 for i in range(2, 12)}
    
    def reset_deck():
        deck = [2,3,4,5,6,7,8,9,10,10,10,10,11] * 4 * n_decks
        random.shuffle(deck)
        return deck
        
    def draw_card(deck):
        card = random.choice(deck)
        return deck.pop(deck.index(card))
    
    def calculate_total(hand):
        # Downgrade aces from 11 to 1, one at a time, until the hand no longer busts.
        while sum(hand) > 21 and 11 in hand:
            hand[hand.index(11)] = 1
        return sum(hand)
    
    for _ in range(n_trials):
        deck = reset_deck()

        hand = []
        start_card = draw_card(deck)
        hand.append(start_card)
        hand.append(draw_card(deck))
        
        while calculate_total(hand) < 17:
            hand.append(draw_card(deck))
            
        if calculate_total(hand) > 21:
            bust = 1
        else:
            bust = 0
        
        visits[start_card] += 1
        busts[start_card] += bust
        probabilities[start_card] = busts[start_card]/visits[start_card]

    return probabilities
    
def calculate_player_bust(n_trials=10000, n_decks=4):
    
    probabilities = {(i,t) : 0 for i in range(12, 22) for t in (True,False)}
    busts = {(i,t) : 0 for i in range(12, 22) for t in (True,False)}
    visits = {(i,t) : 0 for i in range(12, 22) for t in (True,False)}
    
    def reset_deck():
        deck = [2,3,4,5,6,7,8,9,10,10,10,10,11] * 4 * n_decks
        random.shuffle(deck)
        return deck
        
    def draw_card(deck):
        card = random.choice(deck)
        return deck.pop(deck.index(card))
    
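    def calculate_total(hand):
        # Downgrade aces from 11 to 1, one at a time, until the hand no longer busts.
        # A loop is needed here: hitting (10, ace) with another ace requires two downgrades.
        while sum(hand) > 21 and 11 in hand:
            hand[hand.index(11)] = 1
        return sum(hand)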
    
    for _ in range(n_trials):
        deck = reset_deck()

        hand = []
        hand.append(draw_card(deck))
        hand.append(draw_card(deck))
        total = calculate_total(hand)
        
        while total <= 11:
            hand.append(draw_card(deck))
            total = calculate_total(hand)
            
        start = total
        ace = 11 in hand  # an 11 still in the hand means the ace is usable
        
        hand.append(draw_card(deck))
            
        if calculate_total(hand) > 21:
            bust = 1
        else:
            bust = 0
        
        visits[(start, ace)] += 1
        busts[(start, ace)] += bust
        probabilities[(start, ace)] = busts[(start, ace)]/visits[(start, ace)]

    return probabilities

def plot_dealer_bust_probabilities(prob_dict, title='Dealer Card Bust Probabilities'):
    keys = list(prob_dict.keys())
    values = list(prob_dict.values())

    fig = go.Figure(data=go.Heatmap(
        z=[values], 
        x=keys, 
        y=[''], 
        colorscale='Reds',
        zmax=1,
        showscale=False,
        hovertemplate = 'Dealer Upcard: %{x}<br>Bust Probability: %{z}<extra></extra>',
    ))

    # Adding annotations
    for i, val in enumerate(values):
        fig.add_annotation(
            x=keys[i], 
            y='', 
            text="{:.2f}".format(val),
            showarrow=False,
            font=dict(
                color="black",
                size=12
            )
        )

    fig.update_layout(
        title=title,
        xaxis=dict(title='Dealer Upcard'),
        yaxis=dict(title='', showticklabels=False),
        width=800,
        height=250
    )
    
    fig.show()
    
    
def plot_player_bust_probabilities(prob_dict, title='Player Bust Probabilities'):
    # Create DataFrame from prob_dict
    df = pd.DataFrame(
        [(k[0], k[1], v) for k, v in prob_dict.items()],
        columns=['Player Hand', 'Usable Ace', 'Bust Probability']
    )
    df['Usable Ace'] = df['Usable Ace'].map({True: 'Usable Ace', False: 'No Usable Ace'})  # convert boolean to string
    
    fig = go.Figure(data=go.Heatmap(
        z=df['Bust Probability'],
        x=df['Player Hand'],
        y=df['Usable Ace'],
        colorscale='Reds',
        zmax=1,
        showscale=False,
        hovertemplate='Player Hand: %{x}<br>Usable Ace: %{y}<br>Bust Probability: %{z}<extra></extra>',
    ))

    # Adding annotations
    for i, row in df.iterrows():
        fig.add_annotation(
            x=row['Player Hand'], 
            y=row['Usable Ace'], 
            text="{:.2f}".format(row['Bust Probability']),
            showarrow=False,
            font=dict(color="black", size=12)
        )

    fig.update_layout(
        title=title,
        xaxis=dict(title='Player Hand'),
        yaxis=dict(title='', showticklabels=True),
        width=800,
        height=300
    )
    
    fig.show()
    
dealer_bust_probabilities = calculate_dealer_bust(n_trials=100000)
player_bust_probabilities = calculate_player_bust(n_trials=100000)

plot_dealer_bust_probabilities(dealer_bust_probabilities)
plot_player_bust_probabilities(player_bust_probabilities)

### Calculating hand priorities
Now let's initialize the priorities dictionary, with a value for each state, then calculate the priority value for each state. States with higher variance get a higher priority. We will then pass this computed dictionary to our Monte Carlo agent and include the two changes we described above.

In [9]:
def calculate_priority(player_total, dealer_upcard, usable_ace):
    if player_total < 12 or player_total > 19:
        return 0 
    else:
        dealer_variance = 0.5 - abs(0.5 - dealer_bust_probabilities[dealer_upcard])
        player_variance = 0.5 - abs(0.5 - player_bust_probabilities[(player_total, usable_ace)])
        return dealer_variance + player_variance
        
# calculate_priority already returns 0 for totals below 12 and above 19,
# so every state can be scored directly.
priorities = {(i, j, k): calculate_priority(i, j, k)
              for i in range(4, 22) for j in range(2, 12) for k in (True, False)}
        

class MonteCarloAgent:
    def __init__(self, env, priorities, epsilon=0.25, theta=0.5):
        '''
        Initialize a new agent.

        Parameters:
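            env: The environment in which the agent acts.
            priorities: Dictionary mapping each state to its priority score, used to weight
                        the choice of initial state.
            epsilon: The probability with which the agent will take a random action (exploration). 
                     The remaining (1-epsilon) probability will be to choose the best known action (exploitation).
            theta: The probability of starting an episode from a standard random deal.
                   With probability (1-theta) the episode starts from a priority-weighted state.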

        '''
        self.reward = 0 
        self.priorities = priorities
        self.epsilon = epsilon
        self.original_epsilon = epsilon
        self.theta = theta
        self.action_space_size = 2 # Number of actions, in this case just two: hit and stick.
        
        # State-Action pair Function: Each state-action pair maps to a vector of Q-values.
        # For each possible action, there is a Q-value, which is an estimate of the total reward 
        # the agent can achieve starting from the given state and taking the given action.
        self.Q_table = defaultdict(lambda: np.zeros(self.action_space_size))
        
        # Cumulative sum of rewards for state-action pairs.
        self.returns_sum = defaultdict(float)
        
        # Number of times each state-action pair has been visited.
        self.returns_count = defaultdict(float)
        
        # The policy: Mapping from each state to an action.
        self.policy = defaultdict(int)  # Initialize policy dictionary
        
        # Use our simple stick on 20 strategy as the initial policy
        for player_total in range(4, 22):
            for dealer_upcard in range(2, 12):
                for usable_ace in [False, True]:
                    if player_total < 20:
                        self.policy[(player_total, dealer_upcard, usable_ace)] = 1
        
    def reset(self):
        '''
        Resets the cumulative reward.
        '''
        self.reward = 0
        
    def get_probs(self, Q_s):
        '''
        This function implements an epsilon-greedy policy for a given state.

        Parameters:
        - Q_s: A numpy array containing the Q-values of all possible actions in the given state.

        Returns:
        - policy_s: A numpy array representing a probability distribution over all possible actions in the given state. 
                     The action with the highest Q-value has a probability of 1 - epsilon + epsilon / action_space_size, 
                     and all other actions have a probability of epsilon / action_space_size.
        '''
        # Initialize an array of equal probabilities for each possible action.
        # Initially, set all actions to have a probability of epsilon / action_space_size.
        # This represents the 'exploration' part of the policy where we choose a random action.
        policy_s = np.ones(self.action_space_size) * self.epsilon / self.action_space_size
        
        # Find the action that has the highest Q-value in the given state.
        best_a = np.argmax(Q_s)
        
        # Update the probability of the best action to be 1 - epsilon (the 'exploitation' part where we choose the best known action),
        # plus epsilon / action_space_size (the small chance of selecting the best action randomly during the 'exploration' part).
        policy_s[best_a] = 1 - self.epsilon + (self.epsilon / self.action_space_size)
        return policy_s
        
        
    def sample_priority(self):
        keys = list(self.priorities.keys())
        values = list(self.priorities.values())

        weighted_random_key = random.choices(keys, weights=values, k=1)[0]
        return weighted_random_key


    def stochastic_priority_sample(self, env):
        # Use the environment passed in rather than relying on a global `env`.
        if random.random() > self.theta: # If greater than theta use a weighted priority sample
            initial_state = self.sample_priority()
            state = env.reset(initial_state)
            return state
        else:                            # Else just initialize a state normally
            state = env.reset()
            return state

    def generate_episode(self, env):
        '''
        Generate an episode following the epsilon-greedy policy.

        Parameters:
            env: The environment in which the agent acts.

        Returns:
            episode: List of tuples, where each tuple is a (state, action, reward) triplet.
            initial_reward: Reward from the initial state, before any actions were taken.

        '''
        state = self.stochastic_priority_sample(env)
        initial_reward, done = env.calculate_reward(), env.is_done()

        episode = []

        # If the game ends before the agent can take an action, return the reward but do not generate an episode
        if done:
            return episode, initial_reward

        while not done:
            probs = self.get_probs(self.Q_table[state])
            action = np.random.choice([0,1], p=probs)

            next_state, reward, done, _ = env.action(action)
            episode.append((state, action, reward))
            state = next_state

        return episode, 0
    
    def learn(self, env, num_episodes, discount_factor=1.0):
        '''
        Learn from multiple episodes.

        Parameters:
            env: The environment in which the agent acts.
            num_episodes: Number of episodes from which to learn.
            discount_factor: The factor by which to discount future rewards.

        '''
        for i_episode in range(1, num_episodes + 1):
            episode, initial_reward = self.generate_episode(env)
            self.reward += initial_reward
            
            # Check if episode is empty and skip to next iteration if so.
            if not episode:
                continue
            
            states, actions, rewards = zip(*episode)
            discounts = np.array([discount_factor**i for i in range(len(rewards)+1)])
            for i, state in enumerate(states):
                self.returns_sum[(state, actions[i])] += sum(rewards[i:]*discounts[:-(1+i)])
                self.returns_count[(state, actions[i])] += 1.0
                self.Q_table[state][actions[i]] = self.returns_sum[(state, actions[i])] / self.returns_count[(state, actions[i])]
            for state in self.Q_table:
                self.policy[state] = np.argmax(self.Q_table[state])
    
    def update_policy(self, state):
        '''
        Update the policy for a specific state.

        Parameters:
            state: The state for which to update the policy.

        '''
        self.policy[state] = np.argmax(self.Q_table[state])
        
        
    def step(self, env: Blackjack):
        '''
        Take one step in the environment following the policy.

        Parameters:
            env: The environment in which the agent acts.
        '''
        # Get current state
        state = env.get_observations()
        # Choose an action based on the current policy
        action = self.policy[state]
        # Take action and get reward
        next_state, reward, done, _ = env.action(action)
        # Update cumulative reward
        self.reward += reward
    
        
env = Blackjack(n_decks=4)
agent_monte_carlo = MonteCarloAgent(env, priorities, epsilon=0.25, theta=0.5)
agent_monte_carlo.learn(env, num_episodes=10000000)

value_estimates_monte_carlo_agent, avg_reward = first_visit_mc(env, agent_monte_carlo, n_iterations=1000000)
plot_value_function(value_estimates_monte_carlo_agent, usable_ace=False)
plot_policy(agent_monte_carlo.policy, usable_ace=False, policy_name="- Blackjack Optimal Strategy")

env.reset()
n_episodes = 50
n_simulations = 5000
total_rewards = run_simulations(agent_monte_carlo, env, n_episodes, n_simulations)
plot_rewards_histogram(total_rewards, bins=15)
print(f"Mean reward of: ${(sum(total_rewards)/n_simulations):.2f}")
print(f"Profitable {(100*sum([i>0 for i in total_rewards])/n_simulations):.2f}% of the time.")

Mean reward of: $-28.55
Profitable 30.42% of the time.
