In [1]:
import sys
import os
import time
import torch
import gymnasium as gym
from tabulate import tabulate
import copy

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

# Imports
from game_configs.tictactoe_config import TicTacToeConfig
from agent_configs.muzero_config import MuZeroConfig
from agents.tictactoe_expert import TicTacToeBestAgent
from agents.random import RandomAgent

# Dynamic imports for agents
from agents.muzero import MuZeroAgent as MuZeroRay
from agents.muzero_tmp import MuZeroAgent as MuZeroTorchMP
from modules.world_models.muzero_world_model import MuzeroWorldModel
from losses.basic_losses import CategoricalCrossentropyLoss

# Report which compute backend PyTorch can see in this kernel.
print("Imports complete. Device check:")
print("CUDA Available" if torch.cuda.is_available() else "Using CPU")

  from pkg_resources import resource_stream, resource_exists


Imports complete. Device check:
Using CPU


In [None]:
class Benchmark:
    """Wall-clock training benchmark harness for MuZero agents on Tic-Tac-Toe.

    One game config and one hyperparameter dict are created up front so that
    every agent class passed to ``run_benchmark`` is configured from the same
    settings.
    """

    def __init__(self):
        """Create the game config, target device and shared hyperparameters."""
        # All runs are pinned to CPU; run_benchmark hands this device to the agent.
        self.device = torch.device("cpu")

        self.game_config = TicTacToeConfig()

        # Hyperparameters shared by every benchmarked agent.
        # run_benchmark copies this dict and overrides "training_steps",
        # "multi_process" and "num_workers" for each run.
        self.base_params = {
            "num_simulations": 25,
            "per_alpha": 0.0,
            "per_beta": 0.0,
            "per_beta_final": 0.0,
            "n_step": 10,
            "root_dirichlet_alpha": 0.25,
            "residual_layers": [(24, 3, 1)],
            "reward_dense_layer_widths": [],
            "reward_conv_layers": [(16, 1, 1)],
            "actor_dense_layer_widths": [],
            "actor_conv_layers": [(16, 1, 1)],
            "critic_dense_layer_widths": [],
            "critic_conv_layers": [(16, 1, 1)],
            "to_play_dense_layer_widths": [],
            "to_play_conv_layers": [(16, 1, 1)],
            "known_bounds": [-1, 1],
            "support_range": None,
            "minibatch_size": 8,
            "replay_buffer_size": 10000,
            "gumbel": False,
            "gumbel_m": 16,
            "policy_loss_function": CategoricalCrossentropyLoss(),
            "training_steps": 100,  # Reduced for benchmark speed
            "transfer_interval": 100,
            "num_workers": 4,
            "world_model_cls": MuzeroWorldModel,
            "search_batch_size": 5,  # Iterative
            "use_virtual_mean": True,
            "virtual_loss": 3.0,
            "use_torch_compile": True,
            "use_mixed_precision": True,
            "use_quantization": True,
            "qat": True,
        }

    def run_benchmark(self, agent_cls, name, num_steps=400):
        """Train one agent and return how long ``agent.train()`` took.

        Args:
            agent_cls: Agent class to instantiate. It is called with the
                keyword arguments ``env``, ``config``, ``name``, ``device``
                and ``test_agents``.
            name: Human-readable label. Printed in the log and, lower-cased
                with spaces replaced by underscores and prefixed with
                ``bench_``, used as the agent's name.
            num_steps: Value written to the ``training_steps`` hyperparameter
                for this run.

        Returns:
            Elapsed seconds as a float, or ``None`` if building or training
            the agent raised an exception (the traceback is printed).
        """
        print(f"\n--- Benchmarking {name} ---")

        # 1. Setup Config: copy the top-level dict so per-run overrides do not
        # leak back into self.base_params.
        params = self.base_params.copy()
        params["training_steps"] = num_steps
        params["multi_process"] = True
        params["num_workers"] = 4

        # Create Config Object
        config = MuZeroConfig(params, self.game_config)

        # Create Environment
        env = self.game_config.make_env()

        # Instantiate Agent
        test_agents = [RandomAgent(), TicTacToeBestAgent()]
        try:
            agent = agent_cls(
                env=env,
                config=config,
                name=f"bench_{name.lower().replace(' ', '_')}",
                # Use the device chosen in __init__ rather than a second
                # hard-coded copy, so the two can never drift apart.
                device=self.device,
                test_agents=test_agents,
            )

            # Override testing to avoid slowdowns: the test interval is set
            # far above any num_steps used here.
            agent.test_interval = 100000
            agent.checkpoint_interval = 100

            # Run Training
            print(f"  Starting training for {num_steps} steps...")
            agent.training_step = 0

            # perf_counter is monotonic, so the measured duration is immune
            # to system clock adjustments during a long run.
            start_time = time.perf_counter()
            agent.train()
            duration = time.perf_counter() - start_time

            print(f"  Finished. Time: {duration:.2f}s")
            return duration

        except Exception as e:
            # Best-effort harness: report the failure and let the notebook
            # continue with the remaining benchmarks.
            print(f"  FAILED: {e}")
            import traceback

            traceback.print_exc()
            return None

