In [11]:
from torch.nn.utils.fusion import fuse_conv_bn_eval
from torch.fx.node import Argument, Target
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
import copy

In [15]:
import torch 
from torch import fx

In [16]:
from torch.fx import symbolic_trace

In [28]:
def _parent_name(target : str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name 

In [9]:
_parent_name("foo.bar.baz")

('foo', 'bar', 'baz')

In [27]:
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
    if len(node.args) == 0:
        return False
    
    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
    for expected_type,current_node in zip(pattern,nodes):
        if not isinstance(current_node,fx.Node):
            return False
        if current_node.op !="call_module":
            return False
        if not isinstance(current_node.target,str):
            return False
        if current_node.target not in modules:
            return False
        if type(modules[current_node.target]) is not expected_type:
            return False
    return False

In [30]:
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)

In [34]:
import torch 
import torch.nn as nn

In [35]:
def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes. Will deepcopy your
    model by default, but can modify the model inplace as well.
    """
    patterns = [(nn.Conv1d, nn.BatchNorm1d),
                (nn.Conv2d, nn.BatchNorm2d),
                (nn.Conv3d, nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pattern in patterns:
        for node in new_graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)

def remove_dropout(model: nn.Module) -> nn.Module:
    """
    Removes all dropout layers from the module.
    """
    fx_model = fx.symbolic_trace(model)

    class DropoutRemover(torch.fx.Transformer):
        def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
            if isinstance(self.submodules[target], nn.Dropout):
                assert len(args) == 1
                return args[0]
            else:
                return super().call_module(target, args, kwargs)
    return DropoutRemover(fx_model).transform()

In [36]:
class TestConv2d(nn.Module):
    def __init__(self,in_channels,out_channels,**kwargs):
        super(TestConv2d,self).__init__()
        self.conv=nn.Conv2d(in_channels,out_channels,**kwargs)
        self.bn=nn.BatchNorm2d(out_channels)
        self.relu=nn.ReLU(True)
        
    def forward(self,x):
        x=self.conv(x)
        x=self.bn(x)
        x=self.relu(x)
        return x

In [43]:
class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=TestConv2d(3,32,kernel_size=3)
        print(self.conv1)
        self.conv2=TestConv2d(32,64,kernel_size=3)
        self.dropout=nn.Dropout(0.3)
        
    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.dropout(x)
        return x

In [38]:
def show(string,count):
    print(f"{'='*count}{string}{'='*count}")

In [44]:
test_model=TestModel()

TestConv2d(
  (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
)


In [40]:
test_model.eval()

TestModel(
  (conv1): TestConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (conv2): TestConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (dropout): Dropout(p=0.3, inplace=False)
)

In [49]:
origin_model=symbolic_trace(test_model)
show("origin result",20)
print(origin_model.graph)
print(origin_model.code)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1_conv : [num_users=1] = call_module[target=conv1.conv](args = (%x,), kwargs = {})
    %conv1_bn : [num_users=1] = call_module[target=conv1.bn](args = (%conv1_conv,), kwargs = {})
    %conv1_relu : [num_users=1] = call_module[target=conv1.relu](args = (%conv1_bn,), kwargs = {})
    %conv2_conv : [num_users=1] = call_module[target=conv2.conv](args = (%conv1_relu,), kwargs = {})
    %conv2_bn : [num_users=1] = call_module[target=conv2.bn](args = (%conv2_conv,), kwargs = {})
    %conv2_relu : [num_users=1] = call_module[target=conv2.relu](args = (%conv2_bn,), kwargs = {})
    %dropout : [num_users=1] = call_module[target=dropout](args = (%conv2_relu,), kwargs = {})
    return dropout



def forward(self, x):
    conv1_conv = self.conv1.conv(x);  x = None
    conv1_bn = self.conv1.bn(conv1_conv);  conv1_conv = None
    conv1_relu = self.conv1.relu(conv1_bn);  conv1_bn = None
    conv2_conv = self.conv2.conv(conv1_relu);  conv

In [50]:
fuse_model=fuse(test_model)
fuse_model=remove_dropout(fuse_model)
show("fuse result",20)
print(fuse_model.graph)
print(fuse_model.code)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1_conv : [num_users=1] = call_module[target=conv1.conv](args = (%x,), kwargs = {})
    %conv1_bn : [num_users=1] = call_module[target=conv1.bn](args = (%conv1_conv,), kwargs = {})
    %conv1_relu : [num_users=1] = call_module[target=conv1.relu](args = (%conv1_bn,), kwargs = {})
    %conv2_conv : [num_users=1] = call_module[target=conv2.conv](args = (%conv1_relu,), kwargs = {})
    %conv2_bn : [num_users=1] = call_module[target=conv2.bn](args = (%conv2_conv,), kwargs = {})
    %conv2_relu : [num_users=1] = call_module[target=conv2.relu](args = (%conv2_bn,), kwargs = {})
    return conv2_relu



def forward(self, x):
    conv1_conv = self.conv1.conv(x);  x = None
    conv1_bn = self.conv1.bn(conv1_conv);  conv1_conv = None
    conv1_relu = self.conv1.relu(conv1_bn);  conv1_bn = None
    conv2_conv = self.conv2.conv(conv1_relu);  conv1_relu = None
    conv2_bn = self.conv2.bn(conv2_conv);  conv2_conv = None
    conv2_relu = 