In [1]:
from torch import nn
import torch

In [2]:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(1, 3, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.features2 = nn.Sequential(
            nn.Conv2d(1, 3, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.features3 = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),
        )

        self.classifier = nn.Linear(128*128*3 + 32*32*3 + 5, 4)
        
    def forward(self, x1, x2, x3):
        x1 = self.features1(x1)
        x2 = self.features2(x2)
        x3 = self.features3(x3)

        x1 = x1.view(x1.data.size(0), -1)
        x2 = x2.view(x2.data.size(0), -1)
        x3 = x3.view(x3.data.size(0), -1)
        
        x = torch.cat((x1, x2, x3), dim=1)
        x = self.classifier(x)
        return x

In [3]:
x1 = torch.randn(1,1,256,256)
x2 = torch.randn(1,1,64,64)
x3 = torch.randn(1,10)

In [4]:
model = MyModel()

In [5]:
output = model(x1,x2,x3)

In [6]:
# export net_a onnx

torch.onnx.export(model,
                 (x1,x2,x3),
                  "demo_multiinput.onnx",
                 export_params = True,
                 input_names = ["x1","x2","x3"],
                  output_names = ["output"],
                  opset_version = 9,
                  verbose = True
                 )

graph(%x1 : Float(1, 1, 256, 256),
      %x2 : Float(1, 1, 64, 64),
      %x3 : Float(1, 10),
      %features1.0.weight : Float(3, 1, 3, 3),
      %features1.0.bias : Float(3),
      %features2.0.weight : Float(3, 1, 3, 3),
      %features2.0.bias : Float(3),
      %features3.0.weight : Float(5, 10),
      %features3.0.bias : Float(5),
      %classifier.weight : Float(4, 52229),
      %classifier.bias : Float(4)):
  %11 : Float(1, 3, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%x1, %features1.0.weight, %features1.0.bias) # /opt/anaconda3/envs/tf2/lib/python3.8/site-packages/torch/nn/modules/conv.py:341:0
  %12 : Float(1, 3, 128, 128) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%11) # /opt/anaconda3/envs/tf2/lib/python3.8/site-packages/torch/nn/functional.py:487:0
  %13 : Float(1, 3, 128, 128) = onnx::Relu(%12) # /opt/anaconda3/envs/tf2/lib/python3.8/site-packages/torch/nn/functional.py:914:0
  %14

In [7]:
t_x1 = x1.permute((0,2,3,1)).flatten()
t_x2 = x2.permute((0,2,3,1)).flatten()

In [8]:
torch.cat([t_x1,t_x2,x1.flatten()]).numpy().tofile("./input.bin")

In [23]:
torch.cat([t_x1,t_x2,x1.flatten()]).numel()

135168

In [24]:
135168 * 4

540672