diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 036d491..fb4c7d0 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -149,7 +149,7 @@ jobs: run: | sudo apt-get install -y python3-setuptools libopencv-dev python3 -m pip install --upgrade pip - python3 -m pip install torch==${{ matrix.torch-version }} + python3 -m pip install torch==${{ matrix.torch-version }} torchvision python3 -m pip install -U protobuf python3 -m pip install openvino-dev[onnx]==${{env.OPENVINO_VERSION}} diff --git a/README.md b/README.md index dbe654a..23269bb 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ Repository with guides to enable some layers from PyTorch in Intel OpenVINO: * [nn.MaxUnpool2d](examples/unpool) * [torch.fft](examples/fft) * [nn.functional.grid_sample](https://github.com/dkurt/openvino_pytorch_layers/tree/master/examples/grid_sample) +* [torchvision.ops.DeformConv2d](examples/deformable_conv) ## OpenVINO Model Optimizer extension diff --git a/examples/deformable_conv/deformable_conv.py b/examples/deformable_conv/deformable_conv.py new file mode 100644 index 0000000..fce9fa6 --- /dev/null +++ b/examples/deformable_conv/deformable_conv.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +import torchvision.ops as ops + + +class DeformableConvFunc(torch.autograd.Function): + @staticmethod + def symbolic(g, cls, x, offset): + weight = cls.state_dict()["weight"] + weight = g.op("Constant", value_t=weight) + + return g.op( + "DeformableConv2D", + x, + offset, + weight, + strides_i=(cls.stride, cls.stride), + pads_i=(cls.padding, cls.padding, cls.padding, cls.padding), + dilations_i=(cls.dilation, cls.dilation), + deformable_group_i=cls.groups, + ) + + @staticmethod + def forward(self, cls, x, offset): + y = cls.origin_forward(x, offset) + return y + + +class DeformableConvolution(ops.DeformConv2d): + """ + This is a support class which helps export network with SparseConv in ONNX format. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.origin_forward = super().forward + self.stride = kwargs.get("stride", 1) + self.padding = kwargs.get("padding", 0) + self.dilation = kwargs.get("dilation", 1) + self.groups = kwargs.get("groups", 1) + self.pad_l = nn.ConstantPad2d((1, 1, 1, 1), 0) + + def forward(self, x, offset): + """ + Using paddings is a workaround for 2021.4 release. + """ + x = self.pad_l(x) + offset = self.pad_l(offset) + y = DeformableConvFunc.apply(self, x, offset) + y = y[:, :, 1:-1, 1:-1] + return y diff --git a/examples/deformable_conv/export_model.py b/examples/deformable_conv/export_model.py new file mode 100644 index 0000000..a7630ad --- /dev/null +++ b/examples/deformable_conv/export_model.py @@ -0,0 +1,111 @@ +import numpy as np +import argparse +import torch +import torch.nn as nn +from torch.autograd import Variable +from .deformable_conv import DeformableConvolution + +np.random.seed(324) +torch.manual_seed(32) + + +class MyModel(nn.Module): + def __init__( + self, + inplanes, + outplanes, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + bias=False, + deformable_groups=1, + ): + super(MyModel, self).__init__() + self.def_conv = DeformableConvolution( + inplanes, + outplanes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + groups=deformable_groups, + ) + + def forward(self, x, offset): + y = self.def_conv(x, offset) + return y + + +def export( + inplanes, + outplanes, + kernel_size, + stride, + padding, + dilation, + deformable_groups, + inp_shape, + offset_shape, +): + np.random.seed(324) + torch.manual_seed(32) + + model = MyModel( + inplanes, + outplanes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + deformable_groups=deformable_groups, + ) + model.eval() + + x = Variable(torch.randn(inp_shape)) + offset = Variable(torch.randn(offset_shape)) + ref = model(x, offset) + + np.save("inp", x.detach().numpy()) + np.save("inp1", offset.detach().numpy()) + np.save("ref", ref.detach().numpy()) + + with torch.no_grad(): + torch.onnx.export( + model, + (x, offset), + "model.onnx", + input_names=["input", "input1"], + output_names=["output"], + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH, + opset_version=12, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate ONNX model and test data") + parser.add_argument("--inp_shape", type=int, nargs="+", default=[1, 15, 128, 240]) + parser.add_argument( + "--offset_shape", type=int, nargs="+", default=[1, 18, 128, 240] + ) + parser.add_argument("--inplanes", type=int, nargs="+", default=15) + parser.add_argument("--outplanes", type=int, nargs="+", default=15) + parser.add_argument("--kernel_size", type=int, nargs="+", default=3) + parser.add_argument("--stride", type=int, nargs="+", default=1) + parser.add_argument("--padding", type=int, nargs="+", default=1) + parser.add_argument("--dilation", type=int, nargs="+", default=1) + parser.add_argument("--deformable_groups", type=int, nargs="+", default=1) + args = parser.parse_args() + + export( + args.inplanes, + args.outplanes, + args.kernel_size, + args.stride, + args.padding, + args.dilation, + args.deformable_groups, + args.inp_shape, + args.offset_shape, + ) diff --git a/mo_extensions/ops/GridSample.py b/mo_extensions/ops/GridSample.py index cb224b5..b78b994 100644 --- a/mo_extensions/ops/GridSample.py +++ b/mo_extensions/ops/GridSample.py @@ -1,4 +1,3 @@ -import numpy as np from mo.graph.graph import Node, Graph from mo.ops.op import Op diff --git a/run_tests.py b/run_tests.py index fd7675d..96c04ab 100644 --- a/run_tests.py +++ b/run_tests.py @@ -9,14 +9,20 @@ import numpy as np + class TestLayers(unittest.TestCase): def convert_model(self): - subprocess.run([sys.executable, - '-m', - 'mo', - '--input_model=model.onnx', - '--extension', Path(__file__).absolute().parent / 'mo_extensions'], - check=True) + subprocess.run( + [ + sys.executable, + "-m", + "mo", + "--input_model=model.onnx", + "--extension", + Path(__file__).absolute().parent / "mo_extensions", + ], + check=True, + ) def run_test(self, convert_ir=True, test_onnx=False, num_inputs=1, threshold=1e-5): if convert_ir and not test_onnx: @@ -25,20 +31,20 @@ def run_test(self, convert_ir=True, test_onnx=False, num_inputs=1, threshold=1e- inputs = {} shapes = {} for i in range(num_inputs): - suffix = '{}'.format(i if i > 0 else '') - data = np.load('inp' + suffix + '.npy') - inputs['input' + suffix] = data - shapes['input' + suffix] = data.shape + suffix = "{}".format(i if i > 0 else "") + data = np.load("inp" + suffix + ".npy") + inputs["input" + suffix] = data + shapes["input" + suffix] = data.shape - ref = np.load('ref.npy') + ref = np.load("ref.npy") ie = IECore() - ie.add_extension(get_extensions_path(), 'CPU') - ie.set_config({'CONFIG_FILE': 'user_ie_extensions/gpu_extensions.xml'}, 'GPU') + ie.add_extension(get_extensions_path(), "CPU") + ie.set_config({"CONFIG_FILE": "user_ie_extensions/gpu_extensions.xml"}, "GPU") - net = ie.read_network('model.onnx' if test_onnx else 'model.xml') + net = ie.read_network("model.onnx" if test_onnx else "model.xml") net.reshape(shapes) - exec_net = ie.load_network(net, 'CPU') + exec_net = ie.load_network(net, "CPU") out = exec_net.infer(inputs) out = next(iter(out.values())) @@ -46,22 +52,21 @@ def run_test(self, convert_ir=True, test_onnx=False, num_inputs=1, threshold=1e- diff = np.max(np.abs(ref - out)) self.assertLessEqual(diff, threshold) - def test_unpool(self): from examples.unpool.export_model import export - export(mode='default') - self.run_test() + export(mode="default") + self.run_test() def test_unpool_reshape(self): from examples.unpool.export_model import export - export(mode='dynamic_size', shape=[5, 3, 6, 9]) + + export(mode="dynamic_size", shape=[5, 3, 6, 9]) self.run_test() - export(mode='dynamic_size', shape=[4, 3, 17, 8]) + export(mode="dynamic_size", shape=[4, 3, 17, 8]) self.run_test(convert_ir=False) - def test_fft(self): from examples.fft.export_model import export @@ -69,7 +74,6 @@ def test_fft(self): export(shape=shape) self.run_test() - def test_fft_roll(self): from examples.fft.export_model_with_roll import export @@ -77,7 +81,6 @@ def test_fft_roll(self): self.run_test() self.run_test(test_onnx=True) - def test_grid_sample(self): from examples.grid_sample.export_model import export @@ -85,7 +88,6 @@ def test_grid_sample(self): self.run_test(num_inputs=2) self.run_test(num_inputs=2, test_onnx=True) - def test_complex_mul(self): from examples.complex_mul.export_model import export @@ -94,6 +96,23 @@ def test_complex_mul(self): self.run_test(num_inputs=2) self.run_test(num_inputs=2, test_onnx=True) - -if __name__ == '__main__': + def test_deformable_conv(self): + from examples.deformable_conv.export_model import export + + export( + inplanes=15, + outplanes=15, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + deformable_groups=1, + inp_shape=[1, 15, 128, 240], + offset_shape=[1, 18, 128, 240], + ) + self.run_test(num_inputs=2, threshold=2e-5) + self.run_test(num_inputs=2, test_onnx=True, threshold=2e-5) + + +if __name__ == "__main__": unittest.main()