Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added accuracy control quantization #2070

Merged
merged 12 commits into from
Jun 6, 2024
6 changes: 6 additions & 0 deletions src/anomalib/deploy/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ class CompressionType(str, Enum):
Full integer post-training quantization (INT8)
All weights and operations are quantized to INT8. Inference is done in INT8 precision.
"""
INT8_ACQ = "int8_acq"
"""
Accuracy-control quantization (INT8)
Weights and operations are quantized to INT8, except those whose quantization would degrade the quality of the
model more than is acceptable. Inference is done in mixed precision.
"""
adrianboguszewski marked this conversation as resolved.
Show resolved Hide resolved


class InferenceModel(nn.Module):
Expand Down
6 changes: 6 additions & 0 deletions src/anomalib/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric
from torchvision.transforms.v2 import Transform

from anomalib import LearningType, TaskType
Expand Down Expand Up @@ -871,6 +872,7 @@ def export(
transform: Transform | None = None,
compression_type: CompressionType | None = None,
datamodule: AnomalibDataModule | None = None,
metric: Metric | str | None = None,
adrianboguszewski marked this conversation as resolved.
Show resolved Hide resolved
ov_args: dict[str, Any] | None = None,
ckpt_path: str | Path | None = None,
) -> Path | None:
Expand All @@ -891,6 +893,9 @@ def export(
datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
Must be provided if CompressionType.INT8_PTQ or CompressionType.INT8_ACQ is selected.
Defaults to ``None``.
metric (Metric | str | None, optional): Metric to measure quality loss when quantizing.
adrianboguszewski marked this conversation as resolved.
Show resolved Hide resolved
Must be provided if CompressionType.INT8_ACQ is selected.
Defaults to ``None``.
ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer.
Defaults to None.
ckpt_path (str | Path | None): Checkpoint path. If provided, the model will be loaded from this path.
Expand Down Expand Up @@ -954,6 +959,7 @@ def export(
task=self.task,
compression_type=compression_type,
datamodule=datamodule,
metric=metric,
ov_args=ov_args,
)
else:
Expand Down
55 changes: 48 additions & 7 deletions src/anomalib/models/components/base/export_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@

import json
import logging
from collections.abc import Callable
from collections.abc import Callable, Iterable
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
from torch import nn
from torchmetrics import Metric
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data import AnomalibDataModule
from anomalib.deploy.export import CompressionType, ExportType, InferenceModel
from anomalib.metrics import create_metric_collection
from anomalib.utils.exceptions import try_import

if TYPE_CHECKING:
Expand Down Expand Up @@ -159,6 +161,7 @@ def to_openvino(
transform: Transform | None = None,
compression_type: CompressionType | None = None,
datamodule: AnomalibDataModule | None = None,
metric: Metric | str | None = None,
ov_args: dict[str, Any] | None = None,
task: TaskType | None = None,
) -> Path:
Expand All @@ -176,6 +179,10 @@ def to_openvino(
datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
Must be provided if CompressionType.INT8_PTQ or CompressionType.INT8_ACQ is selected.
Defaults to ``None``.
metric (Metric | str | None, optional): Metric to measure quality loss when quantizing.
Must be provided if CompressionType.INT8_ACQ is selected and must return a higher value for better
adrianboguszewski marked this conversation as resolved.
Show resolved Hide resolved
performance of the model.
Defaults to ``None``.
ov_args (dict | None): Model optimizer arguments for OpenVINO model conversion.
Defaults to ``None``.
task (TaskType | None): Task type.
Expand Down Expand Up @@ -218,11 +225,8 @@ def to_openvino(
... task="segmentation",
... )
"""
if not try_import("openvino"):
logger.exception("Could not find OpenVINO. Please check OpenVINO installation.")
raise ModuleNotFoundError
if not try_import("nncf"):
logger.exception("Could not find NNCF. Please check NNCF installation.")
if not try_import("openvino") or not try_import("nncf"):
logger.exception("Could not find OpenVINO or NCCF. Please check OpenVINO and NNCF installation.")
raise ModuleNotFoundError

import nncf
Expand All @@ -235,20 +239,57 @@ def to_openvino(
ov_args = {} if ov_args is None else ov_args

model = ov.convert_model(model_path, **ov_args)
model_input = model.input(0)

if compression_type == CompressionType.INT8:
model = nncf.compress_weights(model)
elif compression_type == CompressionType.INT8_PTQ:
if datamodule is None:
msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression"
raise ValueError(msg)

dataloader = datamodule.val_dataloader()
if model_input.partial_shape[0].is_static:
datamodule.train_batch_size = model_input.shape[0]
dataloader = datamodule.train_dataloader()
if len(dataloader.dataset) < 300:
logger.warning(
f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images",
)

calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"])
model = nncf.quantize(model, calibration_dataset)
elif compression_type == CompressionType.INT8_ACQ:
adrianboguszewski marked this conversation as resolved.
Show resolved Hide resolved
if datamodule is None:
msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression"
raise ValueError(msg)
if metric is None:
msg = "Metric must be provided for OpenVINO INT8_ACQ compression"
raise ValueError(msg)

if model_input.partial_shape[0].is_static:
datamodule.train_batch_size = model_input.shape[0]
datamodule.eval_batch_size = model_input.shape[0]
dataloader = datamodule.train_dataloader()
if len(dataloader.dataset) < 300:
logger.warning(
f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images",
)

calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"])
validation_dataset = nncf.Dataset(datamodule.val_dataloader())

if isinstance(metric, str):
metric = create_metric_collection([metric])[metric]

# validation function to evaluate the quality loss after quantization
def val_fn(nncf_model: ov.CompiledModel, validation_data: Iterable) -> float:
    """Score ``nncf_model`` on ``validation_data`` using the enclosing ``metric``.

    This is the validation callback handed to ``nncf.quantize_with_accuracy_control``.
    The same ``metric`` object is shared by every invocation, so its state is cleared
    first; otherwise statistics from a previously evaluated model would leak into the
    score of the current one.

    Args:
        nncf_model (ov.CompiledModel): Compiled model to evaluate. It is called with the
            batch images and its first output is used as the prediction.
        validation_data (Iterable): Batches providing an ``"image"`` entry plus a
            ``"mask"`` (segmentation task) or ``"label"`` (any other task) entry.

    Returns:
        float: Metric value over all batches; higher is expected to mean better.
    """
    # Start from a clean state: the metric accumulates across ``update`` calls and
    # ``compute`` alone does not clear it.
    metric.reset()
    for batch in validation_data:
        preds = torch.from_numpy(nncf_model(batch["image"])[0])
        # Masks get an explicit channel axis inserted at dim 1
        # (assumes masks arrive as (batch, H, W) -- TODO confirm against the datamodule).
        target = batch["mask"][:, None, :, :] if task == TaskType.SEGMENTATION else batch["label"]
        metric.update(preds, target)
    # ``compute`` yields a tensor; convert so the declared ``float`` return type holds.
    return float(metric.compute())

model = nncf.quantize_with_accuracy_control(model, calibration_dataset, validation_dataset, val_fn)

# fp16 compression is enabled by default
compress_to_fp16 = compression_type == CompressionType.FP16
Expand Down
Loading