In [1]:
from typing import Dict, Union

import numpy as np
import onnxruntime as ort
import timm
import torch
from onnxruntime.quantization import CalibrationDataReader, CalibrationMethod, quantize_static

from src.seed import seed_everything

seed_everything(42)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
img_height = 224
img_width = 224
x = torch.randn(1, 3, img_height, img_width)

model = timm.create_model("resnet50", pretrained=True, num_classes=10)
model(x)


tensor([[ 0.1005, -0.0164,  0.0238, -0.0473, -0.0282, -0.0354,  0.0018,  0.0091,
          0.0022, -0.0017]], grad_fn=<AddmmBackward0>)

In [3]:
torch.onnx.export(
    model,
    x,
    "model.onnx",
    export_params=True,
    opset_version=12,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)


In [4]:
providers = ["CPUExecutionProvider"]

session_fp32 = ort.InferenceSession("model.onnx", providers=providers)

# 入力はfloat32で渡す
x = np.random.randn(4, 3, img_height, img_width).astype(np.float32)

# 推論
output = session_fp32.run([], {"input": x})[0]
print(f"{output=}")


output=array([[ 0.0476543 ,  0.02461434,  0.05225206,  0.11924067, -0.04636435,
        -0.04621083, -0.08258556,  0.08939756, -0.02054376,  0.01423609],
       [ 0.05815532,  0.01002552,  0.05692532,  0.11288723, -0.0532899 ,
        -0.05293959, -0.07087824,  0.09985154, -0.01899079,  0.01401258],
       [ 0.04634621,  0.01690613,  0.05314342,  0.12114876, -0.05235383,
        -0.04252257, -0.08219366,  0.0966015 , -0.02349848,  0.00525925],
       [ 0.04864033,  0.02333969,  0.05780552,  0.12373698, -0.05364507,
        -0.05114723, -0.08610532,  0.09010924, -0.02112752,  0.01301998]],
      dtype=float32)


In [5]:
class ImgDataReader(CalibrationDataReader):
    def __init__(self, imgs: np.ndarray) -> None:
        self.imgs = imgs  # 形状(Batch, C, H, W)
        self.img_dicts = iter([{"input": img[np.newaxis]} for img in self.imgs])
        self.datasize = len(self.imgs)

    def get_next(self) -> dict[str, np.ndarray] | None:
        # git_next関数は{"input": ndarray形式}で返す。全て返し終わったらNone
        return next(self.img_dicts, None)


In [6]:
# imgsは形状が（Batch, C, H, W）の画像データとする
imgs = np.random.randn(10, 3, img_height, img_width).astype(np.float32)
input_model_path = "model-infer.onnx"
output_model_path = "model-sq.onnx"
data_reader = ImgDataReader(imgs)
method = CalibrationMethod.MinMax
quantize_static(input_model_path, output_model_path, data_reader, calibrate_method=method)
