In [4]:
from typing import Generic, Optional, NamedTuple, Callable, Union
from typing import Generic, TypeVar, Protocol
from abc import ABC, abstractmethod

State = TypeVar("State")
Action = TypeVar("Action")
Example = TypeVar("Example")
Trace = tuple[list[State], list[Action]]
Args = TypeVar("Args")


class WorldModel(ABC, Generic[State, Action, Example]):
    """Abstract environment model: produces an initial state, applies actions
    and decides whether a state is terminal.

    The example currently being worked on is kept in ``self.example`` and is
    swapped with :meth:`update_example`.
    """

    def __init__(self) -> None:
        # No example is attached until update_example() is called.
        self.example = None

    @abstractmethod
    def init_state(self) -> State:
        """Return the state a search starts from."""

    @abstractmethod
    def step(self, state: State, action: Action) -> State:
        """Return the state obtained by applying ``action`` to ``state``."""

    @abstractmethod
    def is_terminal(self, state: State) -> bool:
        """Return whether ``state`` is terminal."""

    def update_example(self, example: Example) -> None:
        """Replace the stored example with ``example``."""
        self.example = example


class SearchConfig(ABC, Generic[State, Action, Example]):
    """Abstract search configuration: proposes actions and scores them.

    The example currently being worked on is kept in ``self.example`` and is
    swapped with :meth:`update_example`.
    """

    def __init__(self) -> None:
        # No example is attached until update_example() is called.
        self.example = None

    @abstractmethod
    def get_actions(self, state: State) -> list[Action]:
        """Return the candidate actions available in ``state``."""

    @abstractmethod
    def reward(self, state, action, **kwargs) -> tuple[float, dict]:
        """Return a ``(reward, details)`` pair for taking ``action`` in ``state``."""

    @abstractmethod
    def get_values(self, state: State, action: Action) -> list[tuple[float, bool]]:
        """Return a list of ``(float, bool)`` pairs for ``state`` / ``action``."""

    def update_example(self, example: Example) -> None:
        """Replace the stored example with ``example``."""
        self.example = example


class HasTerminalStateAndTrace(Protocol[State]):
    """Structural type for a search result that exposes ``terminal_state`` and ``trace``."""
    terminal_state: State  # presumably the last state reached by the search -- TODO confirm
    trace: Trace  # a (list of states, list of actions) pair, per the ``Trace`` alias


class SearchAlgorithm(ABC):
    """Abstract base for search procedures driven by a world model and a search config."""

    def __init__(self, **kwargs):
        """Accept arbitrary keyword options and ignore them."""

    @abstractmethod
    def __call__(self, world_model: WorldModel, search_config: SearchConfig, **kwargs) -> HasTerminalStateAndTrace:
        """Run the search with the given components and return its result."""


class TreeConstructor(ABC, Generic[State, Action, Example]):
    """Bundles a world model, a search config and a search algorithm so that a
    search can be launched for one example with a single call."""

    def __init__(self,
                 world_model: WorldModel[State, Action, Example],
                 search_config: SearchConfig[State, Action, Example],
                 search_algo: SearchAlgorithm) -> None:
        self.search_algo = search_algo
        self.search_config = search_config
        self.world_model = world_model

    def __call__(self, example: Example, node=None, **kwargs) -> HasTerminalStateAndTrace[State]:
        """Attach ``example`` to both components, then run the search algorithm.

        :param example: the example handed to ``update_example`` of the world model and the search config
        :param node: forwarded to the search algorithm as ``root_node``
        :param kwargs: forwarded unchanged to the search algorithm
        :return: whatever the search algorithm returns
        """
        # The world model is updated first, then the search config.
        for component in (self.world_model, self.search_config):
            component.update_example(example)
        return self.search_algo(
            self.world_model,
            self.search_config,
            root_node=node,
            **kwargs,
        )

In [18]:
import itertools
from tqdm import trange
from copy import deepcopy
import torch
import numpy as np
import math

In [15]:
def calculate_diversity_score(candidates):
    """Return the spread of visit counts among ``candidates``.

    The score is ``max(N) - min(N)`` over the ``N`` attribute of each candidate.

    :param candidates: an iterable of nodes exposing ``N``, or ``None`` /
        an empty list when the node has no children
    :return: the visit-count gap, or 0 when there are no candidates
    """
    # ``None`` means "never expanded" and ``[]`` means "expanded, no actions";
    # neither has any spread, and max()/min() would raise on an empty list.
    if not candidates:
        return 0

    visit_counts = [sample.N for sample in candidates]
    return max(visit_counts) - min(visit_counts)

