In [None]:
%matplotlib inline
import sys
import os
import torch
import shutil
import matplotlib.pyplot as plt
import numpy as np

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from agent_configs.muzero_config import MuZeroConfig
from agents.muzero import MuZeroAgent
from modules.world_models.muzero_world_model import MuzeroWorldModel
from game_configs.cartpole_config import CartPoleConfig
from losses.basic_losses import CategoricalCrossentropyLoss, KLDivergenceLoss, MSELoss
from torch.optim import SGD, Adam
import gymnasium as gym

: 

In [None]:
# 1. Setup Config
# Hyperparameters for the MuZero run on CartPole. Trailing "noted:" comments
# preserve the alternative values / ranges jotted down while tuning; they are
# not applied anywhere.
game_config = CartPoleConfig()
config_dict = {
    "num_simulations": 50,  # noted: 25
    "per_alpha": 0.0,  # noted: 0.5-1.0, 0.6
    "per_beta": 0.0,  # noted: 1.0
    "per_beta_final": 0.0,  # noted: 1.0
    "root_dirichlet_alpha": 0.5,  # noted: 0.5-1.0, 0.3
    # Layer widths. Noted for both entries below: 1-3 layers, 64-128 wide,
    # with a dynamics size of 8-16.
    "dense_layer_widths": [64, 8],
    "dynamics_dense_layer_widths": [32, 8],
    # Every remaining layer list is left empty.
    "residual_layers": [],
    "reward_dense_layer_widths": [],
    "reward_conv_layers": [],
    "actor_dense_layer_widths": [],
    "actor_conv_layers": [],
    "critic_dense_layer_widths": [],
    "critic_conv_layers": [],
    "to_play_dense_layer_widths": [],
    "to_play_conv_layers": [],
    "known_bounds": [1, 500],
    "support_range": 31,  # noted: 601 (but only for value, not reward), 300
    "minibatch_size": 64,  # noted: 64-128
    "replay_buffer_size": 10000,  # noted: 10k-20k
    "min_replay_buffer_size": 500,
    "gumbel": True,
    "gumbel_m": 2,
    "policy_loss_function": KLDivergenceLoss(),
    "reward_loss_function": CategoricalCrossentropyLoss(),
    "value_loss_function": CategoricalCrossentropyLoss(),
    "training_steps": 20000,
    "transfer_interval": 1,  # TODO: value marked "?" in the tuning notes; confirm
    "num_workers": 4,
    # noted: 0.99 may be too low (but possibly fine for AlphaZero?); 0.997-0.999
    "discount_factor": 0.999,
    "unroll_steps": 5,
    "n_step": 50,  # noted: 5, 10, 50 or 500 (50 or 500 recommended)
    # Alternative temperature schedule kept for reference (disabled):
    #   "temperatures": [3, 2, 1, 0.5, 0.25, 0.125, 0.075, 0.01, 0.0],
    #   "temperature_updates": [100, 200, 300, 400, 500, 600, 700, 800],
    #   "temperature_with_training_steps": True,
    "temperatures": [1.0, 0.0],
    "temperature_updates": [15],
    "temperature_with_training_steps": False,
    "learning_rate": 0.001,
    "optimizer": Adam,
    "value_loss_factor": 0.25,
    "reanalyze_ratio": 0.0,  # TODO: notes say 100% should be reanalyzed (how?)
    "injection_frac": 0.0,
    "value_prefix": False,
    "consistency_loss_factor": 0.0,
    "stochastic": True,
    "num_chance": 32,
    "world_model_cls": MuzeroWorldModel,
}

config = MuZeroConfig(config_dict, game_config)

In [None]:
# 2. Setup Agent
# Create the Gymnasium environment and build the agent from it and `config`.
ENV_ID = "CartPole-v1"
AGENT_NAME = "test_chance_prob_nb"

env = gym.make(ENV_ID)
agent = MuZeroAgent(env, config, name=AGENT_NAME)

In [None]:
# 3. Train
# Set the interval / trial-count attributes on the agent (in this order),
# then start training.
for attr_name, attr_value in (
    ("checkpoint_interval", 100),
    ("test_interval", 1000),
    ("test_trials", 10),
):
    setattr(agent, attr_name, attr_value)

agent.train()

In [None]:
# 4. Verify Stats and Plot
# Confirm that the "chance_probs" stat is present after training, print a short
# summary of its most recent entry, and draw that entry as a bar chart.
print("\nVerifying Stats...")
stats = agent.stats.stats

if "chance_probs" not in stats:
    print("Missing key: chance_probs")
elif len(stats["chance_probs"]) == 0:
    # The key exists but nothing has been recorded; indexing [-1] below would
    # raise IndexError, so report the situation instead of crashing the cell.
    print("'chance_probs' stat found but it has no recorded entries yet.")
else:
    chance_probs_data = stats["chance_probs"]
    # NOTE(review): `.shape` and `.tolist()` assume the stat is stored as a
    # tensor/array-like object - confirm against the stats tracker.
    print(f"'chance_probs' stat found. Shape: {chance_probs_data.shape}")

    # Get the latest distribution
    latest_probs = chance_probs_data[-1].tolist()
    print(f"Latest probability distribution (first 5): {latest_probs[:5]}...")

    # Manual plot for immediate feedback in the notebook. The explicit
    # figure/axes interface avoids relying on pyplot's implicit "current axes".
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(range(len(latest_probs)), latest_probs, color="skyblue")
    ax.set_xlabel("Chance Codes")
    ax.set_ylabel("Probability")
    ax.set_title("Chance Code Probabilities (Last Step)")
    ax.set_ylim(0, 1.0)
    plt.show()

    # Also trigger the agent's own plot function to verify it works with the
    # new BAR type.
    print("Triggering agent.stats.plot_graphs()...")
    agent.stats.plot_graphs()