In [1]:
import sys
sys.path.append("../../")
import time
import torch
from losses.basic_losses import CategoricalCrossentropyLoss, KLDivergenceLoss
from agents.random import RandomAgent
from agents.muzero import MuZeroAgent
from agent_configs.muzero_config import MuZeroConfig
from game_configs.tictactoe_config import TicTacToeConfig
from agents.tictactoe_expert import TicTacToeBestAgent
from modules.world_models.muzero_world_model import MuzeroWorldModel

# Use the CPU so timings are fair/comparable across runs (or "cuda" if a GPU is available).
device = "cpu" # or "cuda" if available
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


# MuZero Benchmark (Iterative vs Batched)

In [2]:
# Base MuZero configuration shared by the benchmark cells below. Each cell
# takes a copy of this dict and overrides only the options it is comparing.
# Keys are added group by group, in the same order as a single literal would.
params = {}

# Simulation count, replay-priority exponents, n-step and root-noise settings.
params.update({
    "num_simulations": 25,
    "per_alpha": 0.0,
    "per_beta": 0.0,
    "per_beta_final": 0.0,
    "n_step": 10,
    "root_dirichlet_alpha": 0.25,
})

# Layer specifications for the network heads (interpreted by MuZeroConfig).
params.update({
    "residual_layers": [(24, 3, 1)],
    "reward_dense_layer_widths": [],
    "reward_conv_layers": [(16, 1, 1)],
    "actor_dense_layer_widths": [],
    "actor_conv_layers": [(16, 1, 1)],
    "critic_dense_layer_widths": [],
    "critic_conv_layers": [(16, 1, 1)],
    "to_play_dense_layer_widths": [],
    "to_play_conv_layers": [(16, 1, 1)],
})

# Value bounds/support and batch/buffer sizes.
params.update({
    "known_bounds": [-1, 1],
    "support_range": None,
    "minibatch_size": 8,
    "replay_buffer_size": 100000,
})

# Plain (non-Gumbel) MuZero with a cross-entropy policy loss.
params.update({
    "gumbel": False,
    "gumbel_m": 16,
    "policy_loss_function": CategoricalCrossentropyLoss(),
})

# Training-loop settings.
params.update({
    "training_steps": 20000, # Reduced for benchmark speed
    "transfer_interval": 1,
    "num_workers": 4,
    "world_model_cls": MuzeroWorldModel,
})

# Search batching: the baseline is the iterative search.
params.update({
    "search_batch_size": 0, # Iterative
    "use_virtual_mean": False,
    "virtual_loss": 3.0,
})

# Runtime acceleration switches.
params.update({
    "use_torch_compile": True,
    "use_mixed_precision": True,
    "use_quantization": False
})

game_config = TicTacToeConfig()


In [3]:
# Batched-search benchmark with every speed option switched on
# (virtual mean, mixed precision, torch.compile, quantization).
print("--- Running MuZero Batched Search Max Fast ---")
# Shallow copy is sufficient: only top-level keys are overridden below.
params_batched = params.copy()
params_batched["num_workers"] = 4
params_batched["search_batch_size"] = 5
params_batched["use_virtual_mean"] = True
params_batched["use_mixed_precision"] = True
params_batched["use_torch_compile"] = True
params_batched["use_quantization"] = True
# params_batched["num_envs_per_worker"] = 4

env_batch = TicTacToeConfig().make_env()
config_batch = MuZeroConfig(config_dict=params_batched, game_config=game_config)