In [None]:


class MCTSNode(Generic[State, Action]):
    """A node of the MCTS search tree.

    Every node receives a unique ``id`` from the class-level counter
    ``id_iter``; :meth:`reset_id` restarts that counter.
    """

    id_iter = itertools.count()

    @classmethod
    def reset_id(cls):
        """Restart node numbering from zero."""
        cls.id_iter = itertools.count()

    def __init__(
        self, 
        state: Optional[State], 
        action: Optional[Action], 
        parent: "Optional[MCTSNode]" = None,
        base_rewards: torch.Tensor = None, 
        value: float = 0.0, 
        embeddings: torch.Tensor = None, 
        log_probs: torch.Tensor = None,
        ref_log_probs: torch.Tensor = None,
        is_terminal: bool = False,
        length_penalty: float = 1.25,
    ):
        """
        Create a node.

        :param state: the current state (may be None until the node is expanded)
        :param action: the action of the last step, i.e., the action from parent node to current node
        :param parent: the parent node, None if root of the tree
        :param base_rewards: base rewards; when given, ``r`` is their mean
        :param value: value estimate used by ``r`` when ``base_rewards`` is None
        :param embeddings: the embeddings of the current state (BERTScore calculation for similar generations filtering)
        :param log_probs: stored on the node as ``log_probs``
        :param ref_log_probs: stored on the node as ``ref_log_probs``
        :param is_terminal: whether the current state is a terminal state
        :param length_penalty: stored on the node as ``length_penalty``
        """
        self.id = next(MCTSNode.id_iter)

        # Position in the tree.
        self.parent = parent
        self.children: 'Optional[list[MCTSNode]]' = None
        if parent is None:
            self.depth = 0
        else:
            self.depth = parent.depth + 1

        # What the node represents.
        self.state = state
        self.action = action
        self.is_terminal = is_terminal
        self.embeddings = embeddings
        self.length_penalty = length_penalty

        # Inputs of the reward property ``r``.
        self.rewards = base_rewards
        self.value = value
        self.log_probs = log_probs
        self.ref_log_probs = ref_log_probs

        # Search statistics: visit count, state value and action value.
        self.N = 0
        self.V = 0.0
        # Initial Q is the node's own reward, shifted by the parent's current V
        # when there is a parent.
        if self.parent is None:
            self.Q = self.r
        else:
            self.Q = self.parent.V + self.r

    @property
    def r(self) -> float:
        """Reward of this node.

        With base rewards present this is their mean; otherwise it is the
        node's value (root) or the value gained over the parent (non-root).
        """
        if self.rewards is not None:
            # Default handling: use rewards.mean() directly.
            return self.rewards.mean().item()
        if self.parent is None:
            return self.value
        return self.value - self.parent.value

    @property
    def p(self) -> float:
        """Prior weight of this node; currently the constant 1.

        The disabled alternative was a length-penalised likelihood:
        (self.log_probs.sum() / self.log_probs.size(-1) ** self.length_penalty).exp().detach().item()
        """
        return 1
    
class MCTSResult(NamedTuple):
    """Outcome of one MCTS run, as assembled by ``MCTS.__call__``.

    The three ``next_action_*`` lists hold one entry per child of ``tree_state``,
    in child order.
    """
    tree_state: MCTSNode  # root node of the searched tree
    next_action_pi: list[float]  # probability assigned to each child of the root
    next_action_V: list[float]  # V of each child of the root
    next_action_Q: list[float]  # Q of each child of the root
    trace_in_each_iter: list[list[MCTSNode]] = None  # per-iteration paths, or None when tracing is off
    next_action_idx: int = 0  # index of the selected child of the root
    trace_of_nodes: list[MCTSNode] = None  # not filled in by MCTS.__call__ in this file
    cum_reward: float = None  # not filled in by MCTS.__call__ in this file


