In [1]:
!pip install optimum[onnxruntime-gpu]

Collecting optimum[onnxruntime-gpu]
  Downloading optimum-1.13.2.tar.gz (300 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m301.0/301.0 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting coloredlogs (from optimum[onnxruntime-gpu])
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting transformers[sentencepiece]>=4.26.0 (from optimum[onnxruntime-gpu])
  Downloading transformers-4.34.1-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m77.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub>=0.8.0 (from optimum[onnxruntime-gpu])
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 

In [2]:
import json
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional

import datasets
import numpy as np
import pandas as pd
import torch
import transformers
from datasets import load_dataset
from evaluate import load
from onnxruntime.quantization import QuantFormat, QuantizationMode, QuantType
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from transformers import AutoFeatureExtractor, EvalPrediction, HfArgumentParser, TrainingArguments
from transformers.utils.versions import require_version

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoCalibrationConfig, QuantizationConfig
from optimum.onnxruntime.model import ORTModel
from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification
from optimum.onnxruntime.preprocessors import QuantizationPreprocessor
from optimum.onnxruntime.preprocessors.passes import (
    ExcludeGeLUNodes,
    ExcludeLayerNormNodes,
    ExcludeNodeAfter,
    ExcludeNodeFollowedBy,
)

In [3]:
@dataclass
class DataTrainingArguments:
    """
    Options describing the data fed to the model for training and evaluation.

    Every field carries a ``help`` entry in its metadata, so `HfArgumentParser`
    can expose this class as command-line arguments.
    """

    # --- dataset selection ---
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )

    # --- preprocessing ---
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached preprocessed datasets or not."},
    )

    # --- optional truncation of the splits ---
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            ),
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            ),
        },
    )

    # --- local data directories ---
    train_dir: Optional[str] = field(
        default=None,
        metadata={"help": "A directory path for the training data."},
    )
    validation_dir: Optional[str] = field(
        default=None,
        metadata={"help": "A directory path for the validation data."},
    )

In [4]:
@dataclass
class ModelArguments:
    """
    Options identifying the pretrained model and where its downloaded files are stored.
    """

    # Required: there is no default, so it must be passed at construction time.
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models",
        },
    )
    # Optional download location.
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from huggingface.co",
        },
    )

In [5]:
@dataclass
class OptimizationArguments:
    """
    Options controlling the quantization applied to the model.

    Groups the quantization approach, the weight quantization switches, the
    calibration settings and the ONNX Runtime execution provider.
    """

    # --- quantization approach and weight options ---
    quantization_approach: str = field(
        default="dynamic",
        metadata={
            "help": "The quantization approach. Supported approach are static and dynamic.",
        },
    )
    per_channel: bool = field(
        default=False,
        metadata={
            "help": "Whether to quantize the weights per channel.",
        },
    )
    reduce_range: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to quantize the weights with 7-bits. It may improve the accuracy for some models running "
                "on non-VNNI machine, especially for per-channel mode."
            ),
        },
    )

    # --- calibration settings ---
    calibration_method: str = field(
        default="minmax",
        metadata={
            "help": (
                "The method chosen to calculate the activation quantization parameters using the calibration "
                "dataset. Current supported calibration methods are minmax, entropy and percentile."
            ),
        },
    )
    num_calibration_samples: int = field(
        default=100,
        metadata={
            "help": "Number of examples to use for the calibration step resulting from static quantization.",
        },
    )
    num_calibration_shards: int = field(
        default=1,
        metadata={
            "help": (
                "How many shards to split the calibration dataset into. Useful for the entropy and percentile "
                "calibration method."
            ),
        },
    )
    calibration_batch_size: int = field(
        default=8,
        metadata={
            "help": "The batch size for the calibration step.",
        },
    )
    calibration_histogram_percentile: float = field(
        default=99.999,
        metadata={
            "help": "The percentile used for the percentile calibration method.",
        },
    )
    calibration_moving_average: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to compute the moving average of the minimum and maximum values for the minmax "
                "calibration method."
            ),
        },
    )
    calibration_moving_average_constant: float = field(
        default=0.01,
        metadata={
            "help": (
                "Constant smoothing factor to use when computing the moving average of the minimum and maximum "
                "values. Effective only when the selected calibration method is minmax and `calibration_moving_average` is "
                "set to True."
            ),
        },
    )

    # --- inference backend ---
    execution_provider: str = field(
        default="CPUExecutionProvider",
        metadata={
            "help": "ONNX Runtime execution provider to use for inference.",
        },
    )

In [6]:
@dataclass
class OnnxExportArguments:
    """
    Options deciding how the ModelProto is written to disk.
    """

    # TODO: onnxruntime currently puts external data in a different path than the model proto,
    # which causes problems when re-loading it.
    # https://github.com/microsoft/onnxruntime/issues/12576
    use_external_data_format: bool = field(
        default=False,
        metadata={
            "help": "Whether to use external data format to store model whose size is >= 2Gb.",
        },
    )

