In [1]:
import sys

sys.path.append("../../")
import time
import torch
from losses.basic_losses import CategoricalCrossentropyLoss, KLDivergenceLoss
from agents.random import RandomAgent
from agents.muzero import MuZeroAgent
from agent_configs.muzero_config import MuZeroConfig
from game_configs.tictactoe_config import TicTacToeConfig
from agents.tictactoe_expert import TicTacToeBestAgent
from modules.world_models.muzero_world_model import MuzeroWorldModel

# Benchmarks are pinned to the CPU for fairness and comparability between runs;
# "cuda" is the alternative when a GPU is available.
device = "cpu"
print("Using device: " + device)

Using device: cpu


  from pkg_resources import resource_stream, resource_exists


# MuZero Benchmark (Iterative vs Batched)

In [2]:
# Baseline MuZero settings. Each benchmark cell below copies this dict and
# overrides only the options it is comparing.
params = dict(
    # Search / targets
    num_simulations=25,
    per_alpha=0.0,
    per_beta=0.0,
    per_beta_final=0.0,
    n_step=10,
    root_dirichlet_alpha=0.25,
    # Network layer specs (tuples are passed through to the project's model
    # builders — see MuzeroWorldModel for their meaning)
    residual_layers=[(24, 3, 1)],
    reward_dense_layer_widths=[],
    reward_conv_layers=[(16, 1, 1)],
    actor_dense_layer_widths=[],
    actor_conv_layers=[(16, 1, 1)],
    critic_dense_layer_widths=[],
    critic_conv_layers=[(16, 1, 1)],
    to_play_dense_layer_widths=[],
    to_play_conv_layers=[(16, 1, 1)],
    known_bounds=[-1, 1],
    support_range=None,
    # Training / replay
    minibatch_size=8,
    replay_buffer_size=100_000,
    gumbel=False,
    gumbel_m=16,
    policy_loss_function=CategoricalCrossentropyLoss(),
    training_steps=20000,  # Reduced for benchmark speed
    transfer_interval=1,
    num_workers=4,
    world_model_cls=MuzeroWorldModel,
    # Search batching (0 = iterative search)
    search_batch_size=0,  # Iterative
    use_virtual_mean=False,
    virtual_loss=3.0,
    # Speed options
    use_torch_compile=True,
    use_mixed_precision=True,
    use_quantization=False,
)

game_config = TicTacToeConfig()

In [None]:
print("--- Running MuZero Batched Search Max Fast ---")
# "Max fast": batched search with every speed-oriented option switched on.
params_batched = params.copy()
params_batched["num_workers"] = 4
params_batched["search_batch_size"] = 5
params_batched["use_virtual_mean"] = True
params_batched["use_mixed_precision"] = True
params_batched["use_torch_compile"] = True
params_batched["use_quantization"] = True
params_batched["qat"] = True
# Baseline uses transfer_interval=1; presumably this is how often weights are
# pushed to workers — TODO confirm against MuZeroAgent.
params_batched["transfer_interval"] = 100

env_batch = TicTacToeConfig().make_env()
config_batch = MuZeroConfig(config_dict=params_batched, game_config=game_config)

