Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added accuracy control quantization #2070

Merged
merged 12 commits into from
Jun 6, 2024
25 changes: 13 additions & 12 deletions src/anomalib/deploy/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,28 @@ class ExportType(str, Enum):
class CompressionType(str, Enum):
"""Model compression type when exporting to OpenVINO.

Attributes:
FP16 (str): Weight compression (FP16). All weights are converted to FP16.
INT8 (str): Weight compression (INT8). All weights are quantized to INT8,
but are dequantized to floating point before inference.
INT8_PTQ (str): Full integer post-training quantization (INT8).
All weights and operations are quantized to INT8. Inference is done
in INT8 precision.
INT8_ACQ (str): Accuracy-control quantization (INT8). Weights and
operations are quantized to INT8, except those that would degrade
quality of the model more than is acceptable. Inference is done in
a mixed precision.

Examples:
>>> from anomalib.deploy import CompressionType
>>> CompressionType.INT8_PTQ
'int8_ptq'
"""

FP16 = "fp16"
"""
Weight compression (FP16)
All weights are converted to FP16.
"""
INT8 = "int8"
"""
Weight compression (INT8)
All weights are quantized to INT8, but are dequantized to floating point before inference.
"""
INT8_PTQ = "int8_ptq"
"""
Full integer post-training quantization (INT8)
All weights and operations are quantized to INT8. Inference is done in INT8 precision.
"""
INT8_ACQ = "int8_acq"


class InferenceModel(nn.Module):
Expand Down
16 changes: 12 additions & 4 deletions src/anomalib/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric
from torchvision.transforms.v2 import Transform