agent_batch = MuZeroAgent(
    env=env_batch,
    config=config_batch,
    name="muzero_batched_bench_fast",
    # Use the device selected in the setup cell instead of a second,
    # independent "cpu" literal that could drift from the printed value.
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_batch.checkpoint_interval = 100
agent_batch.test_interval = 1000
agent_batch.test_trials = 100

# perf_counter() is a monotonic, high-resolution clock, so the measured
# duration is not distorted by system clock adjustments during the run.
start_time = time.perf_counter()
agent_batch.train()
end_time = time.perf_counter()
print(f"MuZero Batched Time: {end_time - start_time:.2f}s")

--- Running MuZero Batched Search Max Fast ---
Using default save_intermediate_weights     : False
Using         training_steps                : 20000
Using default adam_epsilon                  : 1e-08
Using default momentum                      : 0.9
Using default learning_rate                 : 0.001
Using default clipnorm                      : 0
Using default optimizer                     : <class 'torch.optim.adam.Adam'>
Using default weight_decay                  : 0.0
Using default num_minibatches               : 1
Using default training_iterations           : 1
Using default lr_schedule_type              : none
Using default lr_schedule_steps             : []
Using default lr_schedule_steps             : []
Using default lr_schedule_values            : []
Using         use_mixed_precision           : True
Using         use_torch_compile             : True
Using default compile_mode                  : reduce-overhead
Using         minibatch_size                : 8
Using        



Started recording episode 0 to ./videos/muzero_batched_bench_fast/2/episode_000000.mp4Started recording episode 0 to ./videos/muzero_batched_bench_fast/0/episode_000000.mp4

Started recording episode 0 to ./videos/muzero_batched_bench_fast/1/episode_000000.mp4
Started recording episode 0 to ./videos/muzero_batched_bench_fast/3/episode_000000.mp4
Stopped recording episode 0. Recorded 8 frames.learning

Size: 0
Stopped recording episode 0. Recorded 9 frames.
Size: 8
Stopped recording episode 0. Recorded 9 frames.
Size: 17
0
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[2, 3, 4, 7, 0],
        [2, 3, 4, 7, 0],
        [2, 3, 4, 7, 0],
        [2, 3, 4, 7, 0],
        [2, 3, 4, 7, 

  warn(


learning
100
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[4, 7, 0, 5, 8],
        [2, 0, 0, 6, 6],
        [7, 2, 0, 6, 6],
        [1, 6, 2, 8, 0],
        [0, 5, 1, 7, 2],
        [6, 0, 1, 4, 2],
        [1, 5, 3, 2, 0],
        [4, 3, 0, 7, 1]])
target value tensor([[ 0.9606, -0.9703,  0.9801, -0.9900,  1.0000,  0.0000],
        [ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.9900,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.9703,  0.9801, -0.9900,  1.0000,  0.0000,  0.0000],
        [ 0.9606, -0.9703,  0.9801, -0.9900,  1.0000,  0.0000],
        [ 0.9606, -0.9703,  0.9801, -0.9900,  1.0000,  0.0000],
        [-0.9703,  0.9801, -0.9900,  1

  warn(


learning
200
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[3, 1, 4, 8, 0],
        [1, 2, 0, 4, 8],
        [2, 7, 0, 8, 1],
        [6, 3, 7, 0, 8],
        [0, 7, 3, 8, 0],
        [0, 1, 4, 2, 5],
        [2, 6, 7, 4, 3],
        [5, 7, 0, 6, 8]])
target value tensor([[-0.9703,  0.9801, -0.9900,  1.0000,  0.0000,  0.0000],
        [-0.9900,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9606, -0.9703,  0.9801, -0.9900,  1.0000,  0.0000],
        [ 0.9801, -0.9900,  1.0000,  0.0000,  0.0000,  0.0000],
        [-0.9703,  0.9801, -0.9900,  1.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9227, -0.9321,  0.9415, -0

  warn(


learning
300
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[0, 7, 4, 8, 1],
        [0, 1, 4, 5, 3],
        [7, 6, 0, 5, 1],
        [8, 5, 0, 1, 3],
        [7, 1, 8, 0, 1],
        [0, 0, 3, 5, 1],
        [5, 8, 0, 0, 1],
        [5, 7, 0, 2, 6]])
target value tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9415, -0.9510,  0.9606, -0.9703,  0.9801, -0.9900],
        [-0.9900,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9415, -0.9510,  0.9606, -0.9703,  0.9801, -0.9900],
        [ 0.9801, -0.9900,  1.0000,  0.0000,  0.0000,  0.0000],
        [ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9801, -0.9900,  1.0000,  0

Process Process-2:
Process Process-1:
Process Process-3:
Process Process-4:
Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.10/3.10.14/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/homebrew/Cellar/python@3.10/3.10.14/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/jonathanlamontange-kratz/Documents/GitHub/rl-stuff/experiments/tictactoe_muzero_nfsp/../../agents/muzero.py", line 428, in worker_fn
    score, num_steps = self.play_game(
  File "/Users/jonathanlamontange-kratz/Documents/GitHub/rl-stuff/experiments/tictactoe_muzero_nfsp/../../agents/muzero.py", line 1110, in play_game
    prediction = self.predict(
  File "/Users/jonathanlamontange-kratz/Documents/GitHub/rl-stuff/experiments/tictactoe_muzero_nfsp/../../agents/muzero.py", line 1057, in predict
   

KeyboardInterrupt: 

In [None]:
# Baseline: iterative (unbatched) search, search_batch_size = 0.
print("--- Running MuZero Iterative Search (Batch=0) ---")
env_iter = TicTacToeConfig().make_env()
config_iter = MuZeroConfig(config_dict=params, game_config=game_config)
config_iter.search_batch_size = 0 # Explicitly set

agent_iter = MuZeroAgent(
    env=env_iter,
    config=config_iter,
    name="muzero_iterative_bench",
    # Use the device selected in the setup cell rather than a separate literal.
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_iter.checkpoint_interval = 100
agent_iter.test_interval = 1000
agent_iter.test_trials = 100

# Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
start_time = time.perf_counter()
agent_iter.train()
end_time = time.perf_counter()
print(f"MuZero Iterative Time: {end_time - start_time:.2f}s")

In [None]:
# Iterative search run with search_batch_size = 1.
print("--- Running MuZero Iterative Search (Batch=1) ---")
# Put the override into the config dict (as the batched cells do) so the
# config is constructed with the batch size that is actually benchmarked.
params_iter = params.copy()
params_iter["search_batch_size"] = 1

env_iter = TicTacToeConfig().make_env()
config_iter = MuZeroConfig(config_dict=params_iter, game_config=game_config)
config_iter.search_batch_size = 1 # Explicitly set

agent_iter = MuZeroAgent(
    env=env_iter,
    config=config_iter,
    # Distinct name: the Batch=0 cell already uses "muzero_iterative_bench",
    # and run artifacts are stored under the agent name (e.g. ./videos/<name>/),
    # so sharing it would mix the two runs' outputs.
    name="muzero_iterative_bench_batch_1",
    # Use the device selected in the setup cell rather than a separate literal.
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_iter.checkpoint_interval = 100
agent_iter.test_interval = 1000
agent_iter.test_trials = 100

# Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
start_time = time.perf_counter()
agent_iter.train()
end_time = time.perf_counter()
print(f"MuZero Iterative Time: {end_time - start_time:.2f}s")

In [None]:
# Batched search with search_batch_size = 5; all other options come from `params`.
print("--- Running MuZero Batched Search (Batch=5) ---")
params_batched = params.copy()
params_batched["search_batch_size"] = 5

env_batch = TicTacToeConfig().make_env()
config_batch = MuZeroConfig(config_dict=params_batched, game_config=game_config)
config_batch.search_batch_size = 5 # Explicitly set

agent_batch = MuZeroAgent(
    env=env_batch,
    config=config_batch,
    name="muzero_batched_bench_size_5",
    # Use the device selected in the setup cell rather than a separate literal.
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_batch.checkpoint_interval = 100
agent_batch.test_interval = 1000
agent_batch.test_trials = 100

# Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
start_time = time.perf_counter()
agent_batch.train()
end_time = time.perf_counter()
print(f"MuZero Batched Time: {end_time - start_time:.2f}s")

In [None]:
# NOTE(review): this cell uses the same banner, the same search_batch_size
# override and the same agent name ("muzero_batched_bench_size_5") as the cell
# directly above it. Presumably it is either a deliberate repeat run or a
# leftover copy -- confirm which. If it is a repeat run, give the agent a
# distinct name so the two runs do not share a name.
print("--- Running MuZero Batched Search (Batch=5) ---")
# Copy the shared base settings, then override the search batch size.
params_batched = params.copy()
params_batched["search_batch_size"] = 5 

env_batch = TicTacToeConfig().make_env()
config_batch = MuZeroConfig(config_dict=params_batched, game_config=game_config)
config_batch.search_batch_size = 5 # Explicitly set

agent_batch = MuZeroAgent(
    env=env_batch,
    config=config_batch,
    name="muzero_batched_bench_size_5",
    device="cpu",
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
# Checkpoint/test cadence and number of test trials, set on the agent after construction.
agent_batch.checkpoint_interval = 100
agent_batch.test_interval = 1000
agent_batch.test_trials = 100

# Time the full training call and report it in seconds.
start_time = time.time()
agent_batch.train()
end_time = time.time()
print(f"MuZero Batched Time: {end_time - start_time:.2f}s")

In [None]:
# Batched search (search_batch_size = 5) with the virtual-mean option enabled.
print("--- Running MuZero Batched Search (Batch=5) Virtual Mean ---")
params_batched = params.copy()
params_batched["search_batch_size"] = 5
params_batched["use_virtual_mean"] = True

env_batch = TicTacToeConfig().make_env()
config_batch = MuZeroConfig(config_dict=params_batched, game_config=game_config)
config_batch.search_batch_size = 5 # Explicitly set

agent_batch = MuZeroAgent(
    env=env_batch,
    config=config_batch,
    name="muzero_batched_bench_size_5_virtual_mean_1",
    # Use the device selected in the setup cell rather than a separate literal.
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_batch.checkpoint_interval = 100
agent_batch.test_interval = 1000
agent_batch.test_trials = 100

# Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
start_time = time.perf_counter()
agent_batch.train()
end_time = time.perf_counter()
print(f"MuZero Batched Time: {end_time - start_time:.2f}s")

# Gumbel MuZero Benchmark (Iterative vs Batched)

In [None]:
# Base configuration for the Gumbel MuZero benchmarks.
#
# The dict is bound directly to `params_gumbel` instead of rebinding the
# global `params`: the MuZero cells earlier in the notebook read `params`, and
# overwriting it here would make them silently run with Gumbel settings if
# they were re-executed after this cell.
#
# The compile / mixed-precision / quantization switches are not listed here;
# presumably MuZeroConfig falls back to its defaults for missing keys (the
# recorded run output prints "Using default ..." lines) -- confirm if these
# runs are meant to be compared against the MuZero ones.
params_gumbel = {
    "num_simulations": 25,
    "per_alpha": 0.0,
    "per_beta": 0.0,
    "per_beta_final": 0.0,
    "n_step": 10,
    "root_dirichlet_alpha": 0.25,
    "residual_layers": [(24, 3, 1)],
    "reward_dense_layer_widths": [],
    "reward_conv_layers": [(16, 1, 1)],
    "actor_dense_layer_widths": [],
    "actor_conv_layers": [(16, 1, 1)],
    "critic_dense_layer_widths": [],
    "critic_conv_layers": [(16, 1, 1)],
    "to_play_dense_layer_widths": [],
    "to_play_conv_layers": [(16, 1, 1)],
    "known_bounds": [-1, 1],
    "support_range": None,
    "minibatch_size": 8,
    "replay_buffer_size": 100000,
    "gumbel": True,
    "gumbel_m": 4,
    "policy_loss_function": KLDivergenceLoss(),
    "training_steps": 20000, # Reduced for benchmark speed
    "transfer_interval": 1,
    "num_workers": 4,
    "world_model_cls": MuzeroWorldModel,
    "search_batch_size": 0, # Iterative
    "use_virtual_mean": False,
    "virtual_loss": 3.0,
}

game_config = TicTacToeConfig()


In [None]:
# Gumbel MuZero baseline: iterative search, search_batch_size = 0.
# The banner previously said "Batch=1" although the value set below is 0.
print("--- Running Gumbel MuZero Iterative Search (Batch=0) ---")
# Override on a copy so the shared `params_gumbel` dict is never mutated,
# matching how the batched Gumbel cell derives its own settings.
params_gumbel_iter = params_gumbel.copy()
params_gumbel_iter["search_batch_size"] = 0

env_gumbel = TicTacToeConfig().make_env()
config_gumbel = MuZeroConfig(config_dict=params_gumbel_iter, game_config=game_config)

agent_gumbel = MuZeroAgent(
    env=env_gumbel,
    config=config_gumbel,
    name="gumbel_iterative_bench",
    # Use the device selected in the setup cell rather than a separate literal.
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_gumbel.checkpoint_interval = 100
agent_gumbel.test_interval = 1000
agent_gumbel.test_trials = 100

# Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
start_time = time.perf_counter()
agent_gumbel.train()
end_time = time.perf_counter()
print(f"Gumbel Iterative Time: {end_time - start_time:.2f}s")

In [None]:
# Gumbel MuZero with batched search (search_batch_size = 5) and virtual mean.
print("--- Running Gumbel MuZero Batched Search (Batch=5) ---")
# Work on a copy so the shared Gumbel base settings stay untouched.
params_gumbel_batch = params_gumbel.copy()
params_gumbel_batch["search_batch_size"] = 5
params_gumbel_batch["use_virtual_mean"] = True

env_gumbel_batch = TicTacToeConfig().make_env()
config_gumbel_batch = MuZeroConfig(config_dict=params_gumbel_batch, game_config=game_config)

agent_gumbel_batch = MuZeroAgent(
    env=env_gumbel_batch,
    config=config_gumbel_batch,
    name="gumbel_batched_bench",
    # Use the device selected in the setup cell rather than a separate literal.
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_gumbel_batch.checkpoint_interval = 100
agent_gumbel_batch.test_interval = 1000
agent_gumbel_batch.test_trials = 100

# Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
start_time = time.perf_counter()
agent_gumbel_batch.train()
end_time = time.perf_counter()
print(f"Gumbel Batched Time: {end_time - start_time:.2f}s")