agent_batch = MuZeroAgent(
    env=env_batch,
    config=config_batch,
    name="muzero_batched_bench_fast",
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_batch.checkpoint_interval = 100
agent_batch.test_interval = 1000
agent_batch.test_trials = 100

# perf_counter is monotonic (time.time is not). The finally block reports the
# elapsed time even when training is interrupted, e.g. by KeyboardInterrupt.
training_completed = False
start_time = time.perf_counter()
try:
    agent_batch.train()
    training_completed = True
finally:
    end_time = time.perf_counter()
    status = "" if training_completed else " (interrupted)"
    print(f"MuZero Batched Time: {end_time - start_time:.2f}s{status}")

--- Running MuZero Batched Search Max Fast ---
Using default save_intermediate_weights     : False
Using         training_steps                : 20000
Using default adam_epsilon                  : 1e-08
Using default momentum                      : 0.9
Using default learning_rate                 : 0.001
Using default clipnorm                      : 0
Using default optimizer                     : <class 'torch.optim.adam.Adam'>
Using default weight_decay                  : 0.0
Using default num_minibatches               : 1
Using default training_iterations           : 1
Using default lr_schedule_type              : none
Using default lr_schedule_steps             : []
Using default lr_schedule_steps             : []
Using default lr_schedule_values            : []
Using         use_mixed_precision           : True
Using         use_torch_compile             : True
Using default compile_mode                  : reduce-overhead
Using         minibatch_size                : 8
Using        

2026-01-28 20:13:20,022	INFO worker.py:2007 -- Started a local Ray instance.


Max size: 100000
Initializing stat 'score' with subkeys None
Initializing stat 'policy_loss' with subkeys None
Initializing stat 'value_loss' with subkeys None
Initializing stat 'reward_loss' with subkeys None
Initializing stat 'to_play_loss' with subkeys None
Initializing stat 'cons_loss' with subkeys None
Initializing stat 'loss' with subkeys None
Initializing stat 'test_score' with subkeys ['score', 'max_score', 'min_score']
Initializing stat 'episode_length' with subkeys None
Initializing stat 'policy_entropy' with subkeys None
Initializing stat 'value_diff' with subkeys None
Initializing stat 'policy_improvement' with subkeys ['network', 'search']
Initializing stat 'root_children_values' with subkeys None
Initializing stat 'test_score_vs_random' with subkeys ['score', 'player_1_score', 'player_2_score', 'player_1_win%', 'player_2_win%']
Initializing stat 'test_score_vs_tictactoe_expert' with subkeys ['score', 'player_1_score', 'player_2_score', 'player_1_win%', 'player_2_win%']
[

[36m(MuZeroWorker pid=5381)[0m   from pkg_resources import resource_stream, resource_exists


Broadcasting initial weights to workers...
Starting initial batch of games...
Size: 0
0
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[0, 2, 3, 6, 5],
        [2, 3, 6, 5, 8],
        [3, 6, 5, 8, 4],
        [6, 5, 8, 4, 0],
        [6, 5, 8, 4, 0],
        [8, 4, 0, 7, 0],
        [8, 4, 0, 7, 0],
        [4, 0, 1, 7, 0]])
target value tensor([[ 0.9415, -0.9510,  0.9606, -0.9703,  0.9801, -0.9900],
        [-0.9510,  0.9606, -0.9703,  0.9801, -0.9900,  1.0000],
        [ 0.9606, -0.9703,  0.9801, -0.9900,  1.0000,  0.0000],
        [-0.9703,  0.9801, -0.9900,  1.0000,  0.0000,  0.0000],
        [-0.9703,  0.9801, -0.9900,  1.0000,  0.0000,  0.0000],
        [-0.9900,  1.0000, 

[33m(raylet)[0m [2026-01-28 20:13:30,024 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3825 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[36m(MuZeroWorker pid=5383)[0m   from pkg_resources import resource_stream, resource_exists[32m [repeated 3x across cluster][0m


Size: 679
Size: 688
Size: 698
Size: 704
Size: 713
Size: 721
Size: 731
Size: 740
Size: 750
Size: 758
Size: 768
Size: 778
Size: 786
Size: 796
Size: 805
Size: 813
Size: 820
Size: 827
Size: 837
Size: 844
Size: 852
plotting score
plotting policy_loss
plotting value_loss
plotting reward_loss
plotting to_play_loss
plotting cons_loss
plotting loss
plotting episode_length
plotting root_children_values
plotting q_loss
plotting sigma_loss
plotting vqvae_commitment_cost


[36m(MuZeroWorker pid=5386)[0m   if not check_min_max_valid(min_val, max_val):


[36m(MuZeroWorker pid=5386)[0m Worker 2: Compiling INT8 model...
[36m(MuZeroWorker pid=5383)[0m Hidden state shape: (1, 24, 3, 3)[32m [repeated 6x across cluster][0m
[36m(MuZeroWorker pid=5383)[0m encoder input shape (1, 18, 3, 3)[32m [repeated 3x across cluster][0m
plotting policy_entropy
plotting value_diff
plotting policy_improvement
  subkey network
  subkey search
plotting latent viz latent_root using umap
  Saving latent viz to checkpoints/muzero_batched_bench_fast/graphs/muzero_batched_bench_fast_latent_root_umap.png
Size: 860
100
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[1, 4, 0, 6, 7],
        [6, 7, 0, 2, 4],
        [0, 1, 6, 3, 4],
        [7, 0, 0, 6,

[33m(raylet)[0m [2026-01-28 20:13:40,100 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3817 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[36m(MuZeroWorker pid=5380)[0m   if not check_min_max_valid(min_val, max_val):[32m [repeated 3x across cluster][0m


plotting score
plotting policy_loss
plotting value_loss
plotting reward_loss
plotting to_play_loss
plotting cons_loss
plotting loss
plotting episode_length
plotting root_children_values
plotting q_loss
plotting sigma_loss
plotting vqvae_commitment_cost


[36m(MuZeroWorker pid=5381)[0m   if not check_min_max_valid(min_val, max_val):
[36m(MuZeroWorker pid=5386)[0m   if not check_min_max_valid(min_val, max_val):


plotting policy_entropy
plotting value_diff
plotting policy_improvement
  subkey network
  subkey search
plotting latent viz latent_root using umap
  Saving latent viz to checkpoints/muzero_batched_bench_fast/graphs/muzero_batched_bench_fast_latent_root_umap.png
200
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[0, 7, 3, 5, 6],
        [0, 5, 1, 2, 0],
        [3, 4, 8, 5, 6],
        [0, 6, 3, 2, 1],
        [1, 0, 7, 6, 4],
        [5, 0, 7, 6, 4],
        [1, 6, 3, 5, 0],
        [5, 3, 0, 6, 4]])
target value tensor([[ 0.9606, -0.9703,  0.9801, -0.9900,  1.0000,  0.0000],
        [-0.9703,  0.9801, -0.9900,  1.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0

[36m(pid=gcs_server)[0m [2026-01-28 20:13:49,448 E 5372 47028] (gcs_server) gcs_server.cc:303: Failed to establish connection to the event+metrics exporter agent. Events and metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
[33m(raylet)[0m [2026-01-28 20:13:49,990 E 5377 47139] (raylet) main.cc:1032: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
[33m(raylet)[0m [2026-01-28 20:13:50,176 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3797 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[36m(MuZeroWorker pid=5381)[0m [2026-01-28 20:13:50,665 E 5381 47425] core_worker_process.cc:842: Failed to establish connection to the metrics exporter 

plotting score
plotting policy_loss
plotting value_loss
plotting reward_loss
plotting to_play_loss
plotting cons_loss
plotting loss
plotting episode_length
plotting root_children_values
plotting q_loss
plotting sigma_loss
plotting vqvae_commitment_cost
plotting policy_entropy
plotting value_diff
plotting policy_improvement
  subkey network
  subkey search
plotting latent viz latent_root using umap
  Saving latent viz to checkpoints/muzero_batched_bench_fast/graphs/muzero_batched_bench_fast_latent_root_umap.png
300
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[8, 7, 5, 1, 0],
        [0, 1, 5, 4, 8],
        [4, 5, 1, 2, 8],
        [8, 1, 2, 0, 1],
        [7, 6, 0, 2, 1],
    

[36m(MuZeroWorker pid=5383)[0m   if not check_min_max_valid(min_val, max_val):[32m [repeated 2x across cluster][0m
[33m(raylet)[0m [2026-01-28 20:14:00,259 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3824 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[36m(pid=5385)[0m [2026-01-28 20:13:50,760 E 5385 47565] core_worker_process.cc:842: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14[32m [repeated 7x across cluster][0m


plotting score
plotting policy_loss
plotting value_loss
plotting reward_loss
plotting to_play_loss
plotting cons_loss
plotting loss
plotting episode_length
plotting root_children_values
plotting q_loss
plotting sigma_loss
plotting vqvae_commitment_cost
plotting policy_entropy
plotting value_diff
plotting policy_improvement
  subkey network
  subkey search
plotting latent viz latent_root using umap
  Saving latent viz to checkpoints/muzero_batched_bench_fast/graphs/muzero_batched_bench_fast_latent_root_umap.png
400
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[2, 8, 4, 5, 3],
        [8, 3, 1, 7, 0],
        [7, 3, 5, 0, 6],
        [7, 1, 4, 5, 2],
        [0, 1, 8, 3, 5],
    

[33m(raylet)[0m [2026-01-28 20:14:10,346 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3821 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[33m(raylet)[0m [2026-01-28 20:14:20,444 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.381 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.


plotting score
plotting policy_loss
plotting value_loss
plotting reward_loss
plotting to_play_loss
plotting cons_loss
plotting loss
plotting episode_length
plotting root_children_values
plotting q_loss
plotting sigma_loss
plotting vqvae_commitment_cost
plotting policy_entropy
plotting value_diff
plotting policy_improvement
  subkey network
  subkey search
plotting latent viz latent_root using umap
  Saving latent viz to checkpoints/muzero_batched_bench_fast/graphs/muzero_batched_bench_fast_latent_root_umap.png
500
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[3, 2, 5, 6, 8],
        [4, 8, 1, 0, 3],
        [1, 2, 7, 4, 0],
        [0, 4, 5, 2, 1],
        [4, 8, 7, 0, 0],
    

[33m(raylet)[0m [2026-01-28 20:14:30,535 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3799 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[33m(raylet)[0m [2026-01-28 20:14:40,616 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3775 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.


plotting score
plotting policy_loss
plotting value_loss
plotting reward_loss
plotting to_play_loss
plotting cons_loss
plotting loss
plotting episode_length
plotting root_children_values
plotting q_loss
plotting sigma_loss
plotting vqvae_commitment_cost
plotting policy_entropy
plotting value_diff
plotting policy_improvement
  subkey network
  subkey search
plotting latent viz latent_root using umap
  Saving latent viz to checkpoints/muzero_batched_bench_fast/graphs/muzero_batched_bench_fast_latent_root_umap.png
600
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[6, 7, 0, 2, 5],
        [5, 3, 0, 7, 6],
        [2, 3, 5, 4, 8],
        [2, 7, 5, 0, 8],
        [1, 3, 5, 0, 5],
    

KeyboardInterrupt: 

[33m(raylet)[0m [2026-01-28 20:14:50,638 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3748 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[33m(raylet)[0m [2026-01-28 20:15:00,710 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3748 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[33m(raylet)[0m [2026-01-28 20:15:10,789 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3748 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[33m(raylet)[0m [2026-01-28 20:15:20,874 E 5377 47155] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-13-19_172233_5354 is over 95% full, available space: 20.3752 GB; capacity: 460.432 GB.

: 

In [None]:
print("--- Running MuZero Iterative Search (Batch=0) ---")
env_iter = TicTacToeConfig().make_env()
config_iter = MuZeroConfig(config_dict=params, game_config=game_config)
config_iter.search_batch_size = 0  # Explicitly set

agent_iter = MuZeroAgent(
    env=env_iter,
    config=config_iter,
    name="muzero_iterative_bench",
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_iter.checkpoint_interval = 100
agent_iter.test_interval = 1000
agent_iter.test_trials = 100

# perf_counter is monotonic (time.time is not). The finally block reports the
# elapsed time even when training is interrupted, e.g. by KeyboardInterrupt.
training_completed = False
start_time = time.perf_counter()
try:
    agent_iter.train()
    training_completed = True
finally:
    end_time = time.perf_counter()
    status = "" if training_completed else " (interrupted)"
    print(f"MuZero Iterative Time: {end_time - start_time:.2f}s{status}")

In [None]:
print("--- Running MuZero Iterative Search (Batch=1) ---")
# Pass the batch size through the config dict, as the batched cells do, rather
# than only patching the attribute after MuZeroConfig has been built.
params_iter_1 = params.copy()
params_iter_1["search_batch_size"] = 1

env_iter = TicTacToeConfig().make_env()
config_iter = MuZeroConfig(config_dict=params_iter_1, game_config=game_config)
config_iter.search_batch_size = 1  # Explicitly set

agent_iter = MuZeroAgent(
    env=env_iter,
    config=config_iter,
    # Distinct from the Batch=0 run so its checkpoints/graphs
    # (checkpoints/<name>/...) are not overwritten.
    name="muzero_iterative_bench_size_1",
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_iter.checkpoint_interval = 100
agent_iter.test_interval = 1000
agent_iter.test_trials = 100

# perf_counter is monotonic (time.time is not). The finally block reports the
# elapsed time even when training is interrupted, e.g. by KeyboardInterrupt.
training_completed = False
start_time = time.perf_counter()
try:
    agent_iter.train()
    training_completed = True
finally:
    end_time = time.perf_counter()
    status = "" if training_completed else " (interrupted)"
    print(f"MuZero Iterative Time: {end_time - start_time:.2f}s{status}")

In [None]:
print("--- Running MuZero Batched Search (Batch=5) ---")
params_batched = params.copy()
params_batched["search_batch_size"] = 5

env_batch = TicTacToeConfig().make_env()
config_batch = MuZeroConfig(config_dict=params_batched, game_config=game_config)
config_batch.search_batch_size = 5  # Explicitly set

agent_batch = MuZeroAgent(
    env=env_batch,
    config=config_batch,
    name="muzero_batched_bench_size_5",
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_batch.checkpoint_interval = 100
agent_batch.test_interval = 1000
agent_batch.test_trials = 100

# perf_counter is monotonic (time.time is not). The finally block reports the
# elapsed time even when training is interrupted, e.g. by KeyboardInterrupt.
training_completed = False
start_time = time.perf_counter()
try:
    agent_batch.train()
    training_completed = True
finally:
    end_time = time.perf_counter()
    status = "" if training_completed else " (interrupted)"
    print(f"MuZero Batched Time: {end_time - start_time:.2f}s{status}")

In [None]:
print("--- Running MuZero Batched Search (Batch=5) Repeat ---")
# Second trial of the Batch=5 configuration, used to gauge run-to-run timing
# variance. It gets its own name so it does not overwrite the first trial's
# checkpoints/graphs (checkpoints/<name>/...).
params_batched = params.copy()
params_batched["search_batch_size"] = 5

env_batch = TicTacToeConfig().make_env()
config_batch = MuZeroConfig(config_dict=params_batched, game_config=game_config)
config_batch.search_batch_size = 5  # Explicitly set

agent_batch = MuZeroAgent(
    env=env_batch,
    config=config_batch,
    name="muzero_batched_bench_size_5_repeat",
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_batch.checkpoint_interval = 100
agent_batch.test_interval = 1000
agent_batch.test_trials = 100

# perf_counter is monotonic (time.time is not). The finally block reports the
# elapsed time even when training is interrupted, e.g. by KeyboardInterrupt.
training_completed = False
start_time = time.perf_counter()
try:
    agent_batch.train()
    training_completed = True
finally:
    end_time = time.perf_counter()
    status = "" if training_completed else " (interrupted)"
    print(f"MuZero Batched Time: {end_time - start_time:.2f}s{status}")

In [None]:
print("--- Running MuZero Batched Search (Batch=5) Virtual Mean ---")
params_batched = params.copy()
params_batched["search_batch_size"] = 5
params_batched["use_virtual_mean"] = True

env_batch = TicTacToeConfig().make_env()
config_batch = MuZeroConfig(config_dict=params_batched, game_config=game_config)
config_batch.search_batch_size = 5  # Explicitly set

agent_batch = MuZeroAgent(
    env=env_batch,
    config=config_batch,
    name="muzero_batched_bench_size_5_virtual_mean_1",
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_batch.checkpoint_interval = 100
agent_batch.test_interval = 1000
agent_batch.test_trials = 100

# perf_counter is monotonic (time.time is not). The finally block reports the
# elapsed time even when training is interrupted, e.g. by KeyboardInterrupt.
training_completed = False
start_time = time.perf_counter()
try:
    agent_batch.train()
    training_completed = True
finally:
    end_time = time.perf_counter()
    status = "" if training_completed else " (interrupted)"
    print(f"MuZero Batched Time: {end_time - start_time:.2f}s{status}")

# Gumbel MuZero Benchmark (Iterative vs Batched)

In [None]:
# Gumbel MuZero baseline settings. These are built directly into params_gumbel
# instead of rebinding the global `params`. Otherwise, re-running any MuZero
# cell above after this one would silently pick up the Gumbel configuration.
# NOTE(review): unlike the MuZero baseline, this dict does not set
# use_torch_compile / use_mixed_precision / use_quantization, so the config
# defaults apply here. Confirm this is intended when comparing the two sections.
params_gumbel = {
    "num_simulations": 25,
    "per_alpha": 0.0,
    "per_beta": 0.0,
    "per_beta_final": 0.0,
    "n_step": 10,
    "root_dirichlet_alpha": 0.25,
    "residual_layers": [(24, 3, 1)],
    "reward_dense_layer_widths": [],
    "reward_conv_layers": [(16, 1, 1)],
    "actor_dense_layer_widths": [],
    "actor_conv_layers": [(16, 1, 1)],
    "critic_dense_layer_widths": [],
    "critic_conv_layers": [(16, 1, 1)],
    "to_play_dense_layer_widths": [],
    "to_play_conv_layers": [(16, 1, 1)],
    "known_bounds": [-1, 1],
    "support_range": None,
    "minibatch_size": 8,
    "replay_buffer_size": 100000,
    "gumbel": True,
    "gumbel_m": 4,
    "policy_loss_function": KLDivergenceLoss(),
    "training_steps": 20000,  # Reduced for benchmark speed
    "transfer_interval": 1,
    "num_workers": 4,
    "world_model_cls": MuzeroWorldModel,
    "search_batch_size": 0,  # Iterative
    "use_virtual_mean": False,
    "virtual_loss": 3.0,
}

game_config = TicTacToeConfig()

In [None]:
# Label matches the code: search_batch_size=0 is the iterative search.
print("--- Running Gumbel MuZero Iterative Search (Batch=0) ---")
# Work on a copy so this cell does not mutate params_gumbel for later cells.
params_gumbel_iter = params_gumbel.copy()
params_gumbel_iter["search_batch_size"] = 0

env_gumbel = TicTacToeConfig().make_env()
config_gumbel = MuZeroConfig(config_dict=params_gumbel_iter, game_config=game_config)

agent_gumbel = MuZeroAgent(
    env=env_gumbel,
    config=config_gumbel,
    name="gumbel_iterative_bench",
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_gumbel.checkpoint_interval = 100
agent_gumbel.test_interval = 1000
agent_gumbel.test_trials = 100

# perf_counter is monotonic (time.time is not). The finally block reports the
# elapsed time even when training is interrupted, e.g. by KeyboardInterrupt.
training_completed = False
start_time = time.perf_counter()
try:
    agent_gumbel.train()
    training_completed = True
finally:
    end_time = time.perf_counter()
    status = "" if training_completed else " (interrupted)"
    print(f"Gumbel Iterative Time: {end_time - start_time:.2f}s{status}")

In [None]:
print("--- Running Gumbel MuZero Batched Search (Batch=5) ---")
params_gumbel_batch = params_gumbel.copy()
params_gumbel_batch["search_batch_size"] = 5
params_gumbel_batch["use_virtual_mean"] = True

env_gumbel_batch = TicTacToeConfig().make_env()
config_gumbel_batch = MuZeroConfig(
    config_dict=params_gumbel_batch, game_config=game_config
)

agent_gumbel_batch = MuZeroAgent(
    env=env_gumbel_batch,
    config=config_gumbel_batch,
    name="gumbel_batched_bench",
    device=device,
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)
agent_gumbel_batch.checkpoint_interval = 100
agent_gumbel_batch.test_interval = 1000
agent_gumbel_batch.test_trials = 100

# perf_counter is monotonic (time.time is not). The finally block reports the
# elapsed time even when training is interrupted, e.g. by KeyboardInterrupt.
training_completed = False
start_time = time.perf_counter()
try:
    agent_gumbel_batch.train()
    training_completed = True
finally:
    end_time = time.perf_counter()
    status = "" if training_completed else " (interrupted)"
    print(f"Gumbel Batched Time: {end_time - start_time:.2f}s{status}")