In [37]:
import platformdirs
from pathlib import Path
import torch
import numpy as np

from katago.train.load_model import load_model
from katago.game.data import load_sgf_moves_exn
from katago.game.gamestate import GameState
from katago.game.features import Features
from katago.game.board import Board


cache_dir = Path(platformdirs.user_cache_dir("katago-onnx"))

network_name = "kata1-b28c512nbt-adam-s11165M-d5387M"
model_path = cache_dir / network_name / "model.ckpt"

In [38]:
# Load the model from checkpoint
device = "cpu"
model, swa_model, _ = load_model(
    str(model_path),
    use_swa=False,
    device=device,
    pos_len=19,
    verbose=True,
)

model = model.eval()

In [39]:
target_move = 232

# Path to SGF
sgf_path = Path("../sgf/game-19-white-38.sgf")
# sgf_path = Path("../sgf/game-13-white-12.sgf")

# Load SGF
metadata, setup, moves, rules = load_sgf_moves_exn(str(sgf_path))
print(f"Loaded SGF with {len(moves)} moves.")

# Initialize GameState
board_size = metadata.size
rules = GameState.RULES_JAPANESE
gs = GameState(board_size, rules)

if target_move is None:
    target_move = len(moves)

# Play moves
print(f"Replaying up to move {target_move}...")
for i, (pla, loc) in enumerate(moves):
    if i >= target_move:
        break
    gs.play(pla, loc)

print(f"Replayed to move {len(gs.moves)}")

# Prepare features
features = Features(model.config, pos_len=19)
bin_input, global_input = gs.get_input_features(features)

# Convert to torch
bin_input_torch = torch.tensor(bin_input).to(device)
global_input_torch = torch.tensor(global_input).to(device)

Loaded SGF with 282 moves.
Replaying up to move 232...
Replayed to move 232


In [40]:
# Run model
with torch.no_grad():
    outputs = model(bin_input_torch, global_input_torch)

# Unpack outputs
if model.has_intermediate_head:
    main_outputs, intermediate_outputs = outputs
else:
    main_outputs = outputs[0]

(
    policy_logits,
    value_logits,
    miscvalue_logits,
    moremiscvalue_logits,
    ownership_logits,
    scoring_logits,
    futurepos_logits,
    seki_logits,
    scorebelief_logits,
) = main_outputs

# Process Policy
# policy_logits shape: [batch, num_policy_outputs, num_moves]
# We want the first policy output (index 0) and softmax over moves (dim 1)
policy_probs = torch.nn.functional.softmax(policy_logits[:, 0, :], dim=1).cpu().numpy()[0]

# Process Value (Winrate)
# value_logits shape is usually [batch, 4] -> [win, loss, no_result, ?] or similar depending on config
# But typically for KataGo:
# value_output is [batch, 4] corresponding to [win, loss, no_result, draw] or similar.

# Actually, let's check model_pytorch.py or just print shape.
value_probs = torch.nn.functional.softmax(value_logits, dim=1).cpu().numpy()[0]
winrate = value_probs[0] # Assuming index 0 is win for current player? Or black?
# KataGo usually outputs from perspective of current player to move.

policy_probs.shape, value_probs.shape, winrate

((362,), (3,), np.float32(0.00018661308))

In [41]:

current_player_name = "Black" if gs.board.pla == Board.BLACK else "White"
print(f"Current player to play: {current_player_name}")

Current player to play: Black


In [None]:
print(f"Policy shape: {policy_probs.shape}")
print(f"Value probs: {value_probs}")

# Get top moves
top_moves_idx = np.argsort(policy_probs)[::-1][:5]
print("Top moves:")
for idx in top_moves_idx:
    loc = features.tensor_pos_to_loc(idx, gs.board)
    if loc is None:
        move_str = "PASS"
    elif loc == gs.board.loc(-10, -10):  # Illegal/Offboard
        move_str = "Illegal"
    else:
        x = gs.board.loc_x(loc)
        y = gs.board.loc_y(loc)
        # Convert to SGF coordinates (e.g. A1, D4) or similar
        # KataGo board uses 0-indexed coordinates.
        # Let's just print (x,y) for now or convert to GTP-like string
        col_str = "ABCDEFGHJKLMNOPQRST"[x]
        row_str = str(y + 1)
        move_str = f"{col_str}{row_str}"

    print(f"  {move_str}: {policy_probs[idx]:.4f}")


Policy shape: (362,)
Value probs: [1.8661308e-04 9.9981338e-01 3.7821611e-08]
Top moves:
  N14: 0.2287
  L4: 0.1626
  N13: 0.1251
  M3: 0.0615
  T13: 0.0531