class MCTS(SearchAlgorithm, Generic[State, Action]):
    def __init__(self, args: "MCTSConfig"):
        """
        MCTS algorithm

        :param args: search settings; the fields read here are
            ``output_trace_in_each_iter``, ``w_exp``, ``depth_limit``,
            ``breadth_limit``, ``n_iters``, ``gamma``, ``add_kl``,
            ``simulate_strategy``, ``temperature``, ``temperature_decay_ratio``,
            ``disable_tqdm``, ``consider_diversity`` and ``length_penalty``.
            ``simulate_strategy`` is either one of ``'max'``, ``'sample'``,
            ``'random'`` or a callable mapping a list of floats to an index.
        """
        # NOTE: the annotation above is a string on purpose. MCTSConfig is
        # defined in a later cell, and an unquoted annotation is evaluated when
        # this class is created, which fails on a fresh kernel.
        super().__init__()
        # World model / search config are supplied later (see __call__).
        self.world_model = None
        self.search_config = None
        self.output_trace_in_each_iter = args.output_trace_in_each_iter
        self.w_exp = args.w_exp
        self.depth_limit = args.depth_limit
        self.breadth_limit = args.breadth_limit
        self.n_iters = args.n_iters
        self.gamma = args.gamma
        self.add_kl = args.add_kl
        default_simulate_strategies: dict[str, Callable[[list[float]], int]] = {
            'max': lambda x: np.argmax(x),
            'sample': lambda x: np.random.choice(len(x), p=x),
            'random': lambda x: np.random.choice(len(x)),
        }
        # A known strategy name is mapped to its function; anything else
        # (e.g. a user-supplied callable) is stored unchanged.
        self.simulate_choice: Callable[[list[float]], int] = default_simulate_strategies.get(args.simulate_strategy,
                                                                                             args.simulate_strategy)
        self.temperature = args.temperature
        self.temperature_decay_ratio = args.temperature_decay_ratio
        self.follow_probability = False
        self._output_iter: list[MCTSNode] = None
        self._output_cum_reward = -math.inf
        self.trace_in_each_iter: list[list[MCTSNode]] = None
        self.root: Optional[MCTSNode] = None
        self.disable_tqdm = args.disable_tqdm
        self.consider_diversity = args.consider_diversity
        self.length_penalty = args.length_penalty

        # Must be assigned by the caller before searching; it is passed to
        # search_config.get_actions / get_values during expansion.
        self.policy_model = None

    def _get_simulated_pi(self, cur_node: MCTSNode, return_selection=False) -> tuple:
        """
        Build the action distribution over the children of ``cur_node``.

        Adapted from: https://github.com/suragnair/alpha-zero-general/blob/ce020c8eebbabf0e22654279508a6887b4791015/MCTS.py#L28C5-L53C21

        The temperature is ``self.temperature * self.temperature_decay_ratio ** cur_node.depth``.
        With a positive temperature the distribution is a softmax over the
        children's Q values (optionally boosted by each child's number of
        children when ``consider_diversity`` is on). Otherwise all probability
        mass is split evenly among the most-visited children.

        :param cur_node: node whose children are scored; it must have been expanded
        :param return_selection: when True, also pick one child index
        :return: ``(probs, next_action_V, next_action_Q)`` or, with
            ``return_selection``, ``(probs, selected_idx, next_action_V, next_action_Q)``
        """
        visit_counts = [child.N for child in cur_node.children]
        next_action_V = [child.V for child in cur_node.children]
        next_action_Q = [child.Q for child in cur_node.children]
        next_action_n_children = [len(child.children) if child.children is not None else 0 for child in cur_node.children]
        next_action_variance = [calculate_diversity_score(child.children) for child in cur_node.children]

        def _greedy_probs():
            # Uniform over the children tied for the highest visit count.
            top_count = max(visit_counts)
            best_actions = [idx for idx, count in enumerate(visit_counts) if count == top_count]
            probs = [0] * len(visit_counts)
            for best_action in best_actions:
                probs[best_action] = 1 / len(best_actions)
            return probs

        def _cal_probs(temp):
            if temp <= 0:
                return _greedy_probs()
            # Sample based on Q values. The weight of a child is
            #   (exp(Q) * bonus) ** (1 / temp) == exp((Q + log(bonus)) / temp)
            # with bonus = n_children + 1 when diversity is considered, else 1.
            # Working on the exponent and subtracting its maximum keeps the
            # computation finite for any temperature, and gives a child with
            # Q == 0 its proper weight instead of dropping it.
            logits = [
                (q + (math.log(nc + 1) if self.consider_diversity else 0.0)) / temp
                for q, nc in zip(next_action_Q, next_action_n_children)
            ]
            if not logits:
                return []
            peak = max(logits)
            weights = [math.exp(logit - peak) for logit in logits]
            total_weight = sum(weights)
            return [w / total_weight for w in weights]

        temperature = self.temperature * (self.temperature_decay_ratio ** cur_node.depth)
        probs = _cal_probs(temperature)

        if return_selection:
            if temperature == 0:
                # Greedy pick: highest visit count (scaled by the diversity
                # score of the child's own children when enabled); ties are
                # broken by Q, then by V.
                selected_idx = max(range(len(visit_counts)), key=lambda x: (
                    visit_counts[x] * (next_action_variance[x] + 1 if self.consider_diversity else 1), 
                    next_action_Q[x], next_action_V[x]
                ))
            else:
                selected_idx = np.random.choice(range(len(visit_counts)), p=probs)
            return probs, selected_idx, next_action_V, next_action_Q
        return probs, next_action_V, next_action_Q
    
    def iterate(self, node: MCTSNode) -> list[MCTSNode]:
        """Run one MCTS iteration starting at ``node`` and return the visited path.

        The path is first extended through already-expanded nodes, then the
        last node is repeatedly expanded and a child chosen by PUCT until a
        terminal node, the depth limit or a childless node is reached. Finally
        the statistics along the path are updated.
        """
        # NOTE(review): _back_propagate also increments N for every node on the
        # path, so the start node is counted twice per iteration -- confirm
        # this is intended.
        node.N += 1
        path = self._select(node)
        while True:
            frontier = path[-1]
            if self._is_terminal_with_depth_limit(frontier):
                break
            self._expand_and_evaluate(frontier)
            # Expansion may have discovered that the node is terminal, or
            # produced no children at all.
            if self._is_terminal_with_depth_limit(frontier) or len(frontier.children) == 0:
                break
            path.append(self._puct_select(frontier))
        self._back_propagate(path)
        return path

    def _is_terminal_with_depth_limit(self, node: MCTSNode):
        """Tell whether the search must stop at ``node``: it is terminal, or it
        lies ``depth_limit`` or more levels below the current root."""
        if node.is_terminal:
            return node.is_terminal
        levels_below_root = node.depth - self.root.depth
        return levels_below_root >= self.depth_limit

    def _select(self, node: MCTSNode) -> list[MCTSNode]:
        """Follow PUCT choices from ``node`` through expanded nodes.

        :return: the nodes visited, starting with ``node`` and ending at the
            first node that has no children or where the search must stop
        """
        path = [node]
        while node.children and not self._is_terminal_with_depth_limit(node):
            node = self._puct_select(node)
            path.append(node)
        return path

    def _puct(self, node: MCTSNode) -> float:
        """PUCT score of ``node``: its Q plus an exploration bonus that grows
        with the parent's visit count and shrinks with the node's own."""
        exploration_bonus = self.w_exp * node.p * np.sqrt(node.parent.N) / (1 + node.N)
        return node.Q + exploration_bonus
    
    def _puct_select(self, node: MCTSNode) -> MCTSNode:
        """Return the child of ``node`` with the highest PUCT score (the first
        such child on ties)."""
        return max(node.children, key=self._puct)

    def _expand_and_evaluate(self, node: MCTSNode):
        """Materialise ``node``'s state if needed, then create and score its children.

        Children are created with ``state=None``; their state is computed only
        when they are expanded themselves. Nothing is added for terminal nodes.
        """
        # Lazily compute the state from the parent's state and this node's action.
        if node.state is None:
            node.state = self.world_model.step(node.parent.state, node.action, node.log_probs)
            node.is_terminal = self.world_model.is_terminal(node.state)

        if node.is_terminal:
            return

        # Each candidate is (action, (log_probs, ref_log_probs), embeddings).
        candidates = self.search_config.get_actions(self.policy_model, node.state, add_kl=self.add_kl)

        unpacked = [(action, log_probs, ref_log_probs)
                    for action, (log_probs, ref_log_probs), _ in candidates]
        action_batch = [item[0] for item in unpacked]
        log_probs_batch = [item[1] for item in unpacked]
        ref_log_probs_batch = [item[2] for item in unpacked]

        # One (value, base_rewards, is_terminal) triple per candidate.
        reward_value_batch = self.search_config.get_values(
            self.policy_model, node.state, action_batch,
            log_probs_batch, ref_log_probs_batch,
            add_kl=self.add_kl, parent_depth=node.depth,
            parent_value=node.value,
        )

        new_children = [
            MCTSNode(state=None, action=action, parent=node,
                     base_rewards=base_rewards, value=value,
                     embeddings=embs, log_probs=log_probs, ref_log_probs=ref_log_probs,
                     is_terminal=is_terminal, length_penalty=self.length_penalty)
            for (action, (log_probs, ref_log_probs), embs), (value, base_rewards, is_terminal)
            in zip(candidates, reward_value_batch)
        ]
        # Append to any children that already exist.
        if node.children is None:
            node.children = new_children
        else:
            node.children = node.children + new_children

    def _simulate(self, path: list[MCTSNode]):
        """Roll out from the last node of ``path``, appending the visited nodes.

        At each step the node is expanded if necessary and one child is picked
        by ``self.simulate_choice`` from the children's rewards. The rollout
        stops at a terminal node, at the depth limit, or at a childless node.

        :param path: path to extend in place; must be non-empty
        """
        node = path[-1]
        while True:
            # Expansion lives in _expand_and_evaluate; only expand nodes that
            # have not been expanded yet so children are not duplicated.
            if node.state is None or node.children is None:
                self._expand_and_evaluate(node)
            # ``not node.children`` also covers a terminal node that was never
            # given a children list.
            if self._is_terminal_with_depth_limit(node) or not node.children:
                return
            # NOTE(review): this previously read ``child.fast_reward``, an
            # attribute MCTSNode does not define. The node reward ``r`` is used
            # as the per-child signal instead -- confirm this is the intended
            # quantity (the 'sample' strategy needs values that sum to 1).
            child_rewards = [child.r for child in node.children]
            node = node.children[self.simulate_choice(child_rewards)]
            path.append(node)

    def _back_propagate(self, path: list[MCTSNode]):
        """Update N, V and Q along ``path`` from its last node back to its first.

        The last node only refreshes its Q from its own reward and V. Every
        earlier node sets V to the visit-weighted mean of its children's Q
        (unvisited children weigh 1) and, unless it has no action (the root),
        recomputes Q as ``r + gamma * V``.
        """
        leaf = path[-1]
        leaf.Q = leaf.r + self.gamma * leaf.V
        leaf.N += 1

        for ancestor in reversed(path[:-1]):
            weighted_q = sum(max(1, child.N) * child.Q for child in ancestor.children)
            total_weight = sum(max(1, child.N) for child in ancestor.children)
            ancestor.V = weighted_q / total_weight
            ancestor.N += 1
            if ancestor.action is not None:
                ancestor.Q = ancestor.r + self.gamma * ancestor.V

    def search(self):
        """Run the configured number of iterations from ``self.root``.

        A root is created from the world model's initial state when none is
        set. Nothing is returned; results live on ``self.root`` and, when
        tracing is enabled, in ``self.trace_in_each_iter``.
        """
        if self.root is None:
            self.root = MCTSNode(state=self.world_model.init_state(), action=None, parent=None, length_penalty=self.length_penalty)

        keep_traces = self.output_trace_in_each_iter
        if keep_traces:
            self.trace_in_each_iter = []

        # iterate more at the starting point (a root at depth 0 gets 4x the budget)
        n_iters = self.n_iters * (1 if self.root.depth else 4)
        for _ in trange(n_iters, disable=self.disable_tqdm, desc='MCTS iteration', leave=False):
            path = self.iterate(self.root)
            if keep_traces:
                # Nodes reference their parent and children, so each snapshot
                # copies everything reachable from the path.
                self.trace_in_each_iter.append(deepcopy(path))

    def __call__(self,
                 world_model: WorldModel[State, Action, Example],
                 search_config: SearchConfig[State, Action, Example],
                 root_node: Optional[Union[MCTSNode, int]] = None,
                 **kwargs) -> MCTSResult:
        """Search from ``root_node`` (or from a fresh root) and summarise the result.

        :param world_model: world model used to create and advance states
        :param search_config: provider of candidate actions and their values;
            its ``n_actions`` attribute is read here
        :param root_node: node to continue searching from; None starts a new tree
        :param kwargs: accepted but not used
        :return: an ``MCTSResult`` describing the root and its children
        """
        # A brand-new tree restarts node numbering.
        if root_node is None:
            MCTSNode.reset_id()

        self.root = root_node
        self.world_model = world_model
        self.search_config = search_config
        # Diversity is switched off when only one action is proposed per step.
        if self.search_config.n_actions == 1:
            self.consider_diversity = False

        self.search()

        recorded_traces = self.trace_in_each_iter if self.output_trace_in_each_iter else None

        pi, chosen_idx, child_values, child_qs = self._get_simulated_pi(self.root, return_selection=True)

        return MCTSResult(
            tree_state=self.root,
            next_action_pi=pi,
            next_action_V=child_values,
            next_action_Q=child_qs,
            trace_in_each_iter=recorded_traces,
            next_action_idx=chosen_idx,
        )

