Add early stopping with warmup. Remove mandatory background label in semantic segmentation task #3515

Merged: 13 commits, May 28, 2024
82 changes: 82 additions & 0 deletions src/otx/algo/callbacks/adaptive_early_stopping.py
@@ -0,0 +1,82 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Callback for early stopping with warmup possibility."""

from __future__ import annotations

from typing import TYPE_CHECKING

from lightning.pytorch.callbacks.early_stopping import EarlyStopping

if TYPE_CHECKING:
import lightning.pytorch as pl


class EarlyStoppingWithWarmup(EarlyStopping):
"""EarlyStoppingWithWarmup callback."""

def __init__(
self,
monitor: str,
min_delta: float = 0.0,
patience: int = 3,
verbose: bool = False,
mode: str = "min",
strict: bool = True,
check_finite: bool = True,
stopping_threshold: float | None = None,
divergence_threshold: float | None = None,
check_on_train_epoch_end: bool | None = None,
log_rank_zero_only: bool = False,
warmup_iters: int = 100,
warmup_epochs: int = 3,
):
"""EarlyStoppingWithWarmup callback.

Args:
monitor (str): The metric to monitor.
min_delta (float, optional): Minimum change in the monitored quantity
to qualify as an improvement. Defaults to 0.0.
patience (int, optional): Number of epochs with no improvement
after which training will be stopped. Defaults to 3.
verbose (bool, optional): If True, prints messages to stdout. Defaults to False.
mode (str, optional): One of {"min", "max"}. In "min" mode, training will stop when
the quantity monitored has stopped decreasing. In "max" mode,
it will stop when the quantity monitored has stopped increasing. Defaults to "min".
strict (bool, optional): If True, crash training if the monitored metric is
not found in the validation metrics. Defaults to True.
check_finite (bool, optional): If True, stop training when the monitored metric
becomes NaN or infinite. Defaults to True.
stopping_threshold (float | None, optional): Stop training immediately once the
monitored metric reaches this threshold. Defaults to None.
divergence_threshold (float | None, optional): Stop training as soon as the
monitored metric becomes worse than this threshold. Defaults to None.
check_on_train_epoch_end (bool | None, optional): If True,
checks the stopping criterion on train_epoch_end. Defaults to None.
log_rank_zero_only (bool, optional): If True, logs should only be printed from rank 0.
Defaults to False.
warmup_iters (int, optional): Minimum number of training steps to wait
before early-stopping checks are enabled. Defaults to 100.
warmup_epochs (int, optional): Minimum number of warmup epochs; the effective
warmup threshold is max(warmup_epochs * num_training_batches, warmup_iters).
Defaults to 3.
"""
super().__init__(
monitor=monitor,
min_delta=min_delta,
patience=patience,
verbose=verbose,
mode=mode,
strict=strict,
check_finite=check_finite,
stopping_threshold=stopping_threshold,
divergence_threshold=divergence_threshold,
check_on_train_epoch_end=check_on_train_epoch_end,
log_rank_zero_only=log_rank_zero_only,
)
# Two thresholds keep the warmup invariant for both very small and larger datasets
self.warmup_iters = warmup_iters
self.warmup_epochs = warmup_epochs

def _should_skip_check(self, trainer: pl.Trainer) -> bool:
warmup_threshold = max(self.warmup_epochs * trainer.num_training_batches, self.warmup_iters)
return super()._should_skip_check(trainer) or trainer.global_step < warmup_threshold
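For reference, a minimal sketch of how the warmup gate above behaves; the batch counts are illustrative assumptions:

```python
# The callback skips early-stopping checks until
# global_step >= max(warmup_epochs * num_training_batches, warmup_iters).
warmup_iters, warmup_epochs = 100, 3

for num_training_batches in (10, 500):  # assumed sizes: tiny vs. larger dataset
    threshold = max(warmup_epochs * num_training_batches, warmup_iters)
    print(f"num_training_batches={num_training_batches}: checks start at step {threshold}")
# num_training_batches=10: checks start at step 100   (warmup_iters dominates)
# num_training_batches=500: checks start at step 1500 (warmup_epochs dominates)
```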

14 changes: 14 additions & 0 deletions src/otx/core/data/dataset/segmentation.py
@@ -171,13 +171,27 @@
stack_images,
to_tv_image,
)

if self.has_polygons and "background" not in [label_name.lower() for label_name in self.label_info.label_names]:
# insert background class at index 0 since polygons represent only objects
self.label_info.label_names.insert(0, "background")


self.label_info = SegLabelInfo(
label_names=self.label_info.label_names,
label_groups=self.label_info.label_groups,
ignore_index=ignore_index,
)
self.ignore_index = ignore_index

@property
def has_polygons(self) -> bool:
"""Check if the dataset has polygons in annotations."""
for subset in self.dm_subset.subsets().values():
annot_types = set(subset.get_annotated_type())
if annot_types & {"polygon", "ellipse"}:
return True

return False

def _get_item_impl(self, index: int) -> SegDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
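A self-contained sketch of the background-insertion rule added above (the label names and the `has_polygons` value are made up for illustration):

```python
# Polygons/ellipses annotate only foreground objects, so index 0 is
# reserved for an implicit "background" class when such annotations exist.
label_names = ["car", "person"]  # hypothetical dataset labels
has_polygons = True              # hypothetical: dataset uses polygon annotations

if has_polygons and "background" not in [name.lower() for name in label_names]:
    label_names.insert(0, "background")

print(label_names)  # ['background', 'car', 'person']
```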
11 changes: 6 additions & 5 deletions src/otx/core/data/pre_filtering.py
@@ -72,11 +72,12 @@
used_labels: list[int] = list({ann.label for item in dataset for ann in item.annotations})
if data_format == "ava":
used_labels = [0, *used_labels]
elif data_format == "common_semantic_segmentation_with_subset_dirs":
if 0 in used_labels:
used_labels = [label - 1 for label in used_labels[1:]]
else:
used_labels = [label - 1 for label in used_labels]
if data_format == "common_semantic_segmentation_with_subset_dirs" and len(original_categories) < len(used_labels):
msg = (

"There are labeles mismatch in dataset categories and actuall categories comes from semantic masks."
"Please, check `dataset_meta.json` file."
)
raise ValueError(msg)

if len(used_labels) == len(original_categories):
return dataset
msg = "There are unused labels in dataset, they will be filtered out before training."
17 changes: 6 additions & 11 deletions src/otx/core/types/label.py
@@ -6,7 +6,6 @@
from __future__ import annotations

import json
import warnings
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any

@@ -284,14 +283,12 @@
ignore_index: int = 255

def __post_init__(self):
if not any(word.lower() == "background" for word in self.label_names):
if len(self.label_names) <= 1:
msg = (
"Currently, no background label exists for `label_names`. "
"Segmentation requires a background label. "
"To do this, `Background` is added at index 0 of `label_names`."
"The number of labels must be larger than 1. "
"Please, check dataset labels and add background label in case of binary segmentation."
)
warnings.warn(msg, stacklevel=2)
self.label_names.insert(0, "Background")
raise ValueError(msg)
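For illustration, a hedged sketch of the new validation, assuming `SegLabelInfo` is importable from `otx.core.types.label` (the label name is made up):

```python
from otx.core.types.label import SegLabelInfo

# A single label without an explicit background is now rejected
# instead of silently inserting "Background" at index 0.
try:
    SegLabelInfo(label_names=["foreground"], label_groups=[["foreground"]])
except ValueError as err:
    print(err)  # "The number of labels must be larger than 1. ..."
```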


@classmethod
def from_num_classes(cls, num_classes: int) -> LabelInfo:
@@ -307,12 +304,10 @@
)
"""
if num_classes == 1:
label_names = ["Background"]
# binary segmentation
label_names = ["background", "label_0"]
return SegLabelInfo(label_names=label_names, label_groups=[label_names])

# NOTE: It should have "Background" label at the first place.
# To consider it, we need to decrease num_classes by one.
num_classes = num_classes - 1
return super().from_num_classes(num_classes)
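And a short usage sketch of the binary-segmentation path above (same assumed import path):

```python
from otx.core.types.label import SegLabelInfo

# num_classes == 1 is treated as binary segmentation:
# an implicit background plus one foreground label.
info = SegLabelInfo.from_num_classes(1)
print(info.label_names)  # ['background', 'label_0']
```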


4 changes: 3 additions & 1 deletion src/otx/recipe/_base_/train.yaml
@@ -1,13 +1,15 @@
max_epochs: 200
min_epochs: 1
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
monitor: null
mode: max
patience: 10
check_on_train_epoch_end: false
min_delta: 0.001
warmup_iters: 30
warmup_epochs: 3
- class_path: lightning.pytorch.callbacks.RichProgressBar
init_args:
refresh_rate: 1
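For context, a hedged sketch of wiring the same defaults up directly in Lightning; the monitored metric ("val/Dice", taken from the segmentation recipes) and the model/datamodule are placeholders, not part of this recipe:

```python
import lightning.pytorch as pl

from otx.algo.callbacks.adaptive_early_stopping import EarlyStoppingWithWarmup

early_stopping = EarlyStoppingWithWarmup(
    monitor="val/Dice",          # assumed metric; recipes fill `monitor` per task
    mode="max",
    patience=10,
    min_delta=0.001,
    check_on_train_epoch_end=False,
    warmup_iters=30,
    warmup_epochs=3,
)

trainer = pl.Trainer(max_epochs=200, callbacks=[early_stopping])
# trainer.fit(model, datamodule=datamodule)  # model/datamodule supplied by the user
```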
2 changes: 1 addition & 1 deletion src/otx/recipe/classification/h_label_cls/deit_tiny.yaml
@@ -27,7 +27,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -26,7 +26,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -26,7 +26,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -32,7 +32,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -29,7 +29,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -33,7 +33,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 90
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 200
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 4
- class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling
@@ -29,7 +29,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 200
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
@@ -28,7 +28,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 200
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 4
- class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling
@@ -33,7 +33,7 @@ data: ../../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 200
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
patience: 3
data:
5 changes: 5 additions & 0 deletions src/otx/recipe/semantic_segmentation/dino_v2.yaml
@@ -133,3 +133,8 @@ overrides:
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]

callbacks:
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
warmup_iters: 100
9 changes: 6 additions & 3 deletions src/otx/recipe/semantic_segmentation/litehrnet_18.yaml
@@ -127,7 +127,10 @@ engine:

callback_monitor: val/Dice

data: ../_base_/data/mmseg_base.yaml

overrides:
max_epochs: 300
callbacks:
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
warmup_iters: 100

data: ../_base_/data/mmseg_base.yaml
9 changes: 6 additions & 3 deletions src/otx/recipe/semantic_segmentation/litehrnet_s.yaml
@@ -120,7 +120,10 @@ engine:

callback_monitor: val/Dice

data: ../_base_/data/mmseg_base.yaml

overrides:
max_epochs: 300
callbacks:
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
warmup_iters: 100

data: ../_base_/data/mmseg_base.yaml
9 changes: 6 additions & 3 deletions src/otx/recipe/semantic_segmentation/litehrnet_x.yaml
@@ -147,7 +147,10 @@ engine:

callback_monitor: val/Dice

data: ../_base_/data/mmseg_base.yaml

overrides:
max_epochs: 300
callbacks:
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
warmup_iters: 100

data: ../_base_/data/mmseg_base.yaml
8 changes: 6 additions & 2 deletions src/otx/recipe/semantic_segmentation/segnext_b.yaml
@@ -99,6 +99,10 @@ engine:

callback_monitor: val/Dice

data: ../_base_/data/mmseg_base.yaml
overrides:
max_epochs: 170
callbacks:
- class_path: otx.algo.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
init_args:
warmup_iters: 100

data: ../_base_/data/mmseg_base.yaml