In [7]:
# Notebook configuration: the values a command-line run would supply as arguments.

# Checkpoint to load and dataset to evaluate on.
model_args = ModelArguments(
    model_name_or_path="nateraw/vit-base-beans",
)
data_args = DataTrainingArguments(
    dataset_name="beans",
)

# `output_dir` is also where the quantized model is saved further down.
training_args = TrainingArguments(
    output_dir="image_classification_vit_beans",
    do_eval=True,
)

# Quantization and export options; everything not listed keeps its default.
optim_args = OptimizationArguments(
    quantization_approach="dynamic",
)
onnx_export_args = OnnxExportArguments()

In [8]:
# Load the dataset named in `data_args`.
# NOTE(review): `data_args.dataset_config_name` and `model_args.cache_dir` are declared
# but not forwarded here, so setting them has no effect on this call.
dataset = load_dataset(data_args.dataset_name)

Downloading builder script:   0%|          | 0.00/3.61k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.24k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/4.75k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/144M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/18.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.7M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/1034 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/133 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/128 [00:00<?, ? examples/s]

In [9]:
# Name of the label column: "labels" when the validation split has such a column,
# otherwise the second column of that split.
validation_columns = dataset["validation"].column_names
if "labels" in validation_columns:
    labels_column = "labels"
else:
    labels_column = validation_columns[1]

In [10]:
# Load the feature extractor that belongs to the checkpoint named in `model_args`.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.model_name_or_path)

# Define torchvision transforms to be applied to each image.
# The normalization uses the mean and std recorded on the feature extractor.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

Downloading (…)rocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]



In [11]:
# Target size read from the feature extractor's `size` mapping.
target_height = feature_extractor.size["height"]
target_width = feature_extractor.size["width"]

# Preprocessing pipeline applied to every image before it reaches the model.
transforms = Compose(
    [
        # With an int, Resize matches the image's shorter edge to that value. Using the
        # larger target dimension guarantees the crop below fits without padding.
        Resize(max(target_height, target_width)),
        # CenterCrop takes (height, width). Passing only the width produced a
        # width x width crop, which is wrong whenever height != width.
        CenterCrop((target_height, target_width)),
        ToTensor(),
        normalize,
    ]
)

In [12]:
def preprocess_function(example_batch):
    """
    Add a "pixel_values" entry to a batch by transforming each of its images.

    Each item of ``example_batch["image"]`` is converted to RGB, passed through the
    module-level ``transforms`` pipeline, cast to float32 and turned into a NumPy
    array. The batch is modified in place and returned.
    """
    pixel_values = []
    for image in example_batch["image"]:
        transformed = transforms(image.convert("RGB"))
        pixel_values.append(transformed.to(torch.float32).numpy())
    example_batch["pixel_values"] = pixel_values
    return example_batch

In [13]:
# Accuracy metric, loaded through `evaluate.load`; `compute_metrics` calls it.
metric = load("accuracy")

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [14]:
def compute_metrics(p: EvalPrediction):
    """
    Score predictions with the module-level ``metric``.

    When ``p.predictions`` is a tuple only its first element is used. The predicted
    class is the argmax along axis 1, compared against ``p.label_ids``.

    Returns whatever ``metric.compute`` returns.
    """
    scores = p.predictions
    if isinstance(scores, tuple):
        scores = scores[0]
    predicted_classes = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)

In [15]:
# Load the checkpoint as an ONNX Runtime image-classification model.
# `export=True` requests an ONNX export; the log below reports the export going through PyTorch.
model = ORTModelForImageClassification.from_pretrained(model_args.model_name_or_path, export=True)

Downloading (…)lve/main/config.json:   0%|          | 0.00/756 [00:00<?, ?B/s]

Framework not specified. Using pt to export to ONNX.


Downloading pytorch_model.bin:   0%|          | 0.00/343M [00:00<?, ?B/s]

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Using the export variant default. Available variants are:
	- default: The default ONNX variant.
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Using framework PyTorch: 2.1.0+cu118
  if num_channels != self.num_channels:
  if height != self.image_size[0] or width != self.image_size[1]:
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [16]:
# Evaluation data: the validation split, with its labels aligned to the model's
# `label2id` mapping and `preprocess_function` attached through `with_transform`.
eval_dataset = (
    dataset["validation"]
    .align_labels_with_mapping(label2id=model.config.label2id, label_column=labels_column)
    .with_transform(preprocess_function)
)

Aligning the labels:   0%|          | 0/133 [00:00<?, ? examples/s]

In [17]:
# Save the exported model and load it back from the same relative directory.
# The original reloaded from a hardcoded "/content/model_base/..." path, which only
# matches the save location when the working directory is /content.
base_model_dir = Path("model_base")
model.save_pretrained(base_model_dir)

# Wrap the saved ONNX file for evaluation, mirroring how the quantized model is wrapped.
ort_model_base = ORTModel(
    base_model_dir / "model.onnx",
    execution_provider=optim_args.execution_provider,
    compute_metrics=compute_metrics,
    label_names=[labels_column],
)