In [None]:
import numpy as np

class TicTacToeState:
    """A tic-tac-toe position: a 3x3 board plus the player to move."""

    def __init__(self, board=None, to_play=1):
        # board: 3x3 grid with 0 = empty, 1 = X, -1 = O
        if board is None:
            board = np.zeros((3,3), dtype=int)
        self.board = board
        self.to_play = to_play  # 1 or -1

    def __str__(self):
        glyphs = {0: '.', 1: 'X', -1: 'O'}
        return '\n'.join(' '.join(glyphs[cell] for cell in row) for row in self.board)

    def _line_sums(self):
        """Sums of the three rows, three columns and two diagonals, in that order."""
        lines = list(self.board) + list(self.board.T) + \
                [self.board.diagonal(), np.fliplr(self.board).diagonal()]
        return [sum(line) for line in lines]

    def get_legal_actions(self):
        """Return the flat indices (0-8) of the empty cells."""
        cells = self.board.flat
        return [idx for idx in range(9) if cells[idx] == 0]

    def perform_action(self, action):
        """Return the position after the player to move marks cell ``action``.

        The current position is left untouched; the turn passes to the other player.
        """
        successor = self.board.copy()
        successor.flat[action] = self.to_play
        return TicTacToeState(successor, -self.to_play)

    def is_game_over(self):
        """Return True when someone has won or no empty cell remains."""
        if any(total == 3 or total == -3 for total in self._line_sums()):
            return True
        return not self.get_legal_actions()

    def get_winner(self):
        """Return the winner: 1, -1, 0 for a draw, or None if the game is not over."""
        for total in self._line_sums():
            if total == 3:
                return 1
            if total == -3:
                return -1
        # No completed line: a full board is a draw, otherwise play continues.
        return None if self.get_legal_actions() else 0
    
