Skip to content

Commit

Permalink
Pipeline code cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Blaz Rolih <blaz.rolih@gmail.com>
  • Loading branch information
blaz-r committed Jun 15, 2024
1 parent 9713112 commit dac985f
Show file tree
Hide file tree
Showing 13 changed files with 94 additions and 82 deletions.
14 changes: 7 additions & 7 deletions src/anomalib/pipelines/tiled_ensemble/components/merging.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class MergeJob(Job):
"""Job for merging tile-level predictions into image-level predictions.
Args:
predictions (EnsemblePredictions): object containing ensemble predictions.
tiler (EnsembleTiler): ensemble tiler used for untiling.
predictions (EnsemblePredictions): Object containing ensemble predictions.
tiler (EnsembleTiler): Ensemble tiler used for untiling.
"""

name = "pipeline"
Expand All @@ -39,10 +39,10 @@ def run(self, task_id: int | None = None) -> list[Any]:
"""Run merging job that merges all batches of tile-level predictions into image-level predictions.
Args:
task_id: not used in this case
task_id: Not used in this case.
Returns:
list[Any]: list of merged predictions.
list[Any]: List of merged predictions.
"""
del task_id # not needed here

Expand All @@ -63,7 +63,7 @@ def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS:
"""Nothing to collect in this job.
Returns:
list[Any]: list of predictions.
list[Any]: List of predictions.
"""
# take the first element as result is list of lists here
return results[0]
Expand Down Expand Up @@ -92,8 +92,8 @@ def generate_jobs(
"""Return a generator producing a single merging job.
Args:
args: tiled ensemble pipeline args.
prev_stage_result (EnsemblePredictions): ensemble predictions from predict step.
args (dict): Tiled ensemble pipeline args.
prev_stage_result (EnsemblePredictions): Ensemble predictions from predict step.
Returns:
Generator[Job, None, None]: MergeJob generator
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tiled ensemble - metrics calculation job."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

Expand All @@ -24,7 +25,11 @@ class MetricsCalculationJob(Job):
"""Job for image and pixel metrics calculation.
Args:
predictions (list[Any]): list of predictions.
accelerator (str): Accelerator (device) to use.
predictions (list[Any]): List of predictions.
root_dir (Path): Root directory to save checkpoints, stats and images.
image_metrics (AnomalibMetricCollection): Collection of all image-level metrics.
pixel_metrics (AnomalibMetricCollection): Collection of all pixel-level metrics.
"""

name = "pipeline"
Expand All @@ -48,7 +53,7 @@ def run(self, task_id: int | None = None) -> dict:
"""Run a job that calculates image and pixel level metrics.
Args:
task_id: not used in this case
task_id: Not used in this case.
Returns:
dict[str, float]: Dictionary containing calculated metric values.
Expand All @@ -63,6 +68,7 @@ def run(self, task_id: int | None = None) -> dict:
if "mask" in data and "anomaly_maps" in data:
self.pixel_metrics.update(data["anomaly_maps"], data["mask"].int())

# compute all metrics on specified accelerator
metrics_dict = {}
for name, metric in self.image_metrics.items():
metric.to(self.accelerator)
Expand Down Expand Up @@ -95,7 +101,7 @@ def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS:

@staticmethod
def save(results: GATHERED_RESULTS) -> None:
"""Nothing is saved in this job."""
"""Save metrics values to csv."""
logger.info("Saving metrics to csv.")

# get and remove path from stats dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def run(
"""Run train job that fits the model for given tile location.
Args:
task_id: Passed when job is ran in parallel
task_id: Passed when job is ran in parallel.
Returns:
TiledEnsembleEngine: engine with trained model.
TiledEnsembleEngine: Engine containing trained model.
"""
devices: str | list[int] = "auto"
if task_id is not None:
Expand Down Expand Up @@ -141,7 +141,7 @@ def generate_jobs(
Args:
args (dict): Dict with config passed to training.
prev_stage_result (None): not used here
prev_stage_result (None): Not used here.
"""
del prev_stage_result # Not needed for this job

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class NormalizationJob(Job):
"""Job for normalization of predictions.
Args:
predictions (list[Any]): list of predictions.
root_dir (Path): Root directory to save checkpoints, stats and images.
predictions (list[Any]): List of predictions.
root_dir (Path): Root directory containing statistics needed for normalization.
"""

name = "pipeline"
Expand All @@ -37,13 +37,14 @@ def run(self, task_id: int | None = None) -> list[Any]:
"""Run normalization job which normalizes image, pixel and box scores.
Args:
task_id: not used in this case
task_id: Not used in this case.
Returns:
list[Any]: list of normalized predictions.
list[Any]: List of normalized predictions.
"""
del task_id # not needed here

# load all statistics needed for normalization
stats_path = self.root_dir / "weights" / "lightning" / "stats.json"
with stats_path.open("r") as f:
stats = json.load(f)
Expand All @@ -70,7 +71,7 @@ def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS:
"""Nothing to collect in this job.
Returns:
list[Any]: list of predictions.
list[Any]: List of predictions.
"""
# take the first element as result is list of lists here
return results[0]
Expand All @@ -84,7 +85,7 @@ class NormalizationJobGenerator(JobGenerator):
"""Generate NormalizationJob.
Args:
root_dir (Path): Root directory to save checkpoints, stats and images.
root_dir (Path): Root directory where statistics are saved.
"""

def __init__(self, root_dir: Path) -> None:
Expand All @@ -104,10 +105,10 @@ def generate_jobs(
Args:
args: not used here.
prev_stage_result (list[Any]): ensemble predictions from previous step.
prev_stage_result (list[Any]): Ensemble predictions from previous step.
Returns:
Generator[Job, None, None]: NormalizationJob generator
Generator[Job, None, None]: NormalizationJob generator.
"""
del args # not needed here

Expand Down
13 changes: 7 additions & 6 deletions src/anomalib/pipelines/tiled_ensemble/components/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
) -> None:
super().__init__()
if engine is None and ckpt_path is None:
msg = "At least one, engine or checkpoint, must be provided to predict job."
msg = "Either engine or checkpoint must be provided to predict job."
raise ValueError(msg)

self.accelerator = accelerator
Expand All @@ -82,10 +82,10 @@ def run(
"""Predict job that predicts the data with specific model for given tile location.
Args:
task_id: Passed when job is ran in parallel
task_id: Passed when job is ran in parallel.
Returns:
list[Any]: list of predictions.
tuple[tuple[int, int], list[Any]]: Tile index, List of predictions.
"""
devices: str | list[int] = "auto"
if task_id is not None:
Expand All @@ -96,7 +96,7 @@ def run(
seed_everything(self.seed)

if self.engine is None:
# in case predict is invoked separately from train job
# in case predict is invoked separately from train job, make new engine instance
self.engine = get_ensemble_engine(
tile_index=self.tile_index,
accelerator=self.accelerator,
Expand All @@ -107,14 +107,15 @@ def run(

predictions = self.engine.predict(model=self.model, dataloaders=self.dataloader, ckpt_path=self.ckpt_path)

# also return tile index as it's needed in collect method
return self.tile_index, predictions

@staticmethod
def collect(results: list[tuple[tuple[int, int], list[Any]]]) -> EnsemblePredictions:
"""Collect predictions from each tile location into the predictions class.
Returns:
EnsemblePredictions: object containing all predictions in form ready for joining.
EnsemblePredictions: Object containing all predictions in form ready for merging.
"""
storage = EnsemblePredictions()

Expand Down Expand Up @@ -188,7 +189,7 @@ def generate_jobs(
dataloader = datamodule.test_dataloader()
if self.data_source == PredictData.VAL:
dataloader = datamodule.val_dataloader()
# TODO: - this is hack to avoid problem in engine:388 - I think if model has transforms
# TODO: - this is tweak to avoid problem in engine:388 - I think if model has transforms
# that should be preferred over dataset transforms?
dataloader.dataset.transform = None

Expand Down
18 changes: 10 additions & 8 deletions src/anomalib/pipelines/tiled_ensemble/components/smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class SmoothingJob(Job):
"""Job for smoothing the area around the tile seam.
Args:
predictions (list[Any]): list of image-level predictions.
accelerator (str): Accelerator used for processing.
predictions (list[Any]): List of image-level predictions.
width_factor (float): Factor multiplied by tile dimension to get the region around seam which will be smoothed.
filter_sigma (float): Sigma of filter used for smoothing the seams.
tiler (EnsembleTiler): Tiler object used to get tile dimension data.
Expand Down Expand Up @@ -57,7 +58,7 @@ def prepare_seam_mask(self) -> torch.Tensor:
"""Prepare boolean mask of regions around the part where tiles seam in ensemble.
Returns:
Tensor: Representation of boolean mask where filtered seams should be used.
torch.Tensor: Representation of boolean mask where filtered seams should be used.
"""
img_h, img_w = self.tiler.image_size
stride_h, stride_w = self.tiler.stride_h, self.tiler.stride_w
Expand Down Expand Up @@ -86,10 +87,10 @@ def run(self, task_id: int | None = None) -> list[Any]:
"""Run smoothing job.
Args:
task_id: not used in this case
task_id: Not used in this case.
Returns:
list[Any]: list of predictions.
list[Any]: List of predictions.
"""
del task_id # not needed here

Expand All @@ -110,7 +111,7 @@ def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS:
"""Nothing to collect in this job.
Returns:
list[Any]: list of predictions.
list[Any]: List of predictions.
"""
# take the first element as result is list of lists here
return results[0]
Expand All @@ -136,12 +137,13 @@ def generate_jobs(
"""Return a generator producing a single seam smoothing job.
Args:
args: tiled ensemble pipeline args.
prev_stage_result (list[Any]): ensemble predictions from merging step.
args: Tiled ensemble pipeline args.
prev_stage_result (list[Any]): Ensemble predictions from previous step.
Returns:
Generator[Job, None, None]: MergeJob generator
Generator[Job, None, None]: SmoothingJob generator
"""
# tiler is used to determine where seams appear
tiler = get_ensemble_tiler(args)
yield SmoothingJob(
accelerator=args["accelerator"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class StatisticsJob(Job):
"""Job for calculating min, max and threshold statistics for post-processing.
Args:
predictions (list[Any]): list of image-level predictions.
predictions (list[Any]): List of image-level predictions.
root_dir (Path): Root directory to save checkpoints, stats and images.
"""

Expand All @@ -38,10 +38,10 @@ def run(self, task_id: int | None = None) -> dict:
"""Run job that calculates statistics needed in post-processing steps.
Args:
task_id: not used in this case
task_id: Not used in this case
Returns:
dict: statistics dict with min, max and threshold values.
dict: Statistics dict with min, max and threshold values.
"""
del task_id # not needed here

Expand Down Expand Up @@ -111,7 +111,7 @@ class StatisticsJobGenerator(JobGenerator):
"""Generate StatisticsJob.
Args:
root_dir (Path): Root directory to save checkpoints, stats and images.
root_dir (Path): Root directory where statistics file will be saved (in weights folder).
"""

def __init__(self, root_dir: Path) -> None:
Expand All @@ -130,11 +130,11 @@ def generate_jobs(
"""Return a generator producing a single stats calculating job.
Args:
args: not used here.
prev_stage_result (list[Any]): ensemble predictions from previous step.
args: Not used here.
prev_stage_result (list[Any]): Ensemble predictions from previous step.
Returns:
Generator[Job, None, None]: StatisticsJob generator
Generator[Job, None, None]: StatisticsJob generator.
"""
del args # not needed here

Expand Down
15 changes: 8 additions & 7 deletions src/anomalib/pipelines/tiled_ensemble/components/thresholding.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ThresholdingJob(Job):
"""Job used to threshold predictions, producing labels from scores.
Args:
predictions (list[Any]): list of predictions.
predictions (list[Any]): List of predictions.
image_threshold (float): Threshold used for image-level thresholding.
pixel_threshold (float): Threshold used for pixel-level thresholding.
"""
Expand All @@ -41,10 +41,10 @@ def run(self, task_id: int | None = None) -> list[Any]:
"""Run job that produces prediction labels from scores.
Args:
task_id: not used in this case
task_id: Not used in this case.
Returns:
list[Any]: list of thresholded predictions.
list[Any]: List of thresholded predictions.
"""
del task_id # not needed here

Expand Down Expand Up @@ -74,7 +74,7 @@ def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS:
"""Nothing to collect in this job.
Returns:
list[Any]: list of predictions.
list[Any]: List of predictions.
"""
# take the first element as result is list of lists here
return results[0]
Expand All @@ -88,7 +88,7 @@ class ThresholdingJobGenerator(JobGenerator):
"""Generate ThresholdingJob.
Args:
root_dir (Path): Root directory containing post processing stats.
root_dir (Path): Root directory containing post-processing stats.
"""

def __init__(self, root_dir: Path) -> None:
Expand All @@ -108,11 +108,12 @@ def generate_jobs(
Args:
args: ensemble run args.
prev_stage_result (list[Any]): ensemble predictions from previous step.
prev_stage_result (list[Any]): Ensemble predictions from previous step.
Returns:
Generator[Job, None, None]: ThresholdingJob generator
Generator[Job, None, None]: ThresholdingJob generator.
"""
# get threshold values base on normalization
image_threshold, pixel_threshold = get_threshold_values(args, self.root_dir)

yield ThresholdingJob(
Expand Down
Loading

0 comments on commit dac985f

Please sign in to comment.