Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add fast_conv_bn_eval option in ConvModule for fast validation and training in Eval mode #2807

Merged
merged 8 commits into from
Jun 13, 2023

Conversation

youkaichao
Copy link
Contributor

Motivation

This PR is motivated by the arxiv paper https://arxiv.org/abs/2305.11624 Tune-Mode ConvBN Blocks For Efficient Transfer Learning. It leverages the associative law between convolution and affine transform, i.e., normalize (weight conv feature) = (normalize weight) conv feature.

It has two advantages:

  1. During inference/validation, the conv-bn calculation can be made faster.
  2. During training with Eval mode, the conv-bn calculation can be made faster and memory efficient.

Modification

The implementation appears as a pre-forward hook registered on the conv layer. It is compatible with the existing implementation. During each forward calculation, it identifies whether the hook should be activated, and then switch to the fast computation.

BC-breaking (Optional)

This should not break any existing code.

Use cases (Optional)

There are two possible use cases:

  1. Define post_build_model hook in MMCV, which is used by default. The hook traces the network (typically only the backbone) to replace consecutive conv and bn with the new ConvModule. This way, downstream users seamlessly enjoy the speedup.

  2. Modify the build_model function for each downstream repo (like mmdetection and mmpose) to trace consecutive conv and bn, replacing them with a new ConvModule.

Checklist

Before PR:

  • I have read and followed the workflow indicated in the CONTRIBUTING.md to create this PR.
  • Pre-commit or linting tools indicated in CONTRIBUTING.md are used to fix the potential lint issues.
  • Bug fixes are covered by unit tests, the case that causes the bug should be added in the unit tests.
  • New functionalities are covered by complete unit tests. If not, please add more unit test to ensure the correctness.
  • The documentation has been modified accordingly, including docstring or example tutorials.

After PR:

  • If the modification has potential influence on downstream or other related projects, this PR should be tested with some of those projects, like MMDet or MMCls.
  • CLA has been signed and all committers have signed the CLA in this PR.

@youkaichao
Copy link
Contributor Author

Another related implementation is FrozenBatchNorm2d from torchvision and detectron2.

Implementation of this PR is faster than FrozenBatchNorm2d, with almost the same memory cost (significantly less than current ConvModule in mmcv). The table is from the Table 8 of the paper "Tune-Mode ConvBN Blocks For Efficient Transfer Learning":

image

Besides, this PR does not hurt performance, while FrozenBatchNorm2d will. From Table 6 of the MMDetection report, FrozenBatchNorm2d is worse in mAP. While this PR is equivalent with the norm_eval setting.
image

From the Figure 1 of the paper "Tune-Mode ConvBN Blocks For Efficient Transfer Learning", norm_eval is prevalent in MMDetection:

image

Therefore, I think this PR can be a drop-in improvement for mmcv. It automatically identifies the case for possible acceleration with equivalent implementation.

@youkaichao youkaichao changed the title [Feature] add FastConvBN method for fast validation and training in Eval mode [Feature] add fast_conv_bn_eval option in ConvModule for fast validation and training in Eval mode May 25, 2023
@youkaichao
Copy link
Contributor Author

The implementation is compatible with ONNX export, and since torch.onnx.export uses constant folding by default, conv and bn modules in ConvModule will be automatically fused into one operator.

mmcv/cnn/bricks/conv_module.py Outdated Show resolved Hide resolved
mmcv/cnn/bricks/conv_module.py Show resolved Hide resolved
mmcv/cnn/bricks/conv_module.py Outdated Show resolved Hide resolved
mmcv/cnn/bricks/conv_module.py Outdated Show resolved Hide resolved
@youkaichao
Copy link
Contributor Author

youkaichao commented Jun 10, 2023

Here is an example usage:

# Import required libraries
from typing import Tuple
from functools import partial
from operator import attrgetter

import torch
import torch.nn as nn
import torch.fx as fx
from mmcv.cnn import ConvModule