class TicTacToeWorldModel:
    """World model for tic-tac-toe; every call is delegated to the state object."""

    def init_state(self):
        """Return a fresh, empty position."""
        return TicTacToeState()

    def step(self, state, action, log_probs=None):
        """Return the position reached by playing ``action`` in ``state``.

        ``log_probs`` is accepted for compatibility with the caller and ignored.
        """
        successor = state.perform_action(action)
        return successor

    def is_terminal(self, state):
        """Return whether the game is over in ``state``."""
        finished = state.is_game_over()
        return finished
    

import torch

class TicTacToeSearchConfig:
    """Search configuration for tic-tac-toe: enumerates legal moves and scores
    them by the game outcome they lead to.

    NOTE(review): values are always expressed from player 1's (X's) point of
    view, regardless of whose turn it is -- confirm this is what the search expects.
    """

    def __init__(self):
        self.n_actions = 9  # maximum number of actions
        self.n_init_actions = 9
        self.temperature = 1.0

    def get_actions(self, policy_model, state, add_kl=False):
        """Return one ``(action, (log_prob, ref_log_prob), embedding)`` triple per legal move.

        ``policy_model`` and ``add_kl`` are ignored.
        """
        # Fake log-probs of 0 express no preference between moves; there are
        # no embeddings.
        return [
            (move, (torch.tensor(0.0), torch.tensor(0.0)), None)
            for move in state.get_legal_actions()
        ]

    def get_values(self, policy_model, state, actions, log_probs, ref_log_probs, add_kl, parent_depth, parent_value):
        """Return one ``(value, rewards_tensor, done)`` triple per action.

        ``value`` is 1.0 if player 1 wins after the move, -1.0 if player -1
        wins and 0.0 otherwise; ``done`` tells whether the game ended.
        Only ``state`` and ``actions`` influence the result (the two log-prob
        lists merely bound how many actions are processed).
        """
        scored = []
        for action, _lp, _ref_lp in zip(actions, log_probs, ref_log_probs):
            winner = state.perform_action(action).get_winner()
            done = winner is not None
            if winner == 1:
                value = 1.0
            elif winner == -1:
                value = -1.0
            else:
                # draw, or game still in progress
                value = 0.0
            scored.append((value, torch.tensor([value]), done))
        return scored

