<div style="background: linear-gradient(90deg, #17a2b8 0%, #0e5a63 60%, #0a3d44 100%); color: white; padding: 18px 25px; margin-bottom: 20px;">
    <div style="display: flex; justify-content: space-between; align-items: baseline;">
        <h1 style="font-family: 'Helvetica Neue', sans-serif; font-size: 24px; margin: 0; font-weight: 300;">
            Lab 5-1: Blackjack with Monte Carlo ES
        </h1>
        <span style="font-size: 11px; opacity: 0.9;">© Prof. Dehghani</span>
    </div>
    <p style="font-size: 13px; margin-top: 6px; margin-bottom: 0; opacity: 0.9;">
        IE 7295 Reinforcement Learning | Sutton and Barto Chapter 5 Figure 5.2 | 75 minutes
    </p>
</div>

<div style="background: white; padding: 15px 20px; margin-bottom: 12px; border-left: 3px solid #17a2b8;">
    <h3 style="color: #17a2b8; font-size: 14px; margin: 0 0 8px 0;">Background</h3>
    <p style="color: #555; line-height: 1.6; margin: 0; font-size: 13px;">
        This lab implements <strong>Monte Carlo ES (Exploring Starts)</strong> as given in Sutton and Barto, Section 5.3, and 
        reproduces the Blackjack results of their Figure 5.2. The algorithm learns a Blackjack policy from sampled episodes, without 
        a model of the environment's dynamics. Its key ingredient is <strong>Exploring Starts</strong>: each episode begins from a randomly 
        dealt state with a uniformly random first action, so every state-action pair that can start an episode keeps being tried as the number 
        of episodes grows. (The book picks the start state uniformly at random; here it comes from the environment's own deal.) After the first 
        action, the agent follows its current greedy policy. Random starts supply exploration and greedy actions supply exploitation, which in 
        practice drives the policy toward the optimal one; Sutton and Barto note that convergence of MC ES has not been formally proved.
    </p>
</div>
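For Blackjack, Sutton and Barto arrange exploring starts by choosing the dealer's card, the player's sum, and whether the player has a usable ace, all at random with equal probability, together with a random action. This lab instead takes the start state from `env.reset()`, whose deal is not uniform over states. The sketch below, built around a hypothetical helper `sample_exploring_start`, only shows book-style sampling of a start pair; actually running an episode from an arbitrary sampled state would also require a simulator that can be placed in that state, which is assumed here and not shown.

```python
import random

# Hypothetical helper (not part of the lab code): samples a start pair uniformly
# over the book's 200 states (player sum 12-21, dealer card 1-10, usable ace
# yes/no) and the 2 actions.
PLAYER_SUMS = range(12, 22)   # 12..21
DEALER_CARDS = range(1, 11)   # 1 = Ace, 2..10
USABLE_ACE = (False, True)
ACTIONS = (0, 1)              # 0 = stick, 1 = hit


def sample_exploring_start(rng: random.Random):
    """Return ((player_sum, dealer_card, usable_ace), action), all drawn uniformly."""
    state = (rng.choice(PLAYER_SUMS), rng.choice(DEALER_CARDS), rng.choice(USABLE_ACE))
    action = rng.choice(ACTIONS)
    return state, action


rng = random.Random(0)
seen = {sample_exploring_start(rng) for _ in range(20_000)}
print(len(seen))  # distinct (state, action) pairs drawn out of 200 * 2 = 400
```

With uniform starts, every one of the 400 pairs has probability 1/400 per episode, which is the property the exploring-starts assumption asks for.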

<table style="width: 100%; border-spacing: 12px;">
<tr>
<td style="background: white; padding: 12px 15px; border-top: 3px solid #17a2b8; width: 50%;">
    <h4 style="color: #17a2b8; font-size: 13px; margin: 0 0 8px 0;">Learning Objectives</h4>
    <ul style="color: #555; line-height: 1.4; margin: 0; padding-left: 18px; font-size: 12px;">
        <li>Implement Monte Carlo ES (Sutton and Barto, Section 5.3)</li>
        <li>Understand the exploring starts mechanism</li>
        <li>Apply first-visit MC to action-value estimation</li>
        <li>Implement greedy policy improvement</li>
        <li>Reproduce textbook Blackjack results</li>
    </ul>
</td>
<td style="background: white; padding: 12px 15px; border-top: 3px solid #00acc1; width: 50%;">
    <h4 style="color: #00acc1; font-size: 13px; margin: 0 0 8px 0;">Blackjack Rules</h4>
    <div style="color: #555; font-size: 12px; line-height: 1.6;">
        <div style="padding: 2px 0;">Goal: Get a sum as close to 21 as possible without exceeding it</div>
        <div style="padding: 2px 0;">Actions: 0=Stick (stop), 1=Hit (draw)</div>
        <div style="padding: 2px 0;">State: (player_sum, dealer_card, usable_ace)</div>
        <div style="padding: 2px 0;">Rewards: +1 win, 0 draw, -1 lose</div>
    </div>
</td>
</tr>
</table>
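The `usable_ace` part of the state can be derived from the cards in hand as in the sketch below. `hand_value` is a hypothetical helper written for illustration (Gymnasium computes this internally with its own functions); aces are stored as 1 and counted as 11 only when that keeps the total at 21 or below.

```python
def hand_value(cards):
    """Return (player_sum, usable_ace) for a list of card values (Ace = 1).

    An ace is 'usable' when counting it as 11 does not push the total past 21.
    """
    raw = sum(cards)
    usable = 1 in cards and raw + 10 <= 21
    return (raw + 10 if usable else raw), usable


print(hand_value([1, 6]))      # (17, True): soft 17
print(hand_value([1, 6, 10]))  # (17, False): the ace falls back to 1
print(hand_value([10, 7, 5]))  # (22, False): bust
```

While an ace is usable the raw total is at most 11, so a single card (worth at most 10) cannot push the hand past 21.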

---
<div style="border-left: 4px solid #17a2b8; padding-left: 12px; margin: 20px 0;">
  <h2 style="color: #17a2b8; margin: 0; font-size: 18px;">Section 1: Environment Setup and Dependencies</h2>
</div>

We begin by importing the necessary libraries and creating the Blackjack environment. The key libraries are:
- **Gymnasium**: Provides the Blackjack-v1 environment with proper episode handling
- **NumPy**: For numerical computations and array operations
- **Matplotlib**: For creating 3D value function plots and 2D policy heatmaps
- **defaultdict**: For efficient sparse storage of Q-values and returns
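The Q-table and the return lists both rely on `defaultdict` creating an entry the first time a key is accessed. A small illustration with made-up values (the state and return below are examples, not lab results):

```python
from collections import defaultdict

import numpy as np

n_actions = 2  # 0 = stick, 1 = hit

# Each new state key gets its own zero-initialized array of action values.
Q = defaultdict(lambda: np.zeros(n_actions))
returns = defaultdict(list)

state = (13, 2, False)           # example state (player_sum, dealer_card, usable_ace)
returns[(state, 0)].append(1.0)  # one observed return for (state, stick)
Q[state][0] = np.mean(returns[(state, 0)])

print(len(Q), Q[state])          # only the touched state is stored
print(Q[(20, 10, True)])         # reading an unseen state creates a zero entry
print(len(Q))                    # so the table now holds two states
```

Because reading a missing key also creates it, code that only inspects Q (for example, for plotting) should test `state in Q` first, as `plot_value_function` does.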

In [None]:
"""
Cell 1: Import Libraries and Initialize Environment

Purpose:
  - Import required libraries for MC ES implementation
  - Load pretty_print utility for formatted output
  - Configure matplotlib for visualization
  - Create Blackjack environment

Key Components:
  - gymnasium: Modern RL environment library (v1 API)
  - defaultdict: Efficient storage for sparse Q-values
  - matplotlib: 3D surface plots and 2D heatmaps
"""

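import re
import gymnasium as gym
import numpy as np
from collections import defaultdict
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
import matplotlib.pyplot as plt
from matplotlib import cm
import warnings
warnings.filterwarnings('ignore')

# Configure matplotlib for publication-quality figures
plt.rcParams['figure.dpi'] = 100
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

# Load pretty_print utility from GitHub
try:
    import requests
    url = 'https://raw.githubusercontent.com/mdehghani86/RL_labs/master/utility/rl_utility.py'
    response = requests.get(url, timeout=10)
    response.raise_for_status()  # fail fast on HTTP errors instead of exec-ing an error page
    exec(response.text)
    pretty_print("Environment Ready", 
                 f"Gymnasium version: {gym.__version__}<br>" +
                 "Implementing Monte Carlo ES (Sutton and Barto, Section 5.3)<br>" +
                 "All libraries loaded successfully", 
                 style='success')
except Exception as e:
    print(f"Libraries loaded (pretty_print unavailable: {e})")

    # Plain-text fallback so the later pretty_print calls still work offline
    def pretty_print(title, content, style='info'):
        text = re.sub(r'<br\s*/?>', '\n', content)  # HTML line breaks -> newlines
        text = re.sub(r'<[^>]+>', '', text)          # drop remaining HTML tags
        print(f"[{style.upper()}] {title}\n{text}\n")

# Create Blackjack environment (v1 uses the modern reset/step API).
# Note: gym.make('Blackjack-v1', sab=True) follows the Sutton and Barto rules for naturals exactly.
env = gym.make('Blackjack-v1')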

pretty_print("Blackjack Environment Created",
             f"Action space: {env.action_space.n} actions<br>" +
             "0 = Stick (stop drawing), 1 = Hit (draw card)<br>" +
             "State: (player_sum, dealer_showing, usable_ace)<br>" +
             "Terminal rewards: +1 (win), 0 (draw), -1 (lose)",
             style='info')

---
<div style="border-left: 4px solid #17a2b8; padding-left: 12px; margin: 20px 0;">
  <h2 style="color: #17a2b8; margin: 0; font-size: 18px;">Section 2: Monte Carlo ES Algorithm Implementation</h2>
</div>

<div style="text-align: center; margin: 20px 0;">
    <img src="https://github.com/mdehghani86/RL_labs/blob/master/Lab%2005/MCM_ES.jpg?raw=true" 
         alt="Monte Carlo ES Pseudocode" 
         style="width: 70%; border: 2px solid #17a2b8; border-radius: 8px;">
    <p style="color: #666; font-size: 12px; margin-top: 10px;">Monte Carlo ES pseudocode (Sutton and Barto, Section 5.3)</p>
</div>

The Monte Carlo ES algorithm consists of three key components:

1. **Episode Generation with Exploring Starts**: Each episode begins from a randomly dealt state with a uniformly random first action (the exploring start), then follows the current greedy policy for all subsequent actions. The random first action ensures both actions get tried from starting states, even when the greedy policy would never choose one of them.

2. **First-Visit MC Update**: For each state-action pair in an episode, we compute the return G following its first visit, append G to Returns(s,a), and set Q(s,a) to the average of all returns recorded so far.

3. **Greedy Policy Improvement**: After each Q-value update, the policy at that state is immediately made greedy with respect to the new Q-values: π(s) = argmax_a Q(s,a).
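When returns are accumulated backward from the end of an episode, the first occurrence of a pair met on the way back is its *last* visit in time, so a first-visit method also needs each pair's earliest time step. In Blackjack the same state never recurs within an episode, so first-visit and every-visit coincide, but the distinction matters in general. A minimal sketch with a hypothetical helper `first_visit_returns` and a toy (non-Blackjack) episode:

```python
def first_visit_returns(episode, gamma=1.0):
    """Map each (state, action) in `episode` to the return following its FIRST visit.

    episode: list of (state, action, reward) tuples, where reward is R_{t+1}.
    """
    # Forward pass: remember the earliest time step of every pair.
    first_t = {}
    for t, (s, a, _) in enumerate(episode):
        first_t.setdefault((s, a), t)

    # Backward pass: G_t = R_{t+1} + gamma * G_{t+1}; keep G only at first visits.
    G = 0.0
    out = {}
    for t in range(len(episode) - 1, -1, -1):
        s, a, r = episode[t]
        G = r + gamma * G
        if first_t[(s, a)] == t:
            out[(s, a)] = G
    return out


# Toy episode in which ("x", 1) is visited twice (at t = 0 and t = 2).
toy = [("x", 1, 0.0), ("y", 0, 2.0), ("x", 1, 5.0)]
print(first_visit_returns(toy))  # ("x", 1) gets 7.0, not the last-visit return 5.0
```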

In [None]:
"""
Cell 2: Episode Generation with Exploring Starts

Purpose:
  - Generate episodes using exploring starts mechanism
  - First action: RANDOM (ensures exploration)
  - Subsequent actions: GREEDY (follows current policy)

Algorithm:
  1. Reset environment to get initial state
  2. Select RANDOM first action (exploring start)
  3. Execute first action and record (s, a, r)
  4. For rest of episode: follow greedy policy
  5. Continue until episode terminates

Parameters:
  env: Gymnasium Blackjack environment
  policy: Dictionary mapping states to actions

Returns:
  episode: List of (state, action, reward) tuples
"""

def generate_episode_with_exploring_starts(env, policy):
    episode = []
    
    # Initialize episode
    state, _ = env.reset()  # v1 returns (state, info)
    
    # CRITICAL: Exploring start - select a RANDOM first action.
    # Together with the random deal from env.reset(), every start
    # (state, action) pair has a nonzero chance of being tried.
    action = env.action_space.sample()  # Uniform random: 0 or 1
    
    # Execute first action
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated  # Episode ends on either
    episode.append((state, action, reward))
    
    # Continue episode following GREEDY policy
    state = next_state
    while not done:
        # Follow current policy (greedy w.r.t. Q)
        # If state not in policy yet, default to a random action
        # (sample only when needed instead of on every step)
        action = policy[state] if state in policy else env.action_space.sample()
        
        # Execute action
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        
        # Record step
        episode.append((state, action, reward))
        state = next_state
    
    return episode

pretty_print("Episode Generation Ready",
             "<strong>Exploring Starts Mechanism:</strong><br>" +
             "• First action: Random (exploration)<br>" +
             "• Subsequent actions: Greedy (exploitation)<br>" +
             "• Every start (s,a) pair has a nonzero chance of being tried",
             style='success')

In [None]:
"""
Cell 3: Monte Carlo ES - Main Learning Algorithm

Purpose:
  - Implement the complete MC ES algorithm (Sutton and Barto, Section 5.3)
  - Learn optimal Q-values through episode sampling
  - Extract optimal policy via greedy improvement

Algorithm (Sutton and Barto, Section 5.3):
  1. Initialize Q(s,a) arbitrarily for all s,a
  2. Initialize Returns(s,a) as empty list for all s,a
  3. Initialize policy π arbitrarily (will become greedy)
  4. Loop for each episode:
     a) Generate episode with exploring starts
     b) For each (s,a) appearing in episode (first-visit):
        - Calculate return G from first visit
        - Append G to Returns(s,a)
        - Update Q(s,a) as average of Returns(s,a)
        - Update policy: π(s) = argmax_a Q(s,a)

Key Data Structures:
  - Q: defaultdict storing Q(s,a) estimates
  - returns: defaultdict storing list of returns for each (s,a)
  - policy: dict mapping states to greedy actions

Parameters:
  env: Blackjack environment
  num_episodes: Number of episodes to run

Returns:
  Q: Final action-value estimates
  policy: Final greedy policy
"""

def monte_carlo_es(env, num_episodes=500000):
    # Initialize Q(s,a) arbitrarily
    # defaultdict creates entries automatically with zero arrays
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    
    # Initialize Returns(s,a) as empty lists
    # Will store all observed returns for averaging
    returns = defaultdict(list)
    
    # Initialize policy (will become greedy)
    policy = {}
    
    pretty_print("Starting Monte Carlo ES",
                 f"Episodes: {num_episodes:,}<br>" +
                 "Method: First-Visit MC with Exploring Starts<br>" +
                 "This will take a few minutes...",
                 style='warning')
    
    # Main learning loop
    for episode_num in range(1, num_episodes + 1):
        # Generate episode using exploring starts
        episode = generate_episode_with_exploring_starts(env, policy)
        
        # Record the earliest time step of each (state, action) pair (first-visit).
        # Walking backward alone would meet the LAST visit first. In Blackjack a state
        # never recurs within an episode, but this keeps the code faithful to the pseudocode.
        first_visit_t = {}
        for t, (s, a, _) in enumerate(episode):
            first_visit_t.setdefault((s, a), t)
        
        # Process episode BACKWARDS to calculate returns efficiently
        # G accumulates reward from end of episode back to start
        G = 0  # Return (undiscounted, gamma=1 for Blackjack)
        
        for t in range(len(episode) - 1, -1, -1):
            state, action, reward = episode[t]
            
            # Accumulate return: G = r_t + G (since gamma=1)
            G = reward + G
            
            # Create state-action tuple for tracking
            state_action = (state, action)
            
            # First-visit check: update only at the pair's earliest occurrence
            if first_visit_t[state_action] == t:
                
                # Append return to list for this state-action pair
                returns[state_action].append(G)
                
                # Update Q(s,a) as average of all observed returns
                # Q(s,a) = mean(Returns(s,a))
                Q[state][action] = np.mean(returns[state_action])
                
                # Policy improvement: make policy greedy w.r.t. Q
                # π(s) = argmax_a Q(s,a)
                policy[state] = np.argmax(Q[state])
        
        # Progress reporting
        if episode_num % 100000 == 0:
            print(f"Episode {episode_num:,}/{num_episodes:,}")
    
    pretty_print("Monte Carlo ES Complete",
                 f"Processed {num_episodes:,} episodes<br>" +
                 f"Learned Q-values for {len(Q)} states<br>" +
                 f"Policy is greedy w.r.t. learned Q-values",
                 style='success')
    
    return Q, policy

pretty_print("MC ES Algorithm Loaded",
             "Ready to learn optimal Blackjack policy<br>" +
             "Algorithm follows the MC ES pseudocode (Sutton and Barto, Section 5.3)",
             style='info')

---
<div style="border-left: 4px solid #17a2b8; padding-left: 12px; margin: 20px 0;">
  <h2 style="color: #17a2b8; margin: 0; font-size: 18px;">Section 3: Visualization Functions</h2>
</div>

We create two types of visualizations:

**3D Surface Plots**: Show the learned optimal state-value function V*(s) = max_a Q(s,a) as a 3D surface. The x-axis is the player's sum, the y-axis is the dealer's showing card, and the z-axis (height and color) is the value. We create separate plots for states with and without a usable ace.

**2D Policy Heatmaps**: Display the optimal action for each state using discrete colors. Green indicates STICK (action 0) and Red indicates HIT (action 1). Using pcolormesh ensures crisp boundaries between actions with no interpolation.
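Both kinds of plot reduce the learned Q-table to 10 × 10 grids over player sums 12-21 (columns) and dealer cards 1-10 (rows). A sketch of that reduction, using a hypothetical helper `value_and_policy_grids` and made-up Q values (not lab results):

```python
import numpy as np

PLAYER_SUMS = np.arange(12, 22)   # columns
DEALER_CARDS = np.arange(1, 11)   # rows (1 = Ace)


def value_and_policy_grids(Q, usable_ace):
    """Return (V, Pi), each of shape (10, 10): V = max_a Q(s,a), Pi = argmax_a Q(s,a).

    Unvisited states get value 0 and action 1 (hit), like the lab's plotting defaults.
    """
    V = np.zeros((len(DEALER_CARDS), len(PLAYER_SUMS)))
    Pi = np.ones_like(V, dtype=int)
    for i, d in enumerate(DEALER_CARDS):
        for j, p in enumerate(PLAYER_SUMS):
            key = (int(p), int(d), usable_ace)
            if key in Q:  # membership test avoids creating entries in a defaultdict
                V[i, j] = np.max(Q[key])
                Pi[i, j] = int(np.argmax(Q[key]))
    return V, Pi


# Made-up action values [Q(stick), Q(hit)] for two states without a usable ace
Q_demo = {(20, 10, False): np.array([0.44, -0.85]),
          (12, 10, False): np.array([-0.57, -0.28])}
V, Pi = value_and_policy_grids(Q_demo, usable_ace=False)
print(V[9, 8], Pi[9, 8])  # dealer 10 (row 9), player 20 (column 8): stick
print(V[9, 0], Pi[9, 0])  # dealer 10, player 12: hit
```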

In [None]:
"""
Cell 4: Create 3D Value Function Visualization

Purpose:
  - Plot optimal state-value function V*(s) = max_a Q(s,a)
  - Show how value changes with player sum and dealer card
  - Separate plots for usable/non-usable ace states

Visualization Details:
  - X-axis: Player sum (12-21)
  - Y-axis: Dealer showing (1=Ace through 10)
  - Z-axis/Color: State value (-1 to +1)
  - Blue: Low value (likely to lose)
  - Red: High value (likely to win)
"""

def plot_value_function(Q, title="Optimal Value Function"):
    def get_Z(player_sum, dealer_card, usable_ace):
        """Get optimal value V(s) = max_a Q(s,a) for state"""
        state = (player_sum, dealer_card, usable_ace)
        if state in Q:
            return np.max(Q[state])  # V*(s) = max over actions
        return 0  # Default for unvisited states
    
    def create_surface(usable_ace, ax):
        """Create 3D surface plot for given usable_ace condition"""
        # Define state space ranges
        player_range = np.arange(12, 22)  # 12 to 21
        dealer_range = np.arange(1, 11)   # 1 (Ace) to 10
        
        # Create coordinate meshgrid
        X, Y = np.meshgrid(player_range, dealer_range)
        
        # Build value array: Z[i,j] = V(player_range[j], dealer_range[i], usable_ace)
        Z = np.array([[get_Z(x, y, usable_ace) 
                      for x in player_range]  # Columns: player sums
                     for y in dealer_range])  # Rows: dealer cards
        
        # Create 3D surface
        surf = ax.plot_surface(
            X, Y, Z,
            cmap=cm.coolwarm,    # Blue (cold/bad) to Red (warm/good)
            linewidth=0,
            antialiased=True,
            vmin=-1, vmax=1,     # Value range for color mapping
            alpha=0.8
        )
        
        # Configure axes
        ax.set_xlabel('Player Sum')
        ax.set_ylabel('Dealer Showing')
        ax.set_zlabel('Value')
        ax.set_zlim(-1, 1)
        ax.view_init(elev=25, azim=-130)  # Set viewing angle
        return surf
    
    # Create figure with two subplots
    fig = plt.figure(figsize=(14, 10))
    
    # Plot 1: With usable ace
    ax1 = fig.add_subplot(211, projection='3d')
    ax1.set_title(f'{title} - Usable Ace', fontweight='bold', pad=15)
    surf1 = create_surface(True, ax1)
    fig.colorbar(surf1, ax=ax1, shrink=0.5)
    
    # Plot 2: Without usable ace
    ax2 = fig.add_subplot(212, projection='3d')
    ax2.set_title(f'{title} - No Usable Ace', fontweight='bold', pad=15)
    surf2 = create_surface(False, ax2)
    fig.colorbar(surf2, ax=ax2, shrink=0.5)
    
    plt.tight_layout()
    plt.show()

pretty_print("3D Value Visualization Ready",
             "Plots V*(s) = max_a Q(s,a) as 3D surface<br>" +
             "Color: Blue (low) to Red (high)",
             style='success')

In [None]:
"""
Cell 5: Create 2D Policy Heatmap (DISCRETE COLORS)

Purpose:
  - Visualize optimal policy π*(s) = argmax_a Q(s,a)
  - Show STICK vs HIT decisions for each state
  - Use discrete colors (no blending)

Design Notes:
  - Uses pcolormesh with explicit cell edges (not imshow)
  - Two-color colormap gives crisp boundaries between actions
  - No interpolation between policy decisions

Color Coding:
  - Green = STICK (action 0)
  - Red = HIT (action 1)
"""

from matplotlib.colors import ListedColormap

def plot_policy(policy, title="Optimal Policy"):
    def get_action(player_sum, dealer_card, usable_ace):
        """Get optimal action for state"""
        state = (player_sum, dealer_card, usable_ace)
        return policy.get(state, 1)  # Default to HIT if not in policy
    
    def create_heatmap(usable_ace, ax):
        """Create discrete policy heatmap"""
        # Define state space (cell centers)
        player_range = np.arange(12, 22)  # 12-21
        dealer_range = np.arange(1, 11)   # 1-10 (Ace to 10)
        
        # Cell edges: shading='flat' needs one more edge than cells per axis
        player_edges = np.arange(11.5, 22.5, 1.0)  # 11.5, 12.5, ..., 21.5
        dealer_edges = np.arange(0.5, 11.5, 1.0)   # 0.5, 1.5, ..., 10.5
        
        # Build policy grid: Z[i,j] = action
        Z = np.array([[get_action(p, d, usable_ace)
                      for p in player_range]  # Columns: player
                     for d in dealer_range])  # Rows: dealer
        
        # Two-color colormap: 0 = STICK (green), 1 = HIT (red)
        action_cmap = ListedColormap(['#2ca02c', '#d62728'])
        
        # pcolormesh draws one solid cell per state (no interpolation)
        im = ax.pcolormesh(
            player_edges,           # X cell edges
            dealer_edges,           # Y cell edges
            Z,                      # Action values (0 or 1)
            cmap=action_cmap,
            edgecolors='black',     # Black gridlines
            linewidth=0.5,
            vmin=0, vmax=1,         # Map 0 and 1 to the two colors
            shading='flat'
        )
        
        # Configure axes (ticks at cell centers)
        ax.set_xticks(player_range)
        ax.set_yticks(dealer_range)
        ax.set_yticklabels(['A'] + [str(d) for d in range(2, 11)])  # A for Ace
        ax.set_xlabel('Player Sum')
        ax.set_ylabel('Dealer Showing')
        ax.set_aspect('equal')  # Square cells
        
        # Colorbar with one labeled block per action
        cbar = plt.colorbar(im, ax=ax, ticks=[0.25, 0.75], 
                           fraction=0.046, pad=0.04)
        cbar.ax.set_yticklabels(['STICK (0)', 'HIT (1)'])
    
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Heatmap 1: With usable ace
    ax1.set_title(f'{title} - Usable Ace', fontweight='bold')
    create_heatmap(True, ax1)
    
    # Heatmap 2: Without usable ace
    ax2.set_title(f'{title} - No Usable Ace', fontweight='bold')
    create_heatmap(False, ax2)
    
    plt.tight_layout()
    plt.show()

pretty_print("2D Policy Visualization Ready",
             "<strong>Discrete policy heatmaps:</strong><br>" +
             "• Uses pcolormesh for crisp boundaries<br>" +
             "• Green = STICK, Red = HIT<br>" +
             "• No interpolation between actions",
             style='success')

---
<div style="border-left: 4px solid #17a2b8; padding-left: 12px; margin: 20px 0;">
  <h2 style="color: #17a2b8; margin: 0; font-size: 18px;">Section 4: Run Monte Carlo ES Experiment</h2>
</div>

Now we run the complete learning process for 500,000 episodes. A large number of episodes is needed so that:
1. Every state-action pair is visited often enough to be estimated reliably
2. Q-value estimates settle close to their true values
3. The greedy policy stabilizes near the optimal policy
4. The results are comparable to the textbook figures

Depending on hardware, the run can take several minutes.
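`monte_carlo_es` stores every return in `returns[(s, a)]` and calls `np.mean` on the whole list after each update, which follows the pseudocode's "append, then average" steps literally but makes each update cost grow with the list length. Sutton and Barto (Exercise 5.4) point out this inefficiency and suggest the incremental mean of Section 2.4 instead. A sketch of that alternative, assuming a hypothetical count table `N` and helper `incremental_update`; it yields the same sample average up to floating-point rounding:

```python
from collections import defaultdict

import numpy as np

# Hypothetical replacement for the Returns(s,a) lists: keep a visit count instead.
Q = defaultdict(lambda: np.zeros(2))
N = defaultdict(lambda: np.zeros(2, dtype=int))


def incremental_update(Q, N, state, action, G):
    """Q(s,a) <- Q(s,a) + (G - Q(s,a)) / N(s,a): the running sample mean of G."""
    N[state][action] += 1
    Q[state][action] += (G - Q[state][action]) / N[state][action]


# Check against the list-based mean using fake returns for one (state, action) pair.
rng = np.random.default_rng(0)
fake_returns = rng.choice([-1.0, 0.0, 1.0], size=1000)
s = (15, 10, False)
for g in fake_returns:
    incremental_update(Q, N, s, 1, g)
print(Q[s][1], np.mean(fake_returns))
```

Memory per pair becomes constant, and each update takes the same time regardless of how many returns have been seen.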

In [None]:
"""
Cell 6: Execute Monte Carlo ES Learning

Purpose:
  - Run MC ES for 500,000 episodes
  - Learn optimal Q-values and policy
  - Analyze learned policy statistics

Expected Results:
  - Policy should match textbook Figure 5.2
  - Stick more often at higher player sums (17-21)
  - Hit more often at lower player sums (12-16)
  - Different behavior with/without usable ace
"""

# Run Monte Carlo ES
Q, policy = monte_carlo_es(env, num_episodes=500000)

# Analyze learned policy
stick_count = sum(1 for action in policy.values() if action == 0)
hit_count = sum(1 for action in policy.values() if action == 1)
total_states = len(policy)

pretty_print("Learning Complete - Policy Statistics",
             f"<strong>States in learned policy:</strong> {total_states}<br><br>" +
             f"<strong>Action Distribution:</strong><br>" +
             f"• STICK (0): {stick_count} states ({100*stick_count/total_states:.1f}%)<br>" +
             f"• HIT (1): {hit_count} states ({100*hit_count/total_states:.1f}%)<br><br>" +
             "<strong>Expected Pattern:</strong><br>" +
             "Policy should stick at high sums (17-21 without a usable ace)<br>" +
             "Policy should hit at low sums (always below 12, and at 12-16 when the dealer shows 7 through Ace)",
             style='result')

In [None]:
"""
Cell 7: Visualize Optimal Value Function

Purpose:
  - Display 3D plots of learned value function
  - Show how values change across state space
  - Compare usable vs non-usable ace scenarios
"""

pretty_print("Generating 3D Value Function Plots",
             "Creating surface plots of V*(s) = max_a Q(s,a)",
             style='info')

plot_value_function(Q, "Optimal State-Value Function V*")

pretty_print("Value Function Interpretation",
             "<strong>Color Coding:</strong><br>" +
             "• Red (high): Favorable states likely to win<br>" +
             "• Blue (low): Unfavorable states likely to lose<br><br>" +
             "<strong>Key Observations:</strong><br>" +
             "• Peak values near player sum 20-21<br>" +
             "• Usable ace provides higher values (flexibility)<br>" +
             "• Values vary with dealer showing card",
             style='note')

In [None]:
"""
Cell 8: Visualize Optimal Policy

Purpose:
  - Display 2D heatmaps of learned policy
  - Show optimal STICK vs HIT decisions
  - Compare with textbook Figure 5.2
"""

pretty_print("Generating Policy Heatmaps",
             "Creating discrete policy visualizations<br>" +
             "Green = STICK, Red = HIT",
             style='info')

plot_policy(policy, "Optimal Policy π* (from MC ES)")

pretty_print("Policy Interpretation",
             "<strong>Policy Patterns:</strong><br>" +
             "• Stick/hit boundary shifts with the dealer card and the usable ace<br>" +
             "• STICK (green) dominates at high sums<br>" +
             "• HIT (red) dominates at low sums<br>" +
             "• More hitting with a usable ace (one hit cannot bust the hand)<br>" +
             "• Adapts to dealer showing card<br><br>" +
             "<strong>This should match Sutton and Barto Figure 5.2</strong>",
             style='note')

<div style="background: #f8f9fa; padding: 15px 20px; margin-top: 30px; border-left: 3px solid #17a2b8;">
    <h3 style="color: #17a2b8; font-size: 14px; margin: 0 0 8px 0;">Key Findings</h3>
    <p style="color: #555; line-height: 1.6; font-size: 13px;">
        <strong>1. Exploring Starts Effectiveness:</strong> Random deals combined with random first actions gave every starting state-action pair a chance to be tried, so no ongoing exploration such as epsilon-greedy was needed after the first step.<br><br>
        <strong>2. Policy Convergence:</strong> The greedy policy settles close to the textbook policy, with clear decision boundaries that shift with the dealer's card and with the usable ace.<br><br>
        <strong>3. Usable Ace Impact:</strong> States with a usable ace have higher values and a more aggressive hitting strategy, because a single hit cannot bust such a hand: the ace can drop from 11 to 1.<br><br>
        <strong>4. First-Visit MC:</strong> For a fixed policy, averaging first-visit returns gives unbiased Q-value estimates. In MC ES the policy changes between episodes, so each average mixes returns from successive policies; Sutton and Barto note that convergence of MC ES to the optimal policy has not been formally proved, even though it behaves well on Blackjack.<br><br>
        <strong>5. Generalized Policy Iteration:</strong> Interleaving policy evaluation (Q-value updates) with policy improvement (greedy selection) drives the policy toward optimality.
    </p>
</div>
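Because episodes are random, two runs (or a run and a reference policy) can disagree in a few states, typically near the stick/hit boundary. One way to quantify "matches the textbook" is the fraction of the book's 200 states on which two policies agree. The helper `policy_agreement` below is a hypothetical sketch, not part of the lab; missing states default to HIT, as in `plot_policy`, and the two policies in the example are made up.

```python
def policy_agreement(policy_a, policy_b, usable_ace=None):
    """Fraction of book states (player 12-21, dealer 1-10) where two policies agree.

    Missing states default to 1 (hit). usable_ace=None compares both ace conditions.
    """
    aces = (False, True) if usable_ace is None else (usable_ace,)
    states = [(p, d, a) for a in aces for p in range(12, 22) for d in range(1, 11)]
    same = sum(policy_a.get(s, 1) == policy_b.get(s, 1) for s in states)
    return same / len(states)


# Made-up policies that differ only at player 20 vs dealer 10, no usable ace
pa = {(20, 10, False): 0}
pb = {(20, 10, False): 1}
print(policy_agreement(pa, pb))                    # 199 of 200 states agree
print(policy_agreement(pa, pb, usable_ace=False))  # 99 of 100 states agree
```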

<div style="background: linear-gradient(90deg, #17a2b8 0%, #0e5a63 60%, #0a3d44 100%); color: white; padding: 15px 20px; margin-top: 30px; text-align: center;">
    <p style="margin: 0; font-size: 13px;">End of Lab 5-1: Monte Carlo ES</p>
    <p style="margin: 5px 0 0 0; font-size: 11px; opacity: 0.9;">Next: Lab 5-2 - Off-Policy MC with Importance Sampling</p>
</div>