Skip to content

Commit

Permalink
Merge main into feature/pimo (#1273)
Browse files Browse the repository at this point in the history
* Configure readthedocs via `.readthedocs.yaml` file (#1229)

* 馃殮 Refactor Benchmarking Script (#1216)

* New printing stuff

* Remove dead code + address codacy issues

* Refactor try/except + log to comet/wandb during runs

* pre-commit error

* third-party configuration

---------

Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>

* Update CODEOWNERS

* Enable training with only normal images for MVTec (#1241)

* ignore mask check when dataset has only normal samples

* update changelog

* Revert "馃殮 Refactor Benchmarking Script" (#1239)

Revert "馃殮 Refactor Benchmarking Script (#1216)"

This reverts commit 784767f.

* Update benchmarking notebook (#1242)

* Fix metadata path

* Update benchmarking notebook

* Fix links to model architecture images (#1245)

* Fix links to architecture images

* Change links to raw files

* Wandb unwatch method belongs to experiment, not logger (#1246)

unwatch method belongs to experiment, not logger

* (Minor change) Added the tracer_kwargs to the TorchFXFeatureExtractor class (#1214)

* Added tracer_kwargs to torchfx

* Added tracer_kwargs on docstring

* Replace cdist in Patchcore (#1267)

* Ignore hidden directories when creating `Folder` dataset (#1268)

* Remove `config` from argparse in OpenVINO inference script. (#1257)

* Fix metadata path

* Remove leftover argument

* Update openvino entrypoint script

* Fix EfficientAD number of steps for optimizer lr change? (#1266)

* Fix metadata path

* Fix number of steps

---------

Co-authored-by: Ashwin Vaidya <ashwin.vaidya@intel.com>
Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
Co-authored-by: Dick Ameln <dick.ameln@intel.com>
Co-authored-by: Bla啪 Rolih <61357777+blaz-r@users.noreply.github.com>
Co-authored-by: Sean Aubin <seanaubin@gmail.com>
Co-authored-by: JoaoGuibs <32060480+JoaoGuibs@users.noreply.github.com>
  • Loading branch information
7 people committed Aug 15, 2023
1 parent 1dd9434 commit f62fde0
Show file tree
Hide file tree
Showing 17 changed files with 58 additions and 25 deletions.
6 changes: 5 additions & 1 deletion src/anomalib/data/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ def _prepare_files_labels(
if isinstance(extensions, str):
extensions = (extensions,)

filenames = [f for f in path.glob(r"**/*") if f.suffix in extensions and not f.is_dir()]
filenames = [
f
for f in path.glob("**/*")
if f.suffix in extensions and not f.is_dir() and not any(part.startswith(".") for part in f.parts)
]
if not filenames:
raise RuntimeError(f"Found 0 {path_type} images in {path}")

Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/cfa/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Coupled-hypersphere-based Feature Adaptation (CFA) localizes anomalies using fea

## Architecture

![Cfa Architecture](../../../docs/source/images/cfa/architecture.png "Cfa Architecture")
![Cfa Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/cfa/architecture.png "Cfa Architecture")

## Usage

Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/cflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ CFLOW model is based on a conditional normalizing flow framework adopted for ano

## Architecture

![CFlow Architecture](../../../docs/source/images/cflow/architecture.jpg "CFlow Architecture")
![CFlow Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/cflow/architecture.jpg "CFlow Architecture")

## Usage

Expand Down
14 changes: 12 additions & 2 deletions src/anomalib/models/components/feature_extractors/torchfx.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class TorchFXFeatureExtractor(nn.Module):
path for custom models.
requires_grad (bool): Models like ``stfpm`` use the feature extractor for training. In such cases we should
set ``requires_grad`` to ``True``. Default is ``False``.
tracer_kwargs (dict | None): a dictionary of keyword arguments for NodePathTracer (which passes them onto
it's parent class torch.fx.Tracer). Can be used to allow not tracing through a list of problematic
modules, by passing a list of `leaf_modules` as one of the `tracer_kwargs`.
Example:
With torchvision models:
Expand Down Expand Up @@ -91,6 +94,7 @@ def __init__(
return_nodes: list[str],
weights: str | WeightsEnum | None = None,
requires_grad: bool = False,
tracer_kwargs: dict | None = None,
):
super().__init__()
if isinstance(backbone, dict):
Expand All @@ -102,14 +106,17 @@ def __init__(
f"backbone needs to be of type str | BackboneParams | dict | nn.Module, but was type {type(backbone)}"
)

self.feature_extractor = self.initialize_feature_extractor(backbone, return_nodes, weights, requires_grad)
self.feature_extractor = self.initialize_feature_extractor(
backbone, return_nodes, weights, requires_grad, tracer_kwargs
)

def initialize_feature_extractor(
self,
backbone: BackboneParams | nn.Module,
return_nodes: list[str],
weights: str | WeightsEnum | None = None,
requires_grad: bool = False,
tracer_kwargs: dict | None = None,
) -> GraphModule:
"""Extract features from a CNN.
Expand All @@ -125,6 +132,9 @@ class can be provided and it will try to load the weights from the provided weig
path for custom models.
requires_grad (bool): Models like ``stfpm`` use the feature extractor for training. In such cases we should
set ``requires_grad`` to ``True``. Default is ``False``.
tracer_kwargs (dict | None): a dictionary of keyword arguments for NodePathTracer (which passes them onto
it's parent class torch.fx.Tracer). Can be used to allow not tracing through a list of problematic
modules, by passing a list of `leaf_modules` as one of the `tracer_kwargs`.
Returns:
Feature Extractor based on TorchFX.
Expand All @@ -148,7 +158,7 @@ class can be provided and it will try to load the weights from the provided weig
model_weights = model_weights["state_dict"]
backbone_model.load_state_dict(model_weights)

feature_extractor = create_feature_extractor(backbone_model, return_nodes)
feature_extractor = create_feature_extractor(backbone_model, return_nodes, tracer_kwargs=tracer_kwargs)

if not requires_grad:
feature_extractor.eval()
Expand Down
6 changes: 3 additions & 3 deletions src/anomalib/models/csflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ The anomaly score for each local position $(i,j)$ of the feature map $y^s$ at sc

## Architecture

![CS-Flow Architecture](../../../docs/source/images/cs_flow/architecture1.jpg "CS-Flow Architecture")
![CS-Flow Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/cs_flow/architecture1.jpg "CS-Flow Architecture")

![Architecture of a Coupling Block](../../../docs/source/images/cs_flow/architecture2.jpg "Architecture of a Coupling Block")
![Architecture of a Coupling Block](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/cs_flow/architecture2.jpg "Architecture of a Coupling Block")

![Architecture of network predicting scale and shift parameters.](../../../docs/source/images/cs_flow/architecture3.jpg "Architecture of network predicting scale and shift parameters.")
![Architecture of network predicting scale and shift parameters.](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/cs_flow/architecture3.jpg "Architecture of network predicting scale and shift parameters.")

## Usage

Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/draem/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ For optimal results, DRAEM requires specifying the path to a folder of image dat

## Architecture

![DRAEM Architecture](../../../docs/source/images/draem/architecture.png "DRAEM Architecture")
![DRAEM Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/draem/architecture.png "DRAEM Architecture")

## Usage

Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/efficient_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def configure_optimizers(self) -> optim.Optimizer:
lr=self.lr,
weight_decay=self.weight_decay,
)
num_steps = max(
num_steps = min(
self.trainer.max_steps, self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader())
)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=int(0.95 * num_steps), gamma=0.1)
Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/fastflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ FastFlow is a two-dimensional normalizing flow-based probability distribution es

## Architecture

![FastFlow Architecture](../../../docs/source/images/fastflow/architecture.jpg "FastFlow Architecture")
![FastFlow Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/fastflow/architecture.jpg "FastFlow Architecture")

## Usage

Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/ganomaly/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The key idea here is that, during inference, when an anomalous image is passed t

## Architecture

![GANomaly Architecture](../../../docs/source/images/ganomaly/architecture.jpg "GANomaly Architecture")
![GANomaly Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/ganomaly/architecture.jpg "GANomaly Architecture")

## Usage

Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/padim/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ During inference, Mahalanobis distance is used to score each patch position of t

## Architecture

![PaDiM Architecture](../../../docs/source/images/padim/architecture.jpg "PaDiM Architecture")
![PaDiM Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/padim/architecture.jpg "PaDiM Architecture")

## Usage

Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/patchcore/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ During inference this memory bank is coreset subsampled. Coreset subsampling gen

## Architecture

![PatchCore Architecture](../../../../docs/source/images/patchcore/architecture.jpg "PatchCore Architecture")
![PatchCore Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/patchcore/architecture.jpg "PatchCore Architecture")

## Usage

Expand Down
26 changes: 24 additions & 2 deletions src/anomalib/models/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,28 @@ def subsample_embedding(self, embedding: Tensor, sampling_ratio: float) -> None:
coreset = sampler.sample_coreset()
self.memory_bank = coreset

@staticmethod
def euclidean_dist(x: Tensor, y: Tensor) -> Tensor:
"""
Calculates pair-wise distance between row vectors in x and those in y.
Replaces torch cdist with p=2, as cdist is not properly exported to onnx and openvino format.
Resulting matrix is indexed by x vectors in rows and y vectors in columns.
Args:
x: input tensor 1
y: input tensor 2
Returns:
Matrix of distances between row vectors in x and y.
"""
x_norm = x.pow(2).sum(dim=-1, keepdim=True) # |x|
y_norm = y.pow(2).sum(dim=-1, keepdim=True) # |y|
# row distance can be rewritten as sqrt(|x| - 2 * x @ y.T + |y|.T)
res = x_norm - 2 * torch.matmul(x, y.transpose(-2, -1)) + y_norm.transpose(-2, -1)
res = res.clamp_min_(0).sqrt_()
return res

def nearest_neighbors(self, embedding: Tensor, n_neighbors: int) -> tuple[Tensor, Tensor]:
"""Nearest Neighbours using brute force method and euclidean norm.
Expand All @@ -153,7 +175,7 @@ def nearest_neighbors(self, embedding: Tensor, n_neighbors: int) -> tuple[Tensor
Tensor: Patch scores.
Tensor: Locations of the nearest neighbor(s).
"""
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
distances = self.euclidean_dist(embedding, self.memory_bank)
if n_neighbors == 1:
# when n_neighbors is 1, speed up computation by using min instead of topk
patch_scores, locations = distances.min(1)
Expand Down Expand Up @@ -188,7 +210,7 @@ def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embeddi
# indices of N_b(m^*) in the paper
_, support_samples = self.nearest_neighbors(nn_sample, n_neighbors=self.num_neighbors)
# 4. Find the distance of the patch features to each of the support samples
distances = torch.cdist(max_patches_features.unsqueeze(1), self.memory_bank[support_samples], p=2.0)
distances = self.euclidean_dist(max_patches_features.unsqueeze(1), self.memory_bank[support_samples])
# 5. Apply softmax to find the weights
weights = (1 - F.softmax(distances.squeeze(1), 1))[..., 0]
# 6. Apply the weight factor to the score
Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/reverse_distillation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ During testing, a similar step is followed but this time the cosine distance bet

## Architecture

![Anomaly Detection via Reverse Distillation from One-Class Embedding Architecture](../../../docs/source/images/reverse_distillation/architecture.png "Reverse Distillation Architecture")
![Anomaly Detection via Reverse Distillation from One-Class Embedding Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/reverse_distillation/architecture.png "Reverse Distillation Architecture")

## Usage

Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/stfpm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ During inference, the feature pyramids of teacher and student networks are compa

## Architecture

![STFPM Architecture](../../../docs/source/images/stfpm/architecture.jpg "STFPM Architecture")
![STFPM Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/stfpm/architecture.jpg "STFPM Architecture")

## Usage

Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/utils/callbacks/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
if isinstance(logger, (AnomalibCometLogger, AnomalibTensorBoardLogger)):
logger.log_graph(pl_module, input_array=torch.ones((1, 3, 256, 256)))
elif isinstance(logger, AnomalibWandbLogger):
logger.unwatch(pl_module) # type: ignore
logger.experiment.unwatch(pl_module) # type: ignore
2 changes: 0 additions & 2 deletions tests/pre_merge/tools/test_openvino_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ def test_openvino_inference(

arguments = get_parser().parse_args(
[
"--config",
"src/anomalib/models/padim/config.yaml",
"--weights",
project_path + "/weights/openvino/model.bin",
"--metadata",
Expand Down
7 changes: 3 additions & 4 deletions tools/inference/openvino_inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Anomalib Inferencer Script.
"""Anomalib OpenVINO Inferencer Script.
This script performs inference by reading a model config file from
command line, and show the visualization results.
This script performs OpenVINO inference by reading a model from
file system, and show the visualization results.
"""

# Copyright (C) 2022 Intel Corporation
Expand All @@ -23,7 +23,6 @@ def get_parser() -> ArgumentParser:
ArgumentParser: The parser object.
"""
parser = ArgumentParser()
parser.add_argument("--config", type=Path, required=True, help="Path to a config file")
parser.add_argument("--weights", type=Path, required=True, help="Path to model weights")
parser.add_argument("--metadata", type=Path, required=True, help="Path to a JSON file containing the metadata.")
parser.add_argument("--input", type=Path, required=True, help="Path to an image to infer.")
Expand Down

0 comments on commit f62fde0

Please sign in to comment.