class DummyPolicyModel:
    """Placeholder policy: uniform move probabilities and a zero state value."""

    def predict_log_probs(self, state, legal_moves):
        """Return the log of a uniform distribution over ``legal_moves``.

        ``state`` is ignored. With no legal moves the result is an empty tensor.
        """
        move_count = len(legal_moves)
        if move_count > 0:
            uniform = 1 / move_count
        else:
            uniform = 0
        return torch.log(torch.tensor([uniform] * move_count))

    def evaluate_state(self, state):
        """Return 0.0, i.e. no value estimate for ``state``."""
        return 0.0

In [98]:
class MCTSConfig(NamedTuple):
    """Settings read by ``MCTS.__init__``."""

    # --- bookkeeping ---
    output_trace_in_each_iter: bool = True   # keep a copy of the path of every iteration
    # --- tree search ---
    w_exp: float = 1.0                       # weight of the exploration term in PUCT
    depth_limit: int = 10                    # maximum depth below the root
    breadth_limit: int = 3                   # stored by MCTS; not read in the code shown here
    n_iters: int = 20                        # iterations per search (x4 for a depth-0 root)
    simulate_strategy: Union[str, Callable[[list[float]], int]] = 'max'  # 'max' | 'sample' | 'random' | callable
    disable_tqdm: bool = True                # hide the progress bar
    # --- action selection ---
    temperature: float = 0.0                 # base temperature; 0 selects greedily
    temperature_decay_ratio: float = 0.75    # temperature is multiplied by this once per depth level
    # --- value computation ---
    gamma: float = 1.0                       # factor applied to V when computing Q
    add_kl: bool = False                     # forwarded to the search config
    consider_diversity: bool = True          # boost children by the spread of their own children
    length_penalty: float = 1.25             # stored on every created node

