In [None]:
import matplotlib
# Force the TkAgg backend. This call is deliberately placed before
# `matplotlib.pyplot` is imported below; keep that ordering when editing this cell.
# NOTE(review): TkAgg presumably needs a Tk-capable display — confirm it is
# available wherever this notebook's kernel runs.
matplotlib.use('TkAgg')

import json
import ast
import math
import numpy as np
import matplotlib.pyplot as plt
# Axes3D is never referenced by name in this cell; the import is kept for the
# `projection='3d'` subplots created later.
# NOTE(review): presumably only older Matplotlib releases need this import to
# register the projection — confirm against the installed version.
from mpl_toolkits.mplot3d import Axes3D  # Required for 3D plotting

# (Optional) Turn on pyplot's interactive mode
plt.ion()

# Read the saved Q-table from disk
json_path = '/home/beamkeerati/DRL-HW2/CartPole_4.2.0/q_value/Stabilize/MC/MC_900_10_5.0_10_10.json'
with open(json_path, 'r') as f:
    data = json.load(f)

# The table lives under the 'q_values' key: state string -> list of per-action Q-values
q_values = data['q_values']
states_str = list(q_values.keys())

# The length of the first entry decides how many actions are read from every state
sample_q = next(iter(q_values.values()))
num_actions = len(sample_q)

# Turn each key string back into a tuple.
# Assumed CartPole state layout (verify against the environment that wrote the file):
#   index 0 -> cart position (pose_cart), used as the x-axis
#   index 2 -> pole angle (pose_pole), used as the y-axis
state_tuples = [ast.literal_eval(key) for key in states_str]
cart_pos = np.array([state[0] for state in state_tuples])
pole_angle = np.array([state[2] for state in state_tuples])

# One NumPy array per action, aligned element-for-element with cart_pos / pole_angle
action_q = [
    np.array([q_values[key][a] for key in states_str])
    for a in range(num_actions)
]

# Create a regular grid from unique cart positions and pole angles
# (np.unique already returns its values sorted, so no extra sort is needed)
unique_cart = np.unique(cart_pos)
unique_pole = np.unique(pole_angle)
X, Y = np.meshgrid(unique_cart, unique_pole)

# Grid coordinates of every state: row = pole-angle index, column = cart-position
# index. This matches np.meshgrid's default layout, where X varies along columns
# and Y along rows.
col_idx = np.searchsorted(unique_cart, cart_pos)
row_idx = np.searchsorted(unique_pole, pole_angle)

# Several states can share the same (cart position, pole angle) pair because the
# remaining state components are not plotted. Count how many states fall on each
# grid cell so their Q-values can be averaged instead of letting whichever state
# comes last in the file silently overwrite the others.
counts = np.zeros(X.shape, dtype=float)
np.add.at(counts, (row_idx, col_idx), 1.0)
filled = counts > 0

# Determine subplot grid size (up to 3 columns per row)
cols = min(num_actions, 3)
rows = math.ceil(num_actions / cols)
fig = plt.figure(figsize=(5 * cols, 4 * rows))

# For each action, create a 3D surface plot
for i in range(num_actions):
    # Sum this action's Q-values per grid cell (np.add.at accumulates correctly
    # when the same cell index appears more than once)
    sums = np.zeros(X.shape, dtype=float)
    np.add.at(sums, (row_idx, col_idx), action_q[i])

    # Mean Q-value per cell; cells that no state maps to stay NaN.
    # When every (cart, pole) pair is unique this equals the raw Q-value.
    Z = np.full(X.shape, np.nan)
    Z[filled] = sums[filled] / counts[filled]

    # Plot the surface for the current action
    ax = fig.add_subplot(rows, cols, i + 1, projection='3d')
    surf = ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none')
    ax.set_xlabel('Cart Position')
    ax.set_ylabel('Pole Angle')
    ax.set_zlabel('Q-value')
    ax.set_title(f'Action {i} Q-values Surface')
    fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10)

plt.suptitle('3D Surface Visualization of Q-values: Cart Position vs. Pole Angle', fontsize=16)
plt.tight_layout()

# Use plt.show(block=True) to keep the window open and interactive
plt.show(block=True)


In [16]:
import json
import ast

# Location of the Q-value JSON file to inspect
json_path = '/home/beamkeerati/DRL-HW2/CartPole_4.2.0/q_value/Stabilize/Q_Learning/Q_Learning_900_20_1.0_10_10.json'

# Read the file and pull out the state -> Q-value-list mapping
with open(json_path, 'r') as f:
    data = json.load(f)
q_values = data['q_values']

# Running best: start below every possible value so the first entry always wins
max_q, best_state, best_action = float('-inf'), None, None

# Flatten the table into (state, action index, Q-value) triples, in file order
entries = (
    (state_str, action_index, q)
    for state_str, q_list in q_values.items()
    for action_index, q in enumerate(q_list)
)

# Keep the first strictly-largest Q-value; ties leave the earlier entry in place
for state_str, action_index, q in entries:
    if not q > max_q:
        continue
    max_q, best_state, best_action = q, state_str, action_index

# Report the winner
print("The maximum Q value is:", max_q)
print("Achieved at state:", best_state)
print("For action index:", best_action)


The maximum Q value is: 0.01880221734456608
Achieved at state: (2, 0, 0, -2)
For action index: 9