In [3]:
# `results` collects one {"Agent": ..., "Time (s)": ...} row per benchmark run;
# `bench` is the shared harness the benchmark cells call into.
results = list()
bench = Benchmark()
print("Benchmark initialized.")

Benchmark initialized.


In [None]:
# Run Ray Benchmark
# Note: Ray initialization happens inside the agent if not already started.
ray_label = "MuZero (Ray)"
ray_time = bench.run_benchmark(MuZeroRay, ray_label, num_steps=1000)

# Drop any row left by a previous execution of this cell so that re-running it
# replaces the Ray timing instead of adding a duplicate row to the table.
results[:] = [row for row in results if row.get("Agent") != ray_label]
results.append({"Agent": ray_label, "Time (s)": ray_time})


--- Benchmarking MuZero (Ray) ---
Using default save_intermediate_weights     : False
Using         training_steps                : 1000
Using default adam_epsilon                  : 1e-08
Using default momentum                      : 0.9
Using default learning_rate                 : 0.001
Using default clipnorm                      : 0
Using default optimizer                     : <class 'torch.optim.adam.Adam'>
Using default weight_decay                  : 0.0
Using default num_minibatches               : 1
Using default training_iterations           : 1
Using default lr_schedule_type              : none
Using default lr_schedule_steps             : []
Using default lr_schedule_steps             : []
Using default lr_schedule_values            : []
Using         use_mixed_precision           : True
Using         use_torch_compile             : True
Using default compile_mode                  : reduce-overhead
Using         minibatch_size                : 8
Using         replay_buffe

2026-01-28 20:20:41,112	INFO worker.py:2007 -- Started a local Ray instance.


Max size: 10000
Initializing stat 'score' with subkeys None
Initializing stat 'policy_loss' with subkeys None
Initializing stat 'value_loss' with subkeys None
Initializing stat 'reward_loss' with subkeys None
Initializing stat 'to_play_loss' with subkeys None
Initializing stat 'cons_loss' with subkeys None
Initializing stat 'loss' with subkeys None
Initializing stat 'test_score' with subkeys ['score', 'max_score', 'min_score']
Initializing stat 'episode_length' with subkeys None
Initializing stat 'policy_entropy' with subkeys None
Initializing stat 'value_diff' with subkeys None
Initializing stat 'policy_improvement' with subkeys ['network', 'search']
Initializing stat 'root_children_values' with subkeys None
Initializing stat 'test_score_vs_random' with subkeys ['score', 'player_1_score', 'player_2_score', 'player_1_win%', 'player_2_win%']
Initializing stat 'test_score_vs_tictactoe_expert' with subkeys ['score', 'player_1_score', 'player_2_score', 'player_1_win%', 'player_2_win%']
  S

