🛠 Refactor Normalization (#496)
* Use a single placeholder for normalization metrics

* Fix attribute existence issue

* Address codacy issues

* Address some more codacy issues

* Add ignore to solve codacy issues

* Move ignore statement + add exception

* Remove duplication
ashwinvaidya17 committed Aug 15, 2022
1 parent b2b5ad4 commit d6951eb
Showing 10 changed files with 66 additions and 136 deletions.
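The thrust of the refactor: AnomalyModule previously owned both a training_distribution metric (for CDF normalization) and a min_max metric (for min-max normalization), and every consumer had to know which one a given strategy used. After this change each normalization callback installs its own metric under the single attribute normalization_metrics during setup(), and consumers read whatever state it exposes. A minimal, self-contained sketch of the pattern, using simplified stand-ins rather than anomalib's real classes:

import types

import torch
from torchmetrics import Metric


class MinMaxSketch(Metric):
    """Simplified stand-in for anomalib's MinMax metric (illustrative only)."""

    def __init__(self):
        super().__init__()
        # persistent=True so the states show up in state_dict(), which is what
        # the generic metadata collection relies on.
        self.add_state("min", torch.tensor(float("inf")), persistent=True)
        self.add_state("max", torch.tensor(float("-inf")), persistent=True)

    def update(self, scores: torch.Tensor) -> None:
        self.min = torch.minimum(self.min, scores.min())
        self.max = torch.maximum(self.max, scores.max())

    def compute(self) -> None:
        """Nothing to aggregate; the states themselves are the result."""


class MinMaxCallbackSketch:
    """Installs its metric under the single well-known attribute name."""

    def setup(self, pl_module) -> None:
        if not hasattr(pl_module, "normalization_metrics"):
            pl_module.normalization_metrics = MinMaxSketch().cpu()


def collect_metadata(pl_module) -> dict:
    """Generic consumer: flatten whatever state the installed metric holds."""
    if not hasattr(pl_module, "normalization_metrics"):
        return {}
    return {key: value.cpu() for key, value in pl_module.normalization_metrics.state_dict().items()}


module = types.SimpleNamespace()  # stand-in for an AnomalyModule
MinMaxCallbackSketch().setup(module)
module.normalization_metrics.update(torch.tensor([0.1, 0.9]))
# Exact keys depend on the torchmetrics version, but min/max are among them:
print(collect_metadata(module))

The diffs below are the real version of each half of this sketch: the callbacks own the setup() side, and anomalib/deploy/optimize.py owns the state_dict() walk.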
12 changes: 5 additions & 7 deletions anomalib/deploy/optimize.py
@@ -11,6 +11,7 @@
 import numpy as np
 import torch
 from torch import Tensor
+from torch.types import Number
 
 from anomalib.models.components import AnomalyModule

@@ -25,16 +26,13 @@ def get_model_metadata(model: AnomalyModule) -> Dict[str, Tensor]:
         Dict[str, Tensor]: metadata
     """
     meta_data = {}
-    cached_meta_data = {
+    cached_meta_data: Dict[str, Union[Number, Tensor]] = {
         "image_threshold": model.image_threshold.cpu().value.item(),
         "pixel_threshold": model.pixel_threshold.cpu().value.item(),
-        "pixel_mean": model.training_distribution.pixel_mean.cpu(),
-        "image_mean": model.training_distribution.image_mean.cpu(),
-        "pixel_std": model.training_distribution.pixel_std.cpu(),
-        "image_std": model.training_distribution.image_std.cpu(),
-        "min": model.min_max.min.cpu().item(),
-        "max": model.min_max.max.cpu().item(),
     }
+    if hasattr(model, "normalization_metrics") and model.normalization_metrics.state_dict() is not None:
+        for key, value in model.normalization_metrics.state_dict().items():
+            cached_meta_data[key] = value.cpu()
     # Remove undefined values by copying in a new dict
     for key, val in cached_meta_data.items():
         if not np.isinf(val).all():
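The loop truncated above copies only well-defined values into the final dict: metric states that were never updated keep their ±inf defaults and are dropped. A small hedged illustration of that filtering step, with made-up values:

import numpy as np
import torch

# Illustrative cached metadata: "min"/"max" were updated during validation,
# while this hypothetical untouched state still holds its +inf default.
cached_meta_data = {
    "image_threshold": 0.58,
    "min": torch.tensor(0.02),
    "max": torch.tensor(0.97),
    "untouched_state": torch.tensor(float("inf")),
}

# Mirrors the loop above: keep only values that are not entirely infinite.
meta_data = {key: val for key, val in cached_meta_data.items() if not np.isinf(val).all()}
print(sorted(meta_data))  # ['image_threshold', 'max', 'min']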
14 changes: 3 additions & 11 deletions anomalib/models/components/base/anomaly_module.py
@@ -10,13 +10,9 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from torch import Tensor, nn
+from torchmetrics import Metric
 
-from anomalib.utils.metrics import (
-    AdaptiveThreshold,
-    AnomalibMetricCollection,
-    AnomalyScoreDistribution,
-    MinMax,
-)
+from anomalib.utils.metrics import AdaptiveThreshold, AnomalibMetricCollection
 
 logger = logging.getLogger(__name__)

@@ -41,12 +37,8 @@ def __init__(self):
         self.image_threshold = AdaptiveThreshold().cpu()
         self.pixel_threshold = AdaptiveThreshold().cpu()
 
-        self.training_distribution = AnomalyScoreDistribution().cpu()
-        self.min_max = MinMax().cpu()
+        self.normalization_metrics: Metric
 
         # Create placeholders for image and pixel metrics.
         # If set from the config file, MetricsConfigurationCallback will
         # create the metric collections upon setup.
         self.image_metrics: AnomalibMetricCollection
         self.pixel_metrics: AnomalibMetricCollection

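A note on the bare self.normalization_metrics: Metric line: an attribute annotation without an assignment documents the type (the annotation expression on a self.<name> target is evaluated, hence the new torchmetrics import) but creates no attribute. That is exactly what makes the hasattr checks in optimize.py and in the callbacks' setup() work. A minimal illustration of this general Python behaviour:

from torchmetrics import Metric


class Placeholder:
    def __init__(self) -> None:
        self.normalization_metrics: Metric  # annotation only; nothing is assigned


obj = Placeholder()
print(hasattr(obj, "normalization_metrics"))  # False: no attribute exists yet
obj.normalization_metrics = object()  # e.g. assigned later by a callback's setup()
print(hasattr(obj, "normalization_metrics"))  # True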
26 changes: 20 additions & 6 deletions anomalib/utils/callbacks/cdf_normalization.py
@@ -15,6 +15,7 @@
 from anomalib.models import get_model
 from anomalib.models.components import AnomalyModule
 from anomalib.post_processing.normalization.cdf import normalize, standardize
+from anomalib.utils.metrics import AnomalyScoreDistribution
 
 logger = logging.getLogger(__name__)

@@ -27,7 +28,19 @@ def __init__(self):
         self.image_dist: Optional[LogNormal] = None
         self.pixel_dist: Optional[LogNormal] = None
 
-    def on_test_start(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
+    # pylint: disable=unused-argument
+    def setup(self, trainer: pl.Trainer, pl_module: AnomalyModule, stage: Optional[str] = None) -> None:
+        """Adds training_distribution metrics to normalization metrics."""
+        if not hasattr(pl_module, "normalization_metrics"):
+            pl_module.normalization_metrics = AnomalyScoreDistribution().cpu()
+        elif not isinstance(pl_module.normalization_metrics, AnomalyScoreDistribution):
+            raise AttributeError(
+                f"Expected normalization_metrics to be of type AnomalyScoreDistribution,"
+                f" got {type(pl_module.normalization_metrics)}"
+            )
+
+    # pylint: disable=unused-argument
+    def on_test_start(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
         """Called when the test begins."""
         if pl_module.image_metrics is not None:
             pl_module.image_metrics.set_threshold(0.5)
@@ -93,24 +106,25 @@ def _collect_stats(self, trainer, pl_module):
         predictions = Trainer(gpus=trainer.gpus).predict(
             model=self._create_inference_model(pl_module), dataloaders=trainer.datamodule.train_dataloader()
         )
-        pl_module.training_distribution.reset()
+        pl_module.normalization_metrics.reset()
         for batch in predictions:
             if "pred_scores" in batch.keys():
-                pl_module.training_distribution.update(anomaly_scores=batch["pred_scores"])
+                pl_module.normalization_metrics.update(anomaly_scores=batch["pred_scores"])
             if "anomaly_maps" in batch.keys():
-                pl_module.training_distribution.update(anomaly_maps=batch["anomaly_maps"])
-        pl_module.training_distribution.compute()
+                pl_module.normalization_metrics.update(anomaly_maps=batch["anomaly_maps"])
+        pl_module.normalization_metrics.compute()
 
     @staticmethod
     def _create_inference_model(pl_module):
         """Create a duplicate of the PL module that can be used to perform inference on the training set."""
         new_model = get_model(pl_module.hparams)
+        new_model.normalization_metrics = AnomalyScoreDistribution().cpu()
         new_model.load_state_dict(pl_module.state_dict())
         return new_model
 
     @staticmethod
     def _standardize_batch(outputs: STEP_OUTPUT, pl_module) -> None:
-        stats = pl_module.training_distribution.to(outputs["pred_scores"].device)
+        stats = pl_module.normalization_metrics.to(outputs["pred_scores"].device)
         outputs["pred_scores"] = standardize(outputs["pred_scores"], stats.image_mean, stats.image_std)
         if "anomaly_maps" in outputs.keys():
             outputs["anomaly_maps"] = standardize(
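For context on what this callback feeds: CDF normalization standardizes (log-)scores against the training distribution collected in _collect_stats() and maps them through the standard normal CDF so the threshold lands at 0.5. A rough numpy/scipy sketch of the idea — assumed formulas, not anomalib's actual cdf module:

import numpy as np
from scipy.stats import norm


def standardize(scores: np.ndarray, image_mean: float, image_std: float) -> np.ndarray:
    """Standardize log-scores against the training distribution (assumed form)."""
    return (np.log(scores) - image_mean) / image_std


def normalize(standardized: np.ndarray, standardized_threshold: float) -> np.ndarray:
    """Map through the Gaussian CDF so the threshold lands at 0.5 (assumed form)."""
    return norm.cdf(standardized - standardized_threshold)


scores = np.array([0.2, 0.8, 2.5])
stdized = standardize(scores, image_mean=0.0, image_std=1.0)
print(normalize(stdized, standardized_threshold=0.0))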
29 changes: 20 additions & 9 deletions anomalib/utils/callbacks/min_max_normalization.py
@@ -3,7 +3,7 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import pytorch_lightning as pl
 from pytorch_lightning import Callback
@@ -12,18 +12,29 @@
 
 from anomalib.models.components import AnomalyModule
 from anomalib.post_processing.normalization.min_max import normalize
+from anomalib.utils.metrics import MinMax
 
 
 @CALLBACK_REGISTRY
 class MinMaxNormalizationCallback(Callback):
     """Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization."""
 
-    def on_test_start(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
+    # pylint: disable=unused-argument
+    def setup(self, trainer: pl.Trainer, pl_module: AnomalyModule, stage: Optional[str] = None) -> None:
+        """Adds min_max metrics to normalization metrics."""
+        if not hasattr(pl_module, "normalization_metrics"):
+            pl_module.normalization_metrics = MinMax().cpu()
+        elif not isinstance(pl_module.normalization_metrics, MinMax):
+            raise AttributeError(
+                f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
+            )
+
+    # pylint: disable=unused-argument
+    def on_test_start(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
         """Called when the test begins."""
-        if pl_module.image_metrics is not None:
-            pl_module.image_metrics.set_threshold(0.5)
-        if pl_module.pixel_metrics is not None:
-            pl_module.pixel_metrics.set_threshold(0.5)
+        for metric in (pl_module.image_metrics, pl_module.pixel_metrics):
+            if metric is not None:
+                metric.set_threshold(0.5)
 
     def on_validation_batch_end(
         self,
@@ -36,9 +47,9 @@ def on_validation_batch_end(
     ) -> None:
         """Called when the validation batch ends, update the min and max observed values."""
         if "anomaly_maps" in outputs.keys():
-            pl_module.min_max(outputs["anomaly_maps"])
+            pl_module.normalization_metrics(outputs["anomaly_maps"])
         else:
-            pl_module.min_max(outputs["pred_scores"])
+            pl_module.normalization_metrics(outputs["pred_scores"])

     def on_test_batch_end(
         self,
@@ -67,7 +78,7 @@ def on_predict_batch_end(
     @staticmethod
     def _normalize_batch(outputs, pl_module):
         """Normalize a batch of predictions."""
-        stats = pl_module.min_max.cpu()
+        stats = pl_module.normalization_metrics.cpu()
         outputs["pred_scores"] = normalize(
             outputs["pred_scores"], pl_module.image_threshold.value.cpu(), stats.min, stats.max
         )
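The normalize call in _normalize_batch maps raw scores so the adaptive threshold lands at 0.5 and the observed extremes bound the range. A short sketch consistent with the normalize(scores, threshold, min, max) call signature above (assumed formula):

import numpy as np


def normalize_min_max(scores: np.ndarray, threshold: float, v_min: float, v_max: float) -> np.ndarray:
    """Shift by the threshold, scale by the observed range, recenter at 0.5, clip to [0, 1]."""
    normalized = (scores - threshold) / (v_max - v_min) + 0.5
    return np.clip(normalized, 0.0, 1.0)


scores = np.array([0.1, 0.55, 0.9])
print(normalize_min_max(scores, threshold=0.55, v_min=0.1, v_max=0.9))  # [0. 0.5 0.9375]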
8 changes: 1 addition & 7 deletions anomalib/utils/sweep/__init__.py
@@ -4,18 +4,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from .config import flatten_sweep_params, get_run_config, set_in_nested_config
-from .helpers import (
-    get_meta_data,
-    get_openvino_throughput,
-    get_sweep_callbacks,
-    get_torch_throughput,
-)
+from .helpers import get_openvino_throughput, get_sweep_callbacks, get_torch_throughput
 
 __all__ = [
     "get_run_config",
     "set_in_nested_config",
     "get_sweep_callbacks",
-    "get_meta_data",
     "get_openvino_throughput",
     "get_torch_throughput",
     "flatten_sweep_params",
4 changes: 2 additions & 2 deletions anomalib/utils/sweep/helpers/__init__.py
@@ -4,6 +4,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from .callbacks import get_sweep_callbacks
-from .inference import get_meta_data, get_openvino_throughput, get_torch_throughput
+from .inference import get_openvino_throughput, get_torch_throughput
 
-__all__ = ["get_meta_data", "get_openvino_throughput", "get_torch_throughput", "get_sweep_callbacks"]
+__all__ = ["get_openvino_throughput", "get_torch_throughput", "get_sweep_callbacks"]
53 changes: 6 additions & 47 deletions anomalib/utils/sweep/helpers/inference.py
@@ -5,7 +5,7 @@
 
 import time
 from pathlib import Path
-from typing import Dict, Iterable, List, Tuple, Union
+from typing import Iterable, List, Union
 
 import numpy as np
 import torch
@@ -45,53 +45,15 @@ def __call__(self) -> Iterable[np.ndarray]:
         yield self.image
 
 
-def get_meta_data(model: AnomalyModule, input_size: Tuple[int, int]) -> Dict:
-    """Get meta data for inference.
-    Args:
-        model (AnomalyModule): Trained model from which the metadata is extracted.
-        input_size (Tuple[int, int]): Input size used to resize the pixel level mean and std.
-    Returns:
-        (Dict): Metadata as dictionary.
-    """
-    meta_data = {
-        "image_threshold": model.image_threshold.value.cpu().numpy(),
-        "pixel_threshold": model.pixel_threshold.value.cpu().numpy(),
-        "min": model.min_max.min.cpu().numpy(),
-        "max": model.min_max.max.cpu().numpy(),
-        "stats": {},
-    }
-
-    image_mean = model.training_distribution.image_mean.cpu().numpy()
-    if image_mean.size > 0:
-        meta_data["stats"]["image_mean"] = image_mean
-
-    image_std = model.training_distribution.image_std.cpu().numpy()
-    if image_std.size > 0:
-        meta_data["stats"]["image_std"] = image_std
-
-    pixel_mean = model.training_distribution.pixel_mean.cpu().numpy()
-    if pixel_mean.size > 0:
-        meta_data["stats"]["pixel_mean"] = pixel_mean.reshape(input_size)
-
-    pixel_std = model.training_distribution.pixel_std.cpu().numpy()
-    if pixel_std.size > 0:
-        meta_data["stats"]["pixel_std"] = pixel_std.reshape(input_size)
-
-    return meta_data
-
-
 def get_torch_throughput(
-    config: Union[DictConfig, ListConfig], model: AnomalyModule, test_dataset: DataLoader, meta_data: Dict
+    config: Union[DictConfig, ListConfig], model: AnomalyModule, test_dataset: DataLoader
 ) -> float:
     """Tests the model on dummy data. Images are passed sequentially to make the comparison with OpenVINO model fair.
     Args:
         config (Union[DictConfig, ListConfig]): Model config.
         model (AnomalyModule): Model on which inference is called.
         test_dataset (DataLoader): The test dataset used as a reference for the mock dataset.
-        meta_data (Dict): Metadata used for normalization.
     Returns:
         float: Inference throughput
@@ -103,7 +65,7 @@ def get_torch_throughput(
     start_time = time.time()
    # Since we don't care about performance metrics and just the throughput, use mock data.
     for image in torch_dataloader():
-        inferencer.predict(image, meta_data=meta_data)
+        inferencer.predict(image)
 
     # get throughput
     inference_time = time.time() - start_time
@@ -113,26 +75,23 @@
     return throughput
 
 
-def get_openvino_throughput(
-    config: Union[DictConfig, ListConfig], model_path: Path, test_dataset: DataLoader, meta_data: Dict
-) -> float:
+def get_openvino_throughput(config: Union[DictConfig, ListConfig], model_path: Path, test_dataset: DataLoader) -> float:
     """Runs the generated OpenVINO model on a dummy dataset to get throughput.
     Args:
         config (Union[DictConfig, ListConfig]): Model config.
         model_path (Path): Path to folder containing the OpenVINO models. It then searches `model.xml` in the folder.
         test_dataset (DataLoader): The test dataset used as a reference for the mock dataset.
-        meta_data (Dict): Metadata used for normalization.
     Returns:
         float: Inference throughput
     """
-    inferencer = OpenVINOInferencer(config, model_path / "model.xml")
+    inferencer = OpenVINOInferencer(config, model_path / "model.xml", model_path / "meta_data.json")
     openvino_dataloader = MockImageLoader(config.dataset.image_size, total_count=len(test_dataset))
     start_time = time.time()
     # Create test images on CPU. Since we don't care about performance metrics and just the throughput, use mock data.
     for image in openvino_dataloader():
-        inferencer.predict(image, meta_data=meta_data)
+        inferencer.predict(image)
 
     # get throughput
     inference_time = time.time() - start_time
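Both throughput helpers now share one shape: build an inferencer, stream mock images through predict(), and divide the image count by wall-clock time. A generic, self-contained sketch of that measurement loop — an illustrative helper, not part of anomalib:

import time
from typing import Callable, Iterable


def measure_throughput(predict: Callable, images: Iterable) -> float:
    """Return predictions per second for a predict callable over a stream of images."""
    count = 0
    start_time = time.time()
    for image in images:
        predict(image)
        count += 1
    return count / (time.time() - start_time)


# Usage with a trivial stand-in for an inferencer:
print(measure_throughput(lambda img: img * 2, (i for i in range(10_000))))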
37 changes: 1 addition & 36 deletions tests/helpers/inference.py
@@ -4,12 +4,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 
-from typing import Dict, Iterable, List, Tuple
+from typing import Iterable, List
 
 import numpy as np
 
-from anomalib.models.components import AnomalyModule
-
 
 class MockImageLoader:
     """Create mock images for inference on CPU based on the specifics of the original torch test dataset.
@@ -35,36 +33,3 @@ def __call__(self) -> Iterable[np.ndarray]:
         """
         for _ in range(self.total_count):
             yield self.image
-
-
-def get_meta_data(model: AnomalyModule, input_size: Tuple[int, int]) -> Dict:
-    """Get meta data for inference.
-    Args:
-        model (AnomalyModule): Trained model from which the metadata is extracted.
-        input_size (Tuple[int, int]): Input size used to resize the pixel level mean and std.
-    Returns:
-        (Dict): Metadata as dictionary.
-    """
-    meta_data = {
-        "image_threshold": model.image_threshold.value.cpu().numpy(),
-        "pixel_threshold": model.pixel_threshold.value.cpu().numpy(),
-        "stats": {},
-    }
-
-    image_mean = model.training_distribution.image_mean.cpu().numpy()
-    if image_mean.size > 0:
-        meta_data["stats"]["image_mean"] = image_mean
-
-    image_std = model.training_distribution.image_std.cpu().numpy()
-    if image_std.size > 0:
-        meta_data["stats"]["image_std"] = image_std
-
-    pixel_mean = model.training_distribution.pixel_mean.cpu().numpy()
-    if pixel_mean.size > 0:
-        meta_data["stats"]["pixel_mean"] = pixel_mean.reshape(input_size)
-
-    pixel_std = model.training_distribution.pixel_std.cpu().numpy()
-    if pixel_std.size > 0:
-        meta_data["stats"]["pixel_std"] = pixel_std.reshape(input_size)
-
-    return meta_data
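The surviving MockImageLoader is consumed the same way as its twin in the sweep helpers; a hedged usage sketch (the constructor signature is inferred from the MockImageLoader(config.dataset.image_size, total_count=...) call above, and the import path is illustrative):

import numpy as np

# Assumes the class above is importable; path is illustrative.
from tests.helpers.inference import MockImageLoader

loader = MockImageLoader((256, 256), total_count=4)  # hypothetical size and count
for image in loader():
    # The same pre-generated mock image is yielded total_count times.
    assert isinstance(image, np.ndarray)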
