In [1]:
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile


class AnnotatedConvBnReLUModel(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU stack bracketed by quant/dequant stubs.

    The convolution maps 3 input channels to 5 output channels with a 3x3
    kernel and no bias; batch norm and an in-place ReLU follow it. The
    ``QuantStub``/``DeQuantStub`` pair marks where the quantized region of
    the network begins and ends.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept as conv, bn, relu, quant, dequant so the
        # printed module layout and state_dict ordering stay the same.
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=5, kernel_size=3, bias=False
        ).to(dtype=torch.float32)
        self.bn = torch.nn.BatchNorm2d(num_features=5).to(dtype=torch.float32)
        self.relu = torch.nn.ReLU(inplace=True)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Run ``x`` through quant -> conv -> bn -> relu -> dequant.

        The input is first made contiguous in ``channels_last`` memory format.
        """
        nhwc = x.contiguous(memory_format=torch.channels_last)
        activated = self.relu(self.bn(self.conv(self.quant(nhwc))))
        return self.dequant(activated)

In [2]:
# Instantiate the float model and print its module structure.
model = AnnotatedConvBnReLUModel()
print(model)


def prepare_save(model, fused, calibration_inputs=None):
    """Statically quantize ``model`` with the QNNPACK qconfig and save a mobile build.

    The model is prepared, optionally calibrated, converted, compiled with
    TorchScript, passed through ``optimize_for_mobile`` and written to the
    current working directory.

    Args:
        model: Float ``torch.nn.Module`` to quantize. It is modified in place:
            switched to eval mode, given a ``qconfig``, and converted to its
            quantized form. Build a fresh instance if the float model is
            needed afterwards.
        fused: Selects the output file name only: ``"model_fused.pt"`` when
            truthy, ``"model.pt"`` otherwise. No fusion is performed here.
        calibration_inputs: Optional iterable of example inputs. Each one is
            run through the prepared model so the observers record activation
            statistics before conversion. When ``None`` (the default) no
            calibration pass is run and activations are converted with the
            observers' default quantization parameters.

    Returns:
        None. The result is the saved file.
    """
    # Post-training static quantization expects an inference-mode model; this
    # also keeps the calibration pass from updating BatchNorm running stats.
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig("qnnpack")
    torch.quantization.prepare(model, inplace=True)

    if calibration_inputs is not None:
        # Feed representative data through the observers inserted by prepare().
        with torch.no_grad():
            for sample in calibration_inputs:
                model(sample)

    torch.quantization.convert(model, inplace=True)
    scripted = torch.jit.script(model)
    optimized = optimize_for_mobile(scripted)
    output_path = "model_fused.pt" if fused else "model.pt"
    torch.jit.save(optimized, output_path)


# Quantize and export the model as-is; with fused=False the file is "model.pt".
prepare_save(model, False)

# prepare_save converts its argument in place, so build a fresh float model
# before creating the fused variant.
model = AnnotatedConvBnReLUModel()
# Fuse only the "bn" + "relu" pair; inplace=False returns the fused result as
# a separate object bound to `model_fused`.
model_fused = torch.quantization.fuse_modules(model, [["bn", "relu"]], inplace=False)
print(model_fused)

# Quantize and export the fused variant; with fused=True the file is "model_fused.pt".
prepare_save(model_fused, True)

AnnotatedConvBnReLUModel(
  (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)
AnnotatedConvBnReLUModel(
  (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn): BNReLU2d(
    (0): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU(inplace=True)
  )
  (relu): Identity()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)




In [None]:
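# Shell commands (not executed by this notebook) for cloning PyTorch and
# building it for Android arm64-v8a with BUILD_BINARY=ON:
# git clone --recursive https://github.com/pytorch/pytorch
# cd pytorch
# git submodule update --init --recursive
# BUILD_PYTORCH_MOBILE=1 ANDROID_ABI=arm64-v8a ./scripts/build_android.sh -DBUILD_BINARY=ON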