[36m(MuZeroWorker pid=6493)[0m   from pkg_resources import resource_stream, resource_exists


[36m(MuZeroWorker pid=6497)[0m Hidden state shape: (1, 24, 3, 3)
[36m(MuZeroWorker pid=6497)[0m Hidden state shape: (1, 24, 3, 3)
[36m(MuZeroWorker pid=6497)[0m encoder input shape (1, 18, 3, 3)
Size: 0
Size: 7
0
actions shape torch.Size([8, 5])
target value shape torch.Size([8, 6])
predicted values shape torch.Size([8, 6, 1])
target rewards shape torch.Size([8, 6])
predicted rewards shape torch.Size([8, 6, 1])
target to plays shape torch.Size([8, 6, 2])
predicted to_plays shape torch.Size([8, 6, 2])
masks shape torch.Size([8, 6]) torch.Size([8, 6])
actions tensor([[8, 6, 4, 2, 0],
        [6, 4, 2, 0, 0],
        [2, 0, 0, 4, 3],
        [0, 0, 7, 4, 3],
        [7, 2, 1, 8, 0],
        [1, 8, 0, 5, 0],
        [8, 0, 5, 0, 3],
        [0, 5, 0, 4, 3]])
target value tensor([[ 0.9606, -0.9703,  0.9801, -0.9900,  1.0000,  0.0000],
        [-0.9703,  0.9801, -0.9900,  1.0000,  0.0000,  0.0000],
        [-0.9900,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.0000,  0.000

[33m(raylet)[0m [2026-01-28 20:20:51,158 E 6490 57722] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-20-39_848521_6419 is over 95% full, available space: 20.3703 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[36m(MuZeroWorker pid=6499)[0m   from pkg_resources import resource_stream, resource_exists[32m [repeated 3x across cluster][0m


Size: 754
Size: 764
Size: 772
Size: 779
Size: 789
Size: 797
Size: 807
Size: 815
Size: 821
Size: 828
Size: 837
Size: 847
Size: 857
Size: 865
Size: 874
[36m(MuZeroWorker pid=6497)[0m Worker 1: Compiling INT8 model...
[36m(MuZeroWorker pid=6494)[0m Hidden state shape: (1, 24, 3, 3)[32m [repeated 6x across cluster][0m
[36m(MuZeroWorker pid=6494)[0m encoder input shape (1, 18, 3, 3)[32m [repeated 3x across cluster][0m
Size: 884
Size: 890


[36m(MuZeroWorker pid=6493)[0m   if not check_min_max_valid(min_val, max_val):


Size: 898
Size: 907
Size: 916
Size: 924
Size: 934
Size: 942
Size: 950
Size: 959
Size: 965
Size: 971
Size: 981
Size: 990
Size: 997


[33m(raylet)[0m [2026-01-28 20:21:01,247 E 6490 57722] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-20-39_848521_6419 is over 95% full, available space: 20.3701 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[36m(MuZeroWorker pid=6499)[0m   if not check_min_max_valid(min_val, max_val):[32m [repeated 3x across cluster][0m
[36m(pid=gcs_server)[0m [2026-01-28 20:21:10,187 E 6486 57593] (gcs_server) gcs_server.cc:303: Failed to establish connection to the event+metrics exporter agent. Events and metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
[36m(MuZeroWorker pid=6499)[0m   if not check_min_max_valid(min_val, max_val):[32m [repeated 4x across cluster][0m
[33m(raylet)[0m [2026-01-28 20:21:11,114 E 6490 57705] (raylet) main.cc:1032: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent stat

Training Interrupted by User
Shutting down workers...
All workers shut down.
Finished Training
Testing Player 0 vs Agent random
Player 0 prediction: (tensor([0.0800, 0.0400, 0.0800, 0.1200, 0.4000, 0.0400, 0.0400, 0.0400, 0.1600]), tensor([0.0800, 0.0400, 0.0800, 0.1200, 0.4000, 0.0400, 0.0400, 0.0400, 0.1600]), 0.1384068141974795, tensor(4), {'network_policy': tensor([0.1120, 0.0990, 0.1237, 0.1472, 0.1472, 0.0735, 0.1014, 0.0811, 0.1148]), 'network_value': 0.16412188112735748, 'search_policy': tensor([0.0800, 0.0400, 0.0800, 0.1200, 0.4000, 0.0400, 0.0400, 0.0400, 0.1600]), 'search_value': 0.1384068141974795, 'root_children_values': tensor([-0.1376,  0.0484, -0.1193, -0.1606, -0.2906, -0.0552, -0.1403, -0.0168,
        -0.1755])})
action: 4
Player 1 random action: 6
Player 0 prediction: (tensor([0.0800, 0.0400, 0.2000, 0.1200, 0.0000, 0.3200, 0.0000, 0.0400, 0.2000]), tensor([0.0800, 0.0400, 0.2000, 0.1200, 0.0000, 0.3200, 0.0000, 0.0400, 0.2000]), 0.19292533192578587, tensor(5), {'n

[33m(raylet)[0m [2026-01-28 20:21:41,570 E 6490 57722] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-20-39_848521_6419 is over 95% full, available space: 20.3631 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.


  Finished. Time: 61.58s


[33m(raylet)[0m [2026-01-28 20:21:51,648 E 6490 57722] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-20-39_848521_6419 is over 95% full, available space: 20.362 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[33m(raylet)[0m [2026-01-28 20:22:01,723 E 6490 57722] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-20-39_848521_6419 is over 95% full, available space: 20.362 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[33m(raylet)[0m [2026-01-28 20:22:11,809 E 6490 57722] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-20-39_848521_6419 is over 95% full, available space: 20.3618 GB; capacity: 460.432 GB. Object creation will fail if spilling is required.
[33m(raylet)[0m [2026-01-28 20:22:21,892 E 6490 57722] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2026-01-28_20-20-39_848521_6419 is over 95% full, available space: 20.3618 GB; capacity: 460.432 GB. O

In [None]:
print("\n=== RESULTS ===")


def _format_seconds(value):
    """Render a timing for the table; run_benchmark returns None on failure."""
    return "FAILED" if value is None else f"{value:.2f}"


# tablefmt="pretty" treats every cell as a string and ignores floatfmt, so the
# timings are rounded to two decimals here before being handed to tabulate.
display_rows = [
    {**row, "Time (s)": _format_seconds(row.get("Time (s)"))} for row in results
]
results_table = tabulate(display_rows, headers="keys", tablefmt="pretty")
print(results_table)

# Look the timings up defensively: on a fresh kernel one of the benchmark cells
# may not have been run, and a bare reference would raise NameError.
ray_elapsed = globals().get("ray_time")
mp_elapsed = globals().get("mp_time")

if ray_elapsed and mp_elapsed:
    speedup = ray_elapsed / mp_elapsed
    print(f"\nTime Ratio (Ray / TorchMP): {speedup:.2f}x")
    if speedup > 1.0:
        print(f"TorchMP is {speedup:.2f}x FASTER than Ray")
    else:
        print(f"Ray is {1/speedup:.2f}x FASTER than TorchMP")
else:
    print("\nSpeedup not computed: both the Ray and TorchMP benchmarks must finish successfully first.")