In [2]:
# load the model 
# pass something in 
# visualize the decision tree 
# find good and bad paths 
# create the separator between the good path and all the bad paths 
# need to pull activations 


# open questions 
# (1) at what depth do the decisions start mattering?

In [5]:
# Step 1: Setup — board state + MCTS (with or without net)
import numpy as np
from alphazero.games.othello import OthelloBoard, OthelloNet
from alphazero.games.registers import NETWORKS_REGISTER
from alphazero.players import AlphaZeroPlayer, MCTSPlayer

# Tunables for this experiment: board edge length and the per-move search
# budget handed to whichever player ends up being created.
board_size = 6
n_sim = 100
board = OthelloBoard(n=board_size)

# Optional: start from a specific position by playing some moves
# board.play_move((2, 3))
# board.play_move((2, 2))
# ... or leave as initial position

# Load net and create player (use AlphaZero for neural eval, or MCTSPlayer(n_sim=100) for rollout-only)
try:
    net = NETWORKS_REGISTER["othello"].from_pretrained("alphazero-othello", verbose=False)
    player = AlphaZeroPlayer(nn=net, n_sim=n_sim)
except Exception as exc:
    # Best-effort fallback so the notebook still runs without the pretrained
    # net, but report it: every result below depends on which player was built.
    print(f"Could not set up AlphaZeroPlayer ({exc!r}); falling back to rollout-only MCTSPlayer.")
    player = MCTSPlayer(n_sim=n_sim)  # fallback: no net

# Snapshot the board *before* search so we have the root state for the decision path
board_snapshot = board.clone()

In [6]:
# Step 2: Search from the current position. The player keeps the tree it
# built on `player.mct`, which the later cells walk.
move, action_probs, visit_counts, prior_probs = player.get_move(board)
mct = player.mct

# Two-line summary of the root: branching factor, chosen move, visit counts.
print(
    f"Root had {len(mct.root.children)} children. Best move chosen: {move}",
    f"Visit counts at root: {visit_counts}",
    sep="\n",
)

Root had 4 children. Best move chosen: (3, 1)
Visit counts at root: {(3, 1): 41, (2, 4): 24, (1, 3): 1, (4, 2): 34}


In [7]:
# Step 3: Walk down the tree — at each step note best move (chosen) and second-best (disregarded)
def get_decision_path(mct, board=None, max_depth=None, sort_key="N"):
    """Walk the search tree from the root, always descending into the top-ranked child.

    At each level the current node's children are ranked by ``sort_key``
    (descending). The top child is recorded as the "best" move, the runner-up
    as the "second best" (disregarded) move, and the walk continues from the
    top child until a node without children is reached.

    Args:
        mct: Search tree exposing ``root``. Every node must expose ``children``
            (a mapping of move -> child node) plus the attributes ``N``, ``Q``
            and ``P`` (N is the visit count; Q and P are presumably the value
            estimate and the prior -- confirm against the MCTS implementation).
        board: Optional board in the root position. It is cloned, so the
            caller's object is not modified. When given, each step also
            records the ``grid`` and ``player`` of the position in which the
            decision was taken.
        max_depth: Optional upper bound on the number of steps returned.
        sort_key: Name of the node attribute used for ranking (default "N").
            When ranking by "Q", children with ``N == 0`` are placed after
            visited ones, because an unvisited child only carries a
            placeholder Q rather than an estimate.

    Returns:
        A list with one dict per depth containing ``depth``, optionally
        ``grid`` and ``player``, ``best_move`` / ``best_N`` / ``best_Q`` /
        ``best_P``, and the matching ``second_best_*`` entries (all ``None``
        when the node has a single child).

    Raises:
        AttributeError: if ``sort_key`` is not an attribute of the child nodes.
    """
    if sort_key == "Q":
        # Unvisited children report Q == 0 (see the N=0 rows printed by the
        # inspection cell). Ranked purely by Q they would beat every visited
        # child with a negative Q, and the walk would dead-end in a node that
        # has no children. Visited children therefore always come first.
        def rank(item):
            child = item[1]
            return (child.N > 0, child.Q)
    else:
        def rank(item):
            return getattr(item[1], sort_key)

    node = mct.root
    path = []
    depth = 0
    board_copy = board.clone() if board is not None else None

    while node is not None:
        if max_depth is not None and depth >= max_depth:
            break
        # A node without children ends the walk; check before copying the grid.
        if not node.children:
            break

        step = {"depth": depth}
        if board_copy is not None:
            step["grid"] = board_copy.grid.copy()
            step["player"] = board_copy.player

        # sorted() is stable, so equally ranked children keep their
        # insertion order from node.children.
        ranked = sorted(node.children.items(), key=rank, reverse=True)

        best_move, best_node = ranked[0]
        step["best_move"] = best_move
        step["best_N"] = best_node.N
        step["best_Q"] = best_node.Q
        step["best_P"] = best_node.P

        if len(ranked) >= 2:
            second_move, second_node = ranked[1]
            step["second_best_move"] = second_move
            step["second_best_N"] = second_node.N
            step["second_best_Q"] = second_node.Q
            step["second_best_P"] = second_node.P
        else:
            step["second_best_move"] = None
            step["second_best_N"] = step["second_best_Q"] = step["second_best_P"] = None

        path.append(step)
        # Keep the replayed board in sync with the node we descend into.
        if board_copy is not None and best_move is not None:
            board_copy.play_move(best_move)
        node = best_node
        depth += 1

    return path

