use graph transform to deal with more general cases for efficient_conv_bn_eval #1259

Merged
merged 1 commit on Jul 26, 2023

Changes from all commits
93 changes: 52 additions & 41 deletions mmengine/model/efficient_conv_bn_eval.py
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from operator import attrgetter
from typing import List, Union

@@ -58,48 +57,32 @@ def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)
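For reference, below is a minimal numerical sketch of the standard conv+bn folding identity that `efficient_conv_bn_eval_forward` relies on; the actual weight/bias computation (including the `conv.bias is None` case) lives in the collapsed lines above, so this is only an illustration, not the implementation itself.

import torch
import torch.nn as nn

# Eval-mode identity: bn(conv(x, W, b)) == conv(x, W * s, (b - mean) * s + beta),
# where s = gamma / sqrt(running_var + eps) is applied per output channel.
conv = nn.Conv2d(3, 8, 3, bias=True)
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
x = torch.randn(2, 3, 16, 16)

s = bn.weight / torch.sqrt(bn.running_var + bn.eps)
weight_on_the_fly = conv.weight * s.reshape(-1, 1, 1, 1)
bias_on_the_fly = (conv.bias - bn.running_mean) * s + bn.bias

reference = bn(conv(x))
fused = conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)
assert torch.allclose(reference, fused, atol=1e-5)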


def bn_once_identity_forward(bn: nn.modules.batchnorm._BatchNorm,
x: torch.Tensor):
"""The forward function is an identity function.

The magic is that after one call, the `bn.forward` will be restored to what
it used to be.
"""
bn.__dict__.pop('forward')
return x


def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
conv: nn.modules.conv._ConvNd,
x: torch.Tensor):
"""This function controls whether to use `efficient_conv_bn_eval_forward`.

If the following `bn` is in `eval` mode, then we turn on the special
`efficient_conv_bn_eval_forward` and let the following call of `bn.forward`
to be identity. Note that this `bn.forward` modification only works for one
call. After the call, `bn.forward` will be restored to the default
function. This is to deal with the case where one `bn` module is used in
multiple places.
`efficient_conv_bn_eval_forward`.
"""
if not bn.training:
# bn in eval mode
output = efficient_conv_bn_eval_forward(bn, conv, x)
bn.forward = partial(bn_once_identity_forward, bn)
return output
else:
return conv._conv_forward(x, conv.weight, conv.bias)
conv_out = conv._conv_forward(x, conv.weight, conv.bias)
return bn(conv_out)
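After the graph transform below rewrites a call site, code that originally read `x = self.bn1(self.conv1(x))` ends up calling this control function directly. A rough sketch of the generated shape (identifiers here are illustrative; the real code is emitted by `fx` from `get_attr` and `call_function` nodes, and note the argument order: bn first, then conv, then the conv input):

def fused_forward(self, x):
    # replaces: x = self.bn1(self.conv1(x))
    get_conv = self.conv1
    get_bn = self.bn1
    x = efficient_conv_bn_eval_control(get_bn, get_conv, x)
    return x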


def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
# optimize consecutive conv+bn by modifying forward function
# Symbolically trace the input model to create an FX GraphModule
import torch.fx as fx
fx_model: fx.GraphModule = fx.symbolic_trace(model)
def efficient_conv_bn_eval_graph_transform(fx_model):
"""Find consecutive conv+bn calls in the graph, inplace modify the graph
with the fused operation."""
modules = dict(fx_model.named_modules())

patterns = [(torch.nn.modules.conv._ConvNd,
torch.nn.modules.batchnorm._BatchNorm)]

pairs = []
# Iterate through nodes in the graph to find ConvBN blocks
for node in fx_model.graph.nodes:
# If our current node isn't calling a Module then we can ignore it.
@@ -116,26 +99,54 @@ def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
if not found_pair or len(node.args[0].users) > 1:
continue

# check if the conv modules are used in multiple nodes
conv_name = node.args[0].target
bn_name = node.target

conv_usage_count = 0
for _node in fx_model.graph.nodes:
if _node.op != 'call_module':
continue
if _node.target == conv_name:
conv_usage_count += 1
# Find a pair of conv and bn computation nodes to optimize
conv_node = node.args[0]
bn_node = node
pairs.append([conv_node, bn_node])

for conv_node, bn_node in pairs:
# set insertion point
fx_model.graph.inserting_before(conv_node)
# create `get_attr` node to access modules
# note that we directly call `create_node` to fill the `name`
# argument. `fx_model.graph.get_attr` and
# `fx_model.graph.call_function` do not allow the `name` argument.
conv_get_node = fx_model.graph.create_node(
op='get_attr', target=conv_node.target, name='get_conv')
bn_get_node = fx_model.graph.create_node(
op='get_attr', target=bn_node.target, name='get_bn')
# prepare args for the fused function
args = (bn_get_node, conv_get_node, conv_node.args[0])
# create a new node
new_node = fx_model.graph.create_node(
op='call_function',
target=efficient_conv_bn_eval_control,
args=args,
name='efficient_conv_bn_eval')
# this node replaces the original conv + bn, and therefore
# should replace the uses of bn_node
bn_node.replace_all_uses_with(new_node)
# take care of the deletion order:
# delete bn_node first, and then conv_node
fx_model.graph.erase_node(bn_node)
fx_model.graph.erase_node(conv_node)

# regenerate the code
fx_model.graph.lint()
fx_model.recompile()
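A minimal sketch of applying this transform to a traced module and inspecting the rewritten graph; the import path follows the file being modified here, and the exact node names may differ from the comments:

import torch.fx as fx
import torch.nn as nn
from mmengine.model.efficient_conv_bn_eval import (
    efficient_conv_bn_eval_graph_transform)

class ConvBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

gm = fx.symbolic_trace(ConvBN())
print(gm.graph)  # call_module conv -> call_module bn
efficient_conv_bn_eval_graph_transform(gm)
print(gm.graph)  # get_attr nodes feeding call_function efficient_conv_bn_eval_control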

if conv_usage_count > 1:
continue

# Find a pair of conv and bn to optimize
conv_module = modules[conv_name]
bn_module = modules[bn_name]
def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
import torch.fx as fx

conv_module.forward = partial(efficient_conv_bn_eval_control,
bn_module, conv_module)
# Currently we use `fx.symbolic_trace` to trace models.
# In the future, we might switch to the PyTorch 2.0 compile infrastructure to
# get the `fx.GraphModule` IR. Nonetheless, the graph transform function
# can remain unchanged; we just need to change the way
# we get the `fx.GraphModule`.
fx_model: fx.GraphModule = fx.symbolic_trace(model)
efficient_conv_bn_eval_graph_transform(fx_model)
model.forward = fx_model.forward
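A minimal end-to-end usage sketch, assuming an eval-mode model that `fx.symbolic_trace` can handle; outputs before and after turning the feature on should match:

import torch
import torch.nn as nn
from mmengine.model.efficient_conv_bn_eval import (
    turn_on_efficient_conv_bn_eval_for_single_model)

backbone = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
).eval()
# give the bn non-trivial statistics so the check is meaningful
backbone[1].running_mean.uniform_(-1, 1)
backbone[1].running_var.uniform_(0.5, 2.0)

x = torch.randn(2, 3, 32, 32)
reference = backbone(x)

turn_on_efficient_conv_bn_eval_for_single_model(backbone)
assert torch.allclose(backbone(x), reference, atol=1e-5)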


def turn_on_efficient_conv_bn_eval(model: torch.nn.Module,
4 changes: 2 additions & 2 deletions tests/test_model/test_efficient_conv_bn_eval.py
@@ -37,8 +37,8 @@ def forward(self, x):
x = self.mod1(x)
# this conv-bn pair can use efficient_conv_bn_eval feature
x = self.bn1(self.conv1(x))
# this conv-bn pair cannot use efficient_conv_bn_eval feature
# because `self.conv2` is used twice
# this conv-bn pair can use efficient_conv_bn_eval feature
# only for the second `self.conv2` call.
x = self.bn2(self.conv2(self.conv2(x)))
# this conv-bn pair can use efficient_conv_bn_eval feature
# just for the first forward of the `self.bn3`