In [18]:
# Size of the exported base model in MiB (bytes / 1024**2).
# The path is relative, matching `model.save_pretrained("model_base")`, instead of
# assuming the working directory is /content.
base_file_size = os.path.getsize(Path("model_base") / "model.onnx") / 1024**2

In [19]:
# Time the evaluation of the base model, in seconds.
# `time.perf_counter()` is meant for measuring intervals; `time.time()` follows the
# wall clock and can jump if the system clock is adjusted mid-run.
start = time.perf_counter()
outputs = ort_model_base.evaluation_loop(eval_dataset)
base_eval_time = time.perf_counter() - start

***** Running evaluation *****


In [20]:
# Accuracy of the base model, read from the metrics of the evaluation run above.
# `outputs` is reassigned by the quantized evaluation later, so capture the value here.
base_accuracy = outputs.metrics["accuracy"]

In [21]:
# Build the quantizer from the exported ONNX model loaded earlier.
quantizer = ORTQuantizer.from_pretrained(model)

In [22]:
# True only for the exact string "static"; any other value makes the quantization
# config below take its dynamic branch.
apply_static_quantization = optim_args.quantization_approach == "static"

In [23]:
# Choose the format, mode and activation dtype that go with the selected approach.
if apply_static_quantization:
    quant_format = QuantFormat.QDQ
    quant_mode = QuantizationMode.QLinearOps
    activations_dtype = QuantType.QInt8
else:
    quant_format = QuantFormat.QOperator
    quant_mode = QuantizationMode.IntegerOps
    activations_dtype = QuantType.QUInt8

# Weights are int8 in both cases, and only MatMul and Add operators are quantized.
qconfig = QuantizationConfig(
    is_static=apply_static_quantization,
    format=quant_format,
    mode=quant_mode,
    activations_dtype=activations_dtype,
    weights_dtype=QuantType.QInt8,
    per_channel=optim_args.per_channel,
    reduce_range=optim_args.reduce_range,
    operators_to_quantize=["MatMul", "Add"],
)

In [24]:
# No calibration ranges are computed in the cells shown here, so `ranges` stays None
# when it is passed to `quantizer.quantize`.
# NOTE(review): presumably the static approach needs real ranges here — confirm before
# switching `quantization_approach` to "static".
ranges = None
# Create a quantization preprocessor to determine the nodes to exclude
# NOTE(review): no exclusion passes are registered on it in the cells shown here.
quantization_preprocessor = QuantizationPreprocessor()

In [25]:
# Quantize the model and save the result under the training output directory.
# The call is the cell's last expression, so its return value is displayed.
quantizer.quantize(
    quantization_config=qconfig,
    save_dir=training_args.output_dir,
    calibration_tensors_range=ranges,
    preprocessor=quantization_preprocessor,
    use_external_data_format=onnx_export_args.use_external_data_format,
)

Creating dynamic quantizer: QOperator (mode: IntegerOps, schema: u8/s8, channel-wise: False)
Preprocessor detected, collecting nodes to include/exclude
Quantizing model...
Saving quantized model at: image_classification_vit_beans (external data format: False)
Configuration saved in image_classification_vit_beans/ort_config.json


PosixPath('image_classification_vit_beans')

In [26]:
# Wrap the quantized ONNX file for evaluation, with the same provider, metric
# function and label column as the base model.
quantized_model_path = Path(training_args.output_dir) / "model_quantized.onnx"
ort_model_quantized = ORTModel(
    quantized_model_path,
    execution_provider=optim_args.execution_provider,
    compute_metrics=compute_metrics,
    label_names=[labels_column],
)

In [27]:
# Time the evaluation of the quantized model, in seconds.
# `time.perf_counter()` is meant for measuring intervals; `time.time()` follows the
# wall clock and can jump if the system clock is adjusted mid-run.
start = time.perf_counter()
outputs = ort_model_quantized.evaluation_loop(eval_dataset)
quantized_eval_time = time.perf_counter() - start

***** Running evaluation *****


In [28]:
# Size of the quantized model file in MiB (bytes / 1024**2).
quantized_file_size = os.path.getsize(Path(training_args.output_dir) / "model_quantized.onnx") / 1024**2

In [29]:
# Accuracy of the quantized model; `outputs` now holds the quantized evaluation result.
quantized_accuracy = outputs.metrics["accuracy"]

In [31]:
# Base vs. quantized summary: evaluation time in seconds, accuracy, and file size
# in MiB. One print call, with the three reports separated by newlines.
print(
    f"Inference Speed\nBase {base_eval_time}\nQuantized {quantized_eval_time}",
    f"Accuracy\nBase {base_accuracy}\nQuantized {quantized_accuracy}",
    f"File Size\nBase {base_file_size}\nQuantized {quantized_file_size}",
    sep="\n",
)

Inference Speed
Base 143.7753348350525
Quantized 83.55315947532654
Accuracy
Base 0.9774436090225563
Quantized 0.9774436090225563
File Size
Base 327.5556297302246
Quantized 84.76120090484619
