Skip to content

Commit

Permalink
[Enhance] Support callable collate_fn for FlexibleRunner (#1284)
Browse files Browse the repository at this point in the history
  • Loading branch information
LZHgrla committed Aug 1, 2023
1 parent d480df7 commit d772ad0
Showing 1 changed file with 35 additions and 19 deletions.
54 changes: 35 additions & 19 deletions mmengine/runner/_flexible_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import mmengine
from mmengine._strategy import BaseStrategy
from mmengine.config import Config, ConfigDict
from mmengine.dataset import worker_init_fn
from mmengine.dataset import worker_init_fn as default_worker_init_fn
from mmengine.dist import get_rank, infer_launcher, master_only
from mmengine.evaluator import Evaluator
from mmengine.fileio import FileClient, join_path
Expand Down Expand Up @@ -857,22 +857,28 @@ def build_dataloader(

# build dataloader
init_fn: Optional[partial]

if seed is not None:
disable_subprocess_warning = dataloader_cfg.pop(
'disable_subprocess_warning', False)
assert isinstance(
disable_subprocess_warning,
bool), ('disable_subprocess_warning should be a bool, but got '
f'{type(disable_subprocess_warning)}')
init_fn = partial(
worker_init_fn,
num_workers=dataloader_cfg.get('num_workers'),
rank=get_rank(),
seed=seed,
disable_subprocess_warning=disable_subprocess_warning)
if 'worker_init_fn' in dataloader_cfg:
worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
worker_init_fn_type = worker_init_fn_cfg.pop('type')
worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
assert callable(worker_init_fn)
init_fn = partial(worker_init_fn,
**worker_init_fn_cfg) # type: ignore
else:
init_fn = None
if seed is not None:
disable_subprocess_warning = dataloader_cfg.pop(
'disable_subprocess_warning', False)
assert isinstance(disable_subprocess_warning, bool), (
'disable_subprocess_warning should be a bool, but got '
f'{type(disable_subprocess_warning)}')
init_fn = partial(
default_worker_init_fn,
num_workers=dataloader_cfg.get('num_workers'),
rank=get_rank(),
seed=seed,
disable_subprocess_warning=disable_subprocess_warning)
else:
init_fn = None

# `persistent_workers` requires pytorch version >= 1.7
if ('persistent_workers' in dataloader_cfg
Expand All @@ -891,9 +897,19 @@ def build_dataloader(
# samples into a dict without stacking the batch tensor.
collate_fn_cfg = dataloader_cfg.pop('collate_fn',
dict(type='pseudo_collate'))
collate_fn_type = collate_fn_cfg.pop('type')
collate_fn = FUNCTIONS.get(collate_fn_type)
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
if isinstance(collate_fn_cfg, dict):
collate_fn_type = collate_fn_cfg.pop('type')
if isinstance(collate_fn_type, str):
collate_fn = FUNCTIONS.get(collate_fn_type)
else:
collate_fn = collate_fn_type
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
elif callable(collate_fn_cfg):
collate_fn = collate_fn_cfg
else:
raise TypeError(
'collate_fn should be a dict or callable object, but got '
f'{collate_fn_cfg}')
data_loader = DataLoader(
dataset=dataset,
sampler=sampler if batch_sampler is None else None,
Expand Down

0 comments on commit d772ad0

Please sign in to comment.