# Replay the walk from the pre-search snapshot so each recorded grid lines up
# with the tree root.
decisions = get_decision_path(mct=mct, board=board_snapshot)

In [8]:
# Step 4: Inspect decisions at each timestep
# One line per depth: the followed move next to the runner-up that was passed over.
for step in decisions:
    d, best = step["depth"], step["best_move"]
    best_N = step["best_N"]
    best_Q = step["best_Q"]
    second, second_N, second_Q = (
        step.get(field)
        for field in ("second_best_move", "second_best_N", "second_best_Q")
    )

    # A node with a single child has no runner-up; show a dash instead of a number.
    if second_N is None:
        s_N = "N=—"
    else:
        s_N = f"N={second_N:.0f}"
    if second_Q is None:
        s_Q = "Q=—"
    else:
        s_Q = f"Q={second_Q:.3f}"

    print(f"depth {d}: best={best} (N={best_N:.0f}, Q={best_Q:.3f})  |  second_best={second} ({s_N}, {s_Q})")

depth 0: best=(3, 1) (N=41, Q=0.415)  |  second_best=(4, 2) (N=34, Q=0.436)
depth 1: best=(4, 3) (N=38, Q=-0.404)  |  second_best=(4, 1) (N=1, Q=-0.959)
depth 2: best=(2, 4) (N=37, Q=0.402)  |  second_best=(4, 4) (N=0, Q=0.000)
depth 3: best=(3, 0) (N=23, Q=-0.445)  |  second_best=(1, 3) (N=10, Q=-0.168)
depth 4: best=(4, 2) (N=22, Q=0.445)  |  second_best=(4, 4) (N=0, Q=0.000)
depth 5: best=(2, 1) (N=17, Q=-0.406)  |  second_best=(3, 4) (N=1, Q=-0.613)
depth 6: best=(5, 3) (N=16, Q=0.391)  |  second_best=(4, 4) (N=0, Q=0.000)
depth 7: best=(1, 2) (N=8, Q=-0.502)  |  second_best=(1, 3) (N=2, Q=-0.035)
depth 8: best=(2, 0) (N=7, Q=0.498)  |  second_best=(0, 1) (N=0, Q=0.000)
depth 9: best=(1, 0) (N=4, Q=-0.492)  |  second_best=(3, 4) (N=1, Q=-0.028)
depth 10: best=(4, 1) (N=3, Q=0.452)  |  second_best=(0, 1) (N=0, Q=0.000)
depth 11: best=(2, 5) (N=2, Q=-0.416)  |  second_best=(3, 4) (N=0, Q=0.000)
depth 12: best=(0, 2) (N=1, Q=0.358)  |  second_best=(1, 5) (N=0, Q=0.000)
