Skip to content

Commit

Permalink
Add early stopping with warmup. Remove mandatory background label in …
Browse files Browse the repository at this point in the history
…semantic segmentation task (#3515)

* add adaptive early stopping. Remove background label

* add docs

* add has_polygons property

* lower register

* fix unit tests

* fix segm dataset

* fix early stopping in configs

* minor

* reply comments

* added raise error
  • Loading branch information
kprokofi committed May 28, 2024
1 parent 0403503 commit 80acb86
Show file tree
Hide file tree
Showing 35 changed files with 178 additions and 56 deletions.
82 changes: 82 additions & 0 deletions src/otx/algo/callbacks/adaptive_early_stopping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Callback for early stopping with warmup possibility."""

from __future__ import annotations

from typing import TYPE_CHECKING

from lightning.pytorch.callbacks.early_stopping import EarlyStopping

if TYPE_CHECKING:
import lightning.pytorch as pl


class EarlyStoppingWithWarmup(EarlyStopping):
    """Early stopping callback that is inactive during a warmup period.

    Plain ``EarlyStopping`` can fire while the model is still warming up and
    the monitored metric is noisy. This subclass suppresses the stopping
    check until the trainer has passed a warmup threshold expressed both in
    iterations and in epochs, which keeps the behavior invariant to dataset
    size (tiny datasets have few batches per epoch, large ones many).
    """

    def __init__(
        self,
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = "min",
        strict: bool = True,
        check_finite: bool = True,
        stopping_threshold: float | None = None,
        divergence_threshold: float | None = None,
        check_on_train_epoch_end: bool | None = None,
        log_rank_zero_only: bool = False,
        warmup_iters: int = 100,
        warmup_epochs: int = 3,
    ):
        """Initialize EarlyStoppingWithWarmup.

        Args:
            monitor (str): The metric to monitor.
            min_delta (float, optional): Minimum change in the monitored quantity
                to qualify as an improvement. Defaults to 0.0.
            patience (int, optional): Number of validation checks with no improvement
                after which training will be stopped. Defaults to 3.
            verbose (bool, optional): If True, prints messages to stdout. Defaults to False.
            mode (str, optional): One of {"min", "max"}. In "min" mode, training will stop when
                the quantity monitored has stopped decreasing. In "max" mode,
                it will stop when the quantity monitored has stopped increasing. Defaults to "min".
            strict (bool, optional): If True, crash training if the monitored metric
                is not found in the validation metrics. Defaults to True.
            check_finite (bool, optional): If True, stop training when the monitored
                metric becomes NaN or infinite. Defaults to True.
            stopping_threshold (float | None, optional): Stop training immediately once
                the monitored quantity reaches this threshold. Defaults to None.
            divergence_threshold (float | None, optional): Stop training as soon as the
                monitored quantity becomes worse than this threshold. Defaults to None.
            check_on_train_epoch_end (bool | None, optional): If True,
                checks the stopping criterion on train_epoch_end. Defaults to None.
            log_rank_zero_only (bool, optional): If True, logs should only be printed from rank 0.
                Defaults to False.
            warmup_iters (int, optional): Minimum number of training iterations before
                early-stopping checks are allowed. Defaults to 100.
            warmup_epochs (int, optional): Minimum number of training epochs before
                early-stopping checks are allowed (converted to iterations at runtime).
                Defaults to 3.
        """
        super().__init__(
            monitor=monitor,
            min_delta=min_delta,
            patience=patience,
            verbose=verbose,
            mode=mode,
            strict=strict,
            check_finite=check_finite,
            stopping_threshold=stopping_threshold,
            divergence_threshold=divergence_threshold,
            check_on_train_epoch_end=check_on_train_epoch_end,
            log_rank_zero_only=log_rank_zero_only,
        )
        # Two thresholds make the warmup invariant to dataset size:
        # warmup_iters dominates on extra-small datasets, warmup_epochs on larger ones.
        self.warmup_iters = warmup_iters
        self.warmup_epochs = warmup_epochs

    def _should_skip_check(self, trainer: pl.Trainer) -> bool:
        """Skip the early-stopping check while still inside the warmup window."""
        # Effective warmup is the larger of the two thresholds, so both the
        # iteration and the epoch condition must be satisfied before stopping.
        warmup_threshold = max(self.warmup_epochs * trainer.num_training_batches, self.warmup_iters)
        return super()._should_skip_check(trainer) or trainer.global_step < warmup_threshold
14 changes: 14 additions & 0 deletions src/otx/core/data/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,27 @@ def __init__(
stack_images,
to_tv_image,
)

if self.has_polygons and "background" not in [label_name.lower() for label_name in self.label_info.label_names]:
# insert background class at index 0 since polygons represent only objects
self.label_info.label_names.insert(0, "background")

self.label_info = SegLabelInfo(
label_names=self.label_info.label_names,
label_groups=self.label_info.label_groups,
ignore_index=ignore_index,
)
self.ignore_index = ignore_index

@property
def has_polygons(self) -> bool:
    """Return True if any subset carries polygon or ellipse annotations."""
    shape_types = {"polygon", "ellipse"}
    return any(
        not shape_types.isdisjoint(subset.get_annotated_type())
        for subset in self.dm_subset.subsets().values()
    )

def _get_item_impl(self, index: int) -> SegDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
Expand Down
11 changes: 6 additions & 5 deletions src/otx/core/data/pre_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ def remove_unused_labels(dataset: DmDataset, data_format: str) -> DmDataset:
used_labels: list[int] = list({ann.label for item in dataset for ann in item.annotations})
if data_format == "ava":
used_labels = [0, *used_labels]
elif data_format == "common_semantic_segmentation_with_subset_dirs":
if 0 in used_labels:
used_labels = [label - 1 for label in used_labels[1:]]
else:
used_labels = [label - 1 for label in used_labels]
if data_format == "common_semantic_segmentation_with_subset_dirs" and len(original_categories) < len(used_labels):
msg = (
"There are labeles mismatch in dataset categories and actuall categories comes from semantic masks."
"Please, check `dataset_meta.json` file."
)
raise ValueError(msg)
if len(used_labels) == len(original_categories):
return dataset
msg = "There are unused labels in dataset, they will be filtered out before training."
Expand Down
17 changes: 6 additions & 11 deletions src/otx/core/types/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from __future__ import annotations

import json
import warnings
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -284,14 +283,12 @@ class SegLabelInfo(LabelInfo):
ignore_index: int = 255

def __post_init__(self):
if not any(word.lower() == "background" for word in self.label_names):
if len(self.label_names) <= 1:
msg = (
"Currently, no background label exists for `label_names`. "
"Segmentation requires a background label. "
"To do this, `Background` is added at index 0 of `label_names`."
"The number of labels must be larger than 1. "
"Please, check dataset labels and add background label in case of binary segmentation."
)
warnings.warn(msg, stacklevel=2)
self.label_names.insert(0, "Background")
raise ValueError(msg)

@classmethod
def from_num_classes(cls, num_classes: int) -> LabelInfo:
Expand All @@ -307,12 +304,10 @@ def from_num_classes(cls, num_classes: int) -> LabelInfo:
)
"""
if num_classes == 1:
label_names = ["Background"]
# binary segmentation
label_names = ["background", "label_0"]
return SegLabelInfo(label_names=label_names, label_groups=[label_names])

# NOTE: It should have "Background" label at the first place.
# To consider it, we need to decrease num_classes by one.
num_classes = num_classes - 1
return super().from_num_classes(num_classes)


Expand Down
4 changes: 3 additions & 1 deletion src/otx/recipe/_base_/train.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
max_epochs: 200
min_epochs: 1
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
monitor: null
mode: max
patience: 10
check_on_train_epoch_end: false
min_delta: 0.001
warmup_iters: 30
warmup_epochs: 3
- class_path: lightning.pytorch.callbacks.RichProgressBar
init_args:
refresh_rate: 1
Expand Down
2 changes: 1 addition & 1 deletion src/otx/recipe/classification/h_label_cls/deit_tiny.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
2 changes: 1 addition & 1 deletion src/otx/recipe/classification/multi_class_cls/dino_v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 200
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 4
- class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 200
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 200
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 4
- class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 200
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
Expand Down
5 changes: 5 additions & 0 deletions src/otx/recipe/semantic_segmentation/dino_v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,8 @@ overrides:
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]

