In [1]:
%matplotlib inline
import sys
import os
import torch
import shutil
import matplotlib.pyplot as plt
import numpy as np

# Make the parent of the current working directory importable so the
# project-local packages imported below can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

from agent_configs.muzero_config import MuZeroConfig
from agents.muzero import MuZeroAgent
from modules.world_models.muzero_world_model import MuzeroWorldModel
from game_configs.cartpole_config import CartPoleConfig
from losses.basic_losses import CategoricalCrossentropyLoss, KLDivergenceLoss, MSELoss
from torch.optim import SGD, Adam
import gymnasium as gym

In [2]:
# 1. Setup Config
# Hyperparameters for this MuZero run on CartPole. Trailing notes record
# alternative values the author has considered; they are tuning hints only and
# have no effect on the run.
game_config = CartPoleConfig()
config_dict = dict(
    num_simulations=50,  # alternative: 25
    per_alpha=0.0,  # alternatives: 0.5-1.0, 0.6
    per_beta=0.0,  # alternative: 1.0
    per_beta_final=0.0,  # alternative: 1.0
    root_dirichlet_alpha=0.5,  # alternatives: 0.5-1.0, 0.3
    # Layer widths. Author's guidance for both lists below:
    # 1-3 layers, 64-128, with a dynamics size of 8-16.
    dense_layer_widths=[64, 8],
    dynamics_dense_layer_widths=[32, 8],
    # Every remaining layer list is left empty for this run.
    residual_layers=[],
    reward_dense_layer_widths=[],
    reward_conv_layers=[],
    actor_dense_layer_widths=[],
    actor_conv_layers=[],
    critic_dense_layer_widths=[],
    critic_conv_layers=[],
    to_play_dense_layer_widths=[],
    to_play_conv_layers=[],
    known_bounds=[1, 500],
    support_range=31,  # author's note: 601 but only for value and not reward, 300
    minibatch_size=64,  # alternatives: 64-128
    replay_buffer_size=10000,  # alternatives: 10k-20k
    min_replay_buffer_size=500,
    gumbel=True,
    gumbel_m=2,
    # Loss functions (instances, constructed here in this order).
    policy_loss_function=KLDivergenceLoss(),
    reward_loss_function=CategoricalCrossentropyLoss(),
    value_loss_function=CategoricalCrossentropyLoss(),
    training_steps=20000,
    transfer_interval=1,  # author's note: "?" (value not yet settled)
    num_workers=4,
    discount_factor=0.999,  # author's note: 0.99 too low? but good for alpha zero? 0.997-0.999
    unroll_steps=5,  # alternative: 5
    n_step=50,  # alternatives: 5, 10, 50 or 500 (50 or 500 recommended)
    # Temperature schedule. A longer alternative the author has kept around:
    #   temperatures = [3, 2, 1, 0.5, 0.25, 0.125, 0.075, 0.01, 0.0]
    #   temperature_updates = [100, 200, 300, 400, 500, 600, 700, 800]
    #   temperature_with_training_steps = True
    temperatures=[1.0, 0.0],
    temperature_updates=[15],
    temperature_with_training_steps=False,
    learning_rate=0.001,
    optimizer=Adam,  # the optimizer class itself, not an instance
    value_loss_factor=0.25,  # alternative: 0.25
    reanalyze_ratio=0.0,  # TODO (author): 100% should be reanalyzed (how?)
    injection_frac=0.0,
    value_prefix=False,
    consistency_loss_factor=0.0,
    stochastic=True,
    num_chance=32,
    world_model_cls=MuzeroWorldModel,
)

# Combine the overrides above with the game-specific settings.
config = MuZeroConfig(config_dict, game_config)

Using default save_intermediate_weights     : False
Using         training_steps                : 20000
Using default adam_epsilon                  : 1e-08
Using default momentum                      : 0.9
Using         learning_rate                 : 0.001
Using default clipnorm                      : 0
Using         optimizer                     : <class 'torch.optim.adam.Adam'>
Using default weight_decay                  : 0.0
Using default num_minibatches               : 1
Using default training_iterations           : 1
Using default lr_schedule_type              : none
Using default lr_schedule_steps             : []
Using default lr_schedule_values            : []
Using         minibatch_size                : 64
Using         replay_buffer_size            : 10000
Using         min_replay_buffer_size        : 500
Using         n_step                        : 50
Using         discount_factor               : 0.999
Using         per_alpha                     : 0.0
Using         per_b

In [3]:
# 2. Setup Agent
# Create the CartPole-v1 environment, then construct the MuZero agent from it
# and the config built above, with the run name and device passed by keyword.
env = gym.make("CartPole-v1")
agent_options = {
    "name": "refactored_unroll_processor_test",
    "device": "cpu",
}
agent = MuZeroAgent(env, config, **agent_options)

