使用训练好的模型，观察 MCTS 搜索树（Use the trained model to inspect the MCTS search tree）

In [1]:
from muzero.feature_utils import obs2feature
from muzero.mcts import MCTS, render_root
import gymnasium as gym
import gymxq


In [2]:
import torch


In [3]:
from muzero.models import MuZeroNetwork
from muzero.config import MuZeroConfig
import numpy as np
from muzero.feature_utils import encoded_action

In [4]:
# Start from the default MuZero configuration and override a few fields.
# NOTE(review): batch_size / training_steps look like training-only settings;
# num_simulations presumably sets the MCTS search budget — confirm in MuZeroConfig.
config = MuZeroConfig()
config.batch_size, config.training_steps, config.num_simulations = 512, 200, 120


In [5]:
# Load the trained network weights and switch the model to inference mode.
# Use the GPU when available, otherwise fall back to CPU instead of crashing
# on a hard-coded .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MuZeroNetwork(config)
# map_location lets GPU-saved weights load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors/containers (a state_dict
# needs nothing more), so a tampered checkpoint cannot run arbitrary code.
state_dict = torch.load("model_weights.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device).eval()


In [6]:
# Candidate start positions (FEN strings). Only `init_fen` is used below.
# NOTE(review): the board render, root value and policy outputs saved in this
# notebook correspond to PAWNS_ONLY_FEN, not to the active `init_fen` — they
# are stale. Restart & Run All to refresh them, or switch `init_fen` back.
PAWNS_ONLY_FEN = "3k5/2P1P4/9/9/9/9/9/9/4p1p2/5K3 r - 100 0 190"
MIXED_PIECES_FEN = "2r2k3/6R1C/b4N1rb/9/5n3/5C3/6n2/5p3/4p4/5K1R1 r - 110 0 180"
init_fen = MIXED_PIECES_FEN


In [7]:
# Build the Xiangqi environment from the chosen start position. With the
# "ansi" render mode, env.render() returns a text board that can be printed.
env_kwargs = dict(init_fen=init_fen, render_mode="ansi")
env = gym.make("xqv1", **env_kwargs)


In [8]:
# Reset to the start position and show the board as text.
obs, info = env.reset()
board_text = env.render()
print(board_text)  # print (not a bare expression) so the ANSI colour codes render



9 [30m＋[0m[30m＋[0m[30m＋[0m[34m将[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m
8 [30m＋[0m[30m＋[0m[31m兵[0m[30m＋[0m[31m兵[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m
7 [30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m
6 [30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m
5 [30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m
4 [30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m
3 [30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m
2 [30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m
1 [30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[34m卒[0m[30m＋[0m[34m卒[0m[30m＋[0m[30m＋[0m
0 [30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[30m＋[0m[31m帅[0m[30m＋[0m[30m＋[0m[30m＋[0m
  ０１２３４５６７８
轮到红方走子



In [9]:
# Enable MCTS debug mode on the shared config before the search cell runs.
config.debug_mcts = True


In [10]:
# Build the network input: the encoded previous action is prepended to the
# board features along axis 1.
# NOTE(review): 2086 presumably encodes "no previous move" (one past the last
# index of a 2086-move Xiangqi action space) — confirm in feature_utils.
last_a = 2086
board_features = obs2feature(obs, flatten=False)
observation = np.concatenate(
    [encoded_action(last_a)[np.newaxis, :], board_features], axis=1
)

to_play = info["to_play"]
legal_actions = info["legal_actions"]

# Inference only: no autograd graph is needed during the search.
with torch.no_grad():
    root, mcts_info = MCTS(config).run(
        model,
        observation,
        legal_actions,
        to_play,
        False,  # NOTE(review): presumably `add_exploration_noise`; confirm against MCTS.run
    )
    # Render the resulting search tree; the arguments presumably mean
    # (name, format, directory) — confirm against render_root's signature.
    render_root(root, "test", "svg", "mcts_tree")


In [11]:
# Value estimate stored at the search root after MCTS (displayed as cell output).
root.value()

0.9456578975193672

In [12]:
# Search-derived policy at the root, keyed by move string (displayed as cell output).
root.get_updated_policy()

{'2818': 0.0,
 '2838': 0.0,
 '2829': 0.99,
 '4838': 0.01,
 '4858': 0.0,
 '4849': 0.0}