callbacks:
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
warmup_iters: 100
9 changes: 6 additions & 3 deletions src/otx/recipe/semantic_segmentation/litehrnet_18.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ engine:

callback_monitor: val/Dice

data: ../_base_/data/mmseg_base.yaml

overrides:
max_epochs: 300
callbacks:
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
warmup_iters: 100

data: ../_base_/data/mmseg_base.yaml
9 changes: 6 additions & 3 deletions src/otx/recipe/semantic_segmentation/litehrnet_s.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ engine:

callback_monitor: val/Dice

data: ../_base_/data/mmseg_base.yaml

overrides:
max_epochs: 300
callbacks:
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
warmup_iters: 100

data: ../_base_/data/mmseg_base.yaml
9 changes: 6 additions & 3 deletions src/otx/recipe/semantic_segmentation/litehrnet_x.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ engine:

callback_monitor: val/Dice

data: ../_base_/data/mmseg_base.yaml

overrides:
max_epochs: 300
callbacks:
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
warmup_iters: 100

data: ../_base_/data/mmseg_base.yaml
8 changes: 6 additions & 2 deletions src/otx/recipe/semantic_segmentation/segnext_b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ engine:

callback_monitor: val/Dice

data: ../_base_/data/mmseg_base.yaml
overrides:
max_epochs: 170
callbacks:
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
warmup_iters: 100

data: ../_base_/data/mmseg_base.yaml
Loading

0 comments on commit 80acb86

Please sign in to comment.