[refactored_unroll_processor_test] Using device: cpu
Observation dimensions: torch.Size([4])
Num actions: 2 (Discrete: True)
Making test env...
MARL Agent 'refactored_unroll_processor_test' initialized. Test agents: []
Hidden state shape: (64, 8)
Hidden state shape: (64, 8)
encoder input shape (64, 8)
Hidden state shape: (64, 8)
Hidden state shape: (64, 8)
encoder input shape (64, 8)
Max size: 10000
Initializing stat 'score' with subkeys None
Initializing stat 'policy_loss' with subkeys None
Initializing stat 'value_loss' with subkeys None
Initializing stat 'reward_loss' with subkeys None
Initializing stat 'to_play_loss' with subkeys None
Initializing stat 'cons_loss' with subkeys None
Initializing stat 'loss' with subkeys None
Initializing stat 'test_score' with subkeys ['score', 'max_score', 'min_score']
Initializing stat 'episode_length' with subkeys None
Initializing stat 'num_codes' with subkeys None
Initializing stat 'chance_probs' with subkeys None
Initializing stat 'chance_entr

In [None]:
# Override three attributes on the agent, then start training.
run_settings = {
    "checkpoint_interval": 100,
    "test_interval": 1000,
    "test_trials": 10,
}
for setting_name, setting_value in run_settings.items():
    setattr(agent, setting_name, setting_value)

agent.train()

[Worker 1] Starting self-play...
[Worker 0] Starting self-play...
[Worker 2] Starting self-play...
[Worker 3] Starting self-play...
0
actions shape torch.Size([64, 5])
target value shape torch.Size([64, 6])
predicted values shape torch.Size([64, 6, 63])
target rewards shape torch.Size([64, 6])
predicted rewards shape torch.Size([64, 6, 63])
target qs shape torch.Size([64, 6])
predicted qs shape torch.Size([64, 6, 63])
target to plays shape torch.Size([64, 6, 1])
predicted to_plays shape torch.Size([64, 6, 1])
masks shape torch.Size([64, 6]) torch.Size([64, 6])
actions tensor([[1, 0, 1, 1, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 1, 0, 1],
        [1, 1, 0, 1, 1],
        [0, 0, 0, 0, 0],
        [0, 0, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 0, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [0, 0, 1, 0, 1],
        [1, 1, 0, 1, 0],
        [0, 0, 1, 1, 0],
        [1, 1, 0, 1, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [1, 0, 1, 0, 1],


Process Process-10:
Process Process-3:
Process Process-2:
Process Process-4:
Process Process-1:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/jonathanlamontange-kratz/Documents/GitHub/rl-stuff/agents/muzero.py", line 273, in worker_fn
    score, num_steps = self.play_game(
                       ^^^^^^^^^^^^^^^
  File "/Users/jonathanlamontange-kratz/Documents/GitHub/rl-stuff/agents/muzero.py", line 883, in play_game
    prediction = self.predict(
                 ^^^^^^^^^^^^^
  File "/Users/jonathanlamontange-kratz/Documents/GitHub/rl-stuff/agents/muzero.py", line 829, in predict
    root_value, exploratory_policy, target_policy, best_action =

KeyboardInterrupt: 

: 

In [None]:
# 5. Verify Stats and Plot
# Look up the 'chance_probs' stat on the agent, bar-plot its most recent entry,
# then call the agent's own plotting routine.
print("\nVerifying Stats...")
stats = agent.stats.stats

if "chance_probs" not in stats:
    print("Missing key: chance_probs")
elif len(stats["chance_probs"]) == 0:
    # The agent's construction log reports initializing this stat, so the key
    # presumably exists before any value is logged (e.g. when training is
    # interrupted early) -- TODO confirm. Reading `.shape` / `[-1]` from an
    # empty series would raise, so report the situation instead.
    print("'chance_probs' stat found but empty; nothing to plot yet.")
else:
    chance_probs_data = stats["chance_probs"]
    print(f"'chance_probs' stat found. Shape: {chance_probs_data.shape}")

    # Get the latest distribution
    latest_probs = chance_probs_data[-1].tolist()
    print(f"Latest probability distribution (first 5): {latest_probs[:5]}...")

    # Manual plot for immediate feedback in the notebook, using the explicit
    # figure/axes interface rather than pyplot's implicit current figure.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(range(len(latest_probs)), latest_probs, color='skyblue')
    ax.set_xlabel('Chance Codes')
    ax.set_ylabel('Probability')
    ax.set_title('Chance Code Probabilities (Last Step)')
    ax.set_ylim(0, 1.0)
    plt.show()

    # Also trigger the agent's plot function to verify it works with the new BAR type
    print("Triggering agent.stats.plot_graphs()...")
    agent.stats.plot_graphs()