[Enhancement] Rename fast_conv_bn_eval to efficient_conv_bn_eval (#1251)
youkaichao committed Jul 15, 2023
1 parent 276c614 commit 66d828d
Showing 5 changed files with 42 additions and 33 deletions.
4 changes: 2 additions & 2 deletions docs/en/common_usage/save_gpu_memory.md
@@ -2,9 +2,9 @@

Memory capacity is critical in deep learning training and inference and determines whether the model can run successfully. Common memory saving approaches include:

- - Enable Fast Conv BN Eval Feature (Experimental)
+ - Enable Efficient Conv BN Eval Feature (Experimental)

-We've recently [introduced](https://github.com/open-mmlab/mmcv/pull/2807) an experimental feature in MMCV: the Fast Conv BN Eval, based on the concepts discussed in [this paper](https://arxiv.org/abs/2305.11624). This feature has been designed with the aim of reducing memory footprint during network training without hurting performance. If your network architecture contains a series of consecutive Conv+BN blocks, and these normalization layers are maintained in `eval` mode during the training process (a common occurrence when training object detectors with [MMDetection](https://github.com/open-mmlab/mmdetection)), this feature could reduce memory consumption by more than $20\%$. To enable the Fast Conv BN Eval feature, simply add the following command-line argument: `--cfg-options fast_conv_bn_eval="[backbone]"`. When you see `Enabling the "fast_conv_bn_eval" feature for these modules ...` in the output log, the feature is successfully enabled. As this is currently in an experimental phase, we are eagerly looking forward to hearing about your experience with it. Please share your usage reports, observations, and suggestions at [this discussion thread](https://github.com/open-mmlab/mmcv/discussions/2841). Your feedback is crucial for further development and for determining whether this feature should be integrated into the stable release.
+We've recently [introduced](https://github.com/open-mmlab/mmcv/pull/2807) an experimental feature in MMCV: the Efficient Conv BN Eval, based on the concepts discussed in [this paper](https://arxiv.org/abs/2305.11624). This feature has been designed with the aim of reducing memory footprint during network training without hurting performance. If your network architecture contains a series of consecutive Conv+BN blocks, and these normalization layers are maintained in `eval` mode during the training process (a common occurrence when training object detectors with [MMDetection](https://github.com/open-mmlab/mmdetection)), this feature could reduce memory consumption by more than $20\%$. To enable the Efficient Conv BN Eval feature, simply add the following command-line argument: `--cfg-options efficient_conv_bn_eval="[backbone]"`. When you see `Enabling the "efficient_conv_bn_eval" feature for these modules ...` in the output log, the feature is successfully enabled. As this is currently in an experimental phase, we are eagerly looking forward to hearing about your experience with it. Please share your usage reports, observations, and suggestions at [this discussion thread](https://github.com/open-mmlab/mmengine/discussions/1252). Your feedback is crucial for further development and for determining whether this feature should be integrated into the stable release.

- Gradient Accumulation

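The `--cfg-options` flag described above maps to a plain config key that `Runner.train()` looks up (see the `mmengine/runner/runner.py` hunk later in this diff). Below is a minimal sketch of enabling the renamed feature from Python rather than the command line; `my_config.py` and the `backbone` target are placeholders, not part of this commit.

```python
# Minimal sketch; 'my_config.py' and 'backbone' are placeholders for your setup.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('my_config.py')
# Equivalent to: --cfg-options efficient_conv_bn_eval="[backbone]"
cfg.efficient_conv_bn_eval = ['backbone']

runner = Runner.from_cfg(cfg)
runner.train()  # should log: Enabling the "efficient_conv_bn_eval" feature ...
```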
4 changes: 2 additions & 2 deletions docs/zh_cn/common_usage/save_gpu_memory.md
@@ -2,9 +2,9 @@

Memory capacity is critical during deep learning training and inference, as it determines whether a model can run successfully. Common ways to save GPU memory include:

- - Enable Fast Conv BN Eval Feature (Experimental)
+ - Enable Efficient Conv BN Eval Feature (Experimental)

-Based on the concepts discussed in [this paper](https://arxiv.org/abs/2305.11624), we recently [introduced](https://github.com/open-mmlab/mmcv/pull/2807) an experimental feature in MMCV: Fast Conv BN Eval. This feature is designed to reduce GPU memory usage during network training without hurting performance. If your network architecture contains a series of consecutive Conv+BN blocks, and these BN layers stay in `eval` mode during training (a common situation when training object detectors with [MMDetection](https://github.com/open-mmlab/mmdetection)), this feature can cut memory consumption by more than $20\%$. To enable the Fast Conv BN Eval feature, simply add the following command-line argument: `--cfg-options fast_conv_bn_eval="[backbone]"`. When you see `Enabling the "fast_conv_bn_eval" feature for these modules ...` in the output log, the feature has been enabled successfully. As this is still experimental, we are eager to hear about your experience with it. Please share your usage reports, observations, and suggestions in [this discussion thread](https://github.com/open-mmlab/mmcv/discussions/2841). Your feedback is crucial for further development and for deciding whether this feature should be integrated into the stable release.
+Based on the concepts discussed in [this paper](https://arxiv.org/abs/2305.11624), we recently [introduced](https://github.com/open-mmlab/mmcv/pull/2807) an experimental feature in MMCV: Efficient Conv BN Eval. This feature is designed to reduce GPU memory usage during network training without hurting performance. If your network architecture contains a series of consecutive Conv+BN blocks, and these BN layers stay in `eval` mode during training (a common situation when training object detectors with [MMDetection](https://github.com/open-mmlab/mmdetection)), this feature can cut memory consumption by more than $20\%$. To enable the Efficient Conv BN Eval feature, simply add the following command-line argument: `--cfg-options efficient_conv_bn_eval="[backbone]"`. When you see `Enabling the "efficient_conv_bn_eval" feature for these modules ...` in the output log, the feature has been enabled successfully. As this is still experimental, we are eager to hear about your experience with it. Please share your usage reports, observations, and suggestions in [this discussion thread](https://github.com/open-mmlab/mmengine/discussions/1252). Your feedback is crucial for further development and for deciding whether this feature should be integrated into the stable release.

- Gradient Accumulation

mmengine/model/{fast_conv_bn_eval.py → efficient_conv_bn_eval.py}
@@ -7,8 +7,9 @@
import torch.nn as nn


-def fast_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
-                              conv: nn.modules.conv._ConvNd, x: torch.Tensor):
+def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
+                                   conv: nn.modules.conv._ConvNd,
+                                   x: torch.Tensor):
    """Code borrowed from mmcv 2.0.1, so that this feature can be used for old
    mmcv versions.
@@ -68,27 +69,28 @@ def bn_once_identity_forward(bn: nn.modules.batchnorm._BatchNorm,
    return x


-def fast_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
-                              conv: nn.modules.conv._ConvNd, x: torch.Tensor):
-    """This function controls whether to use `fast_conv_bn_eval_forward`.
+def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
+                                   conv: nn.modules.conv._ConvNd,
+                                   x: torch.Tensor):
+    """This function controls whether to use `efficient_conv_bn_eval_forward`.
    If the following `bn` is in `eval` mode, then we turn on the special
-    `fast_conv_bn_eval_forward` and let the following call of `bn.forward` to
-    be identity. Note that this `bn.forward` modification only works for one
+    `efficient_conv_bn_eval_forward` and let the following call of `bn.forward`
+    to be identity. Note that this `bn.forward` modification only works for one
    call. After the call, `bn.forward` will be restored to the default
    function. This is to deal with the case where one `bn` module is used in
    multiple places.
    """
    if not bn.training:
        # bn in eval mode
-        output = fast_conv_bn_eval_forward(bn, conv, x)
+        output = efficient_conv_bn_eval_forward(bn, conv, x)
        bn.forward = partial(bn_once_identity_forward, bn)
        return output
    else:
        return conv._conv_forward(x, conv.weight, conv.bias)

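The body of `efficient_conv_bn_eval_forward` is collapsed in this view, but the identity it exploits (per the paper linked in the docs) is ordinary conv+BN folding: when BN runs in `eval` mode on fixed running statistics, `bn(conv(x))` equals a single convolution with rescaled weight and bias, so the intermediate `conv(x)` activation never has to be saved for the backward pass. A standalone sketch of that identity, not the collapsed mmengine implementation:

```python
# Standalone sketch of eval-mode conv+BN folding (illustration only):
#   scale = gamma / sqrt(running_var + eps)
#   W' = W * scale (per output channel);  b' = beta + scale * (b - running_mean)
import torch
import torch.nn as nn


def folded_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d,
                   x: torch.Tensor) -> torch.Tensor:
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    weight = conv.weight * scale.view(-1, 1, 1, 1)
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(scale)
    bias = bn.bias + scale * (conv_bias - bn.running_mean)
    return conv._conv_forward(x, weight, bias)


conv = nn.Conv2d(3, 8, 3, padding=1)
bn = nn.BatchNorm2d(8)
bn.eval()  # folding is only valid when BN uses running statistics
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    torch.testing.assert_close(bn(conv(x)), folded_conv_bn(conv, bn, x))
```

After the folded convolution runs, `efficient_conv_bn_eval_control` patches `bn.forward` with `bn_once_identity_forward` so the very next `bn` call is a no-op and normalization is not applied twice; per the docstring above, that patch undoes itself after a single call.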

-def turn_on_fast_conv_bn_eval_for_single_model(model: torch.nn.Module):
+def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
    # optimize consecutive conv+bn by modifying forward function
    # Symbolically trace the input model to create an FX GraphModule
    import torch.fx as fx
@@ -132,14 +134,14 @@ def turn_on_fast_conv_bn_eval_for_single_model(model: torch.nn.Module):
        conv_module = modules[conv_name]
        bn_module = modules[bn_name]

-        conv_module.forward = partial(fast_conv_bn_eval_control, bn_module,
-                                      conv_module)
+        conv_module.forward = partial(efficient_conv_bn_eval_control,
+                                      bn_module, conv_module)


-def turn_on_fast_conv_bn_eval(model: torch.nn.Module, modules: Union[List[str],
-                                                                     str]):
+def turn_on_efficient_conv_bn_eval(model: torch.nn.Module,
+                                   modules: Union[List[str], str]):
    if isinstance(modules, str):
        modules = [modules]
    for module_name in modules:
        module = attrgetter(module_name)(model)
-        turn_on_fast_conv_bn_eval_for_single_model(module)
+        turn_on_efficient_conv_bn_eval_for_single_model(module)
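A hypothetical usage sketch for the helper above (not part of this commit): the `modules` argument names sub-modules to rewrite and is resolved with `attrgetter`, so any `torch.fx`-traceable sub-module whose BN layers sit in `eval` mode can be targeted. torchvision's `resnet18` and its `layer1`/`layer2` names are placeholders here.

```python
# Hypothetical usage; assumes torchvision is installed and that the targeted
# sub-modules are torch.fx-traceable (true for resnet18's residual layers).
import torch
from torchvision.models import resnet18

from mmengine.model.efficient_conv_bn_eval import turn_on_efficient_conv_bn_eval

model = resnet18()
model.eval()  # keep the BN layers in eval mode so the rewrite takes effect
turn_on_efficient_conv_bn_eval(model, ['layer1', 'layer2'])
out = model(torch.randn(1, 3, 224, 224))
```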
11 changes: 6 additions & 5 deletions mmengine/runner/runner.py
@@ -29,7 +29,8 @@
from mmengine.logging import MessageHub, MMLogger, print_log
from mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm,
                            is_model_wrapper, revert_sync_batchnorm)
-from mmengine.model.fast_conv_bn_eval import turn_on_fast_conv_bn_eval
+from mmengine.model.efficient_conv_bn_eval import \
+    turn_on_efficient_conv_bn_eval
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
                            build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS,
@@ -1721,12 +1722,12 @@ def train(self) -> nn.Module:
        # initialize the model weights
        self._init_model_weights()

-        # try to enable fast_conv_bn_eval feature
-        modules = self.cfg.get('fast_conv_bn_eval', None)
+        # try to enable efficient_conv_bn_eval feature
+        modules = self.cfg.get('efficient_conv_bn_eval', None)
        if modules is not None:
-            self.logger.info(f'Enabling the "fast_conv_bn_eval" feature'
+            self.logger.info(f'Enabling the "efficient_conv_bn_eval" feature'
                             f' for sub-modules: {modules}')
-            turn_on_fast_conv_bn_eval(ori_model, modules)
+            turn_on_efficient_conv_bn_eval(ori_model, modules)

        # make sure checkpoint-related hooks are triggered after `before_run`
        self.load_or_resume()
tests/test_model/{test_fast_conv_bn_eval.py → test_efficient_conv_bn_eval.py}
@@ -1,13 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
from unittest import TestCase

import torch
from torch import nn

-from mmengine.model.fast_conv_bn_eval import \
-    turn_on_fast_conv_bn_eval_for_single_model
+from mmengine.model.efficient_conv_bn_eval import \
+    turn_on_efficient_conv_bn_eval_for_single_model
from mmengine.testing import assert_allclose
from mmengine.utils import is_installed
+from mmengine.utils.dl_utils import TORCH_VERSION
+from mmengine.utils.version_utils import digit_version

mmcv_is_installed = is_installed('mmcv')

@@ -30,28 +33,31 @@ def __init__(self, *args, **kwargs) -> None:

    def forward(self, x):
        if mmcv_is_installed:
-            # this ConvModule can use fast_conv_bn_eval feature
+            # this ConvModule can use efficient_conv_bn_eval feature
            x = self.mod1(x)
-        # this conv-bn pair can use fast_conv_bn_eval feature
+        # this conv-bn pair can use efficient_conv_bn_eval feature
        x = self.bn1(self.conv1(x))
-        # this conv-bn pair cannot use fast_conv_bn_eval feature
+        # this conv-bn pair cannot use efficient_conv_bn_eval feature
        # because `self.conv2` is used twice
        x = self.bn2(self.conv2(self.conv2(x)))
-        # this conv-bn pair can use fast_conv_bn_eval feature
+        # this conv-bn pair can use efficient_conv_bn_eval feature
        # just for the first forward of the `self.bn3`
        x = self.bn3(self.bn3(self.conv3(x)))
        return x


-class TestFastConvBNEval(TestCase):
-    """Test the turn_on_fast_conv_bn_eval function."""
+@unittest.skipIf(
+    digit_version(TORCH_VERSION) < digit_version('1.8'),
+    reason='torch.fx needs Pytorch 1.8 or higher')
+class TestEfficientConvBNEval(TestCase):
+    """Test the turn_on_efficient_conv_bn_eval function."""

-    def test_fast_conv_bn_eval(self):
+    def test_efficient_conv_bn_eval(self):
        model = BackboneModel()
        model.eval()
        input = torch.randn(64, 6, 32, 32)
        output = model(input)
-        turn_on_fast_conv_bn_eval_for_single_model(model)
+        turn_on_efficient_conv_bn_eval_for_single_model(model)
        output2 = model(input)
        print((output - output2).abs().max().item())
        assert_allclose(output, output2)
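The test above checks numerical equivalence only. A rough way to probe the memory claim from the docs is sketched below; it assumes a CUDA device, and the stacked toy model and tensor sizes are arbitrary placeholders rather than an mmengine benchmark.

```python
# Rough memory probe (illustration only); requires a CUDA device.
import torch
from torch import nn

from mmengine.model.efficient_conv_bn_eval import \
    turn_on_efficient_conv_bn_eval_for_single_model


def peak_backward_memory(model: nn.Module, x: torch.Tensor) -> int:
    torch.cuda.reset_peak_memory_stats()
    model(x).sum().backward()
    return torch.cuda.max_memory_allocated()


# a stack of conv+bn pairs with BN kept in eval mode, as in detector backbones
layers = []
for _ in range(8):
    layers += [nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32)]
model = nn.Sequential(*layers).cuda()
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

x = torch.randn(8, 32, 64, 64, device='cuda')
before = peak_backward_memory(model, x)
turn_on_efficient_conv_bn_eval_for_single_model(model)
after = peak_backward_memory(model, x)
print(f'peak memory: {before} -> {after} bytes')
```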
