[Fix] training time validation bug fix #611
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master     #611      +/-   ##
==========================================
+ Coverage   81.33%   81.78%   +0.44%
==========================================
  Files         116      116
  Lines        6720     6709      -11
  Branches     1156     1151       -5
==========================================
+ Hits         5466     5487      +21
+ Misses       1103     1069      -34
- Partials      151      153       +2
Please also add it in

fixed
mmcls/datasets/builder.py
Outdated
@@ -83,6 +102,8 @@ def build_dataloader(dataset,
    """
    rank, world_size = get_dist_info()

    sampler_cfg = get_sampler_cfg(cfg, distributed=dist)
It is not a good choice to pass the whole config into a module's builder. What if we change it like this?
rank, world_size = get_dist_info()

# default logic
if dist:
    sampler = DistributedSampler(
        dataset, world_size, rank, shuffle=shuffle, round_up=round_up)
    batch_size = samples_per_gpu
    num_workers = workers_per_gpu
else:
    sampler = None
    batch_size = num_gpus * samples_per_gpu
    num_workers = num_gpus * workers_per_gpu

# build custom sampler logic
if sampler_cfg:
    # shuffle=False when val and test
    sampler_cfg.update(shuffle=shuffle)
    sampler = build_sampler(
        sampler_cfg,
        default_args=dict(dataset=dataset, num_replicas=world_size, rank=rank))

# If sampler exists, turn off dataloader shuffle
if sampler is not None:
    shuffle = False
I think it's better to use:
rank, world_size = get_dist_info()

# Custom sampler logic
if sampler_cfg:
    # shuffle=False when val and test
    sampler_cfg.update(shuffle=shuffle)
    sampler = build_sampler(
        sampler_cfg,
        default_args=dict(dataset=dataset, num_replicas=world_size, rank=rank))
# Default sampler logic
elif dist:
    sampler = DistributedSampler(
        dataset, world_size, rank, shuffle=shuffle, round_up=round_up)
else:
    sampler = None

# If sampler exists, turn off dataloader shuffle
if sampler is not None:
    shuffle = False

if dist:
    batch_size = samples_per_gpu
    num_workers = workers_per_gpu
else:
    batch_size = num_gpus * samples_per_gpu
    num_workers = num_gpus * workers_per_gpu
Then we can avoid an unnecessary instantiation of DistributedSampler, and the sampler logic is split from the batch_size logic.
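For reference, `build_sampler` here presumably builds instances from the `SAMPLERS` registry shown in the next diff. A minimal sketch of such a builder on top of mmcv's registry utilities (an assumption; the actual implementation is not shown in this thread):

from mmcv.utils import Registry, build_from_cfg

SAMPLERS = Registry('sampler')

def build_sampler(cfg, default_args=None):
    """Build a sampler instance from a config dict via the registry.

    Returns None when no config is given, so callers can fall back to
    the default sampler logic.
    """
    if cfg is None:
        return None
    return build_from_cfg(cfg, SAMPLERS, default_args=default_args)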
mmcls/datasets/builder.py
Outdated
@@ -23,6 +23,25 @@
SAMPLERS = Registry('sampler')


def get_sampler_cfg(cfg, distributed):
MME needs to handle both DistributedSampler and InfiniteSampler as default samplers, depending on the runner type, so it has a set_default_sampler_cfg function. But MMClassification only has one default sampler, DistributedSampler, so I think this function is unnecessary here.
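For context, a hypothetical sketch of the kind of helper being described; the function name, config keys, and branch conditions here are assumptions for illustration, not the actual MME code:

def set_default_sampler_cfg(cfg, distributed):
    """Pick a default sampler config depending on the runner type
    (hypothetical illustration of the pattern described above)."""
    if not distributed:
        return None
    if cfg.runner.get('type') == 'IterBasedRunner':
        # Iteration-based training feeds workers indefinitely
        return dict(type='InfiniteSampler')
    # Epoch-based training uses the ordinary distributed sampler
    return dict(type='DistributedSampler')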
LGTM
* sampler bugfixes
* sampler bugfixes
* reorganize code
* minor fixes
* reorganize code
* minor fixes
* Use `mmcv.runner.get_dist_info` instead of the `dist` package to get rank and world size.
* Add `build_dataloader` unit tests and fix sampler's unit tests.
* Fix unit tests
* Fix unit tests

Co-authored-by: mzr1996 <mzr1996@163.com>
Motivation
This PR fixes the bug introduced in #588
Modification
Add `sampler_cfg` for validation in `mmcls/apis/train.py`.
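A minimal sketch of the fix, following the argument names used in the snippets above; the exact surrounding code in `mmcls/apis/train.py` and the config key for the sampler are assumptions:

# When building the validation dataloader, forward the sampler config
# from the data config so validation goes through the same custom
# sampler machinery as training (this forwarding was missing before):
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
val_dataloader = build_dataloader(
    val_dataset,
    samples_per_gpu=cfg.data.samples_per_gpu,
    workers_per_gpu=cfg.data.workers_per_gpu,
    dist=distributed,
    shuffle=False,  # never shuffle during validation
    round_up=True,
    sampler_cfg=cfg.data.get('sampler', None))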