Skip to content
This repository was archived by the owner on May 29, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ jobs:
run: |
sudo apt-get install -y python3-setuptools libopencv-dev
python3 -m pip install --upgrade pip
python3 -m pip install torch==${{ matrix.torch-version }}
python3 -m pip install torch==${{ matrix.torch-version }} torchvision
python3 -m pip install -U protobuf
python3 -m pip install openvino-dev[onnx]==${{env.OPENVINO_VERSION}}
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Repository with guides to enable some layers from PyTorch in Intel OpenVINO:
* [nn.MaxUnpool2d](examples/unpool)
* [torch.fft](examples/fft)
* [nn.functional.grid_sample](https://github.com/dkurt/openvino_pytorch_layers/tree/master/examples/grid_sample)
* [torchvision.ops.DeformConv2d](examples/deformable_conv)


## OpenVINO Model Optimizer extension
Expand Down
51 changes: 51 additions & 0 deletions examples/deformable_conv/deformable_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
import torchvision.ops as ops


class DeformableConvFunc(torch.autograd.Function):
@staticmethod
def symbolic(g, cls, x, offset):
weight = cls.state_dict()["weight"]
weight = g.op("Constant", value_t=weight)

return g.op(
"DeformableConv2D",
x,
offset,
weight,
strides_i=(cls.stride, cls.stride),
pads_i=(cls.padding, cls.padding, cls.padding, cls.padding),
dilations_i=(cls.dilation, cls.dilation),
deformable_group_i=cls.groups,
)

@staticmethod
def forward(self, cls, x, offset):
y = cls.origin_forward(x, offset)
return y


class DeformableConvolution(ops.DeformConv2d):
"""
This is a support class which helps export network with SparseConv in ONNX format.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.origin_forward = super().forward
self.stride = kwargs.get("stride", 1)
self.padding = kwargs.get("padding", 0)
self.dilation = kwargs.get("dilation", 1)
self.groups = kwargs.get("groups", 1)
self.pad_l = nn.ConstantPad2d((1, 1, 1, 1), 0)

def forward(self, x, offset):
"""
Using paddings is a workaround for 2021.4 release.
"""
x = self.pad_l(x)
offset = self.pad_l(offset)
y = DeformableConvFunc.apply(self, x, offset)
y = y[:, :, 1:-1, 1:-1]
return y
111 changes: 111 additions & 0 deletions examples/deformable_conv/export_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import numpy as np
import argparse
import torch
import torch.nn as nn
from torch.autograd import Variable
from .deformable_conv import DeformableConvolution

np.random.seed(324)
torch.manual_seed(32)


class MyModel(nn.Module):
def __init__(
self,
inplanes,
outplanes,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
bias=False,
deformable_groups=1,
):
super(MyModel, self).__init__()
self.def_conv = DeformableConvolution(
inplanes,
outplanes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
groups=deformable_groups,
)

def forward(self, x, offset):
y = self.def_conv(x, offset)
return y


def export(
inplanes,
outplanes,
kernel_size,
stride,
padding,
dilation,
deformable_groups,
inp_shape,
offset_shape,
):
np.random.seed(324)
torch.manual_seed(32)

model = MyModel(
inplanes,
outplanes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
deformable_groups=deformable_groups,
)
model.eval()

x = Variable(torch.randn(inp_shape))
offset = Variable(torch.randn(offset_shape))
ref = model(x, offset)

np.save("inp", x.detach().numpy())
np.save("inp1", offset.detach().numpy())
np.save("ref", ref.detach().numpy())

with torch.no_grad():
torch.onnx.export(
model,
(x, offset),
"model.onnx",
input_names=["input", "input1"],
output_names=["output"],
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
opset_version=12,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate ONNX model and test data")
parser.add_argument("--inp_shape", type=int, nargs="+", default=[1, 15, 128, 240])
parser.add_argument(
"--offset_shape", type=int, nargs="+", default=[1, 18, 128, 240]
)
parser.add_argument("--inplanes", type=int, nargs="+", default=15)
parser.add_argument("--outplanes", type=int, nargs="+", default=15)
parser.add_argument("--kernel_size", type=int, nargs="+", default=3)
parser.add_argument("--stride", type=int, nargs="+", default=1)
parser.add_argument("--padding", type=int, nargs="+", default=1)
parser.add_argument("--dilation", type=int, nargs="+", default=1)
parser.add_argument("--deformable_groups", type=int, nargs="+", default=1)
args = parser.parse_args()

export(
args.inplanes,
args.outplanes,
args.kernel_size,
args.stride,
args.padding,
args.dilation,
args.deformable_groups,
args.inp_shape,
args.offset_shape,
)
1 change: 0 additions & 1 deletion mo_extensions/ops/GridSample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import numpy as np
from mo.graph.graph import Node, Graph
from mo.ops.op import Op

Expand Down
71 changes: 45 additions & 26 deletions run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,20 @@

import numpy as np


class TestLayers(unittest.TestCase):
def convert_model(self):
subprocess.run([sys.executable,
'-m',
'mo',
'--input_model=model.onnx',
'--extension', Path(__file__).absolute().parent / 'mo_extensions'],
check=True)
subprocess.run(
[
sys.executable,
"-m",
"mo",
"--input_model=model.onnx",
"--extension",
Path(__file__).absolute().parent / "mo_extensions",
],
check=True,
)

def run_test(self, convert_ir=True, test_onnx=False, num_inputs=1, threshold=1e-5):
if convert_ir and not test_onnx:
Expand All @@ -25,67 +31,63 @@ def run_test(self, convert_ir=True, test_onnx=False, num_inputs=1, threshold=1e-
inputs = {}
shapes = {}
for i in range(num_inputs):
suffix = '{}'.format(i if i > 0 else '')
data = np.load('inp' + suffix + '.npy')
inputs['input' + suffix] = data
shapes['input' + suffix] = data.shape
suffix = "{}".format(i if i > 0 else "")
data = np.load("inp" + suffix + ".npy")
inputs["input" + suffix] = data
shapes["input" + suffix] = data.shape

ref = np.load('ref.npy')
ref = np.load("ref.npy")

ie = IECore()
ie.add_extension(get_extensions_path(), 'CPU')
ie.set_config({'CONFIG_FILE': 'user_ie_extensions/gpu_extensions.xml'}, 'GPU')
ie.add_extension(get_extensions_path(), "CPU")
ie.set_config({"CONFIG_FILE": "user_ie_extensions/gpu_extensions.xml"}, "GPU")

net = ie.read_network('model.onnx' if test_onnx else 'model.xml')
net = ie.read_network("model.onnx" if test_onnx else "model.xml")
net.reshape(shapes)
exec_net = ie.load_network(net, 'CPU')
exec_net = ie.load_network(net, "CPU")

out = exec_net.infer(inputs)
out = next(iter(out.values()))

diff = np.max(np.abs(ref - out))
self.assertLessEqual(diff, threshold)


def test_unpool(self):
from examples.unpool.export_model import export
export(mode='default')
self.run_test()

export(mode="default")
self.run_test()

def test_unpool_reshape(self):
from examples.unpool.export_model import export
export(mode='dynamic_size', shape=[5, 3, 6, 9])

export(mode="dynamic_size", shape=[5, 3, 6, 9])
self.run_test()

export(mode='dynamic_size', shape=[4, 3, 17, 8])
export(mode="dynamic_size", shape=[4, 3, 17, 8])
self.run_test(convert_ir=False)


def test_fft(self):
from examples.fft.export_model import export

for shape in [[5, 120, 2], [4, 240, 320, 2], [3, 5, 240, 320, 2]]:
export(shape=shape)
self.run_test()


def test_fft_roll(self):
from examples.fft.export_model_with_roll import export

export()
self.run_test()
self.run_test(test_onnx=True)


def test_grid_sample(self):
from examples.grid_sample.export_model import export

export()
self.run_test(num_inputs=2)
self.run_test(num_inputs=2, test_onnx=True)


def test_complex_mul(self):
from examples.complex_mul.export_model import export

Expand All @@ -94,6 +96,23 @@ def test_complex_mul(self):
self.run_test(num_inputs=2)
self.run_test(num_inputs=2, test_onnx=True)


if __name__ == '__main__':
def test_deformable_conv(self):
from examples.deformable_conv.export_model import export

export(
inplanes=15,
outplanes=15,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
deformable_groups=1,
inp_shape=[1, 15, 128, 240],
offset_shape=[1, 18, 128, 240],
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why 18?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, 2 * offset_groups * kernel_height * kernel_width = 2 * 1 * 3 * 3 = 18

source: https://pytorch.org/vision/master/generated/torchvision.ops.DeformConv2d.html

)
self.run_test(num_inputs=2, threshold=2e-5)
self.run_test(num_inputs=2, test_onnx=True, threshold=2e-5)


if __name__ == "__main__":
unittest.main()