#### Quantizing the models to int8

In [38]:
import onnx
import onnxruntime as ort
from onnxruntime import quantization
from onnx.onnx_ml_pb2 import ModelProto
from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession
from onnxruntime.quantization.quant_utils import QuantType

from pathlib import Path
import os


In [10]:
# Model directory: a "models" folder next to the current working directory
# (<cwd>/../models). The original fp32 models are expected to be saved there.
MODEL_DIR: Path = Path.cwd().parent / "models"

In [11]:
# File locations of the fp32 decoder and encoder graphs inside MODEL_DIR.
decoder_onnx_path: Path = MODEL_DIR / "sam2_hiera_tiny_decoder.onnx"
encoder_onnx_path: Path = MODEL_DIR / "sam2_hiera_tiny_encoder.onnx"

# Load each graph and run the ONNX checker on it straight away, so a malformed
# model is reported next to the load that produced it.
decoder_onnx: ModelProto = onnx.load(decoder_onnx_path)
onnx.checker.check_model(decoder_onnx)

encoder_onnx: ModelProto = onnx.load(encoder_onnx_path)
onnx.checker.check_model(encoder_onnx)

In [15]:
# Execution providers handed to ONNX Runtime: CPU only.
ort_provider: list[str] = ["CPUExecutionProvider"]

# Build one ONNX Runtime session per fp32 model. Each session must be created
# from its own model file: the encoder session from the encoder path and the
# decoder session from the decoder path.
ort_sess_encoder: InferenceSession = ort.InferenceSession(
    encoder_onnx_path, providers=ort_provider
)
ort_sess_decoder: InferenceSession = ort.InferenceSession(
    decoder_onnx_path, providers=ort_provider
)

In [16]:
# Destination files for the pre-processed graphs. These are deliberately
# different from the fp32 input paths: writing the pre-processed output to the
# input path would overwrite the original fp32 models in place.
encoder_prep_path: Path = MODEL_DIR.joinpath("sam2_hiera_tiny_encoder_prep.onnx")
decoder_prep_path: Path = MODEL_DIR.joinpath("sam2_hiera_tiny_decoder_prep.onnx")

# Pre-process both models ahead of quantization. The encoder runs with symbolic
# shape inference enabled.
quantization.shape_inference.quant_pre_process(
    encoder_onnx_path, encoder_prep_path, skip_symbolic_shape=False
)
# Symbolic shape inference is skipped for the decoder, as it is mostly useful
# for transformer-based models.
quantization.shape_inference.quant_pre_process(
    decoder_onnx_path, decoder_prep_path, skip_symbolic_shape=True
)

In [17]:
# Distinct operator types that occur in each model's graph.
ops_dec = {graph_node.op_type for graph_node in decoder_onnx.graph.node}
ops_enc = {graph_node.op_type for graph_node in encoder_onnx.graph.node}

In [18]:
# Op types to quantize for each model: every op type found in that model's OWN
# graph, minus anything whose name starts with "conv" (case-insensitive).
# Conv layers are excluded because converting them to int8 raises an error
# caused by an ONNX issue.
# The lists are sorted so the selection is reproducible across kernel restarts
# (set iteration order is not stable for strings).
op_to_quantize_dec = sorted(
    op_type for op_type in ops_dec if not op_type.lower().startswith("conv")
)
op_to_quantize_enc = sorted(
    op_type for op_type in ops_enc if not op_type.lower().startswith("conv")
)

In [19]:
# Output files for the quantized models.
quantized_encoder_path: Path = MODEL_DIR / "sam2_hiera_tiny_encoder_quant.onnx"
quantized_decoder_path: Path = MODEL_DIR / "sam2_hiera_tiny_decoder_quant.onnx"

# Dynamic quantization of the pre-processed graphs with QInt8 weights, limited
# to the op types selected earlier (conv ops were filtered out of those lists).
# Author's note: switch weight_type to QuantType.QUInt8 if the conv layers
# should not be left unquantized.
quantization.quantize_dynamic(
    model_input=encoder_prep_path,
    model_output=quantized_encoder_path,
    op_types_to_quantize=op_to_quantize_enc,
    weight_type=QuantType.QInt8,
)
quantization.quantize_dynamic(
    model_input=decoder_prep_path,
    model_output=quantized_decoder_path,
    op_types_to_quantize=op_to_quantize_dec,
    weight_type=QuantType.QInt8,
)

#### Simplify quantized models

In [None]:
# TODO: Simplify quantized models using onnxsim

#### Loading quantized model

In [20]:
# Reload each quantized graph from disk and validate it with the ONNX checker
# immediately after loading.
decoder_quant_onnx: ModelProto = onnx.load(quantized_decoder_path)
onnx.checker.check_model(decoder_quant_onnx)

encoder_quant_onnx: ModelProto = onnx.load(quantized_encoder_path)
onnx.checker.check_model(encoder_quant_onnx)

In [21]:
# CPU-only execution for the quantized models.
ort_provider: list[str] = ["CPUExecutionProvider"]

# Rebind the session names to sessions built from the quantized model files.
ort_sess_encoder: InferenceSession = ort.InferenceSession(
    quantized_encoder_path,
    providers=ort_provider,
)
ort_sess_decoder: InferenceSession = ort.InferenceSession(
    quantized_decoder_path,
    providers=ort_provider,
)

#### Comparing size reduction of fp32 to int8 models

In [37]:
# On-disk sizes in whole KB (byte count floor-divided by 1024).
size_encoder_quant: int = os.path.getsize(quantized_encoder_path) // 1024
size_encoder: int = os.path.getsize(encoder_onnx_path) // 1024

size_decoder_quant: int = os.path.getsize(quantized_decoder_path) // 1024
size_decoder: int = os.path.getsize(decoder_onnx_path) // 1024


print(f"Size of int8 encoder: {size_encoder_quant} KB")
print(f"Size of fp32 encoder: {size_encoder} KB")

# These two lines report the decoder, so they are labelled "decoder".
print(f"Size of int8 decoder: {size_decoder_quant} KB")
print(f"Size of fp32 decoder: {size_decoder} KB")

# Relative size reduction in percent, rounded to 3 decimal places:
# (fp32 - int8) / fp32 * 100.
reduction_enc: float = round(
    ((size_encoder - size_encoder_quant) / size_encoder) * 100, 3
)
reduction_dec: float = round(
    ((size_decoder - size_decoder_quant) / size_decoder) * 100, 3
)
print(f"Reduction in size in percentage for encoder: {reduction_enc} %")
print(f"Reduction in size in percentage for decoder: {reduction_dec} %")

Size of int8 encoder: 52776 KB
Size of fp32 encoder: 131115 KB
Size of int8 encoder: 8597 KB
Size of fp32 encoder: 20152 KB
Reduction in size in percentage for encoder: 59.748 %
Reduction in size in percentage for decoder: 57.339 %