from anomalib import LearningType, TaskType
Expand Down Expand Up @@ -871,6 +872,7 @@ def export(
transform: Transform | None = None,
compression_type: CompressionType | None = None,
datamodule: AnomalibDataModule | None = None,
metric: Metric | str | None = None,
adrianboguszewski marked this conversation as resolved.
Show resolved Hide resolved
ov_args: dict[str, Any] | None = None,
ckpt_path: str | Path | None = None,
) -> Path | None:
Expand All @@ -889,7 +891,12 @@ def export(
compression_type (CompressionType | None, optional): Compression type for OpenVINO exporting only.
Defaults to ``None``.
datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
Must be provided if CompressionType.INT8_PTQ is selected.
Must be provided if ``CompressionType.INT8_PTQ`` or `CompressionType.INT8_ACQ`` is selected
(OpenVINO export only).
Defaults to ``None``.
metric (Metric | str | None, optional): Metric to measure quality loss when quantizing.
adrianboguszewski marked this conversation as resolved.
Show resolved Hide resolved
Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better
performance of the model (OpenVINO export only).
Defaults to ``None``.
ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer.
Defaults to None.
Expand All @@ -915,12 +922,12 @@ def export(
3. To export as an OpenVINO ``.xml`` and ``.bin`` file you can run the following command.
```python
anomalib export --model Padim --export_mode openvino --ckpt_path <PATH_TO_CHECKPOINT> \
--input_size "[256,256]"
--input_size "[256,256] --compression_type "fp16"
```
4. You can also override OpenVINO model optimizer by adding the ``--ov_args.<key>`` arguments.
4. You can also quantize OpenVINO model with the following.
```python
anomalib export --model Padim --export_mode openvino --ckpt_path <PATH_TO_CHECKPOINT> \
--input_size "[256,256]" --ov_args.compress_to_fp16 False
--input_size "[256,256]" --compression_type "int8_ptq" --data MVTec
```
"""
export_type = ExportType(export_type)
Expand Down Expand Up @@ -954,6 +961,7 @@ def export(
task=self.task,
compression_type=compression_type,
datamodule=datamodule,
metric=metric,
ov_args=ov_args,
)
else:
Expand Down
189 changes: 169 additions & 20 deletions src/anomalib/models/components/base/export_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,31 @@

import json
import logging
from collections.abc import Callable
from collections.abc import Callable, Iterable
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
from torch import nn
from torchmetrics import Metric
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data import AnomalibDataModule
from anomalib.deploy.export import CompressionType, ExportType, InferenceModel
from anomalib.metrics import create_metric_collection
from anomalib.utils.exceptions import try_import

if TYPE_CHECKING:
from importlib.util import find_spec

from torch.types import Number

if find_spec("openvino") is not None:
from openvino import CompiledModel

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -159,6 +166,7 @@ def to_openvino(
transform: Transform | None = None,
compression_type: CompressionType | None = None,
datamodule: AnomalibDataModule | None = None,
metric: Metric | str | None = None,
ov_args: dict[str, Any] | None = None,
task: TaskType | None = None,
) -> Path:
Expand All @@ -174,7 +182,11 @@ def to_openvino(
compression_type (CompressionType, optional): Compression type for better inference performance.
Defaults to ``None``.
datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
Must be provided if CompressionType.INT8_PTQ is selected.
Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected.
Defaults to ``None``.
metric (Metric | str | None, optional): Metric to measure quality loss when quantizing.
Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better
performance of the model.
Defaults to ``None``.
ov_args (dict | None): Model optimizer arguments for OpenVINO model conversion.
Defaults to ``None``.
Expand Down Expand Up @@ -206,6 +218,20 @@ def to_openvino(
... task=datamodule.test_data.task
... )

Export and Quantize the Model (OpenVINO IR):
This example demonstrates how to export and quantize the model to OpenVINO IR.

>>> from anomalib.models import Patchcore
>>> from anomalib.data import Visa
>>> datamodule = Visa()
>>> model = Patchcore()
>>> model.to_openvino(
... export_root="path/to/export",
... compression_type=CompressionType.INT8_PTQ,
... datamodule=datamodule,
... task=datamodule.test_data.task
... )

Using Custom Transforms:
This example shows how to use a custom ``Transform`` object for the ``transform`` argument.

Expand All @@ -221,11 +247,7 @@ def to_openvino(
if not try_import("openvino"):
logger.exception("Could not find OpenVINO. Please check OpenVINO installation.")
raise ModuleNotFoundError
if not try_import("nncf"):
logger.exception("Could not find NNCF. Please check NNCF installation.")
raise ModuleNotFoundError

import nncf
import openvino as ov

with TemporaryDirectory() as onnx_directory:
Expand All @@ -235,20 +257,8 @@ def to_openvino(
ov_args = {} if ov_args is None else ov_args

model = ov.convert_model(model_path, **ov_args)
if compression_type == CompressionType.INT8:
model = nncf.compress_weights(model)
elif compression_type == CompressionType.INT8_PTQ:
if datamodule is None:
msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression"
raise ValueError(msg)

dataloader = datamodule.val_dataloader()
if len(dataloader.dataset) < 300:
logger.warning(
f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images",
)
calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"])
model = nncf.quantize(model, calibration_dataset)
if compression_type and compression_type != CompressionType.FP16:
model = self._compress_ov_model(model, compression_type, datamodule, metric, task)

# fp16 compression is enabled by default
compress_to_fp16 = compression_type == CompressionType.FP16
Expand All @@ -257,6 +267,145 @@ def to_openvino(

return ov_model_path

def _compress_ov_model(
adrianboguszewski marked this conversation as resolved.
Show resolved Hide resolved
self,
model: "CompiledModel",
compression_type: CompressionType | None = None,
datamodule: AnomalibDataModule | None = None,
metric: Metric | str | None = None,
task: TaskType | None = None,
) -> "CompiledModel":
"""Compress OpenVINO model with NNCF.

model (CompiledModel): Model already exported to OpenVINO format.
compression_type (CompressionType, optional): Compression type for better inference performance.
Defaults to ``None``.
datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected.
Defaults to ``None``.
metric (Metric | str | None, optional): Metric to measure quality loss when quantizing.
Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better
performance of the model.
Defaults to ``None``.
task (TaskType | None): Task type.
Defaults to ``None``.

Returns:
model (CompiledModel): Model in the OpenVINO format compressed with NNCF quantization.
"""
if not try_import("nncf"):
logger.exception("Could not find NCCF. Please check NNCF installation.")
raise ModuleNotFoundError

import nncf

if compression_type == CompressionType.INT8:
model = nncf.compress_weights(model)
elif compression_type == CompressionType.INT8_PTQ:
model = self._post_training_quantization_ov(model, datamodule)
elif compression_type == CompressionType.INT8_ACQ:
model = self._accuracy_control_quantization_ov(model, datamodule, metric, task)
else:
msg = f"Unrecognized compression type: {compression_type}"
raise ValueError(msg)

return model

def _post_training_quantization_ov(
self,
model: "CompiledModel",
datamodule: AnomalibDataModule | None = None,
) -> "CompiledModel":
"""Post-Training Quantization model with NNCF.

model (CompiledModel): Model already exported to OpenVINO format.
datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected.
Defaults to ``None``.

Returns:
model (CompiledModel): Quantized model.
"""
import nncf

if datamodule is None:
msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression"
raise ValueError(msg)

model_input = model.input(0)

if model_input.partial_shape[0].is_static:
datamodule.train_batch_size = model_input.shape[0]

dataloader = datamodule.train_dataloader()
if len(dataloader.dataset) < 300:
logger.warning(
f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images",
)

calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"])
return nncf.quantize(model, calibration_dataset)

def _accuracy_control_quantization_ov(
self,
model: "CompiledModel",
datamodule: AnomalibDataModule | None = None,
metric: Metric | str | None = None,
task: TaskType | None = None,
) -> "CompiledModel":
"""Accuracy-Control Quantization with NNCF.

model (CompiledModel): Model already exported to OpenVINO format.
datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected.
Defaults to ``None``.
metric (Metric | str | None, optional): Metric to measure quality loss when quantizing.
Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better
performance of the model.
Defaults to ``None``.
task (TaskType | None): Task type.
Defaults to ``None``.

Returns:
model (CompiledModel): Quantized model.
"""
import nncf

if datamodule is None:
msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression"
raise ValueError(msg)
if metric is None:
msg = "Metric must be provided for OpenVINO INT8_ACQ compression"
raise ValueError(msg)

model_input = model.input(0)

if model_input.partial_shape[0].is_static:
datamodule.train_batch_size = model_input.shape[0]
datamodule.eval_batch_size = model_input.shape[0]

dataloader = datamodule.train_dataloader()
if len(dataloader.dataset) < 300:
logger.warning(
f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images",
)

calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"])
validation_dataset = nncf.Dataset(datamodule.val_dataloader())

if isinstance(metric, str):
metric = create_metric_collection([metric])[metric]

# validation function to evaluate the quality loss after quantization
def val_fn(nncf_model: "CompiledModel", validation_data: Iterable) -> float:
for batch in validation_data:
preds = torch.from_numpy(nncf_model(batch["image"])[0])
target = batch["label"] if task == TaskType.CLASSIFICATION else batch["mask"][:, None, :, :]
metric.update(preds, target)
return metric.compute()

return nncf.quantize_with_accuracy_control(model, calibration_dataset, validation_dataset, val_fn)

def _get_metadata(
self,
task: TaskType | None = None,
Expand Down
Loading