In [1]:
def plot_results_from_json(json_file, epsilons=(0.1, 0.2, 0.3), learning_rates=(0.01, 0.001, 0.0001),
                           output_dir="figures"):
    """
    Plot Q-learning vs. Expected SARSA learning curves stored in a JSON results file.

    For every environment in the file, and for each of the 'with_replay' /
    'without_replay' settings present for it, one figure is produced. The figure
    is a grid with one row per epsilon and one column per learning rate; each
    panel shows the mean return per episode of both algorithms with a shaded
    +/- one standard deviation band. Each figure is saved as
    ``<output_dir>/<env_name with '/' replaced by '_'>_<replay_setting>.png``
    and then closed, so nothing is displayed inline.

    The file is read as::

        data[env_name][replay_setting]["epsilon_<eps>"]["step_size_<lr>"][algo][stat]

    with ``algo`` in {'q_learning', 'expected_sarsa'} and ``stat`` in
    {'mean', 'std'}. The key suffixes are the plain ``str()`` form of the
    values passed in ``epsilons`` / ``learning_rates``.

    Args:
        json_file (str): Path to the JSON results file
        epsilons (sequence): Exploration rates to plot (default: (0.1, 0.2, 0.3))
        learning_rates (sequence): Learning rates to plot (default: (0.01, 0.001, 0.0001))
        output_dir (str): Directory the PNG files are written to; created on
            demand (default: "figures")

    Raises:
        ValueError: If ``epsilons`` or ``learning_rates`` is empty.

    Notes:
        A configuration whose keys are missing is reported with a printed
        warning and its panel is labelled 'Data not available'. An environment
        that has neither replay setting is reported and skipped.
    """
    import json
    import os

    # Materialise the grids once so any iterable (list, tuple, generator) is accepted,
    # and fail fast with a clear message instead of an error from inside matplotlib.
    epsilons = list(epsilons)
    learning_rates = list(learning_rates)
    if not epsilons or not learning_rates:
        raise ValueError("epsilons and learning_rates must each contain at least one value")

    # The plotting stack is imported only after the arguments have been validated.
    import matplotlib.pyplot as plt
    import numpy as np

    # (JSON key, legend label, colour) for each algorithm that is drawn.
    algorithms = (
        ('q_learning', 'Q-Learning', 'green'),
        ('expected_sarsa', 'Expected SARSA', 'red'),
    )
    # Replay settings in the order they are processed, with their title text.
    replay_labels = {
        'with_replay': "with Replay Buffer",
        'without_replay': "without Replay Buffer",
    }

    # Load the JSON data
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Process each environment
    for env_name, env_data in data.items():
        available_settings = [name for name in replay_labels if name in env_data]
        if not available_settings:
            # Previously this case produced no figure and no message at all.
            print(f"Warning: No 'with_replay' or 'without_replay' results for {env_name}; nothing plotted")
            continue

        # Process with and without replay buffer separately
        for replay_setting in available_settings:
            replay_text = replay_labels[replay_setting]

            # squeeze=False always yields a 2-D axes array, including 1xN, Nx1 and 1x1 grids
            fig, axes = plt.subplots(len(epsilons), len(learning_rates), figsize=(15, 10),
                                     sharex=True, sharey=True, squeeze=False)
            try:
                legend_drawn = False

                # Plot each combination of epsilon and learning rate
                for i, epsilon in enumerate(epsilons):
                    for j, lr in enumerate(learning_rates):
                        ax = axes[i, j]

                        try:
                            config_data = env_data[replay_setting][f"epsilon_{epsilon}"][f"step_size_{lr}"]
                            # Pull out every series before drawing anything, so a missing
                            # key never leaves a half-drawn panel behind.
                            series = [
                                (label, color,
                                 np.asarray(config_data[key]['mean'], dtype=float),
                                 np.asarray(config_data[key]['std'], dtype=float))
                                for key, label, color in algorithms
                            ]
                        except KeyError as e:
                            print(f"Warning: Missing data for {env_name}, {replay_setting}, ε={epsilon}, α={lr}: {e}")
                            ax.text(0.5, 0.5, 'Data not available',
                                    horizontalalignment='center', verticalalignment='center',
                                    transform=ax.transAxes)
                        else:
                            for label, color, mean, std in series:
                                # Each curve gets its own episode axis, so the two
                                # algorithms may have run for different episode counts.
                                episodes = np.arange(1, len(mean) + 1)
                                ax.plot(episodes, mean, color=color, label=label)
                                ax.fill_between(episodes, mean - std, mean + std, color=color, alpha=0.2)

                            # One legend per figure, on the first panel that actually has curves
                            if not legend_drawn:
                                ax.legend()
                                legend_drawn = True

                        # Set title and labels for this subplot
                        ax.set_title(f"ε={epsilon}, α={lr}")
                        if i == len(epsilons) - 1:  # Bottom row
                            ax.set_xlabel("Episode")
                        if j == 0:  # Leftmost column
                            ax.set_ylabel("Return")

                        # Add grid for better readability
                        ax.grid(True, linestyle='--', alpha=0.6)

                # Add overall title for the entire figure
                fig.suptitle(f"{env_name} {replay_text}", fontsize=16)

                # Optimize layout, then leave headroom for the suptitle
                fig.tight_layout()
                fig.subplots_adjust(top=0.92)

                # Save the figure to a file
                os.makedirs(output_dir, exist_ok=True)
                filename = f"{env_name.replace('/', '_')}_{replay_setting}.png"
                fig.savefig(os.path.join(output_dir, filename), dpi=300)
            finally:
                # Always release the figure, even if plotting or saving failed
                plt.close(fig)

            print(f"Generated visualization for {env_name} {replay_text}")

# Example usage:
# plot_results_from_json("results/numerical_results.json")

In [4]:
# Build and save the epsilon x learning-rate comparison figures for this results file.
# NOTE(review): no output is recorded for this cell. The function prints one line per
# saved figure and skips environments lacking 'with_replay'/'without_replay' keys,
# so confirm this file actually uses that layout.
plot_results_from_json("acrobat_numerical_results.json")

In [5]:
# Build and save the epsilon x learning-rate comparison figures for this results file;
# the PNGs are written under the "figures" directory rather than displayed inline.
plot_results_from_json("assault_replay_numerical_results.json")

Generated visualization for ALE/Assault-ram-v5 with Replay Buffer