# Helper function to split a qualname into parent path and last atom.
def _parent_name(target : str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name


def replace_sub_module(model, name, new_module):
    # Remove the original module from the model
    # usage: replace_sub_module(model, 'layer1.block2.conv2', conv)
    parent_name, name = _parent_name(name)
    if parent_name != '':
        getter = attrgetter(parent_name)
        parent = getter(model)
    else:
        parent = model
    setattr(parent, name, new_module)


# Main function to merge consecutive conv+bn into ConvModule for the given model
def find_and_merge_conv_bn(model: torch.nn.Module):
    # Symbolically trace the input model to create an FX GraphModule
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())

    patterns = [(torch.nn.modules.conv._ConvNd, torch.nn.modules.batchnorm._BatchNorm)]

    # Iterate through nodes in the graph to find ConvBN blocks
    for node in fx_model.graph.nodes:
        if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it.
            continue
        found_pair = [node for conv_class, bn_class in patterns if isinstance(modules[node.target], bn_class) and isinstance(modules[node.args[0].target], conv_class)]
        if not found_pair or len(node.args[0].users) > 1: # Not a conv-BN pattern or output of conv is used by other nodes
            continue

        # Find a pair of conv and bn to optimize
        conv_name = node.args[0].target
        bn_name = node.target

        print(f'Merging {conv_name} and {bn_name} into a ConvModule')
        conv = modules[conv_name]
        bn = modules[bn_name]

        # Fuse conv and bn into a ConvModule
        new_conv = ConvModule.create_from_conv_bn(conv, bn)
        replace_sub_module(model, conv_name, new_conv)
        replace_sub_module(model, bn_name, nn.Identity())

if __name__ == '__main__':
    import torchvision.models as models
    from copy import deepcopy
    resnet = models.resnet50(pretrained=False)
    resnet.eval()
    resnet2 = deepcopy(resnet)
    resnet2.eval()
    find_and_merge_conv_bn(resnet2)

    resnet.cuda()
    resnet2.cuda()
    input = torch.randn(32, 3, 224, 224).cuda()
    output = resnet(input)
    output2 = resnet2(input)
    print(torch.allclose(output, output2, atol=1e-4))

    del output
    del output2

    import time
    start = time.time()
    # reset pytorch max_memory_allocated
    torch.cuda.reset_max_memory_allocated()
    start_memory = torch.cuda.memory_allocated()
    for i in range(10):
        resnet(input).sum().backward()
    end = time.time()
    max_memory = torch.cuda.max_memory_allocated()
    print(f'time for resnet: {end - start} seconds (10 batches with batch size 32)')
    print(f'max memory for resnet: {(max_memory - start_memory) / 1024 ** 3} GB')

    start = time.time()
    # reset pytorch max_memory_allocated
    torch.cuda.reset_max_memory_allocated()
    start_memory = torch.cuda.memory_allocated()
    for i in range(10):
        resnet2(input).sum().backward()
    end = time.time()
    max_memory = torch.cuda.max_memory_allocated()
    print(f'time for resnet with ConvModule: {end - start} seconds (10 batches with batch size 32)')
    print(f'max memory for resnet with ConvModule: {(max_memory - start_memory) / 1024 ** 3} GB')

On my server with RTX 2080 Ti GPU, the output is :

... ... other logs omitted ... ...
time for resnet: 2.198728322982788 seconds (10 batches with batch size 32)
max memory for resnet: 2.703915596008301 GB
time for resnet with ConvModule: 1.1381397247314453 seconds (10 batches with batch size 32)
max memory for resnet with ConvModule: 1.4728455543518066 GB

Merging conv and bn into a ConvModule with fast_conv_bn_eval=True reduces half memory usage and also reduces about half the wallclock time for forward and backward computation.

Update: I re-run the test, with the following results:

time for resnet: 1.157458782196045 seconds (10 batches with batch size 32)
max memory for resnet: 2.703915596008301 GB
time for resnet with ConvModule: 1.133033037185669 seconds (10 batches with batch size 32)
max memory for resnet with ConvModule: 1.4728455543518066 GB

The memory reduction is obvious, but the time reduction is not that obvious. The wallclock time can vary from time to time. It is not very stable.

@HAOCHENYE
Copy link
Collaborator

