🐞 Fix issues when validation and test split modes set to none #1703

Open · wants to merge 6 commits into main
src/anomalib/engine/engine.py (13 changes: 11 additions & 2 deletions)
@@ -22,6 +22,7 @@
 from anomalib.callbacks.thresholding import _ThresholdCallback
 from anomalib.callbacks.visualizer import _VisualizationCallback
 from anomalib.data import AnomalibDataModule, AnomalibDataset, PredictDataset
+from anomalib.data.utils import TestSplitMode
 from anomalib.deploy.export import ExportType, export_to_onnx, export_to_openvino, export_to_torch
 from anomalib.models import AnomalyModule
 from anomalib.utils.normalization import NormalizationMethod
@@ -633,7 +634,7 @@ def train(
         test_dataloaders: EVAL_DATALOADERS | None = None,
         datamodule: AnomalibDataModule | None = None,
         ckpt_path: str | None = None,
-    ) -> _EVALUATE_OUTPUT:
+    ) -> _EVALUATE_OUTPUT | None:
         """Fits the model and then calls test on it.

         Args:
@@ -650,6 +651,9 @@ def train(
             ckpt_path (str | None, optional): Checkpoint path. If provided, the model will be loaded from this path.
                 Defaults to None.

+        Returns:
+            _EVALUATE_OUTPUT | None: A list of dictionaries containing the test results, one dict per dataloader.
+
         CLI Usage:
            1. you can pick a model, and you can run through the MVTec dataset.
            ```python
@@ -671,7 +675,12 @@
             self.trainer.validate(model, val_dataloaders, None, verbose=False, datamodule=datamodule)
         else:
             self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
-        self.trainer.test(model, test_dataloaders, ckpt_path=ckpt_path, datamodule=datamodule)
+
+        if datamodule is not None and datamodule.test_split_mode == TestSplitMode.NONE:
+            logger.info(f"The test_split_mode is set to '{TestSplitMode.NONE}'. Skipping test stage.")
+            logger.warning(f"Found {len(datamodule.test_data)} images in the test set.")
+            return None
+        return self.trainer.test(model, test_dataloaders, ckpt_path=ckpt_path, datamodule=datamodule)

     def export(
         self,

Collaborator (review comment on the new `if datamodule is not None ...` check): What happens if `datamodule` is not None and `test_dataloaders` is not None?
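For context, a minimal usage sketch of the behaviour this hunk introduces; the `MVTec`/`Padim` choices and the string-valued split mode are illustrative assumptions, not part of the diff. When the datamodule is built with `test_split_mode="none"`, `Engine.train()` now logs that the test stage is skipped and returns `None` instead of a list of test metrics:

```python
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import Padim

# Datamodule deliberately configured without a test split
# (string alias assumed for TestSplitMode.NONE).
datamodule = MVTec(category="bottle", test_split_mode="none")
model = Padim()
engine = Engine()

# With this patch, train() fits the model, skips Trainer.test(), and returns None.
results = engine.train(model=model, datamodule=datamodule)
assert results is None
```

Note that this sketch passes no explicit `test_dataloaders`; the reviewer's question above concerns exactly the case where both a datamodule with `test_split_mode == NONE` and explicit `test_dataloaders` are supplied.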
src/anomalib/utils/config.py (18 changes: 18 additions & 0 deletions)
@@ -19,6 +19,8 @@
 from jsonargparse import Path as JSONArgparsePath
 from omegaconf import DictConfig, ListConfig, OmegaConf

+from anomalib.data.utils import ValSplitMode
+
 logger = logging.getLogger(__name__)

@@ -122,6 +124,7 @@ def update_config(config: DictConfig | ListConfig | Namespace) -> DictConfig | ListConfig:
     config.results_dir.path = str(project_path)

     config = _update_nncf_config(config)
+    config = _update_val_config(config)

     # write the original config for eventual debug (modified config at the end of the function)
     (project_path / "config_original.yaml").write_text(to_yaml(config_original))
@@ -214,6 +217,21 @@ def _update_nncf_config(config: DictConfig | ListConfig) -> DictConfig | ListConfig:
     return config


+def _update_val_config(config: DictConfig | ListConfig) -> DictConfig | ListConfig:
+    """Skip validation if `val_split_mode` is set to 'none'.
+
+    Args:
+        config (DictConfig | ListConfig): Configurable parameters of the current run.
+
+    Returns:
+        DictConfig | ListConfig: Updated configurable parameters in DictConfig object.
+    """
+    if config.data.init_args.val_split_mode == ValSplitMode.NONE and config.trainer.limit_val_batches != 0.0:
+        logger.warning("Running without validation set. Setting trainer.limit_val_batches to 0.")
+        config.trainer.limit_val_batches = 0.0
+    return config
+
+
 def _show_warnings(config: DictConfig | ListConfig | Namespace) -> None:
     """Show warnings if any based on the configuration settings.

Collaborator (review comment on the `val_split_mode == ValSplitMode.NONE` check): Are we sure that if `val_split_mode` is none, a validation dataset is not provided separately?
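For illustration, a minimal sketch of the situation `_update_val_config` guards against; the config below is a hypothetical stand-in for a parsed anomalib config, reduced to the fields the helper inspects. It covers a run whose datamodule uses `val_split_mode: none` while the trainer still has a non-zero `limit_val_batches`:

```python
from omegaconf import OmegaConf

# Hypothetical parsed config containing only the fields _update_val_config looks at.
config = OmegaConf.create(
    {
        "data": {"init_args": {"val_split_mode": "none"}},
        "trainer": {"limit_val_batches": 1.0},
    }
)

# ValSplitMode is a string-valued enum, so comparing against the plain string
# "none" mirrors the check added above.
if config.data.init_args.val_split_mode == "none" and config.trainer.limit_val_batches != 0.0:
    config.trainer.limit_val_batches = 0.0  # Lightning then skips the validation loop

print(config.trainer.limit_val_batches)  # 0.0
```

Because zeroing `limit_val_batches` disables validation globally, the reviewer's question about a separately supplied validation dataset remains open in the current patch.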