In [3]:
%load_ext autoreload
%autoreload 2

import platformdirs
import pathlib
import torch
import numpy as np


from katago.game.board import Board

from katago_onnx.utils import load_model
from katago_onnx.utils import load_sgf
from katago_onnx.utils import featurize
from katago_onnx.utils import run_inference
from katago_onnx.utils import get_top_moves


cache_dir = pathlib.Path(platformdirs.user_cache_dir("katago-onnx"))

# Load model
network_name = "kata1-b28c512nbt-adam-s11165M-d5387M"
model_path = cache_dir / network_name / "model.ckpt"
model = load_model(model_path, device="cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
# Inputs for this analysis: which SGF to load and how many moves of it to replay.
# Other SGFs previously analysed here (swap one in to analyse that game instead):
#   "../sgf/game-19-white-38.sgf"
#   "../sgf/game-13-white-12.sgf"
sgf_path = "../sgf/game-9-white-19.sgf"
target_move = 8

# Load the game state from the SGF, replayed up to `target_move`.
gamestate = load_sgf(sgf_path, target_move=target_move)

# Featurize the game state into the network's spatial (`bin_input`) and
# global (`global_input`) input tensors.
features, bin_input, global_input = featurize(gamestate, model)

current_player_name = "Black" if gamestate.board.pla == Board.BLACK else "White"
print(f"Current player to play: {current_player_name}")

# Display the input tensor shapes as a quick sanity check.
bin_input.shape, global_input.shape

Loaded SGF with 50 moves.
Game size: 9
Replayed to move 8
Current player to play: Black


(torch.Size([1, 22, 9, 9]), torch.Size([1, 19]))

In [24]:
# Evaluate the network once on the featurized position.
# NOTE(review): the bare `torch.Size([1, 3])` line in this cell's output appears to come
# from a debug print inside run_inference — consider removing it there.
(
    policy_probs,
    value_probs,
    winrate,
) = run_inference(model, bin_input, global_input)

# Display the output sizes and the returned winrate.
(policy_probs.shape, value_probs.shape, winrate)

torch.Size([1, 3])


((82,), (3,), np.float32(0.05814434))

In [8]:
# Rank the ten highest-probability moves from the policy output for the current position.
# NOTE(review): this cell's saved output (In [8]) was produced before the In [23]/In [24]
# runs above, so the table may not reflect the current position. Restart & Run All to refresh it.
top_moves = get_top_moves(
    policy_probs,
    gamestate,
    features,
    top_k=10,
)
top_moves


Unnamed: 0,loc,x,y,prob,move_str
0,53,2,4,0.988265,C5
1,48,7,3,0.004197,H4
2,47,6,3,0.003007,G4
3,43,2,3,0.000806,C4
4,75,4,6,0.000786,E7
5,67,6,5,0.000626,G6
6,58,7,4,0.000537,H5
7,42,1,3,0.000444,B4
8,32,1,2,0.00025,B3
9,65,4,5,0.000247,E6
