Adding U-Flow method #1415

Merged
merged 26 commits into from
Nov 15, 2023
d65dab2
Added uflow model
mtailanian Oct 15, 2023
b8a07eb
Added documentation (README) for uflow model
mtailanian Oct 15, 2023
81c049c
Added uflow to the list of available models, and main README updated
mtailanian Oct 15, 2023
1f00a64
Merge branch 'main' into mt_uflow
mtailanian Oct 15, 2023
b0eb956
Added missing images for the documentation
mtailanian Oct 17, 2023
0afb726
Update src/anomalib/models/uflow/anomaly_map.py
mtailanian Oct 17, 2023
6bf1d16
Update src/anomalib/models/uflow/anomaly_map.py
mtailanian Oct 17, 2023
6e1fe6c
Update src/anomalib/models/uflow/feature_extraction.py
mtailanian Oct 17, 2023
13a1b05
Update src/anomalib/models/uflow/torch_model.py
mtailanian Oct 17, 2023
22c3ebb
Merge branch 'mt_uflow' of github.com:mtailanian/anomalib into mt_uflow
mtailanian Oct 17, 2023
9fe38cb
Added uflow to the reference guide in docs
mtailanian Oct 21, 2023
09bdf61
Added uflow to the pre-merge tests
mtailanian Oct 21, 2023
4ca19fc
removed the _step function, and merged the code with training_step
mtailanian Oct 21, 2023
858c589
added as a comment the values used in the paper
mtailanian Oct 21, 2023
04ce7e5
re-factorized feature extractors to use the TimmFeatureExtractor class
mtailanian Oct 21, 2023
0398ffb
added annotations for some functions, where the flow graph is created
mtailanian Oct 21, 2023
33d771d
updated readme to fix images loading
mtailanian Oct 22, 2023
c302a73
Added link in the README to the original code for reproducing the res…
mtailanian Oct 22, 2023
4161fce
Merge branch 'main' into mt_uflow
mtailanian Oct 22, 2023
50d15db
Removed unused kwargs
mtailanian Nov 1, 2023
9dbab19
Added docstrigs with args explanations to UFlow classes
mtailanian Nov 1, 2023
3e4c610
Added models in a github release, and linked here
mtailanian Nov 1, 2023
9eb4856
Merge branch 'main' into mt_uflow
samet-akcay Nov 3, 2023
c3055e4
Passing all pre-commit checks
mtailanian Nov 6, 2023
573bc87
Merge branch 'mt_uflow' of github.com:mtailanian/anomalib into mt_uflow
mtailanian Nov 6, 2023
bf7bc82
Changed freia's AllInOneBlock by Anomalib's version, and converted th…
mtailanian Nov 12, 2023
1 change: 1 addition & 0 deletions README.md
@@ -114,6 +114,7 @@ where the currently available models are:
- [PatchCore](src/anomalib/models/patchcore)
- [Reverse Distillation](src/anomalib/models/reverse_distillation)
- [STFPM](src/anomalib/models/stfpm)
- [UFlow](src/anomalib/models/uflow)

## Feature extraction & (pre-trained) backbones

Binary file added docs/source/images/uflow/diagram.png
Binary file added docs/source/images/uflow/iou.png
Binary file added docs/source/images/uflow/more-results.png
Binary file added docs/source/images/uflow/pixel-aupro.png
Binary file added docs/source/images/uflow/pixel-auroc.png
Binary file added docs/source/images/uflow/teaser.jpg
1 change: 1 addition & 0 deletions docs/source/reference_guide/algorithms/index.rst
@@ -18,6 +18,7 @@ Algorithms
patchcore
reverse_distillation
stfpm
uflow


Feature extraction & (pre-trained) backbones
44 changes: 44 additions & 0 deletions docs/source/reference_guide/algorithms/uflow.rst
@@ -0,0 +1,44 @@
U-Flow
---------

This is the implementation of the `U-Flow <https://arxiv.org/abs/2211.12353>`_ paper.

Model Type: Segmentation

Description
***********

U-Flow is a U-Shaped normalizing flow-based probability distribution estimator.
The method consists of three phases.
(1) Multi-scale feature extraction: a rich multi-scale representation is obtained with MSCaiT, by combining pre-trained image Transformers acting at different image scales. Any other feature extractor, such as ResNet, can also be used.
(2) U-shaped Normalizing Flow: by adapting the widely used U-like architecture to NFs, a fully invertible architecture is designed. This architecture is capable of merging the information from different scales while ensuring independence both intra- and inter-scales. To make it fully invertible, split and invertible up-sampling operations are used.
(3) Anomaly score and segmentation computation: besides generating the anomaly map based on the likelihood of test data, we also propose to adapt the a contrario framework to obtain an automatic threshold by controlling the allowed number of false alarms.

Architecture
************

.. image:: ../../images/uflow/diagram.png
    :alt: U-Flow Architecture

Usage
*****

.. code-block:: bash

    $ python tools/train.py --model uflow


.. automodule:: anomalib.models.uflow.torch_model
   :members:
   :undoc-members:
   :show-inheritance:

.. automodule:: anomalib.models.uflow.lightning_model
   :members:
   :undoc-members:
   :show-inheritance:

.. automodule:: anomalib.models.uflow.anomaly_map
   :members:
   :undoc-members:
   :show-inheritance:
2 changes: 2 additions & 0 deletions src/anomalib/models/__init__.py
@@ -30,6 +30,7 @@
from anomalib.models.reverse_distillation import ReverseDistillation
from anomalib.models.rkde import Rkde
from anomalib.models.stfpm import Stfpm
from anomalib.models.uflow import Uflow

__all__ = [
    "Cfa",
@@ -46,6 +47,7 @@
    "ReverseDistillation",
    "Rkde",
    "Stfpm",
    "Uflow",
    "AiVad",
    "EfficientAd",
]
128 changes: 128 additions & 0 deletions src/anomalib/models/uflow/README.md
@@ -0,0 +1,128 @@
# U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold

[//]: # "This is the implementation of the [U-Flow](https://arxiv.org/abs/2211.12353) paper, based on the [original code](https://www.github.com/mtailanian/uflow)"

This is the implementation of the [U-Flow](https://www.researchsquare.com/article/rs-3367286/latest) paper, based on the [original code](https://www.github.com/mtailanian/uflow)

![U-Flow Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/diagram.png "U-Flow Architecture")

## Abstract

_In this work we propose a one-class self-supervised method for anomaly segmentation in images, that benefits both from a modern machine learning approach and a more classic statistical detection theory.
The method consists of three phases. First, features are extracted using a multi-scale image Transformer architecture. Then, these features are fed into a U-shaped Normalizing Flow that lays the theoretical foundations for the last phase, which computes a pixel-level anomaly map and performs a segmentation based on the a contrario framework.
This multiple-hypothesis testing strategy permits the derivation of robust automatic detection thresholds, which are crucial in real-world applications where an operational point is needed.
The segmentation results are evaluated using the Intersection over Union (IoU) metric, and for assessing the generated anomaly maps we report the area under the Receiver Operating Characteristic curve (AUROC), and the area under the per-region-overlap curve (AUPRO).
Extensive experimentation in various datasets shows that the proposed approach produces state-of-the-art results for all metrics and all datasets, ranking first in most MVTec-AD categories, with a mean pixel-level AUROC of 98.74%._

![Teaser image](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/teaser.jpg)

## Localization results

### Pixel AUROC over MVTec-AD Dataset

![Pixel-AUROC results](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/pixel-auroc.png "Pixel-AUROC results")

### Pixel AUPRO over MVTec-AD Dataset

![Pixel-AUPRO results](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/pixel-aupro.png "Pixel-AUPRO results")

## Segmentation results (IoU) with threshold log(NFA)=0

The paper also proposes a method to compute the segmentation threshold automatically using the a contrario framework. All results below are obtained with the threshold log(NFA)=0.
In the default code here, for the sake of comparison with the other methods of the library, the segmentation threshold is instead computed over the anomaly map at train time.
Nevertheless, the code for computing the segmentation mask with the NFA criterion is included in `src/anomalib/models/uflow/anomaly_map.py`.
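
As a rough illustration of the log(NFA)=0 criterion, the sketch below computes the log-NFA of a single pixel window in pure Python. It assumes a 7x7 window, a per-pixel candidate probability of 0.5 under the null hypothesis, and 28x28 + 14x14 tests. The helper names `log10_binomial_tail` and `log10_nfa` are hypothetical and not part of the library, and the tail used here is P(X >= k), which may differ slightly from the convention used in `anomaly_map.py`.

```python
import math


def log10_binomial_tail(k: int, n: int, p: float) -> float:
    """log10 of P(X >= k) for X ~ Binomial(n, p) (hypothetical helper)."""
    tail = sum(math.comb(n, i) * p**i * (1 - p) ** (n - i) for i in range(k, n + 1))
    return math.log10(tail)


def log10_nfa(k: int, n: int, p: float, number_of_tests: int) -> float:
    """log10(NFA) = log10(number of tests) + log10(tail probability)."""
    return math.log10(number_of_tests) + log10_binomial_tail(k, n, p)


n_pixels = 7 * 7  # pixels in one window (assumed window size)
n_tests = 28 * 28 + 14 * 14  # one test per latent position at both scales (assumed sizes)

for candidates in (25, 40, 49):
    value = log10_nfa(candidates, n_pixels, 0.5, n_tests)
    label = "anomalous" if value < 0 else "normal"
    print(f"{candidates} candidates: log10(NFA) = {value:.2f} -> {label}")
```

A window with about half of its pixels flagged is unremarkable under the null hypothesis, so its log-NFA stays positive. A window where almost every pixel is flagged yields a negative log-NFA and would be segmented as anomalous.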

![IoU results](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/iou.png "IoU results")

## Results over other datasets

![Results over other datasets](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/more-results.png "Results over other datasets")

## Benchmarking

Note that the proposed method uses the MS-CaiT feature extractor, which has an input size of 448x448. In the benchmarking, a size of 256x256 is used for all methods, and therefore the results may differ from those reported. To reproduce all results exactly, refer to the [original code](https://www.github.com/mtailanian/uflow), where the configs used are available, and the trained checkpoints can be downloaded from [this release](https://github.com/mtailanian/uflow/releases/tag/trained-mvtec-models).

## Reproducing paper's results

Using the default parameters of the config file (`src/anomalib/models/uflow/config.yaml`), the results obtained are very close to the ones reported in the paper:

bottle: 97.98, cable: 98.17, capsule: 98.95, carpet: 99.45, grid: 98.19, hazelnut: 99.01, leather: 99.41, metal_nut: 98.19, pill: 99.15, screw: 99.25, tile: 96.93, toothbrush: 98.97, transistor: 96.70, wood: 96.87, zipper: 97.92
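
For a quick comparison with the mean reported in the abstract (98.74%), the values above can be averaged. The snippet below is only a convenience sketch. It assumes the listed numbers are per-category pixel-level AUROC percentages, which the text above does not state explicitly.

```python
# Per-category results copied from the list above.
results = {
    "bottle": 97.98, "cable": 98.17, "capsule": 98.95, "carpet": 99.45,
    "grid": 98.19, "hazelnut": 99.01, "leather": 99.41, "metal_nut": 98.19,
    "pill": 99.15, "screw": 99.25, "tile": 96.93, "toothbrush": 98.97,
    "transistor": 96.70, "wood": 96.87, "zipper": 97.92,
}

mean_score = sum(results.values()) / len(results)
print(f"mean over {len(results)} categories: {mean_score:.2f}")
# prints: mean over 15 categories: 98.34
```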

To obtain exactly the same results, the architecture parameters always stay the same, but category-specific values for the learning rate and batch size should be used. Please refer to the [original code](https://www.github.com/mtailanian/uflow) for more details: the configs used are available in the source code ([here](https://github.com/mtailanian/uflow/tree/main/configs)), and trained checkpoints are available in [this release](https://github.com/mtailanian/uflow/releases/tag/trained-mvtec-models).

## Usage

`python tools/train.py --model uflow`

## Download data

### MVTec

https://www.mvtec.com/company/research/datasets/mvtec-ad

### Bean Tech

https://paperswithcode.com/dataset/btad

### LGG MRI

https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation

### ShanghaiTech Campus

https://svip-lab.github.io/dataset/campus_dataset.html

## [Optional] Download pre-trained models

Pre-trained models can be found in [this release](https://github.com/mtailanian/uflow/releases/tag/trained-mvtec-models), or can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1W1rE0mu4Lv3uWHA5GZigmvVNlBVHqTv_?usp=sharing).

For an easier way of downloading them, please refer to the `README.md` from the [original code](https://www.github.com/mtailanian/uflow)

For reproducing the exact results from the paper, different learning rates and batch sizes are to be used for each category. You can find the exact values in the `configs` folder, following the [previous link](https://drive.google.com/drive/folders/1W1rE0mu4Lv3uWHA5GZigmvVNlBVHqTv_?usp=sharing).

## A note on sizes at different points

Input

```text
- Scale 1: [3, 448, 448]
- Scale 2: [3, 224, 224]
```

MS-Cait outputs

```text
- Scale 1: [768, 28, 28]
- Scale 2: [384, 14, 14]
```

Normalizing Flow outputs

```text
- Scale 1: [816, 28, 28] --> 816 = 768 + 384 / 2 / 4
- Scale 2: [192, 14, 14] --> 192 = 384 / 2
```

`/ 2` corresponds to the split, and `/ 4` to the invertible upsample.
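
The channel arithmetic above can be checked with a few lines of Python. This is a sketch under the assumption that the split keeps half of the coarse-scale channels as that scale's output and passes the other half, after an invertible upsample, to the finer scale. `flow_output_channels` is a hypothetical helper, not a function of the library.

```python
def flow_output_channels(c_fine: int, c_coarse: int) -> tuple[int, int]:
    """Return (fine-scale, coarse-scale) output channels of the flow."""
    kept_coarse = c_coarse // 2  # split: half of the channels stay at the coarse scale
    passed_up = c_coarse - kept_coarse  # the other half goes to the finer scale
    upsampled = passed_up // 4  # invertible upsample: 4x fewer channels, 2x height and width
    return c_fine + upsampled, kept_coarse


fine, coarse = flow_output_channels(768, 384)
print(fine, coarse)  # prints: 816 192
```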

## Example results

### Anomalies

#### MVTec

![MVTec results - anomalies](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/results-mvtec-anomalies.jpg "MVTec results - anomalies")

#### BeanTech, LGG MRI, STC

![BeanTech, LGG MRI, STC results - anomalies](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/results-others-anomalies.jpg "BeanTech, LGG MRI, STC results - anomalies")

### Normal images

#### MVTec

![MVTec results - normal](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/results-mvtec-good.jpg "MVTec results - normal")

#### BeanTech, LGG MRI, STC

![BeanTech, LGG MRI, STC results - normal](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/results-others-good.jpg "BeanTech, LGG MRI, STC results - normal")
8 changes: 8 additions & 0 deletions src/anomalib/models/uflow/__init__.py
@@ -0,0 +1,8 @@
"""U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import Uflow, UflowLightning

__all__ = ["Uflow", "UflowLightning"]
166 changes: 166 additions & 0 deletions src/anomalib/models/uflow/anomaly_map.py
@@ -0,0 +1,166 @@
"""UFlow Anomaly Map Generator Implementation."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import List

import numpy as np
import scipy.stats as st
import torch
import torch.nn.functional as F
from mpmath import binomial, mp
from omegaconf import ListConfig
from scipy import integrate
from torch import Tensor, nn

mp.dps = 15 # Set precision for NFA computation (in case of high_precision=True)


class AnomalyMapGenerator(nn.Module):
    """Generate Anomaly Heatmap and segmentation."""

    def __init__(self, input_size: ListConfig | tuple) -> None:
        super().__init__()
        self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size)

    def forward(self, latent_variables: list[Tensor]) -> Tensor:
        return self.compute_anomaly_map(latent_variables)

    def compute_anomaly_map(self, latent_variables: list[Tensor]) -> Tensor:
        """
        Generate a likelihood-based anomaly map, from latent variables.

        Args:
            latent_variables: List of latent variables from the UFlow model. Each element is a tensor of shape
                (N, Cl, Hl, Wl), where N is the batch size, Cl is the number of channels, and Hl and Wl are the
                height and width of the latent variables, respectively, for each scale l.

        Returns:
            Final Anomaly Map. Tensor of shape (N, 1, H, W), where N is the batch size, and H and W are the height and
            width of the input image, respectively.
        """

        likelihoods = []
        for z in latent_variables:
            # Mean prob by scale. Likelihood is actually with sum instead of mean. Using mean to avoid numerical issues.
            # Also, this way all scales have the same weight, and it does not depend on the number of channels
            log_prob_i = -torch.mean(z**2, dim=1, keepdim=True) * 0.5
            prob_i = torch.exp(log_prob_i)
            likelihoods.append(
                F.interpolate(
                    prob_i,
                    size=self.input_size,
                    mode="bilinear",
                    align_corners=False,
                )
            )
        anomaly_map = 1 - torch.mean(torch.stack(likelihoods, dim=-1), dim=-1)
        return anomaly_map

    def compute_anomaly_mask(
        self,
        z: List[torch.Tensor],
        win_size: int = 7,
        binomial_probability_thr: float = 0.5,
        high_precision: bool = False,
    ):
        """
        This method is not used in the basic functionality of training and testing. It is a bit slow, so we decided to
        leave it as an option for the user. It is included as it is part of the U-Flow paper, and can be called
        separately if an unsupervised anomaly segmentation is needed.

        Generate an anomaly mask, from latent variables. It is based on the NFA (Number of False Alarms) method, which
        is a statistical method to detect anomalies. The NFA is computed as the log of the probability of the null
        hypothesis, which is that all pixels are normal. First, we compute a list of candidate pixels, with
        suspiciously high values of z^2, by applying a binomial test to each pixel, looking at a window around it.
        Then, to compute the NFA values (actually the log-NFA), we evaluate how probable is that a pixel belongs to the
        normal distribution. The null-hypothesis is that under normality assumptions, all candidate pixels are uniformly
        distributed. Then, the detection is based on the concentration of candidate pixels.

        Args:
            z: List of latent variables from the UFlow model. Each element is a tensor of shape
                (N, Cl, Hl, Wl), where N is the batch size, Cl is the number of channels, and Hl and Wl are the
                height and width of the latent variables, respectively, for each scale l.
            win_size: Window size for the binomial test.
            binomial_probability_thr: Probability threshold for the binomial test.
            high_precision: Whether to use high precision for the binomial test.

        Returns:
            Anomaly mask. Tensor of shape (N, 1, H, W), where N is the batch size, and H and W are the height and
            width of the input image, respectively.
        """
        log_prob_l = [
            self.binomial_test(zi, win_size / (2**scale), binomial_probability_thr, high_precision)
            for scale, zi in enumerate(z)
        ]

        log_prob_l_up = torch.cat(
            [F.interpolate(lpl, size=self.input_size, mode="bicubic", align_corners=True) for lpl in log_prob_l], dim=1
        )

        log_prob = torch.sum(log_prob_l_up, dim=1, keepdim=True)

        log_number_of_tests = torch.log10(torch.sum(torch.tensor([zi.shape[-2] * zi.shape[-1] for zi in z])))
        log_nfa = log_number_of_tests + log_prob

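        anomaly_score = -log_nfa
        # With the threshold log(NFA)=0, a pixel is detected as anomalous when log_nfa < 0,
        # that is, when the anomaly score is positive.
        anomaly_mask = anomaly_score > 0

        return anomaly_mask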

    @staticmethod
    def binomial_test(z: torch.Tensor, win, probability_thr: float, high_precision: bool = False) -> torch.Tensor:
        """
        The binomial test applied to validate or reject the null hypothesis that the pixel is normal. The null
        hypothesis is that the pixel is normal, and the alternative hypothesis is that the pixel is anomalous. The
        binomial test is applied to a window around the pixel, and the number of pixels in the window that are
        anomalous is compared to the number of pixels that are expected to be anomalous under the null hypothesis.

        Args:
            z: Latent variable from the UFlow model. Tensor of shape (N, Cl, Hl, Wl), where N is the batch size, Cl is
                the number of channels, and Hl and Wl are the height and width of the latent variables, respectively.
            win: Window size for the binomial test.
            probability_thr: Probability threshold for the binomial test.
            high_precision: Whether to use high precision for the binomial test.

        Returns:
            Log of the probability of the null hypothesis.

        """
        tau = st.chi2.ppf(probability_thr, 1)
        half_win = np.max([int(win // 2), 1])

        n_chann = z.shape[1]

        # Candidates
        z2 = F.pad(z**2, tuple(4 * [half_win]), "reflect").detach().cpu()
        z2_unfold_h = z2.unfold(-2, 2 * half_win + 1, 1)
        z2_unfold_hw = z2_unfold_h.unfold(-2, 2 * half_win + 1, 1).numpy()
        observed_candidates_k = np.sum(z2_unfold_hw >= tau, axis=(-2, -1))

        # All volume together
        observed_candidates = np.sum(observed_candidates_k, axis=1, keepdims=True)
        x = observed_candidates / n_chann
        n = int((2 * half_win + 1) ** 2)

        # Low precision
        if not high_precision:
            log_prob = torch.tensor(st.binom.logsf(x, n, 1 - probability_thr) / np.log(10))
        # High precision - good and slow
        else:
            to_mp = np.frompyfunc(mp.mpf, 1, 1)
            mpn = mp.mpf(n)
            mpp = probability_thr

            def binomial_density(k):
                return binomial(mpn, to_mp(k)) * (1 - mpp) ** k * mpp ** (mpn - k)

            def integral(xx):
                return integrate.quad(binomial_density, xx, n)[0]

            integral_array = np.vectorize(integral)
            prob = integral_array(x)
            log_prob = torch.tensor(np.log10(prob))

        return log_prob
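
The likelihood-based score computed by `compute_anomaly_map` can be illustrated for a single pixel without PyTorch. The sketch below is a simplified stand-in, not the library code: it skips the bilinear upsampling to the input size, and `pixel_anomaly_score` is a hypothetical helper name.

```python
import math


def pixel_anomaly_score(latents_per_scale: list[list[float]]) -> float:
    """Anomaly score of one pixel, given its latent channel values at each scale."""
    likelihoods = []
    for z in latents_per_scale:
        mean_sq = sum(v * v for v in z) / len(z)  # mean of z^2 over channels
        likelihoods.append(math.exp(-0.5 * mean_sq))  # per-scale likelihood in (0, 1]
    # Average the per-scale likelihoods and invert, so that a higher score is more anomalous.
    return 1.0 - sum(likelihoods) / len(likelihoods)


normal_pixel = pixel_anomaly_score([[0.0, 0.0], [0.0]])  # latents at the mode of N(0, 1)
odd_pixel = pixel_anomaly_score([[3.0, 3.0], [3.0]])  # latents three standard deviations away
print(normal_pixel, odd_pixel)
```

Because each scale contributes the mean of z^2 rather than the sum, both scales weigh the same in the final score regardless of their number of channels.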