Introduce AdaptiveRepeatDataHook for classification #2428

Merged
merged 24 commits into from
Aug 16, 2023
Changes from 23 commits
Commits
24 commits
160744b
Add OTX Sampler
sungmanc Aug 8, 2023
1254702
Merge branch 'develop' of https://github.com/openvinotoolkit/training…
sungmanc Aug 8, 2023
bc82ee3
Add docs and tests and changes the configurations
sungmanc Aug 9, 2023
a5c53c2
Merge branch 'develop' of https://github.com/openvinotoolkit/training…
sungmanc Aug 9, 2023
4bed99b
Fix precommit
sungmanc Aug 9, 2023
a0d6cb6
Merge branch 'otx-adaptive-sampler' of https://github.com/sungmanc/tr…
sungmanc Aug 9, 2023
2761162
Update changelog and license
sungmanc Aug 9, 2023
e01095f
Update CHANGELOG.md
sungmanc Aug 9, 2023
5c4a474
Make black happy
sungmanc Aug 9, 2023
fb40fea
Fix mis-information for the logger and training metrics
sungmanc Aug 9, 2023
6990aec
Fix unit-test: remove unwrap_dataset
sungmanc Aug 9, 2023
c3193c2
Merge branch 'develop' into otx-adaptive-sampler
sungmanc Aug 9, 2023
5dcd836
Remove mmcls dependency
sungmanc Aug 9, 2023
557b72c
Merge branch 'otx-adaptive-sampler' of https://github.com/sungmanc/tr…
sungmanc Aug 10, 2023
bafd1e6
Fix OTXSampler to enable the YOLOX case: reintroduce unwrap_dataset
sungmanc Aug 10, 2023
30323ae
Make black happy
sungmanc Aug 10, 2023
19f78e2
Only apply to multi-class classification
sungmanc Aug 10, 2023
5c35edb
Update src/otx/algorithms/classification/adapters/mmcls/configurer.py
sungmanc Aug 11, 2023
a496a82
Reflect reviews and add unit-tests for OTXSampler
sungmanc Aug 11, 2023
52e62c5
Fix conflicts
sungmanc Aug 11, 2023
c71cce7
Integrate model.py adaptive hook config --> recipe
sungmanc Aug 16, 2023
5058d29
Merge branch 'develop' into otx-adaptive-sampler
sungmanc Aug 16, 2023
b14a5ab
Make black happy
sungmanc Aug 16, 2023
19a512f
Fix typo
sungmanc Aug 16, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ All notable changes to this project will be documented in this file.
- Enable FeatureVectorHook to support action tasks(<https://github.com/openvinotoolkit/training_extensions/pull/2408>)
- Add ONNX metadata to detection, instance segmentation, and segmentation models (<https://github.com/openvinotoolkit/training_extensions/pull/2418>)
- Add a new feature to configure input size(<https://github.com/openvinotoolkit/training_extensions/pull/2420>)
- Introduce OTXSampler and AdaptiveRepeatDataHook to achieve faster training in the small-data regime (<https://github.com/openvinotoolkit/training_extensions/pull/2428>)

### Enhancements

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Adaptive Training
==================

Adaptive training focuses on adjusting the number of iterations or the validation interval to achieve faster training.
In the small-data regime, we don't need to validate the model at every epoch since a single epoch contains only a few iterations.
To handle this, we have implemented two modules, ``AdaptiveTrainingHook`` and ``AdaptiveRepeatDataHook``. Both of them control the validation interval to speed up training.

.. note::
1. ``AdaptiveTrainingHook`` changes the intervals for validation, evaluation, and learning-rate updates based on the size of the dataset.
2. ``AdaptiveRepeatDataHook`` changes the number of dataset repeats by patching the sampler.
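
As a rough illustration of the repeat logic, here is a minimal Python sketch assuming the formula stated in the ``AdaptiveRepeatDataHook`` docstring, ``(coef * sqrt(num_iters - 1)) + 5`` floored at ``min_repeat``; the actual ``OTXSampler`` implementation is not part of this diff and may round or clamp differently.

import math

def adaptive_num_repeats(num_samples: int, samples_per_gpu: int,
                         coef: float = -0.7, min_repeat: float = 1.0) -> int:
    # Iterations for one pass over the data with the given batch size.
    num_iters = max(num_samples // samples_per_gpu, 1)
    # Fewer iterations per epoch -> larger repeat factor, bounded below by min_repeat.
    repeat = (coef * math.sqrt(num_iters - 1)) + 5
    return int(max(round(repeat), min_repeat))

# Example: 160 images at batch size 32 (5 iterations/epoch) repeats the data ~4x,
# while 6400 images (200 iterations/epoch) falls back to a single repeat.
print(adaptive_num_repeats(160, 32))   # -> 4
print(adaptive_num_repeats(6400, 32))  # -> 1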
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Additional Features
models_optimization
hpo
auto_configuration
adaptive_training
xai
noisy_label_detection
fast_data_loading
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def configure_task_adapt_hook(self, cfg):

train_data_cfg.classes = self.model_classes

if not cfg.model.get("multilabel", False) and not cfg.model.get("hierarchical", False):
is_multiclass = not cfg.model.get("multilabel", False) and not cfg.model.get("hierarchical", False)
if is_multiclass:
efficient_mode = cfg["task_adapt"].get("efficient_mode", True)
sampler_type = "balanced"
self.configure_loss(cfg)
Expand All @@ -251,6 +252,7 @@ def configure_task_adapt_hook(self, cfg):
sampler_flag = False
else:
sampler_flag = True

# Update Task Adapt Hook
task_adapt_hook = ConfigDict(
type="TaskAdaptHook",
Expand All @@ -260,6 +262,8 @@ def configure_task_adapt_hook(self, cfg):
sampler_flag=sampler_flag,
sampler_type=sampler_type,
efficient_mode=efficient_mode,
use_adaptive_repeat=is_multiclass,
priority="NORMAL",
)
update_or_add_custom_hook(cfg, task_adapt_hook)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ learning_parameters:
warning: This is applied exclusively when early stopping is enabled.
use_adaptive_interval:
affects_outcome_of: TRAINING
default_value: true
default_value: false
description: Depending on the number of iterations per epoch, adaptively update the validation interval and related values.
editable: true
header: Use adaptive validation interval
Expand Down
2 changes: 1 addition & 1 deletion src/otx/algorithms/classification/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def _generate_training_metrics(self, learning_curves): # pylint: disable=argume
elif self._hierarchical:
metric_key = "val/MHAcc"
else:
metric_key = "val/accuracy_top-1"
metric_key = "val/accuracy (%)"

# Learning curves
best_acc = -1
Expand Down
2 changes: 2 additions & 0 deletions src/otx/algorithms/common/adapters/mmcv/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .adaptive_repeat_data_hook import AdaptiveRepeatDataHook
from .adaptive_training_hook import AdaptiveTrainSchedulingHook
from .cancel_hook import CancelInterfaceHook, CancelTrainingHook
from .checkpoint_hook import (
Expand Down Expand Up @@ -53,6 +54,7 @@
from .unbiased_teacher_hook import UnbiasedTeacherHook

__all__ = [
"AdaptiveRepeatDataHook",
"AdaptiveTrainSchedulingHook",
"CancelInterfaceHook",
"CancelTrainingHook",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Adaptive repeat data hook."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from mmcv.runner import HOOKS, Hook, get_dist_info
from torch.utils.data import DataLoader

from otx.algorithms.common.adapters.torch.dataloaders.samplers import OTXSampler
from otx.algorithms.common.utils.logger import get_logger

logger = get_logger()


@HOOKS.register_module()
class AdaptiveRepeatDataHook(Hook):
"""Hook that adaptively repeats the dataset to control the number of iterations.

Args:
coef (float, optional): coefficient that affects the number of repeats,
computed as (coef * math.sqrt(num_iters - 1)) + 5
min_repeat (float, optional): minimum number of repeats
"""

def __init__(self, coef: float = -0.7, min_repeat: float = 1.0):
self.coef = coef
self.min_repeat = min_repeat

def before_epoch(self, runner):
"""Convert to OTX Sampler."""
dataset = runner.data_loader.dataset
batch_size = runner.data_loader.batch_size
num_workers = runner.data_loader.num_workers
collate_fn = runner.data_loader.collate_fn
worker_init_fn = runner.data_loader.worker_init_fn
rank, world_size = get_dist_info()

sampler = OTXSampler(
dataset=dataset,
samples_per_gpu=batch_size,
use_adaptive_repeats=True,
num_replicas=world_size,
rank=rank,
shuffle=True,
coef=self.coef,
min_repeat=self.min_repeat,
)

runner.data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=False,
worker_init_fn=worker_init_fn,
)
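
Since the hook is registered in the mmcv HOOKS registry, it can be enabled from a config file. The snippet below is a hypothetical mmcv-style example for illustration only; the exact recipe file and keys used by OTX may differ (this PR wires the hook through the classification recipe).

# Hypothetical mmcv-style config snippet enabling the hook via custom_hooks;
# the recipe location and option names in OTX may differ.
custom_hooks = [
    dict(
        type="AdaptiveRepeatDataHook",
        coef=-0.7,       # more negative coef shrinks the repeat factor faster as the dataset grows
        min_repeat=1.0,  # never repeat the dataset less than once
    ),
]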
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class TaskAdaptHook(Hook):
model_type (str): Types of models used for learning
sampler_flag (bool): Flag about using ClsIncrSampler
efficient_mode (bool): Flag about using efficient mode sampler
use_adaptive_repeat (bool): Flag about using adaptive repeat data
"""

def __init__(
Expand All @@ -35,18 +36,21 @@ def __init__(
sampler_flag=False,
sampler_type="cls_incr",
efficient_mode=False,
use_adaptive_repeat=False,
):
self.src_classes = src_classes
self.dst_classes = dst_classes
self.model_type = model_type
self.sampler_flag = sampler_flag
self.sampler_type = sampler_type
self.efficient_mode = efficient_mode
self.use_adaptive_repeat = use_adaptive_repeat

logger.info(f"Task Adaptation: {self.src_classes} => {self.dst_classes}")
logger.info(f"- Efficient Mode: {self.efficient_mode}")
logger.info(f"- Sampler type: {self.sampler_type}")
logger.info(f"- Sampler flag: {self.sampler_flag}")
logger.info(f"- Adaptive repeat: {self.use_adaptive_repeat}")

def before_epoch(self, runner):
"""Produce a proper sampler for task-adaptation."""
Expand All @@ -59,11 +63,21 @@ def before_epoch(self, runner):
rank, world_size = get_dist_info()
if self.sampler_type == "balanced":
sampler = BalancedSampler(
dataset, batch_size, efficient_mode=self.efficient_mode, num_replicas=world_size, rank=rank
dataset,
batch_size,
efficient_mode=self.efficient_mode,
num_replicas=world_size,
rank=rank,
use_adaptive_repeats=self.use_adaptive_repeat,
)
else:
sampler = ClsIncrSampler(
dataset, batch_size, efficient_mode=self.efficient_mode, num_replicas=world_size, rank=rank
dataset,
batch_size,
efficient_mode=self.efficient_mode,
num_replicas=world_size,
rank=rank,
use_adaptive_repeats=self.use_adaptive_repeat,
)
runner.data_loader = DataLoader(
dataset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

# flake8: noqa

from .otx_sampler import OTXSampler
from .balanced_sampler import BalancedSampler
from .cls_incr_sampler import ClsIncrSampler

__all__ = ["BalancedSampler", "ClsIncrSampler"]
__all__ = ["OTXSampler", "BalancedSampler", "ClsIncrSampler"]
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
"""Balanced sampler for imbalanced data."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import math

import numpy as np
from torch.utils.data.sampler import Sampler
from torch.utils.data import Dataset

from otx.algorithms.common.utils.logger import get_logger

from .otx_sampler import OTXSampler

logger = get_logger()


class BalancedSampler(Sampler): # pylint: disable=too-many-instance-attributes
class BalancedSampler(OTXSampler): # pylint: disable=too-many-instance-attributes
"""Balanced sampler for imbalanced data for class-incremental task.

This sampler creates an effective batch
Expand All @@ -21,24 +27,40 @@ class BalancedSampler(Sampler): # pylint: disable=too-many-instance-attributes
dataset (Dataset): A built-up dataset
samples_per_gpu (int): batch size of Sampling
efficient_mode (bool): Flag about using efficient mode
num_replicas (int, optional): Number of processes participating in
distributed training. By default, :attr:`world_size` is retrieved from the
current distributed group.
rank (int, optional): Rank of the current process within :attr:`num_replicas`.
By default, :attr:`rank` is retrieved from the current distributed
group.
drop_last (bool, optional): if ``True``, then the sampler will drop the
tail of the data to make it evenly divisible across the number of
replicas. If ``False``, the sampler will add extra indices to make
the data evenly divisible across the replicas. Default: ``False``.
use_adaptive_repeats (bool, optional): Flag about using adaptive repeats
"""

def __init__(self, dataset, batch_size, efficient_mode=True, num_replicas=1, rank=0, drop_last=False):
self.batch_size = batch_size
self.repeat = 1
if hasattr(dataset, "times"):
self.repeat = dataset.times
if hasattr(dataset, "dataset"):
self.dataset = dataset.dataset
else:
self.dataset = dataset
self.img_indices = self.dataset.img_indices
self.num_cls = len(self.img_indices.keys())
self.data_length = len(self.dataset)
def __init__(
self,
dataset: Dataset,
samples_per_gpu: int,
efficient_mode: bool = False,
num_replicas: int = 1,
rank: int = 0,
drop_last: bool = False,
use_adaptive_repeats: bool = False,
):
self.samples_per_gpu = samples_per_gpu
self.num_replicas = num_replicas
self.rank = rank
self.drop_last = drop_last

super().__init__(dataset, samples_per_gpu, use_adaptive_repeats)

self.img_indices = self.dataset.img_indices # type: ignore[attr-defined]
self.num_cls = len(self.img_indices.keys())
self.data_length = len(self.dataset)

if efficient_mode:
# Reduce the # of sampling (sampling data for a single epoch)
self.num_tail = min(len(cls_indices) for cls_indices in self.img_indices.values())
Expand Down Expand Up @@ -109,7 +131,6 @@ def __iter__(self):
]

assert len(indices) == self.num_samples

return iter(indices)

def __len__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import random

import numpy as np
from torch.utils.data.sampler import Sampler
from torch.utils.data import Dataset

from otx.algorithms.common.utils.task_adapt import unwrap_dataset
from .otx_sampler import OTXSampler


class ClsIncrSampler(Sampler): # pylint: disable=too-many-instance-attributes
class ClsIncrSampler(OTXSampler): # pylint: disable=too-many-instance-attributes
"""Sampler for Class-Incremental Task.

This sampler creates an effective batch
Expand All @@ -25,16 +25,35 @@ class ClsIncrSampler(Sampler): # pylint: disable=too-many-instance-attributes
dataset (Dataset): A built-up dataset
samples_per_gpu (int): batch size of Sampling
efficient_mode (bool): Flag about using efficient mode
num_replicas (int, optional): Number of processes participating in
distributed training. By default, :attr:`world_size` is retrieved from the
current distributed group.
rank (int, optional): Rank of the current process within :attr:`num_replicas`.
By default, :attr:`rank` is retrieved from the current distributed
group.
drop_last (bool, optional): if ``True``, then the sampler will drop the
tail of the data to make it evenly divisible across the number of
replicas. If ``False``, the sampler will add extra indices to make
the data evenly divisible across the replicas. Default: ``False``.
use_adaptive_repeats (bool, optional): Flag about using adaptive repeats
"""

def __init__(self, dataset, samples_per_gpu, efficient_mode=False, num_replicas=1, rank=0, drop_last=False):
def __init__(
self,
dataset: Dataset,
samples_per_gpu: int,
efficient_mode: bool = False,
num_replicas: int = 1,
rank: int = 0,
drop_last: bool = False,
use_adaptive_repeats: bool = False,
):
self.samples_per_gpu = samples_per_gpu
self.num_replicas = num_replicas
self.rank = rank
self.drop_last = drop_last

# Dataset Wrapping remove & repeat for RepeatDataset
self.dataset, self.repeat = unwrap_dataset(dataset)
super().__init__(dataset, samples_per_gpu, use_adaptive_repeats)

if hasattr(self.dataset, "img_indices"):
self.new_indices = self.dataset.img_indices["new"]
Expand Down