The model below is converted from PyTorch's Torchvision module. The original model source can be found [here](https://github.com/pytorch/vision/tree/master/torchvision/models/segmentation/fcn.py). The conversion process follows the procedure outlined in [the PyTorch ONNX documentation](https://pytorch.org/docs/stable/onnx.html), also borrowing from the [RetinaNet conversion](../../retinanet/README.md) in this repository.

This code requires PyTorch and Torchvision to be installed, along with the other packages imported below (Pillow, NumPy, ONNX and ONNX Runtime).

In [1]:
from PIL import Image
import numpy as np
from onnx import numpy_helper
import os
import onnxruntime as rt
import torch
from torchvision import transforms, models
import urllib

Utility functions to save the model and test data

In [2]:
def flatten(inputs):
    """Wrap ``inputs`` into a nested-list form.

    A list or tuple becomes a one-element list that holds the list of its
    recursively wrapped items; any other value becomes a one-element list
    holding that value. Despite the name, the result is nested rather than
    flat.
    """
    if not isinstance(inputs, (list, tuple)):
        return [inputs]

    wrapped_items = []
    for item in inputs:
        wrapped_items.append(flatten(item))
    return [wrapped_items]


def update_flatten_list(inputs, res_list):
    """Append every non-list, non-tuple element found in ``inputs`` to ``res_list``.

    Lists and tuples inside ``inputs`` are walked recursively, in order.
    ``res_list`` is modified in place and also returned.
    """
    for element in inputs:
        if isinstance(element, (list, tuple)):
            update_flatten_list(element, res_list)
        else:
            res_list.append(element)
    return res_list

def full_flatten(inputs):
    """Return a new flat list of the non-list, non-tuple values in ``inputs``.

    ``inputs`` is first wrapped by ``flatten`` and the wrapped structure is
    then collected into a fresh list by ``update_flatten_list``. A value that
    is not a list or tuple yields a one-element list.
    """
    return update_flatten_list(flatten(inputs), [])


def to_numpy(x):
    """Return ``x`` as a NumPy array.

    Args:
        x: A ``numpy.ndarray`` (or subclass of it), which is returned
            unchanged, or an object exposing the tensor members used below
            (``requires_grad``, ``detach()``, ``cpu()``, ``numpy()``).

    Returns:
        ``x`` itself when it already is an ndarray; otherwise the result of
        ``x.cpu().numpy()``, with ``x`` detached first when
        ``x.requires_grad`` is set.
    """
    # isinstance (rather than an exact type comparison) lets ndarray
    # subclasses pass through instead of failing on ``requires_grad``.
    if isinstance(x, np.ndarray):
        return x
    if x.requires_grad:
        x = x.detach()
    return x.cpu().numpy()


def save_tensor_proto(file_path, name, data):
    """Serialize ``data`` as a named tensor proto and write it to ``file_path``.

    Args:
        file_path: Destination file; opened in binary write mode.
        name: Value assigned to the proto's ``name`` field.
        data: Array handed to ``numpy_helper.from_array``.
    """
    tensor_proto = numpy_helper.from_array(data)
    tensor_proto.name = name

    with open(file_path, 'wb') as proto_file:
        proto_file.write(tensor_proto.SerializeToString())


def save_data(test_data_dir, prefix, names, data_list):
    """Save each tensor of ``data_list`` as ``<prefix>_<index>.pb`` in ``test_data_dir``.

    Args:
        test_data_dir: Directory the files are written to.
        prefix: File-name prefix placed before the tensor's index.
        names: Sequence indexed in step with ``data_list``; ``names[i]`` is
            the name stored with the i-th tensor.
        data_list: A single tensor, or an iterable of tensors.
    """
    tensor_types = (torch.autograd.Variable, torch.Tensor)
    if isinstance(data_list, tensor_types):
        # A lone tensor is handled as a one-element list.
        data_list = [data_list]

    for index, tensor in enumerate(data_list):
        array = tensor.data.cpu().numpy()
        file_name = '{0}_{1}.pb'.format(prefix, index)
        save_tensor_proto(os.path.join(test_data_dir, file_name), names[index], array)


def save_model(name, model, data_dir, inputs, outputs, input_names=None, output_names=None, **kwargs):
    """Export ``model`` to ONNX and save its example inputs/outputs as test data.

    Everything is written below ``./test_<name>/``: the model as
    ``model.onnx`` and the tensors inside the ``data_dir`` sub-directory as
    ``input_<i>.pb`` / ``output_<i>.pb``.

    Args:
        name: Model name; used to build the ``test_<name>`` directory.
        model: Model to export. When it has a ``train`` attribute,
            ``model.train(False)`` is called before exporting.
        data_dir: Name of the test-data sub-directory inside the output
            directory.
        inputs: Example input(s) handed to ``torch.onnx.export``; may be
            nested lists/tuples.
        outputs: Outputs produced for ``inputs``; may be nested
            lists/tuples.
        input_names: Optional names, one per flattened input. Defaults to
            ``input1``, ``input2``, ...
        output_names: Optional names, one per flattened output. Defaults to
            ``output1``, ``output2``, ...
        **kwargs: Forwarded to ``torch.onnx.export``.

    Returns:
        Tuple ``(model_path, test_data_dir)``.

    Raises:
        AssertionError: If the number of supplied names differs from the
            number of flattened inputs/outputs.
    """
    # Local import: only needed for the export-signature check below.
    import inspect

    if hasattr(model, 'train'):
        model.train(False)

    output_dir = os.path.join('./', 'test_' + name)
    os.makedirs(output_dir, exist_ok=True)

    inputs_flatten = full_flatten(inputs)
    outputs_flatten = full_flatten(outputs)
    if input_names is None:
        input_names = ['input' + str(i) for i, _ in enumerate(inputs_flatten, start=1)]
    else:
        np.testing.assert_equal(len(input_names), len(inputs_flatten),
                                "Number of input names provided is not equal to the number of inputs.")

    if output_names is None:
        output_names = ['output' + str(i) for i, _ in enumerate(outputs_flatten, start=1)]
    else:
        np.testing.assert_equal(len(output_names), len(outputs_flatten),
                                "Number of output names provided is not equal to the number of output.")

    # `example_outputs` is not accepted by every version of
    # torch.onnx.export; pass it only when the installed version declares it,
    # so the export does not fail with a TypeError.
    export_kwargs = dict(kwargs)
    if 'example_outputs' in inspect.signature(torch.onnx.export).parameters:
        export_kwargs.setdefault('example_outputs', outputs)

    model_path = os.path.join(output_dir, 'model.onnx')
    torch.onnx.export(model, inputs, model_path, verbose=True, input_names=input_names,
                      output_names=output_names, **export_kwargs)

    test_data_dir = os.path.join(output_dir, data_dir)
    os.makedirs(test_data_dir, exist_ok=True)

    save_data(test_data_dir, "input", input_names, inputs_flatten)
    save_data(test_data_dir, "output", output_names, outputs_flatten)

    return model_path, test_data_dir

Utility functions to run inference on the PyTorch and ORT models

In [3]:
def torch_inference(model, input):
    """Run ``model`` on ``input`` with gradient tracking disabled.

    Args:
        model: Callable invoked as ``model(input)``.
        input: Value passed to the model.

    Returns:
        Whatever ``model(input)`` returns.
    """
    print("====== Torch Inference ======")
    # Inference needs no autograd graph; skipping it lowers memory use.
    with torch.no_grad():
        return model(input)


def ort_inference(file, inputs, outputs=None):
    """Run the ONNX model at ``file`` with ONNX Runtime and optionally check the results.

    Args:
        file: Model passed to ``rt.InferenceSession``.
        inputs: Input value(s), possibly nested in lists/tuples. After
            flattening, the i-th value is fed to the session's i-th input.
        outputs: Optional reference output(s). When given, each flattened
            reference is compared to the matching session result with
            ``rtol=1e-03`` and ``atol=2e-04``.

    Returns:
        The list returned by ``sess.run``.

    Raises:
        AssertionError: If a session result is not close to its reference.
    """
    print("====== ORT Inference ======")
    flat_inputs = full_flatten(inputs)
    flat_outputs = full_flatten(outputs)

    # Starting with ORT 1.10, execution providers other than the default CPU
    # provider must be requested explicitly through the `providers` parameter
    # when creating an InferenceSession (earlier versions registered them
    # automatically based on the build flags).
    # For example, with an NVIDIA GPU and an ORT package built with CUDA:
    # onnxruntime.InferenceSession(path/to/model, providers=['CUDAExecutionProvider'])
    sess = rt.InferenceSession(file)

    feed = {}
    for position, value in enumerate(flat_inputs):
        input_name = sess.get_inputs()[position].name
        feed[input_name] = to_numpy(value)
    res = sess.run(None, feed)

    if outputs is not None:
        print("== Checking model output ==")
        for position, expected in enumerate(flat_outputs):
            np.testing.assert_allclose(to_numpy(expected), res[position], rtol=1e-03, atol=2e-04)

    print("== Done ==")
    return res

## Step 1: Download models from PyTorch's model zoo
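Memory constraints mean that only one model can be converted at a time. The boolean variable `DO_101` selects which one: FCN ResNet-101 when it is `True`, FCN ResNet-50 otherwise. To convert the other model, change its value and run the notebook again.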

In [4]:
# True selects FCN ResNet-101, False selects FCN ResNet-50.
DO_101 = True

# Only the constructor of the selected architecture is looked up and called.
model = (
    models.segmentation.fcn_resnet101 if DO_101 else models.segmentation.fcn_resnet50
)(pretrained=True)

# Switch the model to evaluation mode before inference and export.
model.eval()
# NOTE(review): nothing visible in this notebook reads `exporting`; confirm whether it is still needed.
model.exporting = True

## Step 2: Preprocess, run PyTorch inference on test images

In [8]:
# Name of the sub-directory that will receive the saved test tensors.
data_dir = 'test_data_set_0'

# Test image: its source URL and local file name.
url = "https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/fcn/dependencies/000000017968.jpg"
filename = "000000017968.jpg"
# Alternative test image:
#url, filename = ("https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/fcn/dependencies/000000025205.jpg", "000000025205.jpg")
# The download is disabled; as written, `filename` must already exist locally.
#urllib.request.urlretrieve(url, filename)

input_image = Image.open(filename)
# Convert the image to a tensor, then normalize each channel with the given mean/std.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# unsqueeze(0) adds a leading dimension to the preprocessed image tensor.
input_tensor = preprocess(input_image).unsqueeze(0)

output = torch_inference(model, input_tensor)
output_tensor = output['out']
aux_tensor = output['aux']



## Step 3: Save ONNX Models
This model accepts arbitrary input resolutions and batch sizes, so mark the batch, height and width axes of the input and of both outputs as dynamic with the `dynamic_axes` parameter when exporting the model.

In [9]:
# Name used for the output directory; follows the architecture chosen by DO_101.
model_name = 'fcn_resnet101' if DO_101 else 'fcn_resnet50'

# Axes 0, 2 and 3 (batch, height, width) are declared dynamic on the input
# and on both outputs.
model_path, data_dir = save_model(
    model_name, model.cpu(),
    data_dir,
    input_tensor, [output_tensor, aux_tensor],
    input_names=['input'], output_names=['out', 'aux'],
    dynamic_axes={
        tensor_name: {0: 'batch', 2: 'height', 3: 'width'}
        for tensor_name in ('input', 'out', 'aux')
    },
    opset_version=11
)

graph(%input : Float(1:921600, 3:307200, 480:640, 640:1, requires_grad=0, device=cpu),
      %classifier.4.weight : Float(21:512, 512:1, 1:1, 1:1, requires_grad=1, device=cpu),
      %classifier.4.bias : Float(21:1, requires_grad=1, device=cpu),
      %aux_classifier.4.weight : Float(21:256, 256:1, 1:1, 1:1, requires_grad=1, device=cpu),
      %aux_classifier.4.bias : Float(21:1, requires_grad=1, device=cpu),
      %1024 : Float(64:147, 3:49, 7:7, 7:1, requires_grad=0, device=cpu),
      %1025 : Float(64:1, requires_grad=0, device=cpu),
      %1027 : Float(64:64, 64:1, 1:1, 1:1, requires_grad=0, device=cpu),
      %1028 : Float(64:1, requires_grad=0, device=cpu),
      %1030 : Float(64:576, 64:9, 3:3, 3:1, requires_grad=0, device=cpu),
      %1031 : Float(64:1, requires_grad=0, device=cpu),
      %1033 : Float(256:64, 64:1, 1:1, 1:1, requires_grad=0, device=cpu),
      %1034 : Float(256:1, requires_grad=0, device=cpu),
      %1036 : Float(256:64, 64:1, 1:1, 1:1, requires_grad=0, device

## Step 4 (optional): Test ONNX models vs. the PyTorch outputs

In [10]:
# Run the exported model with ONNX Runtime and check its results against the
# PyTorch outputs; the returned list is the value displayed by this cell.
ort_inference(model_path, input_tensor.detach().cpu().numpy(), [output_tensor, aux_tensor])

== Checking model output ==
== Done ==


[array([[[[ 7.9147706 ,  7.9147706 ,  7.9147706 , ...,  7.9186316 ,
            7.9186316 ,  7.9186316 ],
          [ 7.9147706 ,  7.9147706 ,  7.9147706 , ...,  7.9186316 ,
            7.9186316 ,  7.9186316 ],
          [ 7.9147706 ,  7.9147706 ,  7.9147706 , ...,  7.9186316 ,
            7.9186316 ,  7.9186316 ],
          ...,
          [ 7.2662272 ,  7.2662272 ,  7.2662272 , ...,  6.977423  ,
            6.977423  ,  6.977423  ],
          [ 7.2662272 ,  7.2662272 ,  7.2662272 , ...,  6.977423  ,
            6.977423  ,  6.977423  ],
          [ 7.2662272 ,  7.2662272 ,  7.2662272 , ...,  6.977423  ,
            6.977423  ,  6.977423  ]],
 
         [[-1.1648259 , -1.1648259 , -1.1648259 , ..., -1.9420264 ,
           -1.9420264 , -1.9420264 ],
          [-1.1648259 , -1.1648259 , -1.1648259 , ..., -1.9420264 ,
           -1.9420264 , -1.9420264 ],
          [-1.1648259 , -1.1648259 , -1.1648259 , ..., -1.9420264 ,
           -1.9420264 , -1.9420264 ],
          ...,
          [-2