[Update] Infinite Sampler #1508

Draft
wants to merge 2 commits into base: 0.x
17 changes: 12 additions & 5 deletions mmaction/apis/train.py
@@ -11,8 +11,8 @@
build_optimizer, get_dist_info)
from mmcv.runner.hooks import Fp16OptimizerHook

-from ..core import (DistEvalHook, EvalHook, OmniSourceDistSamplerSeedHook,
-                    OmniSourceRunner)
+from ..core import (DistEvalHook, EvalHook, InfiniteEpochBasedRunner,
+                    OmniSourceDistSamplerSeedHook, OmniSourceRunner)
from ..datasets import build_dataloader, build_dataset
from ..utils import (PreciseBNHook, build_ddp, build_dp, default_device,
get_root_logger)
@@ -91,7 +91,8 @@ def train_model(model,
persistent_workers=cfg.data.get('persistent_workers', False),
num_gpus=len(cfg.gpu_ids),
dist=distributed,
-seed=cfg.seed)
+seed=cfg.seed,
+use_infinite_sampler=cfg.use_infinite_sampler)
dataloader_setting = dict(dataloader_setting,
**cfg.data.get('train_dataloader', {}))

@@ -137,7 +138,12 @@ def train_model(model,
# build runner
optimizer = build_optimizer(model, cfg.optimizer)

-Runner = OmniSourceRunner if cfg.omnisource else EpochBasedRunner
+if cfg.omnisource:
+    Runner = OmniSourceRunner
+elif cfg.use_infinite_sampler:
+    Runner = InfiniteEpochBasedRunner
+else:
+    Runner = EpochBasedRunner
runner = Runner(
model,
optimizer=optimizer,
@@ -265,7 +271,8 @@ def train_model(model,
persistent_workers=cfg.data.get('persistent_workers', False),
num_gpus=len(cfg.gpu_ids),
dist=distributed,
-shuffle=False)
+shuffle=False,
+use_infinite_sampler=cfg.use_infinite_sampler)
dataloader_setting = dict(dataloader_setting,
**cfg.data.get('test_dataloader', {}))

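Reviewer note: train.py reads cfg.use_infinite_sampler directly rather than via cfg.get, so configs apparently need to define the flag. A minimal, hypothetical config snippet to exercise the new code path could look like the following; the top-level placement of the key is an assumption inferred from the attribute access above, not something shown in this diff.

# Hypothetical config snippet (names and values are illustrative only).
# train.py accesses cfg.use_infinite_sampler, so the flag is assumed to be
# a top-level config key.
use_infinite_sampler = True  # selects InfiniteEpochBasedRunner + InfiniteSampler

data = dict(
    videos_per_gpu=8,
    workers_per_gpu=4,
    # persistent_workers matters less here: the infinite sampler keeps a
    # single dataloader iterator (and its workers) alive across epochs.
    persistent_workers=False)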
6 changes: 5 additions & 1 deletion mmaction/core/runner/__init__.py
@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .infinite_runner import InfiniteEpochBasedRunner
from .omnisource_runner import OmniSourceDistSamplerSeedHook, OmniSourceRunner

-__all__ = ['OmniSourceRunner', 'OmniSourceDistSamplerSeedHook']
+__all__ = [
+    'OmniSourceRunner', 'OmniSourceDistSamplerSeedHook',
+    'InfiniteEpochBasedRunner'
+]
47 changes: 47 additions & 0 deletions mmaction/core/runner/infinite_runner.py
@@ -0,0 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
import time

from mmcv.runner import EpochBasedRunner
from mmcv.runner.builder import RUNNERS
from torch.utils.data import DataLoader


@RUNNERS.register_module()
class InfiniteEpochBasedRunner(EpochBasedRunner):
"""Epoch-based Runner supports dataloader with InfiniteSampler.

The workers of dataloader will re-initialize, when the iterator of
dataloader is created. InfiniteSampler is designed to avoid these time
consuming operations, since the iterator with InfiniteSampler will never
reach the end.
"""

def train(self, data_loader: DataLoader, **kwargs) -> None:
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition

# To reuse the iterator, we only create iterator once and bind it
# with runner. In the next epoch, the iterator will be used against
if not hasattr(self, 'data_loader_iter'):
self.data_loader_iter = iter(self.data_loader)

# The InfiniteSampler will never reach the end, but we set the
# length of InfiniteSampler to the actual length of dataset.
# The length of dataloader is determined by the length of sampler,
# when the sampler is not None. Therefore, we can simply forward the
# whole dataset in a epoch by length of dataloader.

for i in range(len(self.data_loader)):
data_batch = next(self.data_loader_iter)
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
self._iter += 1

self.call_hook('after_train_epoch')
self._epoch += 1
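Reviewer note: the sampler this runner relies on is not included in this diff (it lives in mmaction/datasets/samplers/infinite_sampler.py). As a rough sketch of the contract the runner assumes: __iter__ yields indices forever, while __len__ still reports the dataset size, so len(dataloader) corresponds to one epoch. The class below is a hypothetical illustration using the constructor signature seen in builder.py (dataset, seed, shuffle), not the PR's actual implementation.

# Sketch only: illustrates the behaviour InfiniteEpochBasedRunner depends on.
import torch
from torch.utils.data import Sampler


class InfiniteSamplerSketch(Sampler):
    """Yields dataset indices forever; __len__ stays at the dataset size."""

    def __init__(self, dataset, seed=0, shuffle=True):
        self.size = len(dataset)
        self.seed = seed if seed is not None else 0
        self.shuffle = shuffle

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:  # never raises StopIteration, so the iterator is reusable
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from range(self.size)

    def __len__(self):
        # One "epoch" in the runner is len(dataloader) iterations, which is
        # derived from this value.
        return self.size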
2 changes: 2 additions & 0 deletions mmaction/core/runner/omnisource_runner.py
@@ -4,6 +4,7 @@

import mmcv
from mmcv.runner import EpochBasedRunner, Hook
+from mmcv.runner.builder import RUNNERS
from mmcv.runner.utils import get_host_info


@@ -28,6 +29,7 @@ def before_epoch(self, runner):
data_loader.batch_sampler.sampler.set_epoch(runner.epoch)


+@RUNNERS.register_module()
class OmniSourceRunner(EpochBasedRunner):
"""OmniSource Epoch-based Runner.

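Reviewer note: a side effect of registering both runners is that they become discoverable through mmcv's RUNNERS registry, in addition to the direct Runner(...) construction in apis/train.py. A small sketch; the lookup below is illustrative and not part of this PR.

# Sketch: after mmaction.core.runner is imported, the register_module()
# decorators have run and the runners can be looked up by name.
import mmaction.core.runner  # noqa: F401
from mmcv.runner.builder import RUNNERS

assert RUNNERS.get('InfiniteEpochBasedRunner') is not None
assert RUNNERS.get('OmniSourceRunner') is not None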
21 changes: 16 additions & 5 deletions mmaction/datasets/builder.py
@@ -11,7 +11,9 @@
from torch.utils.data import DataLoader

from ..utils.multigrid import ShortCycleSampler
from .samplers import ClassSpecificDistributedSampler, DistributedSampler
from .samplers import (ClassSpecificDistributedSampler,
DistributedInfiniteSampler, DistributedSampler,
InfiniteSampler)

if platform.system() != 'Windows':
# https://github.com/pytorch/pytorch/issues/973
@@ -51,6 +53,7 @@ def build_dataloader(dataset,
drop_last=False,
pin_memory=True,
persistent_workers=False,
+use_infinite_sampler=False,
**kwargs):
"""Build PyTorch DataLoader.

@@ -77,7 +80,11 @@
the worker processes after a dataset has been consumed once.
This allows to maintain the workers Dataset instances alive.
The argument also has effect in PyTorch>=1.8.0.
-Default: False
+Default: False.
+use_infinite_sampler (bool): Whether to use an infinite sampler.
+    Note that the infinite sampler keeps the iterator of the dataloader
+    running forever, which avoids the overhead of worker initialization
+    between epochs. Default: False.
kwargs (dict, optional): Any keyword argument to be used to initialize
DataLoader.

@@ -92,7 +99,10 @@
crop_size = kwargs.pop('crop_size', 224)

if dist:
-if sample_by_class:
+if use_infinite_sampler:
+    sampler = DistributedInfiniteSampler(
+        dataset, world_size, rank, shuffle=shuffle)
+elif sample_by_class:
dynamic_length = getattr(dataset, 'dynamic_length', True)
sampler = ClassSpecificDistributedSampler(
dataset,
@@ -132,7 +142,8 @@
raise NotImplementedError(
'Short cycle using non-dist is not supported')

-sampler = None
+sampler = InfiniteSampler(dataset, seed=seed, shuffle=shuffle) \
+    if use_infinite_sampler else None
batch_size = num_gpus * videos_per_gpu
num_workers = num_gpus * workers_per_gpu

@@ -150,7 +161,7 @@
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=videos_per_gpu),
pin_memory=pin_memory,
-shuffle=shuffle,
+shuffle=shuffle if sampler is None else None,
worker_init_fn=init_fn,
drop_last=drop_last,
**kwargs)
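Reviewer note: to illustrate the new builder argument, here is a self-contained usage sketch with a toy dataset (not an mmaction dataset). It assumes, as the runner comment states, that InfiniteSampler reports the dataset length through __len__.

# Sketch of calling build_dataloader with the new flag. ToyDataset is a
# stand-in; real callers pass a dataset built by build_dataset(cfg).
from torch.utils.data import Dataset

from mmaction.datasets.builder import build_dataloader


class ToyDataset(Dataset):

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return idx


loader = build_dataloader(
    ToyDataset(),
    videos_per_gpu=2,
    workers_per_gpu=0,
    dist=False,
    shuffle=True,  # forwarded to InfiniteSampler; DataLoader gets shuffle=None
    seed=42,
    use_infinite_sampler=True)

# len(loader) still corresponds to one pass over the dataset (50 batches
# here), but iter(loader) never raises StopIteration, which is what lets
# InfiniteEpochBasedRunner reuse a single iterator across epochs.
print(len(loader))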
9 changes: 8 additions & 1 deletion mmaction/datasets/samplers/__init__.py
@@ -1,5 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .distributed_sampler import (ClassSpecificDistributedSampler,
DistributedSampler)
+from .infinite_sampler import (DistributedInfiniteGroupSampler,
+                               DistributedInfiniteSampler,
+                               InfiniteGroupSampler, InfiniteSampler)

-__all__ = ['DistributedSampler', 'ClassSpecificDistributedSampler']
+__all__ = [
+    'DistributedSampler', 'ClassSpecificDistributedSampler', 'InfiniteSampler',
+    'InfiniteGroupSampler', 'DistributedInfiniteSampler',
+    'DistributedInfiniteGroupSampler'
+]
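Reviewer note: the distributed and group variants exported here come from mmaction/datasets/samplers/infinite_sampler.py, which is not part of this diff view. A common way to build the distributed variant is to share one seeded infinite index stream and let each rank take an interleaved slice; the sketch below follows that assumption, and the class and argument names are illustrative rather than the PR's.

# Sketch only: one shared, seeded infinite stream; each rank takes every
# num_replicas-th index via itertools.islice.
import itertools

import torch
from torch.utils.data import Sampler


class DistributedInfiniteSamplerSketch(Sampler):

    def __init__(self, dataset, num_replicas, rank, shuffle=True, seed=0):
        self.size = len(dataset)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self.seed)  # identical stream on every rank
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from range(self.size)

    def __iter__(self):
        # Disjoint, interleaved shard of the shared stream for this rank.
        return itertools.islice(self._infinite_indices(), self.rank, None,
                                self.num_replicas)

    def __len__(self):
        return self.size // self.num_replicas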