Skip to content

Commit

Permalink
Refactor Reverse Distillation to match official code (#1389)
Browse files Browse the repository at this point in the history
* Non-mandatory early stopping

* Added conv4 and bn4 to OCBE

* Loss as in the official code (flattened arrays)

* Added comment on how to use torchvision model as an encoder to reproduce results in the paper

* Remove early stop from config, change default anomaly_map_mode to add

* pre-commit fix

* Updated results

* Update src/anomalib/models/reverse_distillation/README.md

Co-authored-by: Samet Akcay <samet.akcay@intel.com>

* Update src/anomalib/models/reverse_distillation/README.md

Co-authored-by: Samet Akcay <samet.akcay@intel.com>

* Update src/anomalib/models/reverse_distillation/README.md

Co-authored-by: Samet Akcay <samet.akcay@intel.com>

* Remove early_stopping

* Update src/anomalib/models/reverse_distillation/lightning_model.py

Co-authored-by: Samet Akcay <samet.akcay@intel.com>

* Easier to read code

---------

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
abc-125 and samet-akcay committed Oct 14, 2023
1 parent e9ded36 commit e85de73
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 83 deletions.
Binary file modified docs/source/images/reverse_distillation/results/0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/images/reverse_distillation/results/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/images/reverse_distillation/results/2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
73 changes: 13 additions & 60 deletions src/anomalib/models/reverse_distillation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,79 +20,32 @@ During testing, a similar step is followed but this time the cosine distance bet

## Benchmark

All results gathered with seed `42`.

Note: Early Stopping (with patience 3) was enabled during training.
All results gathered with seed `42`, train batch size `16`.

## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad)

### Image-Level AUC

| | ResNet 18 | Wide ResNet 50 |
| :--------- | --------: | -------------: |
| Bottle | 0.998 | 0.992 |
| Cable | 0.982 | 0.583 |
| Capsule | 0.864 | 0.78 |
| Carpet | 0.996 | 0.539 |
| Grid | 0.941 | 0.975 |
| Hazelnut | 0.978 | 0.817 |
| Leather | 0.878 | 1 |
| Metal_nut | 0.999 | 0.929 |
| Pill | 0.944 | 0.553 |
| Screw | 0.778 | 0.86 |
| Tile | 0.833 | 0.513 |
| Toothbrush | 0.967 | 0.7 |
| Transistor | 0.928 | 0.829 |
| Wood | 0.989 | 0.993 |
| Zipper | 0.968 | 0.787 |
| Average | 0.936 | 0.79 |
| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.985 | 0.984 | 1.000 | 1.000 | 1.000 | 0.997 | 1.000 | 0.966 | 0.974 | 1.000 | 1.000 | 0.972 | 0.985 | 0.953 | 0.970 | 0.978 |

### Pixel-Level AUC

| | ResNet 18 | Wide ResNet 50 |
| :--------- | --------: | -------------: |
| Bottle | 0.981 | 0.985 |
| Cable | 0.965 | 0.794 |
| Capsule | 0.983 | 0.986 |
| Carpet | 0.989 | 0.99 |
| Grid | 0.964 | 0.99 |
| Hazelnut | 0.988 | 0.983 |
| Leather | 0.984 | 0.995 |
| Metal_nut | 0.971 | 0.979 |
| Pill | 0.975 | 0.977 |
| Screw | 0.987 | 0.989 |
| Tile | 0.867 | 0.953 |
| Toothbrush | 0.99 | 0.979 |
| Transistor | 0.84 | 0.853 |
| Wood | 0.939 | 0.958 |
| Zipper | 0.988 | 0.959 |
| Average | 0.961 | 0.958 |
| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.969 | 0.988 | 0.992 | 0.991 | 0.954 | 0.947 | 0.984 | 0.964 | 0.987 | 0.988 | 0.969 | 0.975 | 0.996 | 0.991 | 0.893 | 0.984 |

### Image F1 Score

| | ResNet 18 | Wide ResNet 50 |
| :--------- | --------: | -------------: |
| Bottle | 0.95 | 0.959 |
| Cable | 0.911 | 0.76 |
| Capsule | 0.933 | 0.905 |
| Carpet | 0.965 | 0.864 |
| Grid | 0.964 | 0.945 |
| Hazelnut | 0.909 | 0.901 |
| Leather | 0.896 | 0.989 |
| Metal_nut | 0.995 | 0.939 |
| Pill | 0.931 | 0.922 |
| Screw | 0.88 | 0.891 |
| Tile | 0.88 | 0.836 |
| Toothbrush | 0.933 | 0.833 |
| Transistor | 0.769 | 0.744 |
| Wood | 0.966 | 0.948 |
| Zipper | 0.944 | 0.926 |
| Average | 0.922 | 0.891 |
| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.976 | 0.977 | 1.000 | 1.000 | 0.994 | 0.992 | 0.984 | 0.930 | 0.982 | 1.000 | 1.000 | 0.967 | 0.963 | 0.952 | 0.927 | 0.975 |

### Sample Results

![Sample Result 1](../../../docs/source/images/reverse_distillation/results/0.png "Sample Result 1")
![Sample Result 1](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/reverse_distillation/results/0.png "Sample Result 1")

![Sample Result 2](../../../docs/source/images/reverse_distillation/results/1.png "Sample Result 2")
![Sample Result 2](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/reverse_distillation/results/1.png "Sample Result 2")

![Sample Result 3](../../../docs/source/images/reverse_distillation/results/2.png "Sample Result 3")
![Sample Result 3](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/reverse_distillation/results/2.png "Sample Result 3")
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def __init__(
self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2)
self.bn3 = norm_layer(256 * block.expansion)

# This is present in the paper but not in the original code. With some initial experiments, removing this leads
# to better results
# self.conv4 = conv1x1(256 * block.expansion * 3, 256 * block.expansion * 3, 1) # x3 as we concatenate 3 layers
# self.bn4 = norm_layer(256 * block.expansion * 3)
# self.conv4 and self.bn4 are from the original code:
# https://github.com/hq-deng/RD4AD/blob/6554076872c65f8784f6ece8cfb39ce77e1aee12/resnet.py#L412
self.conv4 = conv1x1(1024 * block.expansion, 512 * block.expansion, 1)
self.bn4 = norm_layer(512 * block.expansion)

for module in self.modules():
if isinstance(module, nn.Conv2d):
Expand Down
13 changes: 4 additions & 9 deletions src/anomalib/models/reverse_distillation/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ dataset:
path: ./datasets/MVTec
category: bottle
task: segmentation
train_batch_size: 32
train_batch_size: 16
eval_batch_size: 32
inference_batch_size: 32
num_workers: 8
Expand Down Expand Up @@ -35,21 +35,16 @@ model:
- layer1
- layer2
- layer3
early_stopping:
patience: 3
metric: pixel_AUROC
mode: max
beta1: 0.5
beta2: 0.99
beta2: 0.999
normalization_method: min_max # options: [null, min_max, cdf]
anomaly_map_mode: multiply
anomaly_map_mode: add # options: [add, multiply]

metrics:
image:
- F1Score
- AUROC
pixel:
- F1Score
- AUROC
threshold:
method: adaptive #options: [adaptive, manual]
Expand Down Expand Up @@ -85,7 +80,7 @@ trainer:
enable_progress_bar: true
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 2 # Don't validate before extracting features.
check_val_every_n_epoch: 200 # Don't validate before extracting features.
fast_dev_run: false
accumulate_grad_batches: 1
max_epochs: 200
Expand Down
17 changes: 10 additions & 7 deletions src/anomalib/models/reverse_distillation/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,20 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
self.save_hyperparameters(hparams)

def configure_callbacks(self) -> list[EarlyStopping]:
"""Configure model-specific callbacks.
"""Configure model-specific non-mandatory callbacks.
Note:
This method is used for the existing CLI.
When PL CLI is introduced, configure callback method will be
deprecated, and callbacks will be configured from either
config.yaml file or from CLI.
"""
early_stopping = EarlyStopping(
monitor=self.hparams.model.early_stopping.metric,
patience=self.hparams.model.early_stopping.patience,
mode=self.hparams.model.early_stopping.mode,
)
return [early_stopping]
callbacks = []
if "early_stopping" in self.hparams.model:
early_stopping = EarlyStopping(
monitor=self.hparams.model.early_stopping.metric,
patience=self.hparams.model.early_stopping.patience,
mode=self.hparams.model.early_stopping.mode,
)
callbacks.append(early_stopping)
return callbacks
15 changes: 12 additions & 3 deletions src/anomalib/models/reverse_distillation/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ class ReverseDistillationLoss(nn.Module):
def forward(self, encoder_features: list[Tensor], decoder_features: list[Tensor]) -> Tensor:
"""Computes cosine similarity loss based on features from encoder and decoder.
Based on the official code:
https://github.com/hq-deng/RD4AD/blob/6554076872c65f8784f6ece8cfb39ce77e1aee12/main.py#L33C25-L33C25
Calculates loss from flattened arrays of features, see https://github.com/hq-deng/RD4AD/issues/22
Args:
encoder_features (list[Tensor]): List of features extracted from encoder
decoder_features (list[Tensor]): List of features extracted from decoder
Expand All @@ -23,8 +27,13 @@ def forward(self, encoder_features: list[Tensor], decoder_features: list[Tensor]
Tensor: Cosine similarity loss
"""
cos_loss = torch.nn.CosineSimilarity()
losses = list(map(cos_loss, encoder_features, decoder_features))
loss_sum = 0
for loss in losses:
loss_sum += torch.mean(1 - loss) # mean of cosine distance
for encoder_feature, decoder_feature in zip(encoder_features, decoder_features):
loss_sum += torch.mean(
1
- cos_loss(
encoder_feature.view(encoder_feature.shape[0], -1),
decoder_feature.view(decoder_feature.shape[0], -1),
)
)
return loss_sum
3 changes: 3 additions & 0 deletions src/anomalib/models/reverse_distillation/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
class ReverseDistillationModel(nn.Module):
"""Reverse Distillation Model.
To reproduce results in the paper, use torchvision model for the encoder:
self.encoder = torchvision.models.wide_resnet50_2(pretrained=True)
Args:
backbone (str): Name of the backbone used for encoder and decoder
input_size (tuple[int, int]): Size of input image
Expand Down

0 comments on commit e85de73

Please sign in to comment.