In [1]:
import os
import os.path as osp
import pprint
import random

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import skimage.io
import skimage.transform
import torch
import yaml
from docopt import docopt

import lcnn
from lcnn.config import C, M
from lcnn.models.line_vectorizer import LineVectorizer
from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
from lcnn.postprocess import postprocess
from lcnn.utils import recursive_to

In [2]:
# Read the experiment settings from YAML and copy the model section into M.
# args = docopt(__doc__)
config_file = "config/wireframe.yaml"
C.update(C.from_yaml(filename=config_file))
M.update(C.model)
pprint.pprint(C, indent=4)

# Seed each RNG used by the notebook so repeated runs start from the same state.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(0)

# Choose the device: CUDA when it is available, CPU otherwise.
# os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
if not torch.cuda.is_available():
    device_name = "cpu"
    print("CUDA is not available")
else:
    device_name = "cuda"
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(0)
    print("Let's use", torch.cuda.device_count(), "GPU(s)!")
device = torch.device(device_name)

# NOTE(review): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
checkpoint = torch.load(
    "190418-201834-f8934c6-lr4d10-312k.pth.tar",
    map_location=device,
)

{   'io': {   'datadir': 'data/wireframe/',
              'logdir': 'logs/',
              'num_workers': 4,
              'resume_from': None,
              'tensorboard_port': 0,
              'validation_interval': 24000},
    'model': {   'backbone': 'stacked_hourglass',
                 'batch_size': 6,
                 'batch_size_eval': 2,
                 'depth': 4,
                 'dim_fc': 1024,
                 'dim_loi': 128,
                 'eval_junc_thres': 0.008,
                 'head_size': <BoxList: [[2], [1], [2]]>,
                 'image': {   'mean': <BoxList: [109.73, 103.832, 98.681]>,
                              'stddev': <BoxList: [22.275, 22.124, 23.229]>},
                 'loss_weight': {   'jmap': 8.0,
                                    'joff': 0.25,
                                    'lmap': 0.5,
                                    'lneg': 1,
                                    'lpos': 1},
                 'n_dyn_junc': 300,
                 'n_dy

In [3]:
# Build the network (hourglass -> MultitaskLearner -> LineVectorizer) and
# restore the weights stored in the checkpoint.
def _make_head(c_in, c_out):
    """Construct a MultitaskHead for the given input/output channel counts."""
    return MultitaskHead(c_in, c_out)


hourglass = lcnn.models.hg(
    depth=M.depth,
    head=_make_head,
    num_stacks=M.num_stacks,
    num_blocks=M.num_blocks,
    # total over every entry of the nested M.head_size list
    num_classes=sum(size for group in M.head_size for size in group),
)
multitask = MultitaskLearner(hourglass)
linevec = LineVectorizer(multitask)
linevec.load_state_dict(checkpoint["model_state_dict"])
linevec = linevec.to(device)
linevec.eval()  # switch the whole module tree to evaluation mode
print("Evaluated")

Evaluated


In [4]:
# Grabbing random image
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
im = skimage.io.imread("/home/fcr/Pictures/room.jpg")
if im.ndim == 2:
    # Single-channel image: replicate it into three identical channels.
    im = np.repeat(im[..., np.newaxis], repeats=3, axis=2)
im = im[..., :3]  # keep only the first three channels
# Resize to 512x512 and scale by 255 before normalising with the configured
# per-channel mean and standard deviation.
im_resized = skimage.transform.resize(im, (512, 512)) * 255
image = (im_resized - M.image.mean) / M.image.stddev
# Move the channel axis to the front, add a leading batch axis, convert to float32.
image = torch.from_numpy(np.moveaxis(image, 2, 0)[np.newaxis].copy()).float()
image.shape

torch.Size([1, 3, 512, 512])

In [5]:
# Model input: the preprocessed image plus zero-filled placeholder tensors for
# the "meta" and "target" entries, with the mode set to "testing".
input_dict = dict(
    image=image.to(device),
    meta=[
        dict(
            junc=torch.zeros(1, 2, device=device),
            jtyp=torch.zeros(1, dtype=torch.uint8, device=device),
            Lpos=torch.zeros(2, 2, dtype=torch.uint8, device=device),
            Lneg=torch.zeros(2, 2, dtype=torch.uint8, device=device),
        )
    ],
    target=dict(
        jmap=torch.zeros((1, 1, 128, 128), device=device),
        joff=torch.zeros((1, 1, 2, 128, 128), device=device),
    ),
    mode="testing",
)

In [6]:
# Display the `num_stacks` attribute of the hourglass model built above.
hourglass.num_stacks

2

In [7]:
# Assemble the input: the preprocessed image plus zero-filled placeholders for
# the "meta" and "target" entries, with the mode set to "testing".
placeholder_meta = {
    "junc": torch.zeros(1, 2, device=device),
    "jtyp": torch.zeros(1, dtype=torch.uint8, device=device),
    "Lpos": torch.zeros(2, 2, dtype=torch.uint8, device=device),
    "Lneg": torch.zeros(2, 2, dtype=torch.uint8, device=device),
}
placeholder_target = {
    "jmap": torch.zeros((1, 1, 128, 128), device=device),
    "joff": torch.zeros((1, 1, 2, 128, 128), device=device),
}
input_dict = {
    "image": image.to(device),
    "meta": [placeholder_meta],
    "target": placeholder_target,
    "mode": "testing",
}

# Forward pass through the full line vectorizer without autograd tracking.
with torch.no_grad():
    model_out = linevec(input_dict)

torch.Size([1, 256, 128, 128])


## Exporting Multitask

Optimizing Hourglass Model:

```
python3 /opt/intel/openvino/deployment_tools/model_optimizer/mo.py --input_model hg.onnx --model_name hg_fp16 --data_type FP16 --output "1306,1021,1296"
```

Optimizing Multitask Model:

```
python3 /opt/intel/openvino/deployment_tools/model_optimizer/mo.py --input_model ../multitask.onnx --model_name multitask_fp16 --data_type FP16 --input "input.1,5" --output "1302,1370,1371,1375" --mean_values "input.1[109.73,103.832,98.681]" --scale_values "input.1[22.275,22.124,23.229]"
```

In [7]:
# Trace the full `hourglass` model on the preprocessed image and write it to
# hg.onnx (weights included, constant folding enabled, opset 10).
torch.onnx.export(
    hourglass,
    input_dict["image"],
    "hg.onnx",
    export_params=True,
    do_constant_folding=True,
    opset_version=10,
    verbose=True,
)

ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. Attributes to determine how to transform the input were added in onnx:Resize in opset 11 to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).
We recommend using opset 11 and above for models using this operator. 
  "" + str(_export_onnx_opset_version) + ". "


graph(%input.1 : Float(1, 3, 512, 512),
      %conv1.weight : Float(64, 3, 7, 7),
      %conv1.bias : Float(64),
      %bn1.weight : Float(64),
      %bn1.bias : Float(64),
      %bn1.running_mean : Float(64),
      %bn1.running_var : Float(64),
      %layer1.0.bn1.weight : Float(64),
      %layer1.0.bn1.bias : Float(64),
      %layer1.0.bn1.running_mean : Float(64),
      %layer1.0.bn1.running_var : Float(64),
      %layer1.0.conv1.weight : Float(64, 64, 1, 1),
      %layer1.0.conv1.bias : Float(64),
      %layer1.0.bn2.weight : Float(64),
      %layer1.0.bn2.bias : Float(64),
      %layer1.0.bn2.running_mean : Float(64),
      %layer1.0.bn2.running_var : Float(64),
      %layer1.0.conv2.weight : Float(64, 64, 3, 3),
      %layer1.0.conv2.bias : Float(64),
      %layer1.0.bn3.weight : Float(64),
      %layer1.0.bn3.bias : Float(64),
      %layer1.0.bn3.running_mean : Float(64),
      %layer1.0.bn3.running_var : Float(64),
      %layer1.0.conv3.weight : Float(128, 64, 1, 1),
      %lay

In [9]:
# Export `hourglass.hg[0]` on its own, traced with a random dummy input of
# shape (1, 256, 128, 128).
input_hg = torch.randn(
    (1, 256, 128, 128),
    device=device,
    requires_grad=True,
)
torch.onnx.export(
    hourglass.hg[0],
    input_hg,
    "hg_11.onnx",
    export_params=True,
    do_constant_folding=True,
    opset_version=10,
    verbose=True,
)

ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. Attributes to determine how to transform the input were added in onnx:Resize in opset 11 to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).
We recommend using opset 11 and above for models using this operator. 
  "" + str(_export_onnx_opset_version) + ". "


graph(%input.1 : Float(1, 256, 128, 128),
      %hg.0.0.0.bn1.weight : Float(256),
      %hg.0.0.0.bn1.bias : Float(256),
      %hg.0.0.0.bn1.running_mean : Float(256),
      %hg.0.0.0.bn1.running_var : Float(256),
      %hg.0.0.0.conv1.weight : Float(128, 256, 1, 1),
      %hg.0.0.0.conv1.bias : Float(128),
      %hg.0.0.0.bn2.weight : Float(128),
      %hg.0.0.0.bn2.bias : Float(128),
      %hg.0.0.0.bn2.running_mean : Float(128),
      %hg.0.0.0.bn2.running_var : Float(128),
      %hg.0.0.0.conv2.weight : Float(128, 128, 3, 3),
      %hg.0.0.0.conv2.bias : Float(128),
      %hg.0.0.0.bn3.weight : Float(128),
      %hg.0.0.0.bn3.bias : Float(128),
      %hg.0.0.0.bn3.running_mean : Float(128),
      %hg.0.0.0.bn3.running_var : Float(128),
      %hg.0.0.0.conv3.weight : Float(256, 128, 1, 1),
      %hg.0.0.0.conv3.bias : Float(256),
      %hg.0.1.0.bn1.weight : Float(256),
      %hg.0.1.0.bn1.bias : Float(256),
      %hg.0.1.0.bn1.running_mean : Float(256),
      %hg.0.1.0.bn1.running

In [10]:
# Export `hourglass.hg[1]` on its own, traced with a random dummy input of
# shape (1, 256, 128, 128).
input_hg = torch.randn(
    (1, 256, 128, 128),
    device=device,
    requires_grad=True,
)
torch.onnx.export(
    hourglass.hg[1],
    input_hg,
    "hg_12.onnx",
    export_params=True,
    do_constant_folding=True,
    opset_version=10,
    verbose=True,
)

graph(%input.1 : Float(1, 256, 128, 128),
      %hg.0.0.0.bn1.weight : Float(256),
      %hg.0.0.0.bn1.bias : Float(256),
      %hg.0.0.0.bn1.running_mean : Float(256),
      %hg.0.0.0.bn1.running_var : Float(256),
      %hg.0.0.0.conv1.weight : Float(128, 256, 1, 1),
      %hg.0.0.0.conv1.bias : Float(128),
      %hg.0.0.0.bn2.weight : Float(128),
      %hg.0.0.0.bn2.bias : Float(128),
      %hg.0.0.0.bn2.running_mean : Float(128),
      %hg.0.0.0.bn2.running_var : Float(128),
      %hg.0.0.0.conv2.weight : Float(128, 128, 3, 3),
      %hg.0.0.0.conv2.bias : Float(128),
      %hg.0.0.0.bn3.weight : Float(128),
      %hg.0.0.0.bn3.bias : Float(128),
      %hg.0.0.0.bn3.running_mean : Float(128),
      %hg.0.0.0.bn3.running_var : Float(128),
      %hg.0.0.0.conv3.weight : Float(256, 128, 1, 1),
      %hg.0.0.0.conv3.bias : Float(256),
      %hg.0.1.0.bn1.weight : Float(256),
      %hg.0.1.0.bn1.bias : Float(256),
      %hg.0.1.0.bn1.running_mean : Float(256),
      %hg.0.1.0.bn1.running

In [11]:
# Export `hourglass.res[0]` on its own, traced with a random dummy input of
# shape (1, 256, 128, 128).
input_hg = torch.randn(
    (1, 256, 128, 128),
    device=device,
    requires_grad=True,
)
torch.onnx.export(
    hourglass.res[0],
    input_hg,
    "hg_res_11.onnx",
    export_params=True,
    do_constant_folding=True,
    opset_version=10,
    verbose=True,
)

graph(%input.1 : Float(1, 256, 128, 128),
      %0.bn1.weight : Float(256),
      %0.bn1.bias : Float(256),
      %0.bn1.running_mean : Float(256),
      %0.bn1.running_var : Float(256),
      %0.conv1.weight : Float(128, 256, 1, 1),
      %0.conv1.bias : Float(128),
      %0.bn2.weight : Float(128),
      %0.bn2.bias : Float(128),
      %0.bn2.running_mean : Float(128),
      %0.bn2.running_var : Float(128),
      %0.conv2.weight : Float(128, 128, 3, 3),
      %0.conv2.bias : Float(128),
      %0.bn3.weight : Float(128),
      %0.bn3.bias : Float(128),
      %0.bn3.running_mean : Float(128),
      %0.bn3.running_var : Float(128),
      %0.conv3.weight : Float(256, 128, 1, 1),
      %0.conv3.bias : Float(256)):
  %22 : Float(1, 256, 128, 128) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%input.1, %0.bn1.weight, %0.bn1.bias, %0.bn1.running_mean, %0.bn1.running_var) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/functional.py:1670:0
  %23 : Float(1, 256, 

In [12]:
# Export `hourglass.res[1]` on its own, traced with a random dummy input of
# shape (1, 256, 128, 128).
input_hg = torch.randn(
    (1, 256, 128, 128),
    device=device,
    requires_grad=True,
)
torch.onnx.export(
    hourglass.res[1],
    input_hg,
    "hg_res_12.onnx",
    export_params=True,
    do_constant_folding=True,
    opset_version=10,
    verbose=True,
)

graph(%input.1 : Float(1, 256, 128, 128),
      %0.bn1.weight : Float(256),
      %0.bn1.bias : Float(256),
      %0.bn1.running_mean : Float(256),
      %0.bn1.running_var : Float(256),
      %0.conv1.weight : Float(128, 256, 1, 1),
      %0.conv1.bias : Float(128),
      %0.bn2.weight : Float(128),
      %0.bn2.bias : Float(128),
      %0.bn2.running_mean : Float(128),
      %0.bn2.running_var : Float(128),
      %0.conv2.weight : Float(128, 128, 3, 3),
      %0.conv2.bias : Float(128),
      %0.bn3.weight : Float(128),
      %0.bn3.bias : Float(128),
      %0.bn3.running_mean : Float(128),
      %0.bn3.running_var : Float(128),
      %0.conv3.weight : Float(256, 128, 1, 1),
      %0.conv3.bias : Float(256)):
  %22 : Float(1, 256, 128, 128) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%input.1, %0.bn1.weight, %0.bn1.bias, %0.bn1.running_mean, %0.bn1.running_var) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/functional.py:1670:0
  %23 : Float(1, 256, 

In [13]:
# Export `hourglass.fc[0]` on its own, traced with a random dummy input of
# shape (1, 256, 128, 128).
input_hg = torch.randn(
    (1, 256, 128, 128),
    device=device,
    requires_grad=True,
)
torch.onnx.export(
    hourglass.fc[0],
    input_hg,
    "hg_fc_11.onnx",
    export_params=True,
    do_constant_folding=True,
    opset_version=10,
    verbose=True,
)

graph(%input.1 : Float(1, 256, 128, 128),
      %0.weight : Float(256, 256, 1, 1),
      %0.bias : Float(256),
      %1.weight : Float(256),
      %1.bias : Float(256),
      %1.running_mean : Float(256),
      %1.running_var : Float(256)):
  %8 : Float(1, 256, 128, 128) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1]](%input.1, %0.weight, %0.bias) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/modules/conv.py:342:0
  %9 : Float(1, 256, 128, 128) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%8, %1.weight, %1.bias, %1.running_mean, %1.running_var) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/functional.py:1670:0
  %10 : Float(1, 256, 128, 128) = onnx::Relu(%9) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/functional.py:912:0
  return (%10)



In [14]:
# Export `hourglass.fc[1]` on its own, traced with a random dummy input of
# shape (1, 256, 128, 128).
input_hg = torch.randn(
    (1, 256, 128, 128),
    device=device,
    requires_grad=True,
)
torch.onnx.export(
    hourglass.fc[1],
    input_hg,
    "hg_fc_12.onnx",
    export_params=True,
    do_constant_folding=True,
    opset_version=10,
    verbose=True,
)

graph(%input.1 : Float(1, 256, 128, 128),
      %0.weight : Float(256, 256, 1, 1),
      %0.bias : Float(256),
      %1.weight : Float(256),
      %1.bias : Float(256),
      %1.running_mean : Float(256),
      %1.running_var : Float(256)):
  %8 : Float(1, 256, 128, 128) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1]](%input.1, %0.weight, %0.bias) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/modules/conv.py:342:0
  %9 : Float(1, 256, 128, 128) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%8, %1.weight, %1.bias, %1.running_mean, %1.running_var) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/functional.py:1670:0
  %10 : Float(1, 256, 128, 128) = onnx::Relu(%9) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/functional.py:912:0
  return (%10)



In [15]:
# Export `hourglass.score[0]` on its own, traced with a random dummy input of
# shape (1, 256, 128, 128).
input_hg = torch.randn(
    (1, 256, 128, 128),
    device=device,
    requires_grad=True,
)
torch.onnx.export(
    hourglass.score[0],
    input_hg,
    "hg_score_1.onnx",
    export_params=True,
    do_constant_folding=True,
    opset_version=10,
    verbose=True,
)

graph(%input.1 : Float(1, 256, 128, 128),
      %heads.0.0.weight : Float(64, 256, 3, 3),
      %heads.0.0.bias : Float(64),
      %heads.0.2.weight : Float(2, 64, 1, 1),
      %heads.0.2.bias : Float(2),
      %heads.1.0.weight : Float(64, 256, 3, 3),
      %heads.1.0.bias : Float(64),
      %heads.1.2.weight : Float(1, 64, 1, 1),
      %heads.1.2.bias : Float(1),
      %heads.2.0.weight : Float(64, 256, 3, 3),
      %heads.2.0.bias : Float(64),
      %heads.2.2.weight : Float(2, 64, 1, 1),
      %heads.2.2.bias : Float(2)):
  %13 : Float(1, 64, 128, 128) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%input.1, %heads.0.0.weight, %heads.0.0.bias) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/modules/conv.py:342:0
  %14 : Float(1, 64, 128, 128) = onnx::Relu(%13) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/functional.py:912:0
  %15 : Float(1, 2, 128, 128) = onnx::Conv[dilations=[1, 1], group

In [16]:
# Export `hourglass.score[1]` on its own, traced with a random dummy input of
# shape (1, 256, 128, 128).
input_hg = torch.randn(
    (1, 256, 128, 128),
    device=device,
    requires_grad=True,
)
torch.onnx.export(
    hourglass.score[1],
    input_hg,
    "hg_score_2.onnx",
    export_params=True,
    do_constant_folding=True,
    opset_version=10,
    verbose=True,
)

graph(%input.1 : Float(1, 256, 128, 128),
      %heads.0.0.weight : Float(64, 256, 3, 3),
      %heads.0.0.bias : Float(64),
      %heads.0.2.weight : Float(2, 64, 1, 1),
      %heads.0.2.bias : Float(2),
      %heads.1.0.weight : Float(64, 256, 3, 3),
      %heads.1.0.bias : Float(64),
      %heads.1.2.weight : Float(1, 64, 1, 1),
      %heads.1.2.bias : Float(1),
      %heads.2.0.weight : Float(64, 256, 3, 3),
      %heads.2.0.bias : Float(64),
      %heads.2.2.weight : Float(2, 64, 1, 1),
      %heads.2.2.bias : Float(2)):
  %13 : Float(1, 64, 128, 128) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%input.1, %heads.0.0.weight, %heads.0.0.bias) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/modules/conv.py:342:0
  %14 : Float(1, 64, 128, 128) = onnx::Relu(%13) # /home/fcr/anaconda3/envs/lcnn/lib/python3.6/site-packages/torch/nn/functional.py:912:0
  %15 : Float(1, 2, 128, 128) = onnx::Conv[dilations=[1, 1], group

In [8]:
# torch.onnx.export(multitask, input_dict, "multitask.onnx", 
#                   opset_version=10, verbose=True, export_params=True,
#                   do_constant_folding=True)

## Exporting Linevectorizer

Exporting layers used in module:
1. fc1 (`nn.Conv2d(256, M.dim_loi, 1)`)
2. pooling (nn.Sequential(...) or nn.MaxPool1d)
3. fc2 (nn.Sequential(...))

Expected input shape of each layer, found by inspection:
```
fc1 expected input: torch.Size([1, 256, 128, 128])
pooling expected input: torch.Size([10731, 128, 32])
fc2 expected input: torch.Size([10731, 1032])
```

In [9]:
# multi_results = multitask(input_dict)
# input_fc1 = multi_results["feature"]
# torch.onnx.export(linevec.fc1, input_fc1, "linevec_fc1.onnx", 
#                   opset_version=11, verbose=True)

In [10]:
# input_pool = torch.randn(10731, 128, 32, requires_grad=True)
# torch.onnx.export(linevec.pooling, input_pool, "linevec_pool.onnx", 
#                   opset_version=10, verbose=True, export_params=True)

In [11]:
# input_fc2 = torch.randn(10731, 1032, device=device, requires_grad=True)
# torch.onnx.export(linevec.fc2, input_fc2, "linevec_fc2.onnx", 
#                   opset_version=11, verbose=True)

## Testing ONNX conversions

In [1]:
import onnx
import onnxruntime

In [2]:
def to_numpy(tensor):
    """Return the contents of `tensor` as a NumPy array on the CPU.

    A tensor that requires grad is detached from the autograd graph first,
    because `.numpy()` cannot be called on it directly.
    """
    if not tensor.requires_grad:
        return tensor.cpu().numpy()
    return tensor.detach().cpu().numpy()

def test_onnx_model(name, torch_model, shape):
    """Check that an exported ONNX file reproduces a PyTorch module's output.

    A random input of the given shape is run through `torch_model` and through
    an ONNX Runtime session opened on `name`; the first ONNX output must match
    the PyTorch output within rtol=1e-3 / atol=1e-5.

    Args:
        name: Path of the ONNX file to load.
        torch_model: Callable PyTorch module used as the reference.
        shape: Shape of the random input tensor.

    Raises:
        AssertionError: If the two outputs differ beyond the tolerances.
    """
    # The reference pass is only compared numerically, so no autograd graph is
    # needed; skipping it avoids holding activations for large inputs.
    input_rand = torch.randn(shape, device=device)
    with torch.no_grad():
        model_out = torch_model(input_rand)

    ort_session = onnxruntime.InferenceSession(name)
    # Feed the same values under the session's first (and only used) input name.
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_rand)}
    ort_outs = ort_session.run(None, ort_inputs)

    np.testing.assert_allclose(to_numpy(model_out), ort_outs[0],
                               rtol=1e-03, atol=1e-05)
    print("Similarity test passed")

Testing fc1

In [3]:
# Compare the exported fc1 ONNX file with `linevec.fc1` on a random
# (1, 256, 128, 128) input.
test_onnx_model(
    "./linevec_fc1_opt/linevec_fc1.onnx",
    linevec.fc1,
    (1, 256, 128, 128),
)

NameError: name 'linevec' is not defined

Testing pooling

In [17]:
# Compare the exported pooling ONNX file with `linevec.pooling` on a random
# (10731, 128, 32) input.
test_onnx_model(
    "./linevec_pool_opt/linevec_pool.onnx",
    linevec.pooling,
    (10731, 128, 32),
)

Similarity test passed


Testing fc2

In [18]:
# Compare the exported fc2 ONNX file with `linevec.fc2` on a random
# (10731, 1032) input.
test_onnx_model(
    "./linevec_fc2_opt/linevec_fc2.onnx",
    linevec.fc2,
    (10731, 1032),
)

Similarity test passed


Testing Hourglass Net

In [22]:
# Open the exported hourglass model with ONNX Runtime and keep the list of
# its input descriptors for the cells that follow.
ort_session = onnxruntime.InferenceSession(
    "./hg_opt/hg.onnx"
)
hg_inputs = ort_session.get_inputs()

In [23]:
# Name and shape of every input of the hourglass session.
[(node.name, node.shape) for node in hg_inputs]

[('input.1', [1, 3, 512, 512])]

In [24]:
# Name and shape of every output of the hourglass session.
[(node.name, node.shape) for node in ort_session.get_outputs()]

[('1306', [1, 5, 128, 128]),
 ('1021', [1, 5, 128, 128]),
 ('1296', [1, 256, 128, 128])]

In [25]:
# Feed the preprocessed image under the session's first input name and
# collect every output (None requests all of them).
ort_inputs = dict([(hg_inputs[0].name, to_numpy(input_dict["image"]))])
ort_outs = ort_session.run(None, ort_inputs)

In [40]:
# Number of arrays returned by the ONNX Runtime session.
len(ort_outs)

3

In [41]:
# Shapes of the first three ONNX Runtime outputs, as a tuple.
tuple(ort_outs[idx].shape for idx in range(3))

((1, 5, 128, 128), (1, 5, 128, 128), (1, 256, 128, 128))

Testing Multitask

Recall input: <br/>
```
input_dict = {
    "image": image.to(device),
    "meta": [
        {
            "junc": torch.zeros(1, 2).to(device),
            "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
            "Lpos": torch.zeros(2, 2, dtype=torch.uint8).to(device),
            "Lneg": torch.zeros(2, 2, dtype=torch.uint8).to(device),
        }
    ],
    "target": {
        "jmap": torch.zeros([1, 1, 128, 128]).to(device),
        "joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
    },
    "mode": "testing",
}
```

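Based on graph analysis, the image input, meta inputs, and target inputs made it into the ONNX trace. (Note that the session inspected below exposes only two inputs: `input.1` with shape `[1, 3, 512, 512]` and `5` with shape `[1, 1, 128, 128]`.)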

In [14]:
# Reference PyTorch outputs of the multitask model, compared against the ONNX
# Runtime outputs further down. The results are only compared numerically, so
# the forward pass runs without autograd tracking (consistent with the
# `linevec` inference cell) to avoid building an unused graph.
with torch.no_grad():
    multi_out = multitask(input_dict)

In [5]:
# Open the multitask ONNX file with ONNX Runtime and keep the list of its
# input descriptors.
ort_session = onnxruntime.InferenceSession(
    "multitask_opt/multitask.onnx"
)
mul_inputs = ort_session.get_inputs()

In [6]:
# Name and shape of every input of the multitask session.
[(node.name, node.shape) for node in mul_inputs]

[('input.1', [1, 3, 512, 512]), ('5', [1, 1, 128, 128])]

In [7]:
# Name and shape of every output of the multitask session.
mul_outputs = ort_session.get_outputs()
[(node.name, node.shape) for node in mul_outputs]

[('1302', [1, 256, 128, 128]),
 ('1370', [1, 1, 128, 128]),
 ('1371', [1, 128, 128]),
 ('1375', [1, 1, 2, 128, 128])]

In [17]:
# Run the exported multitask graph with ONNX Runtime.
ort_session = onnxruntime.InferenceSession("multitask.onnx")
mul_inputs = ort_session.get_inputs()

# Only two inputs are fed: the preprocessed image and a zero tensor of shape
# [1, 1, 128, 128]. The remaining placeholder tensors of `input_dict`
# (the "meta" entries and the [1, 1, 2, 128, 128] target) are not passed.
ort_inputs = {}
ort_inputs[mul_inputs[0].name] = to_numpy(input_dict["image"])
ort_inputs[mul_inputs[1].name] = to_numpy(
    torch.zeros((1, 1, 128, 128), device=device)
)
ort_outs = ort_session.run(None, ort_inputs)

In [18]:
# Number of arrays returned by the multitask ONNX Runtime session.
len(ort_outs)

4

In [19]:
# Pull the PyTorch reference tensors out of the multitask result: the
# "feature" entry plus the three maps stored under "preds".
feats = multi_out["feature"]
jmap, lmap, joff = (
    multi_out["preds"][key] for key in ("jmap", "lmap", "joff")
)

In [20]:
# `feature` output: PyTorch result vs. ONNX Runtime output 0.
np.testing.assert_allclose(
    actual=to_numpy(feats),
    desired=ort_outs[0],
    rtol=1e-03,
    atol=1e-05,
)

In [21]:
# `jmap` prediction: PyTorch result vs. ONNX Runtime output 1.
np.testing.assert_allclose(
    actual=to_numpy(jmap),
    desired=ort_outs[1],
    rtol=1e-03,
    atol=1e-05,
)

In [22]:
# `lmap` prediction: PyTorch result vs. ONNX Runtime output 2.
np.testing.assert_allclose(
    actual=to_numpy(lmap),
    desired=ort_outs[2],
    rtol=1e-03,
    atol=1e-05,
)

In [23]:
# `joff` prediction: PyTorch result vs. ONNX Runtime output 3.
np.testing.assert_allclose(
    actual=to_numpy(joff),
    desired=ort_outs[3],
    rtol=1e-03,
    atol=1e-05,
)

Similarity test passed for multitasking