In [1]:
import pathlib
import shutil

import numpy as np
import onnx
import platformdirs
import torch
from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from onnxruntime.quantization.preprocess import quant_pre_process

from katago_onnx.utils import featurize, load_model, load_sgf

# Per-user cache directory holding the KataGo checkpoint and every exported ONNX file.
cache_dir = pathlib.Path(platformdirs.user_cache_dir("katago-onnx"))

# Network to export. Its PyTorch checkpoint is read from <cache_dir>/<network_name>/model.ckpt.
network_name = "kata1-b28c512nbt-adam-s11165M-d5387M"
network_dir = cache_dir / network_name
torch_model_path = network_dir / "model.ckpt"

# Extension-less stem shared by all exported variants; later cells derive the
# per-precision ONNX file names (fp32 / fp16 / uint8) from it.
onnx_base_path = network_dir / network_name

In [2]:
# Load the PyTorch checkpoint onto the CPU.
model = load_model(torch_model_path, device="cpu")

# Dummy inputs used only to trace the graph for export: a (batch, 22, 19, 19)
# board-shaped tensor and a (batch, 19) global-feature tensor. The sizes are
# named so the shapes are self-explanatory. A local fixed-seed generator makes
# re-runs reproducible without touching torch's global RNG state.
NUM_BIN_FEATURES = 22
NUM_GLOBAL_FEATURES = 19
BOARD_SIZE = 19
trace_rng = torch.Generator().manual_seed(0)
bin_input = torch.randn(
    1, NUM_BIN_FEATURES, BOARD_SIZE, BOARD_SIZE, dtype=torch.float32, generator=trace_rng
)
global_input = torch.randn(1, NUM_GLOBAL_FEATURES, dtype=torch.float32, generator=trace_rng)
model_inputs = (bin_input, global_input)

# Names given to the graph inputs/outputs, in the order the model consumes/returns them.
input_names = ["bin_input", "global_input"]
output_names = [
    "policy",
    "value",
    "miscvalue",
    "moremiscvalue",
    "ownership",
    "scoring",
    "futurepos",
    "seki",
    "scorebelief",
]

# Axes listed here stay symbolic in the exported graph, so the ONNX model is not
# pinned to the batch size / board size used for tracing. `dynamic_axes` is only
# honoured by the TorchScript exporter, hence `dynamo=False` below.
dynamic_axes = {
    "bin_input": {0: "batch_size", 2: "height", 3: "width"},
    "global_input": {0: "batch_size"},
    "policy": {0: "batch_size", 2: "moves"},  # NOTE(review): presumably board cells (+ pass) — confirm
    "value": {0: "batch_size"},
    "miscvalue": {0: "batch_size"},
    "moremiscvalue": {0: "batch_size"},
    "ownership": {0: "batch_size", 2: "height", 3: "width"},
    "scoring": {0: "batch_size", 2: "height", 3: "width"},
    "futurepos": {0: "batch_size", 2: "height", 3: "width"},
    "seki": {0: "batch_size", 2: "height", 3: "width"},
    "scorebelief": {0: "batch_size"},
}

# Build the file name from the stem instead of using with_suffix(): with_suffix()
# replaces everything after the last "." in the name, which would mangle (and
# possibly collide) network names that contain a dot.
model_fp32_path = onnx_base_path.with_name(f"{onnx_base_path.name}.fp32.onnx")
print(f"Exporting FP32 model to {model_fp32_path}...")
torch.onnx.export(
    model,
    model_inputs,
    str(model_fp32_path),
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=17,
    dynamo=False,
)
# Validate the exported graph before the FP16/UINT8 cells build on it.
onnx.checker.check_model(str(model_fp32_path))
print(f"FP32 model saved to: {model_fp32_path}")

