[Feature] Turning On fast_conv_bn_eval #1202

Merged on Jul 14, 2023 · 26 commits

Commits
970bd28
draft implementation of turning on fast_conv_bn_eval feature
youkaichao Jun 15, 2023
cdd5d82
add support for specifying the parts of network to turn on fast_conv_…
youkaichao Jun 15, 2023
32ddc9c
use ConvModule.turn_on_fast_conv_bn_eval to reduce repetitive code
youkaichao Jun 15, 2023
4efbaaa
add more logging to indicate if the feature can be enabled
youkaichao Jun 20, 2023
0f0061b
update documentation to introduce fast_conv_bn_eval
youkaichao Jun 20, 2023
d9fe010
raise error instead of warning when fast_conv_bn_eval is configured but…
youkaichao Jul 4, 2023
e3ddabc
use one argument fast_conv_bn_eval instead of two to simplify usage
youkaichao Jul 4, 2023
e0cfd55
simplify the complex if-else logic
youkaichao Jul 4, 2023
804c7fd
use more verbose names for variables
youkaichao Jul 4, 2023
af318cc
correct found_pair logic and check if conv/bn are used multiple times
youkaichao Jul 4, 2023
23bc688
add testcase for turn_on_fast_conv_bn_eval function
youkaichao Jul 5, 2023
30622cd
rename file
youkaichao Jul 10, 2023
7834610
re-organize files
youkaichao Jul 10, 2023
6f04d01
add conditional skip in testing
youkaichao Jul 10, 2023
4e434fb
move str case disposal out of runner
youkaichao Jul 10, 2023
0de92fa
remove redundant code
youkaichao Jul 10, 2023
df91f3e
add support for native torch conv and bn, independent of mmcv
youkaichao Jul 11, 2023
8f5e48c
add comments
youkaichao Jul 13, 2023
8565608
update function name
youkaichao Jul 13, 2023
4954c16
relax the multi-bn usage
youkaichao Jul 13, 2023
09badf1
update test code
youkaichao Jul 13, 2023
bee4711
remove unnecessary global
youkaichao Jul 13, 2023
ba29292
update comment
youkaichao Jul 13, 2023
9d01ea0
remove dependency on mmcv and ConvModule
youkaichao Jul 14, 2023
ab4d587
unconditional import in runner
youkaichao Jul 14, 2023
4be634c
move torch.fx import into functions
youkaichao Jul 14, 2023
4 changes: 4 additions & 0 deletions docs/en/common_usage/save_gpu_memory.md
@@ -2,6 +2,10 @@

Memory capacity is critical in deep learning training and inference and determines whether the model can run successfully. Common memory saving approaches include:

- Enable Fast Conv BN Eval Feature (Experimental)

We've recently [introduced](https://github.com/open-mmlab/mmcv/pull/2807) an experimental feature in MMCV: Fast Conv BN Eval, based on the concepts discussed in [this paper](https://arxiv.org/abs/2305.11624). This feature is designed to reduce the memory footprint of network training without hurting performance. If your network architecture contains a series of consecutive Conv+BN blocks, and these normalization layers are kept in `eval` mode during training (a common occurrence when training object detectors with [MMDetection](https://github.com/open-mmlab/mmdetection)), this feature can reduce memory consumption by more than 20%. To enable the Fast Conv BN Eval feature, simply add the following command-line argument: `--cfg-options fast_conv_bn_eval="[backbone]"`. When you see `Enabling the "fast_conv_bn_eval" feature for sub-modules ...` in the output log, the feature has been enabled successfully. As this is currently in an experimental phase, we are eager to hear about your experience with it. Please share your usage reports, observations, and suggestions at [this discussion thread](https://github.com/open-mmlab/mmcv/discussions/2841). Your feedback is crucial for further development and for determining whether this feature should be integrated into the stable release.
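
  For illustration, here is a minimal sketch of enabling the same option from Python rather than the command line; `my_config.py` is a placeholder for your own training config, and only the `fast_conv_bn_eval` key comes from this feature:

  ```python
  # Sketch: enable the feature programmatically; equivalent to passing
  # --cfg-options fast_conv_bn_eval="[backbone]" on the command line.
  from mmengine.config import Config
  from mmengine.runner import Runner

  cfg = Config.fromfile('my_config.py')  # placeholder: your training config
  cfg.fast_conv_bn_eval = ['backbone']   # sub-modules to optimize
  runner = Runner.from_cfg(cfg)
  runner.train()
  ```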

- Gradient Accumulation

Gradient accumulation is a mechanism that, for a configured number of steps, accumulates gradients instead of updating parameters; after those steps, the network parameters are updated and the gradients are cleared. With this delayed parameter update, the result is similar to using a large batch size, while the memory used by activations is saved. Note, however, that if the model contains a batch normalization layer, using gradient accumulation will impact performance. A hand-rolled sketch of the mechanism follows below.
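
  A minimal sketch of the mechanism (hand-rolled PyTorch for illustration; the step count and toy model are assumptions, and MMEngine users would normally configure accumulation through the optimizer wrapper rather than by hand):

  ```python
  # Sketch of gradient accumulation by hand. All sizes are illustrative.
  import torch
  from torch import nn

  model = nn.Linear(16, 4)
  optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  criterion = nn.MSELoss()
  accumulation_steps = 4  # assumed value: update parameters every 4 batches

  optimizer.zero_grad()
  for step in range(16):
      data = torch.randn(8, 16)   # toy batch
      target = torch.randn(8, 4)
      loss = criterion(model(data), target)
      (loss / accumulation_steps).backward()  # accumulate scaled gradients
      if (step + 1) % accumulation_steps == 0:
          optimizer.step()        # delayed parameter update
          optimizer.zero_grad()
  ```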
4 changes: 4 additions & 0 deletions docs/zh_cn/common_usage/save_gpu_memory.md
@@ -2,6 +2,10 @@

Memory capacity is critical in deep learning training and inference and determines whether the model can run successfully. Common memory-saving approaches include:

- Enable Fast Conv BN Eval Feature (Experimental)

Based on the concepts discussed in [this paper](https://arxiv.org/abs/2305.11624), we recently [introduced](https://github.com/open-mmlab/mmcv/pull/2807) an experimental feature in MMCV: Fast Conv BN Eval. It is designed to reduce memory usage during network training without hurting performance. If your network architecture contains a series of consecutive Conv+BN blocks, and these BN layers are kept in `eval` mode during training (a common occurrence when training object detectors with [MMDetection](https://github.com/open-mmlab/mmdetection)), this feature can reduce memory consumption by more than 20%. To enable it, simply add the following command-line argument: `--cfg-options fast_conv_bn_eval="[backbone]"`. When you see `Enabling the "fast_conv_bn_eval" feature for sub-modules ...` in the output log, the feature has been enabled successfully. As this is still in an experimental phase, we are eager to hear about your experience with it. Please share your usage reports, observations, and suggestions in [this discussion thread](https://github.com/open-mmlab/mmcv/discussions/2841). Your feedback is crucial for further development and for determining whether this feature should be integrated into the stable release.

- Gradient Accumulation

Gradient accumulation means that, after the gradients of a batch are computed, they are accumulated rather than cleared; once a configured number of steps has accumulated, the network parameters are updated and the gradients are cleared. With this delayed parameter update, the result is similar to using a large batch size, while the memory used by activations is saved. Note, however, that if the model contains a batch normalization layer, gradient accumulation will impact performance.
145 changes: 145 additions & 0 deletions mmengine/model/fast_conv_bn_eval.py
@@ -0,0 +1,145 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from operator import attrgetter
from typing import List, Union

import torch
import torch.nn as nn


def fast_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
                              conv: nn.modules.conv._ConvNd, x: torch.Tensor):
    """Code borrowed from mmcv 2.0.1, so that this feature can be used for old
    mmcv versions.

    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning".
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for the eval mode of ConvBN blocks during validation, and can be
    used for training as well. It reduces memory and computation cost.

    Args:
        bn (_BatchNorm): a BatchNorm module.
        conv (nn._ConvNd): a conv module.
        x (torch.Tensor): Input feature map.
    """
    # These lines of code are designed to deal with various cases
    # like bn without affine transform, and conv without bias.
    weight_on_the_fly = conv.weight
    if conv.bias is not None:
        bias_on_the_fly = conv.bias
    else:
        bias_on_the_fly = torch.zeros_like(bn.running_var)

    if bn.weight is not None:
        bn_weight = bn.weight
    else:
        bn_weight = torch.ones_like(bn.running_var)

    if bn.bias is not None:
        bn_bias = bn.bias
    else:
        bn_bias = torch.zeros_like(bn.running_var)

    # shape of [C_out, 1, 1, 1] in Conv2d
    weight_coeff = torch.rsqrt(bn.running_var +
                               bn.eps).reshape([-1] + [1] *
                                               (len(conv.weight.shape) - 1))
    # shape of [C_out, 1, 1, 1] in Conv2d
    coeff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff

    # shape of [C_out, C_in, k, k] in Conv2d
    weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() * \
        (bias_on_the_fly - bn.running_mean)

    return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)


def bn_once_identity_forward(bn: nn.modules.batchnorm._BatchNorm,
                             x: torch.Tensor):
    """The forward function is an identity function.

    The magic is that after one call, `bn.forward` is restored to what it used
    to be: popping the instance attribute re-exposes the class method.
    """
    bn.__dict__.pop('forward')
    return x


def fast_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
                              conv: nn.modules.conv._ConvNd, x: torch.Tensor):
    """This function controls whether to use `fast_conv_bn_eval_forward`.

    If the following `bn` is in `eval` mode, we turn on the special
    `fast_conv_bn_eval_forward` and make the following call of `bn.forward`
    an identity. Note that this `bn.forward` modification only works for one
    call. After the call, `bn.forward` is restored to the default function.
    This deals with the case where one `bn` module is used in multiple places.
    """
    if not bn.training:
        # bn in eval mode
        output = fast_conv_bn_eval_forward(bn, conv, x)
        bn.forward = partial(bn_once_identity_forward, bn)
        return output
    else:
        return conv._conv_forward(x, conv.weight, conv.bias)


def turn_on_fast_conv_bn_eval_for_single_model(model: torch.nn.Module):
    # optimize consecutive conv+bn by modifying the conv's forward function
    # Symbolically trace the input model to create an FX GraphModule
    import torch.fx as fx
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())

    patterns = [(torch.nn.modules.conv._ConvNd,
                 torch.nn.modules.batchnorm._BatchNorm)]

    # Iterate through nodes in the graph to find ConvBN blocks
    for node in fx_model.graph.nodes:
        # If our current node isn't calling a Module then we can ignore it.
        if node.op != 'call_module':
            continue
        target_module = modules[node.target]
        found_pair = False
        for conv_class, bn_class in patterns:
            if isinstance(target_module, bn_class):
                input_node = node.args[0]
                # the bn input must itself be a module call, otherwise its
                # target is not a name in `modules`
                if isinstance(input_node, fx.Node) and \
                        input_node.op == 'call_module':
                    source_module = modules[input_node.target]
                    if isinstance(source_module, conv_class):
                        found_pair = True
        # Not a conv-BN pattern, or the output of conv is used by other nodes
        if not found_pair or len(node.args[0].users) > 1:
            continue

        # check if the conv module is used in multiple nodes
        conv_name = node.args[0].target
        bn_name = node.target

        conv_usage_count = 0
        for _node in fx_model.graph.nodes:
            if _node.op != 'call_module':
                continue
            if _node.target == conv_name:
                conv_usage_count += 1

        if conv_usage_count > 1:
            continue

        # Found a pair of conv and bn to optimize
        conv_module = modules[conv_name]
        bn_module = modules[bn_name]

        conv_module.forward = partial(fast_conv_bn_eval_control, bn_module,
                                      conv_module)


def turn_on_fast_conv_bn_eval(model: torch.nn.Module,
                              modules: Union[List[str], str]):
    if isinstance(modules, str):
        modules = [modules]
    for module_name in modules:
        module = attrgetter(module_name)(model)
        turn_on_fast_conv_bn_eval_for_single_model(module)
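
In eval mode, BN computes $y = \gamma (x - \mu) / \sqrt{\sigma^2 + \epsilon} + \beta$ with fixed statistics, so its scale can be folded into the convolution weight and its shift into the bias, which is what `fast_conv_bn_eval_forward` does on the fly. A minimal sketch checking the fused path against the two-step reference (all sizes and statistics below are arbitrary illustrative choices, not values from the PR):

```python
# Sketch: verify fast_conv_bn_eval_forward against conv followed by
# eval-mode BN. Sizes and statistics are arbitrary illustrative choices.
import torch
import torch.nn as nn

from mmengine.model.fast_conv_bn_eval import fast_conv_bn_eval_forward
from mmengine.testing import assert_allclose

conv = nn.Conv2d(3, 8, kernel_size=3, bias=True)
bn = nn.BatchNorm2d(8).eval()  # the fused path requires eval-mode BN
with torch.no_grad():  # give BN non-trivial statistics and affine params
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.normal_()
    bn.bias.normal_()

x = torch.randn(2, 3, 16, 16)
assert_allclose(bn(conv(x)), fast_conv_bn_eval_forward(bn, conv, x))
```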
9 changes: 9 additions & 0 deletions mmengine/runner/runner.py
@@ -29,6 +29,7 @@
from mmengine.logging import MessageHub, MMLogger, print_log
from mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm,
                            is_model_wrapper, revert_sync_batchnorm)
from mmengine.model.fast_conv_bn_eval import turn_on_fast_conv_bn_eval
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
                            build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS,
@@ -1705,6 +1706,14 @@ def train(self) -> nn.Module:

        # initialize the model weights
        self._init_model_weights()

        # try to enable the fast_conv_bn_eval feature
        modules = self.cfg.get('fast_conv_bn_eval', None)
        if modules is not None:
            self.logger.info(f'Enabling the "fast_conv_bn_eval" feature'
                             f' for sub-modules: {modules}')
            turn_on_fast_conv_bn_eval(ori_model, modules)

        # make sure checkpoint-related hooks are triggered after `before_run`
        self.load_or_resume()

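For reference, here is a minimal sketch of what this runner hook amounts to when applied by hand; the `Detector` class and its `backbone` attribute are illustrative stand-ins rather than code from this PR:

```python
# Sketch: apply the feature manually, outside the runner. `Detector` and
# its `backbone` attribute are hypothetical; any dotted sub-module works.
import torch
import torch.nn as nn

from mmengine.model.fast_conv_bn_eval import turn_on_fast_conv_bn_eval


class Detector(nn.Module):
    """Toy stand-in for a model with a `backbone` sub-module."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

    def forward(self, x):
        return self.backbone(x)


model = Detector()
model.backbone[1].eval()  # the fused path only activates for eval-mode BN
turn_on_fast_conv_bn_eval(model, ['backbone'])  # same as the cfg option
out = model(torch.randn(1, 3, 32, 32))  # now runs the fused conv+bn path
```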
57 changes: 57 additions & 0 deletions tests/test_model/test_fast_conv_bn_eval.py
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch import nn

from mmengine.model.fast_conv_bn_eval import \
    turn_on_fast_conv_bn_eval_for_single_model
from mmengine.testing import assert_allclose
from mmengine.utils import is_installed

mmcv_is_installed = is_installed('mmcv')


class BackboneModel(nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if mmcv_is_installed:
            from mmcv.cnn import ConvModule
            conv0 = nn.Conv2d(6, 6, 6)
            bn0 = nn.BatchNorm2d(6)
            self.mod1 = ConvModule.create_from_conv_bn(conv0, bn0)
        self.conv1 = nn.Conv2d(6, 6, 6)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 6, 6)
        self.bn2 = nn.BatchNorm2d(6)
        self.conv3 = nn.Conv2d(6, 6, 6)
        self.bn3 = nn.BatchNorm2d(6)

    def forward(self, x):
        if mmcv_is_installed:
            # this ConvModule can use the fast_conv_bn_eval feature
            x = self.mod1(x)
        # this conv-bn pair can use the fast_conv_bn_eval feature
        x = self.bn1(self.conv1(x))
        # this conv-bn pair cannot use the fast_conv_bn_eval feature
        # because `self.conv2` is used twice
        x = self.bn2(self.conv2(self.conv2(x)))
        # this conv-bn pair can use the fast_conv_bn_eval feature,
        # but only for the first forward of `self.bn3`
        x = self.bn3(self.bn3(self.conv3(x)))
        return x


class TestFastConvBNEval(TestCase):
    """Test the turn_on_fast_conv_bn_eval function."""

    def test_fast_conv_bn_eval(self):
        model = BackboneModel()
        model.eval()
        inputs = torch.randn(64, 6, 32, 32)
        output = model(inputs)
        turn_on_fast_conv_bn_eval_for_single_model(model)
        output2 = model(inputs)
        print((output - output2).abs().max().item())
        assert_allclose(output, output2)
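
Beyond numerical equivalence, the point of the feature is the memory saved when training with frozen BN layers. A rough, GPU-only sketch for eyeballing the effect (everything below is illustrative and not part of the PR; actual savings depend on the model and hardware):

```python
# Rough sketch: compare peak training memory with and without the feature.
import torch
from torch import nn

from mmengine.model.fast_conv_bn_eval import \
    turn_on_fast_conv_bn_eval_for_single_model


def peak_memory(model: nn.Module, x: torch.Tensor) -> int:
    """Peak CUDA memory (bytes) for one forward+backward pass."""
    model.zero_grad(set_to_none=True)
    torch.cuda.reset_peak_memory_stats()
    model(x).sum().backward()
    return torch.cuda.max_memory_allocated()


layers = []
for _ in range(8):
    layers += [nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32)]
model = nn.Sequential(*layers).cuda()
for m in model.modules():  # freeze BN, as in typical detector training
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

x = torch.randn(8, 32, 64, 64, device='cuda')
before = peak_memory(model, x)
turn_on_fast_conv_bn_eval_for_single_model(model)
after = peak_memory(model, x)
print(f'peak memory: {before} -> {after} bytes')
```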