Fix unexpected key pixel_metrics.AUPRO.fpr_limit (#1055)
* fix unexpected key pixel_metrics.AUPRO.fpr_limit

Signed-off-by: FanJiangIntel <fan.jiang@intel.com>

* load AUPRO before create_metric_collection

Signed-off-by: FanJiangIntel <fan.jiang@intel.com>

* code refine

Signed-off-by: FanJiangIntel <fan.jiang@intel.com>

* fix comment

Signed-off-by: FanJiangIntel <fan.jiang@intel.com>

* fix

Signed-off-by: FanJiangIntel <fan.jiang@intel.com>

* Support test

Signed-off-by: Kang Wenjing <wenjing.kang@intel.com>

* Update test

Signed-off-by: Kang Wenjing <wenjing.kang@intel.com>

* Update test

Signed-off-by: Kang Wenjing <wenjing.kang@intel.com>

---------

Signed-off-by: FanJiangIntel <fan.jiang@intel.com>
Signed-off-by: Kang Wenjing <wenjing.kang@intel.com>
Co-authored-by: FanJiangIntel <fan.jiang@intel.com>
Co-authored-by: Samet Akcay <samet.akcay@intel.com>
3 people committed Oct 24, 2023
1 parent 8a7f1d6 commit 66532fc
Showing 5 changed files with 127 additions and 1 deletion.
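
Background for the change: with strict checkpoint loading, a state dict that carries entries for a metric the freshly built module never registered (such as AUPRO, whose fpr_limit is serialized under a key like pixel_metrics.AUPRO.fpr_limit) is rejected as an unexpected key. A minimal sketch of that failure mode, using plain torch stand-ins rather than anomalib classes (the class name and the 0.3 value are illustrative):

import torch
from torch import nn


class FakeAUPRO(nn.Module):
    """Hypothetical stand-in for a metric that registers fpr_limit as a buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("fpr_limit", torch.tensor(0.3))  # illustrative value


saved = nn.ModuleDict({"AUPRO": FakeAUPRO()})  # collection as it was checkpointed
fresh = nn.ModuleDict()                        # collection rebuilt without AUPRO

state_dict = saved.state_dict()                # contains the key "AUPRO.fpr_limit"
try:
    fresh.load_state_dict(state_dict)          # strict=True by default
except RuntimeError as err:
    print(err)  # Unexpected key(s) in state_dict: "AUPRO.fpr_limit"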
28 changes: 28 additions & 0 deletions src/anomalib/models/components/base/anomaly_module.py
@@ -5,6 +5,7 @@

from __future__ import annotations

import importlib
import logging
from abc import ABC
from typing import Any, OrderedDict
@@ -234,11 +235,38 @@ def _load_normalization_class(self, state_dict: OrderedDict[str, Tensor]) -> None:
else:
warn("No known normalization found in model weights.")

def _load_metrics(self, state_dict: OrderedDict[str, Tensor]) -> None:
"""Load metrics from saved checkpoint."""
self._set_metrics("pixel", state_dict)
self._set_metrics("image", state_dict)

def _set_metrics(self, name: str, state_dict: OrderedDict[str, Tensor]) -> None:
    """Set the pixel or image metrics.

    Args:
        name (str): Metric prefix, either "pixel" or "image".
        state_dict (OrderedDict[str, Tensor]): State dict of the model.
    """
metric_keys = [key for key in state_dict.keys() if key.startswith(f"{name}_metrics")]
if not hasattr(self, f"{name}_metrics") and any(metric_keys):
metrics = AnomalibMetricCollection([], prefix=f"{name}_")
for key in metric_keys:
class_name = key.split(".")[1]
try:
metrics_module = importlib.import_module("anomalib.utils.metrics")
metrics_cls = getattr(metrics_module, class_name)
except Exception as exception:
raise ImportError(f"Class {class_name} not found in module anomalib.utils.metrics") from exception
metrics.add_metrics(metrics_cls())
setattr(self, f"{name}_metrics", metrics)

def load_state_dict(self, state_dict: OrderedDict[str, Tensor], strict: bool = True):
    """Load state dict from checkpoint.

    Ensures that the normalization, thresholding, and metric attributes are properly set up before the model is loaded.
    """
# Used to load missing normalization and threshold parameters
self._load_normalization_class(state_dict)
# Used to load metrics if there is any related data in state_dict
self._load_metrics(state_dict)
return super().load_state_dict(state_dict, strict=strict)
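
The recovery above leans on the state-dict naming convention: metric entries are serialized as <name>_metrics.<ClassName>.<state>, so the class to re-instantiate is the second dotted component, looked up dynamically in anomalib.utils.metrics. A condensed sketch of just that parsing step (the keys below are illustrative, only the AUPRO one is taken from this commit, and anomalib must be importable):

import importlib

# Illustrative checkpoint keys of the form "<name>_metrics.<ClassName>.<state>".
metric_keys = [
    "pixel_metrics.AUPRO.fpr_limit",
    "pixel_metrics.AUROC.dummy_state",
]

class_names = {key.split(".")[1] for key in metric_keys}  # {"AUPRO", "AUROC"}

metrics_module = importlib.import_module("anomalib.utils.metrics")
for class_name in class_names:
    metric_cls = getattr(metrics_module, class_name)  # AttributeError if missing
    print(metric_cls())  # instantiate with defaults, as _set_metrics does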
8 changes: 7 additions & 1 deletion src/anomalib/utils/callbacks/metrics_configuration.py
@@ -77,7 +77,13 @@ def setup(

 if isinstance(pl_module, AnomalyModule):
     pl_module.image_metrics = create_metric_collection(image_metric_names, "image_")
-    pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")
+    if hasattr(pl_module, "pixel_metrics"):
+        new_metrics = create_metric_collection(pixel_metric_names, "pixel_")
+        for name in new_metrics.keys():
+            if name not in pl_module.pixel_metrics.keys():
+                pl_module.pixel_metrics.add_metrics(new_metrics[name.split("_")[1]])
+    else:
+        pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")

     pl_module.image_metrics.set_threshold(pl_module.image_threshold.value)
     pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value)
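
The effect of this change is merge-rather-than-overwrite: pixel metrics already restored from the checkpoint survive, and only metrics named in the config but missing from the collection are added. A rough sketch of the same merge using plain torchmetrics.MetricCollection as a stand-in for AnomalibMetricCollection (the metric classes chosen here are illustrative):

from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAUROC, BinaryF1Score

# Collection as restored from the checkpoint (already holds F1Score).
existing = MetricCollection({"F1Score": BinaryF1Score()}, prefix="pixel_")
# Collection built from the config (holds F1Score and AUROC).
configured = MetricCollection(
    {"F1Score": BinaryF1Score(), "AUROC": BinaryAUROC()}, prefix="pixel_"
)

for name in configured.keys():          # keys carry the "pixel_" prefix
    if name not in existing.keys():
        bare_name = name.split("_")[1]  # strip the prefix, as the callback does
        existing.add_metrics({bare_name: configured[bare_name]})

print(sorted(existing.keys()))  # ['pixel_AUROC', 'pixel_F1Score']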
New file (a config-good-02 test fixture referenced by the test below):
@@ -0,0 +1,19 @@
metrics:
pixel:
F1Score:
class_path: torchmetrics.F1Score
init_args:
compute_on_cpu: true
AUPRO:
class_path: anomalib.utils.metrics.AUPRO
init_args:
compute_on_cpu: true
image:
F1Score:
class_path: torchmetrics.F1Score
init_args:
compute_on_cpu: true
AUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true
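
For context, a metrics config like the one above is consumed by loading it with OmegaConf and handing the per-task metric sections to the callback, exactly as the new test below does (the path is the fixture's test-relative path):

from omegaconf import OmegaConf

from anomalib.utils.callbacks.metrics_configuration import MetricsConfigurationCallback

config = OmegaConf.load("data/config-good-02.yaml")  # fixture path from the test
callback = MetricsConfigurationCallback(
    task="segmentation",
    image_metrics=config.metrics.image,
    pixel_metrics=config.metrics.pixel,
)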
New file (the other config-good-02 test fixture):
@@ -0,0 +1,19 @@
metrics:
pixel:
F1Score:
class_path: torchmetrics.F1Score
init_args:
compute_on_cpu: true
AUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true
image:
F1Score:
class_path: torchmetrics.F1Score
init_args:
compute_on_cpu: true
AUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true
@@ -1,8 +1,11 @@
from itertools import chain
from pathlib import Path
from collections import OrderedDict

import pytest
import pytorch_lightning as pl
from omegaconf import OmegaConf
import torch

from anomalib.models.components import AnomalyModule
from anomalib.utils.callbacks.metrics_configuration import MetricsConfigurationCallback
@@ -61,3 +64,54 @@ def test_metric_collection_configuration_callback(config_from_yaml):
assert isinstance(
dummy_anomaly_module.pixel_metrics, AnomalibMetricCollection
), f"{dummy_anomaly_module.pixel_metrics}"


@pytest.mark.parametrize(
["ori_config_from_yaml", "saved_config_from_yaml"],
[("data/config-good-02.yaml", "data/config-good-02-serialized.yaml")],
)
def test_metric_collection_configuration_deserialzation_callback(ori_config_from_yaml, saved_config_from_yaml):
"""Test if metrics are properly instantiated during deserialzation."""

ori_config_from_yaml_res = OmegaConf.load(Path(__file__).parent / ori_config_from_yaml)
saved_config_from_yaml_res = OmegaConf.load(Path(__file__).parent / saved_config_from_yaml)
callback = MetricsConfigurationCallback(
task="segmentation",
image_metrics=ori_config_from_yaml_res.metrics.image,
pixel_metrics=ori_config_from_yaml_res.metrics.pixel,
)

dummy_logger = DummyLogger()
dummy_anomaly_module = _DummyAnomalyModule()
trainer = pl.Trainer(
callbacks=[callback], logger=dummy_logger, enable_checkpointing=False, default_root_dir=dummy_logger.tempdir
)

saved_image_state_dict = OrderedDict(
{
"image_metrics." + k: torch.tensor(1.0)
for k, v in saved_config_from_yaml_res.metrics.image.items()
if v["class_path"].startswith("anomalib.utils.metrics")
}
)
saved_pixel_state_dict = OrderedDict(
{
"pixel_metrics." + k: torch.tensor(1.0)
for k, v in saved_config_from_yaml_res.metrics.pixel.items()
if v["class_path"].startswith("anomalib.utils.metrics")
}
)

final_state_dict = OrderedDict(chain(saved_image_state_dict.items(), saved_pixel_state_dict.items()))

dummy_anomaly_module._load_metrics(final_state_dict)
callback.setup(trainer, dummy_anomaly_module, DummyDataModule())

assert isinstance(
dummy_anomaly_module.image_metrics, AnomalibMetricCollection
), f"{dummy_anomaly_module.image_metrics}"
assert isinstance(
dummy_anomaly_module.pixel_metrics, AnomalibMetricCollection
), f"{dummy_anomaly_module.pixel_metrics}"

assert sorted(dummy_anomaly_module.pixel_metrics) == ["AUPRO", "AUROC", "F1Score"]
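
This final assertion captures the point of the fix: after deserialization, the pixel collection is the union of the metrics recovered from the simulated checkpoint keys and those named in the config, with no duplicates and no unexpected-key failure.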