In [99]:
# Search settings: every field keeps its MCTSConfig default.
args = MCTSConfig()

In [None]:
world_model = TicTacToeWorldModel()
search_config = TicTacToeSearchConfig()
policy_model = DummyPolicyModel()
mcts = MCTS(args)
# MCTS.__call__ takes the world model and the search config but not the policy
# model, so the latter is attached by hand.
mcts.policy_model = policy_model
# Run the search. Calling the MCTS object (rather than mcts.search(), which
# returns None) yields an MCTSResult; the tree stays available as mcts.root.
result = mcts(world_model, search_config)

In [102]:
# Recompute the root's action distribution, the selected child index and the
# per-child V / Q lists from the finished search tree.
next_action_pi, selected_idx, next_action_V, next_action_Q = mcts._get_simulated_pi(mcts.root, return_selection=True)

In [103]:
# Dump the state of every node on every recorded path. A node that was never
# expanded has no state yet and prints as None.
for i, recorded_path in enumerate(mcts.trace_in_each_iter):
    for j, visited in enumerate(recorded_path):
        print(f"Iteration {i}, Node {j}:")
        print(visited.state)
        print("-----------------------")

Iteration 0, Node 0:
. . .
. . .
. . .
-----------------------
Iteration 0, Node 1:
X . .
. . .
. . .
-----------------------
Iteration 0, Node 2:
X O .
. . .
. . .
-----------------------
Iteration 0, Node 3:
X O X
. . .
. . .
-----------------------
Iteration 0, Node 4:
X O X
O . .
. . .
-----------------------
Iteration 0, Node 5:
X O X
O X .
. . .
-----------------------
Iteration 0, Node 6:
X O X
O X O
. . .
-----------------------
Iteration 0, Node 7:
None
-----------------------
Iteration 1, Node 0:
. . .
. . .
. . .
-----------------------
Iteration 1, Node 1:
. X .
. . .
. . .
-----------------------
Iteration 1, Node 2:
O X .
. . .
. . .
-----------------------
Iteration 1, Node 3:
O X X
. . .
. . .
-----------------------
Iteration 1, Node 4:
O X X
O . .
. . .
-----------------------
Iteration 1, Node 5:
O X X
O X .
. . .
-----------------------
Iteration 1, Node 6:
O X X
O X O
. . .
-----------------------
Iteration 1, Node 7:
None
-----------------------
Iteration 2, Node 

In [129]:
# Show the board of node 6 on the path recorded in iteration 76 (hand-picked indices).
print(mcts.trace_in_each_iter[76][6].state)

O O X
X O .
. X .


In [127]:
# Compute the action distribution and selection for that same recorded node.
# This overwrites the root-level variables of the same names.
next_action_pi, selected_idx, next_action_V, next_action_Q = mcts._get_simulated_pi(mcts.trace_in_each_iter[76][6], return_selection=True)

In [131]:
# Children are created in the order of get_legal_actions(), so the selected
# child index can be mapped back to a board cell through this list.
legal_actions = mcts.trace_in_each_iter[76][6].state.get_legal_actions()
legal_actions[selected_idx]