Thanks for your guidance, and I've test it with retinanet_r50

The training accuracy matches the original result well:

DONE (t=10.26s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.365
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=1000 ] = 0.555
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=1000 ] = 0.389
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.205
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.400
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.481
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.538
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 ] = 0.538
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=1000 ] = 0.538
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.333
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.582
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.691
06/12 13:47:05 - mmengine - INFO - bbox_mAP_copypaste: 0.365 0.555 0.389 0.205 0.400 0.481
06/12 13:47:06 - mmengine - INFO - Epoch(val) [12][625/625]    coco/bbox_mAP: 0.3650  coco/bbox_mAP_50: 0.5550  coco/bbox_mAP_75: 0.3890  coco/bbox_mAP_s: 0.2050  coco/bbox_mAP_m: 0.4000  coco/bbox_mAP_l: 0.4810  data_time: 0.0019  time: 0.0268

Besides, the memory optimization is also obvious:

Result of fast conv-bn

06/12 11:21:01 - mmengine - INFO - Epoch(train)  [1][  50/7330]  lr: 9.9098e-04  eta: 7:08:46  time: 0.2926  data_time: 0.0063  memory: 2430  loss: 1.9231  loss_cls: 1.2129  loss_bbox: 0.7102
06/12 11:21:07 - mmengine - INFO - Epoch(train)  [1][ 100/7330]  lr: 1.9920e-03  eta: 5:04:01  time: 0.1226  data_time: 0.0037  memory: 2428  loss: 1.9257  loss_cls: 1.2262  loss_bbox: 0.6995
06/12 11:21:13 - mmengine - INFO - Epoch(train)  [1][ 150/7330]  lr: 2.9930e-03  eta: 4:22:55  time: 0.1237  data_time: 0.0039  memory: 2431  loss: 1.8945  loss_cls: 1.2170  loss_bbox: 0.6775
06/12 11:21:19 - mmengine - INFO - Epoch(train)  [1][ 200/7330]  lr: 3.9940e-03  eta: 4:00:36  time: 0.1190  data_time: 0.0041  memory: 2429  loss: 1.8793  loss_cls: 1.1972  loss_bbox: 0.6821
06/12 11:21:25 - mmengine - INFO - Epoch(train)  [1][ 250/7330]  lr: 4.9950e-03  eta: 3:43:07  time: 0.1052  data_time: 0.0039  memory: 2431  loss: 1.7925  loss_cls: 1.1160  loss_bbox: 0.6765
06/12 11:21:30 - mmengine - INFO - Epoch(train)  [1][ 300/7330]  lr: 5.9960e-03  eta: 3:32:09  time: 0.1081  data_time: 0.0036  memory: 2429  loss: 1.7069  loss_cls: 1.0512  loss_bbox: 0.6557
06/12 11:21:35 - mmengine - INFO - Epoch(train)  [1][ 350/7330]  lr: 6.9970e-03  eta: 3:22:29  time: 0.0995  data_time: 0.0034  memory: 2429  loss: 1.6845  loss_cls: 1.0676  loss_bbox: 0.6168
06/12 11:21:40 - mmengine - INFO - Epoch(train)  [1][ 400/7330]  lr: 7.9980e-03  eta: 3:13:55  time: 0.0923  data_time: 0.0034  memory: 2429  loss: 1.7729  loss_cls: 1.1561  loss_bbox: 0.6168
06/12 11:21:45 - mmengine - INFO - Epoch(train)  [1][ 450/7330]  lr: 8.9990e-03  eta: 3:08:30  time: 0.1001  data_time: 0.0034  memory: 2429  loss: 1.7027  loss_cls: 1.0813  loss_bbox: 0.6215
06/12 11:21:49 - mmengine - INFO - Epoch(train)  [1][ 500/7330]  lr: 1.0000e-02  eta: 3:03:07  time: 0.0931  data_time: 0.0034  memory: 2430  loss: 1.6259  loss_cls: 1.0479  loss_bbox: 0.5780
06/12 11:21:54 - mmengine - INFO - Epoch(train)  [1][ 550/7330]  lr: 1.0000e-02  eta: 2:58:32  time: 0.0919  data_time: 0.0034  memory: 2432  loss: 1.7480  loss_cls: 1.1371  loss_bbox: 0.6109

