<div style="background: linear-gradient(90deg, #17a2b8 0%, #0e5a63 60%, #0a3d44 100%); color: white; padding: 18px 25px;">
    <h1 style="margin: 0; font-size: 24px;">Lab 5-1: Blackjack with Monte Carlo ES</h1>
    <p style="font-size: 13px; margin: 6px 0 0 0;">IE 7295 RL | Sutton and Barto Ch 5 Figure 5.2</p>
</div>

## Section 1: Setup

Import all necessary libraries for MC ES implementation.

In [None]:
# CELL 1: Imports and Environment Setup
# This cell loads all required libraries and creates the Blackjack environment

import gymnasium as gym
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

# Load pretty_print utility (optional: the lab still runs if this fails)
#
# SECURITY NOTE(review): exec() runs whatever the remote file contains at
# download time. Keep this only if the repository is trusted; pinning a commit
# hash or vendoring rl_utility.py locally would be safer.
try:
    import requests
    _utility_response = requests.get(
        'https://raw.githubusercontent.com/mdehghani86/RL_labs/master/utility/rl_utility.py',
        timeout=10,  # do not hang the notebook when the network is unreachable
    )
    _utility_response.raise_for_status()  # never exec an HTTP error page
    exec(_utility_response.text)
    # pretty_print is presumably defined by the downloaded utility file
    pretty_print("Ready", "All libraries loaded", style='success')
except Exception:
    # Best-effort fallback (missing requests, network/HTTP error, bad utility
    # code). `Exception` rather than a bare `except` so that Ctrl-C and
    # SystemExit are not swallowed.
    print("Libraries loaded")

# Create Blackjack environment
env = gym.make('Blackjack-v1')
print(f"Environment created: {env}")

## Section 2: Episode Generation

**CRITICAL**: This implements EXPLORING STARTS
- First action = RANDOM (exploration)
- Rest = GREEDY (exploitation)

In [None]:
# CELL 2: Episode Generation with Exploring Starts
#
# PURPOSE: Generate episodes using the exploring starts mechanism
# 
# EXPLORING STARTS EXPLAINED:
#   - First action: RANDOM - ensures all (s,a) pairs explored
#   - Remaining actions: GREEDY - follow current policy
#   - This guarantees sufficient exploration while exploiting learned knowledge
#
# PARAMETERS:
#   env: Blackjack environment
#   policy: Current greedy policy (dict mapping states to actions)
#
# RETURNS:
#   episode: List of (state, action, reward) tuples

def generate_episode_es(env, policy):
    """Play one episode with exploring starts and return its trajectory.

    The start state comes from ``env.reset()``. The first action is always
    drawn with ``env.action_space.sample()``; every later action is looked up
    in ``policy``, falling back to a random sample only for states the policy
    has no entry for.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` (returning the
            5-tuple ``(state, reward, terminated, truncated, info)``) and
            ``action_space.sample()``.
        policy: Mapping from state to action (the current greedy policy).

    Returns:
        List of ``(state, action, reward)`` tuples in time order.
    """
    episode = []
    state, _ = env.reset()

    # EXPLORING START: the first action is random regardless of the policy.
    action = env.action_space.sample()

    while True:
        next_state, reward, terminated, truncated, _ = env.step(action)
        episode.append((state, action, reward))
        if terminated or truncated:
            return episode

        state = next_state
        # Follow the greedy policy; draw a random action only when the state
        # is unknown (avoids sampling on every step just to build a default).
        if state in policy:
            action = policy[state]
        else:
            action = env.action_space.sample()

print("Episode generation ready")

## Section 3: MC ES Algorithm

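Implements the Monte Carlo ES algorithm from the textbook (Sutton and Barto, Section 5.3); its result on Blackjack is shown in Figure 5.2: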
1. Generate episode with exploring starts
2. For each (s,a) - first visit only:
   - Calculate return G
   - Update Q(s,a) = average of returns
3. Make policy greedy: π(s) = argmax Q(s,a)

In [None]:
# CELL 3: Monte Carlo ES Main Algorithm
#
# PURPOSE: Learn optimal Q-values and policy using MC ES
#
# ALGORITHM (Monte Carlo ES; its Blackjack result is Figure 5.2):
#   1. Initialize Q(s,a) arbitrarily
#   2. Initialize visit counts N(s,a) to zero
#   3. Initialize policy π arbitrarily
#   4. Loop for each episode:
#      a) Generate episode with exploring starts
#      b) For each (s,a) in episode (FIRST-VISIT only):
#         - Calculate return G
#         - N(s,a) += 1
#         - Q(s,a) += (G - Q(s,a)) / N(s,a)   (running average of returns)
#         - π(s) = argmax_a Q(s,a)
#
# DATA STRUCTURES:
#   Q: dict of arrays - Q[state][action] = value estimate
#   visit_counts: dict - visit_counts[(state,action)] = number of returns averaged
#   policy: dict - policy[state] = best action

def monte_carlo_es(env, num_episodes=500000):
    """Estimate Q-values and a greedy policy with first-visit Monte Carlo ES.

    Args:
        env: Environment passed to ``generate_episode_es``; must also expose
            ``action_space.n`` (number of discrete actions).
        num_episodes: Number of episodes to sample.

    Returns:
        Tuple ``(Q, policy)`` where ``Q`` is a ``defaultdict`` mapping each
        state to a NumPy array of action values and ``policy`` maps each
        visited state to its greedy action (a plain ``int``).
    """
    n_actions = env.action_space.n

    # Initialize Q-values (start at zero)
    Q = defaultdict(lambda: np.zeros(n_actions))

    # Number of returns averaged into each Q(s,a). Keeping a count plus a
    # running mean replaces storing every return and re-averaging the whole
    # list on each update (O(n) per update, unbounded memory).
    visit_counts = defaultdict(int)

    # Initialize policy (will become greedy)
    policy = {}

    print(f"Starting MC ES: {num_episodes:,} episodes")

    # Main learning loop
    for ep_num in range(1, num_episodes + 1):
        # Generate episode with exploring starts
        episode = generate_episode_es(env, policy)

        # Earliest time step at which each (s,a) occurs. The returns are
        # accumulated backward, so a plain "seen" set would select the LAST
        # occurrence; this index makes the check a true first-visit check.
        first_visit = {}
        for t, (state, action, _) in enumerate(episode):
            first_visit.setdefault((state, action), t)

        G = 0  # Return (gamma=1 for Blackjack)

        # Process episode BACKWARD (makes return calculation easy)
        for t in range(len(episode) - 1, -1, -1):
            state, action, reward = episode[t]
            G = reward + G  # Accumulate return

            sa = (state, action)

            # FIRST-VISIT CHECK: skip later occurrences of the same (s,a)
            if first_visit[sa] != t:
                continue

            # Incremental average: Q <- Q + (G - Q) / N
            visit_counts[sa] += 1
            Q[state][action] += (G - Q[state][action]) / visit_counts[sa]

            # POLICY IMPROVEMENT: make greedy (stored as a built-in int)
            policy[state] = int(np.argmax(Q[state]))

        # Progress
        if ep_num % 100000 == 0:
            print(f"Episode {ep_num:,}")

    print("MC ES complete")
    return Q, policy

print("MC ES algorithm loaded")

## Section 4: Visualization

In [None]:
# CELL 4: Visualization Functions
#
# Two visualization types:
#   1. 3D surface: V(s) = max Q(s,a)
#   2. 2D heatmap: π(s) with DISCRETE colors

def plot_value(Q):
    """Plot V(s) = max_a Q(s,a) as two stacked 3D surfaces.

    The top surface covers states with a usable ace, the bottom one states
    without. Player sums 12-21 are on the x axis and dealer cards 1-10 on the
    y axis.

    Args:
        Q: Mapping from ``(player_sum, dealer_card, usable_ace)`` to an array
            of action values. States missing from ``Q`` are drawn at 0.
    """
    player_sums = np.arange(12, 22)
    dealer_cards = np.arange(1, 11)

    def state_value(player, dealer, usable_ace):
        key = (player, dealer, usable_ace)
        # Test membership before indexing so that a defaultdict Q does not
        # gain new entries just from being plotted.
        if key in Q:
            return np.max(Q[key])
        return 0

    def draw_surface(usable_ace, ax):
        grid_x, grid_y = np.meshgrid(player_sums, dealer_cards)
        rows = []
        for dealer in dealer_cards:
            rows.append([state_value(player, dealer, usable_ace)
                         for player in player_sums])
        heights = np.array(rows)
        surface = ax.plot_surface(grid_x, grid_y, heights, cmap=cm.coolwarm,
                                  vmin=-1, vmax=1, alpha=0.8)
        ax.set_xlabel('Player')
        ax.set_ylabel('Dealer')
        ax.set_zlabel('Value')
        ax.view_init(25, -130)
        return surface

    fig = plt.figure(figsize=(14, 10))
    panels = ((211, True, 'With Ace'), (212, False, 'No Ace'))
    for position, usable_ace, title in panels:
        ax = fig.add_subplot(position, projection='3d')
        ax.set_title(title)
        fig.colorbar(draw_surface(usable_ace, ax), ax=ax, shrink=0.5)
    plt.tight_layout()
    plt.show()

def plot_policy(pol):
    """Plot the policy as two side-by-side heatmaps (usable ace / no ace).

    Player sums 12-21 are on the x axis and dealer cards 1-10 on the y axis
    (1 is labelled 'A'). The colorbar labels action 0 as STICK and action 1
    as HIT.

    Args:
        pol: Mapping from ``(player_sum, dealer_card, usable_ace)`` to an
            action. States missing from ``pol`` are drawn as HIT (1).
    """
    # CRITICAL: Uses pcolormesh for DISCRETE cells (no blending between cells)
    def get_a(ps, dc, ua):
        return pol.get((ps, dc, ua), 1)

    def hm(ua, ax):
        pr, dr = np.arange(12, 22), np.arange(1, 11)
        Z = np.array([[get_a(p, d, ua) for p in pr] for d in dr])
        # pr/dr are cell CENTERS with the same size as Z. shading='flat'
        # needs cell EDGES (one more than Z per axis) and rejects or truncates
        # same-sized input; shading='nearest' centers each cell on its
        # coordinate and still draws without interpolation.
        im = ax.pcolormesh(pr, dr, Z, cmap='RdYlGn_r', edgecolors='black',
                           linewidth=0.5, vmin=0, vmax=1, shading='nearest')
        ax.set_xticks(pr)
        ax.set_yticks(dr)
        ax.set_yticklabels(['A'] + list(range(2, 11)))
        ax.set_xlabel('Player')
        ax.set_ylabel('Dealer')
        ax.set_aspect('equal')
        cb = plt.colorbar(im, ax=ax, ticks=[0.25, 0.75])
        cb.ax.set_yticklabels(['STICK', 'HIT'])

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    ax1.set_title('With Ace')
    hm(True, ax1)
    ax2.set_title('No Ace')
    hm(False, ax2)
    plt.tight_layout()
    plt.show()

print("Viz functions ready")

## Section 5: Run Experiment

In [None]:
# CELL 5: Run MC ES
# Train on the Blackjack environment for 500,000 episodes.
Q, policy = monte_carlo_es(env, 500000)

# Stats: how many learned states choose action 0 (STICK)
total = len(policy)
stick = list(policy.values()).count(0)
stick_pct = 100 * stick / total
print(f"\nPolicy: {total} states, {stick} STICK ({stick_pct:.1f}%)")

In [None]:
# CELL 6: Visualize Results
# Uses Q and policy produced by the MC ES run in the previous cell.

# 3D surfaces of V(s) = max_a Q(s,a), one per usable-ace setting
print("Value function:")
plot_value(Q)

# Heatmaps of the learned greedy policy, one per usable-ace setting
print("\nPolicy (should match Figure 5.2):")
plot_policy(policy)