[Fix] Convert SyncBN to BN when training on DP #772

Merged
merged 14 commits on Sep 15, 2021
8 changes: 8 additions & 0 deletions docs/train.md
@@ -19,6 +19,14 @@ To trade speed with GPU memory, you may pass in `--options model.backbone.with_cp=True`

### Train with a single GPU

official support:

```shell
./tools/dist_train.sh ${CONFIG_FILE} 1 [optional arguments]
```

experimental support (Convert SyncBN to BN):

```shell
python tools/train.py ${CONFIG_FILE} [optional arguments]
```
4 changes: 3 additions & 1 deletion mmseg/models/utils/__init__.py
@@ -6,10 +6,12 @@
from .se_layer import SELayer
from .self_attention_block import SelfAttentionBlock
from .shape_convert import nchw_to_nlc, nlc_to_nchw
from .syncbn2bn import revert_sync_batchnorm
from .up_conv_block import UpConvBlock

__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert',
'mit_convert', 'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
'mit_convert', 'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw',
'revert_sync_batchnorm'
]
28 changes: 28 additions & 0 deletions mmseg/models/utils/syncbn2bn.py
@@ -0,0 +1,28 @@
"""Modified from
https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547."""

from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm


def revert_sync_batchnorm(module):
    """Convert all `SyncBatchNorm` layers in the model to `BatchNorm` layers.

    SyncBN only works under DDP; this helper is intended for non-distributed
    (DP / single-GPU) training, where SyncBN cannot be used.

    Args:
        module (nn.Module): The module to convert, possibly containing
            SyncBN layers.

    Returns:
        nn.Module: The converted module.
    """
    # this is very similar to the function that it is trying to revert:
    # https://github.com/pytorch/pytorch/blob/c8b3686a3e4ba63dc59e5dcfe5db3430df256833/torch/nn/modules/batchnorm.py#L679
    module_output = module
    if isinstance(module, SyncBatchNorm):
        # to be consistent with SyncBN, we hack dim check function in BN
        module_output = _BatchNorm(module.num_features, module.eps,
                                   module.momentum, module.affine,
                                   module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    del module
    return module_output
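
For orientation, here is a minimal usage sketch (not part of the PR; it assumes the torch backend, where the parrots wrapper resolves to `torch.nn.SyncBatchNorm`). Because the helper recurses over `named_children`, SyncBN layers nested inside containers are converted as well:

```python
import torch.nn as nn

from mmseg.models.utils import revert_sync_batchnorm

# Toy model with a SyncBN layer nested one level deep (hypothetical example).
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.Sequential(nn.SyncBatchNorm(8), nn.ReLU()),
)
model = revert_sync_batchnorm(model)

# After conversion, no SyncBatchNorm layers remain anywhere in the module tree.
assert not any(isinstance(m, nn.SyncBatchNorm) for m in model.modules())
```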
27 changes: 3 additions & 24 deletions tests/test_models/test_forward.py
@@ -7,7 +7,8 @@
import pytest
import torch
import torch.nn as nn
from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm

from mmseg.models.utils import revert_sync_batchnorm


def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
@@ -183,6 +184,6 @@ def _check_input_dim(self, inputs):
    pass


def _convert_batchnorm(module):
    module_output = module
    if isinstance(module, SyncBatchNorm):
        # to be consistent with SyncBN, we hack dim check function in BN
        module_output = _BatchNorm(module.num_features, module.eps,
                                   module.momentum, module.affine,
                                   module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


@patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
_check_input_dim)
@patch('torch.distributed.get_world_size', get_world_size)
@@ -235,7 +214,7 @@ def _test_encoder_decoder_forward(cfg_file):
        imgs = imgs.cuda()
        gt_semantic_seg = gt_semantic_seg.cuda()
    else:
        segmentor = _convert_batchnorm(segmentor)
        segmentor = revert_sync_batchnorm(segmentor)

    # Test forward train
    losses = segmentor.forward(
14 changes: 14 additions & 0 deletions tests/test_models/test_utils/test_syncbn2bn.py
@@ -0,0 +1,14 @@
import torch.nn as nn
from mmcv.utils.parrots_wrapper import SyncBatchNorm

from mmseg.models.utils import revert_sync_batchnorm


def test_syncbn2bn():
    model = nn.Module()
    model.add_module('SyncBN', SyncBatchNorm(1))
    model = revert_sync_batchnorm(model)

    for m in model.modules():
        if isinstance(m, SyncBatchNorm):
            raise TypeError
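
As a possible extension of this test (not part of the PR), one could also check that the converted layer keeps the learned affine parameters and running statistics, since `revert_sync_batchnorm` clones `weight`/`bias` and carries the running buffers over:

```python
import torch
import torch.nn as nn

from mmseg.models.utils import revert_sync_batchnorm

sync_bn = nn.SyncBatchNorm(4)
sync_bn.running_mean = torch.rand(4)  # pretend some statistics were accumulated
bn = revert_sync_batchnorm(sync_bn)

assert not isinstance(bn, nn.SyncBatchNorm)                 # type was reverted
assert torch.equal(bn.weight, sync_bn.weight)               # affine params preserved
assert torch.equal(bn.running_mean, sync_bn.running_mean)   # stats carried over
```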
10 changes: 10 additions & 0 deletions tools/train.py
@@ -3,6 +3,7 @@
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
@@ -13,6 +14,7 @@
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.models.utils import revert_sync_batchnorm
from mmseg.utils import collect_env, get_root_logger


@@ -133,6 +135,14 @@ def main():
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    # SyncBN is not supported for DP (non-distributed, single-GPU training)
    if not distributed:
        warnings.warn(
            'SyncBN is only supported with DDP. To be compatible with DP, '
            'we convert SyncBN to BN. Please use dist_train.sh which has '
            'official support to avoid this problem.')
        model = revert_sync_batchnorm(model)

    logger.info(model)

    datasets = [build_dataset(cfg.data.train)]
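
Apart from this runtime conversion, a config-level workaround is also possible: mmseg models usually pick their normalization layer from a `norm_cfg` field, so a single-GPU experiment can switch it from SyncBN to BN up front. A hedged sketch, assuming the usual config layout where `norm_cfg` is shared by the backbone and heads:

```python
# Hypothetical config override (not part of the PR): use plain BN so that no
# runtime conversion is needed when training with tools/train.py on one GPU.
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
    backbone=dict(norm_cfg=norm_cfg),
    decode_head=dict(norm_cfg=norm_cfg),
)
```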