5

In [130]:
# Display the legal moves of the inspected node.
legal_actions

[2, 3, 5, 6, 7, 8]

In [126]:
# Display the action distribution computed for the inspected node.
next_action_pi

[1.0, 0, 0, 0, 0, 0]

In [114]:
# Display next_action_pi; the saved output was produced by a different,
# earlier execution (In [114]) than the cell above.
next_action_pi

[0, 0, 0, 0, 1.0]

In [113]:
# Display the per-child Q values (saved output from execution In [113]).
next_action_Q

[0.0, 0.0, 0.0, 0.0, 1.0]

In [None]:
# Print the board stored on the root of the search tree.
print(mcts.root.state)

. . .
. . .
. . .


In [None]:
# Inspect the rollout choice function. It only exists as an attribute of the
# MCTS instance; the bare name is not defined at notebook level.
mcts.simulate_choice

In [75]:
# Display next_action_pi (saved output from execution In [75]).
next_action_pi

[0, 0, 0, 0.25, 0.25, 0, 0.25, 0, 0.25]

In [70]:
# Display the index of the selected child (saved output from execution In [70]).
selected_idx

4

In [71]:
# Display the per-child V values (saved output from execution In [71]).
next_action_V

[0.0,
 0.0,
 0.0,
 0.0023809523809523807,
 0.002976190476190476,
 0.0,
 0.0017857142857142854,
 0.0005952380952380952,
 0.0017857142857142854]

In [62]:
# Display the per-child Q values (saved output from execution In [62]).
next_action_Q

[0.0,
 0.0,
 0.0,
 0.0023809523809523807,
 0.002976190476190476,
 0.0,
 0.0017857142857142854,
 0.0005952380952380952,
 0.0017857142857142854]

In [None]:

# MCTS is configured through an MCTSConfig; the policy model is attached
# separately, and the world model / search config are passed when the search
# is invoked.
mcts = MCTS(args)
mcts.policy_model = policy_model
# Calling the instance runs the search and returns an MCTSResult
# (mcts.search() on its own returns None).
result = mcts(world_model, search_config)

print("Best action:", result.next_action_idx)


In [22]:
class MCTSNode:
    """Minimal stand-in node used only to build sample MCTSResult values.

    NOTE(review): this reuses the name of the full search-tree node class
    defined earlier in this notebook. Once this cell has run, MCTS.search will
    pick up this definition instead -- consider renaming it.
    """

    def __init__(self, state, parent=None):
        self.state, self.parent = state, parent
        # Statistics start at zero; every node gets its own children list.
        self.visit_count = 0
        self.total_value = 0.0
        self.children = []

    def __repr__(self):
        return f"MCTSNode(state={self.state})"

# Build a few fake nodes for testing.
root = MCTSNode("root")
child1, child2 = MCTSNode("child1", root), MCTSNode("child2", root)
root.children = [child1, child2]

# Assemble an MCTSResult instance from fixed values.
result = MCTSResult(
    tree_state=root,
    next_action_pi=[0.7, 0.3],
    next_action_V=[0.9, 0.8],
    next_action_Q=[1.0, 0.6],
    trace_in_each_iter=[[root, child1], [root, child2]],
    next_action_idx=0,
    trace_of_nodes=[root, child1],
    cum_reward=1.5,
)

# Print every field of the result, one per line.
report_lines = [
    "MCTS搜索结果：",
    f"当前树状态：{result.tree_state}",
    f"下一步动作概率 pi：{result.next_action_pi}",
    f"下一步状态值 V：{result.next_action_V}",
    f"下一步Q值 Q：{result.next_action_Q}",
    f"选择的动作索引：{result.next_action_idx}",
    f"模拟路径：{result.trace_of_nodes}",
    f"模拟累计奖励：{result.cum_reward}",
    f"每次模拟路径跟踪：{result.trace_in_each_iter}",
]
for report_line in report_lines:
    print(report_line)

MCTS搜索结果：
当前树状态：MCTSNode(state=root)
下一步动作概率 pi：[0.7, 0.3]
下一步状态值 V：[0.9, 0.8]
下一步Q值 Q：[1.0, 0.6]
选择的动作索引：0
模拟路径：[MCTSNode(state=root), MCTSNode(state=child1)]
模拟累计奖励：1.5
每次模拟路径跟踪：[[MCTSNode(state=root), MCTSNode(state=child1)], [MCTSNode(state=root), MCTSNode(state=child2)]]