Exporting FP32 model to /Users/hadim/Library/Caches/katago-onnx/kata1-b28c512nbt-adam-s11165M-d5387M/kata1-b28c512nbt-adam-s11165M-d5387M.fp32.onnx...


  torch.onnx.export(
  assert x.shape[1] == self.c_in


FP32 model saved to: /Users/hadim/Library/Caches/katago-onnx/kata1-b28c512nbt-adam-s11165M-d5387M/kata1-b28c512nbt-adam-s11165M-d5387M.fp32.onnx


In [3]:
# Convert to FP16 (recommended for Apple Silicon and modern GPUs).
# FP16 roughly halves the file size and can speed up inference on hardware with
# native FP16 support.
# The name is built from the stem rather than with with_suffix(), which would
# replace any "." already present in the network name.
model_fp16_path = onnx_base_path.with_name(f"{onnx_base_path.name}.fp16.onnx")
print(f"Converting to FP16: {model_fp16_path}...")

# Load the FP32 export produced by the previous cell.
onnx_model = onnx.load(str(model_fp32_path))

# keep_io_types=True leaves graph inputs/outputs as FP32 (cast nodes are
# inserted), so callers can feed the same float32 arrays as for the FP32 model.
# min_positive_val / max_finite_val are the bounds used to clamp initializer
# magnitudes that FP16 cannot represent well (these are the library defaults).
onnx_model_fp16 = float16.convert_float_to_float16(
    onnx_model,
    keep_io_types=True,
    min_positive_val=1e-7,
    max_finite_val=1e4,
)
onnx.save(onnx_model_fp16, str(model_fp16_path))
# Validate the converted graph so a broken conversion fails here, not at inference time.
onnx.checker.check_model(str(model_fp16_path))
print(f"FP16 model saved to: {model_fp16_path}")

Converting to FP16: /Users/hadim/Library/Caches/katago-onnx/kata1-b28c512nbt-adam-s11165M-d5387M/kata1-b28c512nbt-adam-s11165M-d5387M.fp16.onnx...




FP16 model saved to: /Users/hadim/Library/Caches/katago-onnx/kata1-b28c512nbt-adam-s11165M-d5387M/kata1-b28c512nbt-adam-s11165M-d5387M.fp16.onnx


In [4]:
# Quantize weights to UINT8 for memory-constrained devices. The result is ~4x
# smaller than FP32, but it is often slower than FP16 on modern CPUs and on
# Apple Silicon (e.g. M3).
# Names are built from the stem rather than with with_suffix(), which would
# replace any "." already present in the network name.
model_prep_path = onnx_base_path.with_name(f"{onnx_base_path.name}.prep.onnx")
model_uint8_path = onnx_base_path.with_name(f"{onnx_base_path.name}.uint8.onnx")

try:
    # Pre-process the FP32 model (shape inference and graph optimization) into a
    # temporary file that quantization reads.
    print("Pre-processing model for UINT8 quantization...")
    quant_pre_process(model_fp32_path, model_prep_path)

    # Dynamic quantization: weights stored as UINT8, activations quantized at run time.
    # NOTE(review): QUInt8 (not the QInt8 default) is presumably chosen for
    # ConvInteger kernel support on CPU — confirm before changing.
    print(f"Quantizing model to {model_uint8_path}...")
    quantize_dynamic(
        model_input=model_prep_path,
        model_output=model_uint8_path,
        weight_type=QuantType.QUInt8,
    )
    print(f"UINT8 quantized model saved to: {model_uint8_path}")
finally:
    # Always remove the intermediate file, even if quantization raised.
    # missing_ok covers a failure before quant_pre_process wrote it.
    model_prep_path.unlink(missing_ok=True)
print("Done!")

Pre-processing model for UINT8 quantization...
Quantizing model to /Users/hadim/Library/Caches/katago-onnx/kata1-b28c512nbt-adam-s11165M-d5387M/kata1-b28c512nbt-adam-s11165M-d5387M.uint8.onnx...
UINT8 quantized model saved to: /Users/hadim/Library/Caches/katago-onnx/kata1-b28c512nbt-adam-s11165M-d5387M/kata1-b28c512nbt-adam-s11165M-d5387M.uint8.onnx
Done!


In [5]:
# Compare the on-disk sizes of the exported variants.
# pathlib's stat() is used (consistent with the rest of the notebook), so no
# extra mid-notebook import is needed.
models = [
    ("FP32", model_fp32_path),
    ("FP16", model_fp16_path),
    ("UINT8", model_uint8_path),
]

print("Model Size Comparison:")
print("-" * 40)
for name, path in models:
    if not path.exists():
        # Report instead of silently skipping, so a failed export is visible.
        print(f"{name:6}: {'missing':>10}  ({path.name})")
        continue
    size_mb = path.stat().st_size / (1024 * 1024)
    print(f"{name:6}: {size_mb:>7.1f} MB  ({path.name})")

Model Size Comparison:
----------------------------------------
FP32  :   279.5 MB  (kata1-b28c512nbt-adam-s11165M-d5387M.fp32.onnx)
FP16  :   140.2 MB  (kata1-b28c512nbt-adam-s11165M-d5387M.fp16.onnx)
UINT8 :    71.7 MB  (kata1-b28c512nbt-adam-s11165M-d5387M.uint8.onnx)


'/Users/hadim/Library/Caches/katago-onnx/kata1-b28c512nbt-adam-s11165M-d5387M/kata1-b28c512nbt-adam-s11165M-d5387M.fp16.onnx'