Result of normal conv-bn

2023/06/12 11:14:35 - mmengine - INFO - Epoch(train)  [1][  50/7330]  lr: 9.9098e-04  eta: 6:11:53  time: 0.2538  data_time: 0.0058  memory: 3306  loss: 1.9298  loss_cls: 1.2243  loss_bbox: 0.7054
2023/06/12 11:14:41 - mmengine - INFO - Epoch(train)  [1][ 100/7330]  lr: 1.9920e-03  eta: 4:33:11  time: 0.1193  data_time: 0.0033  memory: 3303  loss: 1.8993  loss_cls: 1.2241  loss_bbox: 0.6752
2023/06/12 11:14:47 - mmengine - INFO - Epoch(train)  [1][ 150/7330]  lr: 2.9930e-03  eta: 3:53:54  time: 0.1064  data_time: 0.0033  memory: 3301  loss: 1.9179  loss_cls: 1.2271  loss_bbox: 0.6908
2023/06/12 11:14:52 - mmengine - INFO - Epoch(train)  [1][ 200/7330]  lr: 3.9940e-03  eta: 3:35:22  time: 0.1095  data_time: 0.0032  memory: 3306  loss: 1.9083  loss_cls: 1.2412  loss_bbox: 0.6671
2023/06/12 11:14:57 - mmengine - INFO - Epoch(train)  [1][ 250/7330]  lr: 4.9950e-03  eta: 3:20:48  time: 0.0978  data_time: 0.0033  memory: 3305  loss: 1.7696  loss_cls: 1.1125  loss_bbox: 0.6571
2023/06/12 11:15:02 - mmengine - INFO - Epoch(train)  [1][ 300/7330]  lr: 5.9960e-03  eta: 3:11:37  time: 0.1001  data_time: 0.0033  memory: 3302  loss: 1.6912  loss_cls: 1.0489  loss_bbox: 0.6424
2023/06/12 11:15:07 - mmengine - INFO - Epoch(train)  [1][ 350/7330]  lr: 6.9970e-03  eta: 3:04:33  time: 0.0978  data_time: 0.0033  memory: 3307  loss: 1.6070  loss_cls: 0.9797  loss_bbox: 0.6273
2023/06/12 11:15:12 - mmengine - INFO - Epoch(train)  [1][ 400/7330]  lr: 7.9980e-03  eta: 2:58:13  time: 0.0922  data_time: 0.0033  memory: 3303  loss: 1.7249  loss_cls: 1.1164  loss_bbox: 0.6085
2023/06/12 11:15:16 - mmengine - INFO - Epoch(train)  [1][ 450/7330]  lr: 8.9990e-03  eta: 2:53:47  time: 0.0954  data_time: 0.0033  memory: 3302  loss: 1.5828  loss_cls: 0.9846  loss_bbox: 0.5982
2023/06/12 11:15:21 - mmengine - INFO - Epoch(train)  [1][ 500/7330]  lr: 1.0000e-02  eta: 2:50:09  time: 0.0949  data_time: 0.0033  memory: 3304  loss: 1.4992  loss_cls: 0.9300  loss_bbox: 0.5692
2023/06/12 11:15:26 - mmengine - INFO - Epoch(train)  [1][ 550/7330]  lr: 1.0000e-02  eta: 2:47:14  time: 0.0954  data_time: 0.0032  memory: 3302  loss: 1.4713  loss_cls: 0.9419  loss_bbox: 0.5294

The memory allocated is optimized from 3300 to 2432

@zhouzaida zhouzaida changed the title [Feature] add fast_conv_bn_eval option in ConvModule for fast validation and training in Eval mode [Feature] Add fast_conv_bn_eval option in ConvModule for fast validation and training in Eval mode Jun 13, 2023
@zhouzaida zhouzaida merged commit 36003b7 into open-mmlab:main Jun 13, 2023
9 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants