In [1]:
import platformdirs
from pathlib import Path
import torch
import numpy as np

import onnxruntime

from katago_onnx.utils import load_model
from katago_onnx.utils import load_sgf
from katago_onnx.utils import featurize
from katago_onnx.utils import run_inference
from katago_onnx.utils import get_top_moves

# Per-user cache directory where the KataGo checkpoint is expected to be stored.
cache_dir = Path(platformdirs.user_cache_dir("katago-onnx"))

# Load the pretrained network checkpoint onto the CPU.
network_name = "kata1-b28c512nbt-adam-s11165M-d5387M"
model_path = cache_dir.joinpath(network_name, "model.ckpt")
model = load_model(model_path, device="cpu")

In [None]:
# Replay the game up to this move number before featurizing.
target_move = 8

# SGF fixtures available for testing. Select one by key instead of stacking
# consecutive `sgf_path = ...` reassignments, where only the last one takes effect.
sgf_candidates = {
    "game-19": "../sgf/game-19-white-38.sgf",
    "game-13": "../sgf/game-13-white-12.sgf",
    "game-9": "../sgf/game-9-white-19.sgf",
}
sgf_path = sgf_candidates["game-9"]

# Load the game state from the SGF, replayed to `target_move`.
gamestate = load_sgf(sgf_path, target_move=target_move)

# Featurize the game state; `bin_input` and `global_input` are the tensors
# fed to the model by the export and inference cells.
features, bin_input, global_input = featurize(gamestate, model)

Loaded SGF with 117 moves.
Game size: 13
Replayed to move 8


In [6]:
# Example inputs for tracing: the exporter runs `model` on these, so they are
# given in the same positional order as the model's forward arguments.
model_inputs = (bin_input, global_input)

# Convert with the TorchDynamo-based exporter (torch.export under the hood).
onnx_program = torch.onnx.export(
    model,
    model_inputs,
    dynamo=True,
)

# NOTE(review): no `dynamic_shapes` are passed, so the exported graph is presumably
# specialized to the example inputs' shapes (batch size and board size) — confirm
# before feeding it positions from a different board size.
# TODO: user-specific absolute path; the ONNX Runtime cell reads this same file back.
onnx_program.save("/Users/hadim/test_model.onnx")

W1123 19:33:04.087000 37894 site-packages/torch/onnx/_internal/exporter/_registration.py:107] torchvision is not installed. Skipping torchvision::nms


[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 5 of general pattern rewrite rules.


In [29]:
# Replay the game up to this move number before featurizing.
target_move = 8

# SGF fixtures available for testing. Select one by key instead of stacking
# consecutive `sgf_path = ...` reassignments, where only the last one takes effect.
sgf_candidates = {
    "game-19": "../sgf/game-19-white-38.sgf",
    "game-13": "../sgf/game-13-white-12.sgf",
    "game-9": "../sgf/game-9-white-19.sgf",
}
sgf_path = sgf_candidates["game-9"]

# Load the game state from the SGF, replayed to `target_move`.
gamestate = load_sgf(sgf_path, target_move=target_move)

# Featurize the game state into the tensors fed to the ONNX model below.
features, bin_input, global_input = featurize(gamestate, model)

# Prepare inputs for ONNX Runtime: same positional order as used for the export,
# converted from torch tensors to numpy arrays.
model_inputs = (bin_input, global_input)
onnx_inputs = [tensor.numpy(force=True) for tensor in model_inputs]

# TODO: user-specific absolute path; must match the path the export cell saved to.
ort_session = onnxruntime.InferenceSession(
    "/Users/hadim/test_model.onnx",
    providers=["CPUExecutionProvider"],
)

# Bind each array to the graph input at the same position. `zip` would silently
# drop entries if the counts diverged, so check the counts explicitly first.
ort_input_args = ort_session.get_inputs()
if len(ort_input_args) != len(onnx_inputs):
    raise ValueError(
        f"ONNX model expects {len(ort_input_args)} inputs, got {len(onnx_inputs)}"
    )
onnxruntime_input = {
    input_arg.name: input_value
    for input_arg, input_value in zip(ort_input_args, onnx_inputs)
}

# Passing None as the output list makes ONNX Runtime return every graph output,
# as a list in graph-output order.
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

# Unpack the two outputs used below. The graph has more outputs than these (see the
# shape listing printed at the end of this cell); they are selected by position.
policy_logits = np.asarray(onnxruntime_outputs[0])
value_logits = np.asarray(onnxruntime_outputs[1])


def _softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Return the softmax of ``logits`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged, but keeps ``np.exp`` from overflowing to
    ``inf`` on large logits (``inf / inf`` would yield NaN probabilities) and
    from underflowing every term to 0 on very negative logits.
    """
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=axis, keepdims=True)


# Policy: shape [batch, num_policy_outputs, num_moves] (observed (1, 6, 82) on a 9x9 board).
# Take the first policy output, normalize over moves, then drop the batch dimension.
policy_probs = _softmax(policy_logits[:, 0, :], axis=1)[0]

# Value: shape [batch, 3], corresponding to [win, loss, no_result].
value_probs = _softmax(value_logits, axis=1)[0]

# NOTE(review): whose perspective index 0 refers to (side to move, Black, or White) is
# not established here — presumably the side to move; TODO confirm against KataGo.
winrate = value_probs[0]

# List every graph output with its ONNX name and shape. Outputs are consumed by
# position (index 0 = policy, 1 = value), so the names help confirm that mapping.
for output_meta, output_array in zip(ort_session.get_outputs(), onnxruntime_outputs):
    print(f"{output_meta.name}: {output_array.shape}")

Loaded SGF with 50 moves.
Game size: 9
Replayed to move 8
(1, 6, 82)
(1, 3)
(1, 10)
(1, 8)
(1, 1, 9, 9)
(1, 1, 9, 9)
(1, 2, 9, 9)
(1, 4, 9, 9)
(1, 842)
(1, 6, 82)
(1, 3)
(1, 10)
(1, 8)
(1, 1, 9, 9)
(1, 1, 9, 9)
(1, 2, 9, 9)
(1, 4, 9, 9)
(1, 842)


In [31]:
# Display the win probability computed above: the first entry of softmax(value_logits),
# labelled [win, loss, no_result] in that cell. NOTE(review): whose perspective (side to
# move, Black, or White) this is has not been established — TODO confirm before interpreting.
winrate

np.float32(0.058144335)