In [1]:
import platformdirs
from pathlib import Path
import numpy as np

import onnxruntime

from katago.game.board import Board

from katago_onnx.utils import load_model
from katago_onnx.utils import load_sgf
from katago_onnx.utils import featurize

# Per-user cache directory for the "katago-onnx" application
cache_dir = Path(platformdirs.user_cache_dir("katago-onnx"))

# Every network gets its own sub-directory inside the cache
network_name = "kata1-b28c512nbt-adam-s11165M-d5387M"
network_dir = cache_dir / network_name

# Checkpoint of the network, loaded on the CPU
model_path = network_dir / "model.ckpt"
model = load_model(model_path, device="cpu")

# ONNX file for the same network; consumed by the ONNX Runtime cell below
onnx_model_path = network_dir / "model.onnx"

In [2]:
# Move number handed to load_sgf as the position to stop at
target_move = 11

# Game record to analyse (alternatives kept for quick switching)
sgf_path = "../sgf/game-19-white-38.sgf"
# sgf_path = "../sgf/game-13-white-12.sgf"
# sgf_path = "../sgf/game-9-white-19.sgf"

# Rebuild the game state from the SGF, then turn it into network inputs
gamestate = load_sgf(sgf_path, target_move=target_move)
features, bin_input, global_input = featurize(gamestate, model)

# Report which colour is to move in the loaded position
if gamestate.board.pla == Board.BLACK:
    current_player_name = "Black"
else:
    current_player_name = "White"
print(f"Current player to play: {current_player_name}")

# Cell output: shapes of the two input tensors
bin_input.shape, global_input.shape

Loaded SGF with 282 moves.
Game size: 19
Replayed to move 11
Current player to play: White


(torch.Size([1, 22, 19, 19]), torch.Size([1, 19]))

In [3]:
target_move = 8

# Path to SGF (uncomment exactly one alternative)
# sgf_path = "../sgf/game-19-white-38.sgf"
sgf_path = "../sgf/game-13-white-12.sgf"
# sgf_path = "../sgf/game-9-white-19.sgf"


def stable_softmax(logits, axis=-1):
    """Return the softmax of ``logits`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps ``np.exp`` from
    overflowing to ``inf`` (and the quotient from becoming NaN) when the
    logits are large.

    Args:
        logits: Array-like of raw scores.
        axis: Axis along which the probabilities should sum to 1.

    Returns:
        ``np.ndarray`` with the same shape as ``logits``.
    """
    logits = np.asarray(logits)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=axis, keepdims=True)


# Load game state from SGF
gamestate = load_sgf(sgf_path, target_move=target_move)

# Featurize the game state
features, bin_input, global_input = featurize(gamestate, model)

# Prepare inputs for ONNX Runtime inference (torch tensors -> numpy arrays)
model_inputs = (bin_input, global_input)
onnx_inputs = [tensor.numpy(force=True) for tensor in model_inputs]

ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=["CPUExecutionProvider"],
)

# Inputs are matched to the session's declared inputs by position; fail
# loudly instead of letting zip() silently drop a mismatched input.
session_inputs = ort_session.get_inputs()
if len(session_inputs) != len(onnx_inputs):
    raise ValueError(
        f"ONNX model expects {len(session_inputs)} inputs, got {len(onnx_inputs)}"
    )
onnxruntime_input = {
    input_arg.name: input_value
    for input_arg, input_value in zip(session_inputs, onnx_inputs)
}

# ONNX Runtime returns a list of outputs
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

# Unpack outputs
policy_logits = np.asarray(onnxruntime_outputs[0])
value_logits = np.asarray(onnxruntime_outputs[1])

# Process Policy
# Shape: [batch, num_policy_outputs, num_moves]
# Take the first policy output (index 0), softmax over moves, drop the batch axis
policy_probs = stable_softmax(policy_logits[:, 0, :], axis=1)[0]

# Process Value (Winrate)
# Shape: [batch, 3] corresponding to [win, loss, no_result]
value_probs = stable_softmax(value_logits, axis=1)[0]

# NOTE(review): index 0 is taken as the "win" probability. Whose perspective
# that is (player to move, Black or White) cannot be determined from this
# notebook; presumably the player to move -- TODO confirm against the model.
winrate = value_probs[0]

# Process Score Lead
# miscvalue_logits is at index 2 of the outputs
miscvalue_logits = np.asarray(onnxruntime_outputs[2])
# Index 2 of miscvalue is lead; scaled by the model's lead multiplier
score_lead = miscvalue_logits[0, 2] * model.lead_multiplier

print(f"Winrate: {winrate}")
print(f"Score Lead: {score_lead}")

# Shape of every output the session produced, for reference
for output in onnxruntime_outputs:
    print(output.shape)

Loaded SGF with 117 moves.
Game size: 13
Replayed to move 8
Winrate: 0.5851287245750427
Score Lead: 0.2296016365289688
(1, 6, 170)
(1, 3)
(1, 10)
(1, 8)
(1, 1, 13, 13)
(1, 1, 13, 13)
(1, 2, 13, 13)
(1, 4, 13, 13)
(1, 842)
(1, 6, 170)
(1, 3)
(1, 10)
(1, 8)
(1, 1, 13, 13)
(1, 1, 13, 13)
(1, 2, 13, 13)
(1, 4, 13, 13)
(1, 842)


In [18]:
# Echo `winrate` (index 0 of the softmaxed value output computed in the
# ONNX inference cell) as this cell's output.
winrate

np.float32(0.5851287)