In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np

In [20]:
class NetA(nn.Module):
    def __init__(self, *args, **kwargs):
        super(NetA, self).__init__(*args, **kwargs)
        self.conv2d = nn.Conv2d(6,3,3,1)
    
    def forward(self, x):
        return self.conv2d(x)

In [21]:
class NetB(nn.Module):
    def __init__(self, *args, **kwargs):
        super(NetB, self).__init__(*args, **kwargs)
    def forward(self, x, w, b=None):
        x = F.conv2d(x, w, 
                     bias = b, 
                     stride=1, 
                     padding=(w.size()[2] // 2, w.size()[3] // 2),
                     groups = w.size()[0])
        
        return x

In [22]:
net_a = NetA()

In [26]:
inp = torch.rand(1,6,7,7) # 1, 6, 7, 7

In [78]:
inp.size()

torch.Size([1, 6, 7, 7])

In [31]:
w = net_a(inp) # 1, 3, 5, 5

In [32]:
b_input = torch.rand(1,3,25,25)

In [33]:
net_b = NetB()

In [34]:
output = net_b(b_input, w)

In [35]:
output.shape

torch.Size([1, 1, 25, 25])

In [39]:
# export net_a onnx

torch.onnx.export(net_a,
                 inp,
                  "weight.onnx",
                 export_params = True,
                 input_names = ["data_a"],
                  output_names = ["weight"],
                  opset_version = 9,
                  verbose = True
                 )

graph(%data_a : Float(1, 6, 7, 7, strides=[294, 49, 7, 1], requires_grad=0, device=cpu),
      %conv2d.weight : Float(3, 6, 3, 3, strides=[54, 9, 3, 1], requires_grad=1, device=cpu),
      %conv2d.bias : Float(3, strides=[1], requires_grad=1, device=cpu)):
  %weight : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%data_a, %conv2d.weight, %conv2d.bias) # /opt/anaconda3/envs/tf2/lib/python3.8/site-packages/torch/nn/modules/conv.py:395:0
  return (%weight)



In [41]:
# export  net_b onnx
torch.onnx.export(net_b,
                 (b_input,w),
                 "second.onnx",
                  export_params = True,
                 input_names = ["data_b", "weight"],
                  output_names = ["result"],
                  opset_version = 9,
                  verbose = True)

graph(%data_b : Float(1, 3, 25, 25, strides=[1875, 625, 25, 1], requires_grad=0, device=cpu),
      %weight : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=1, device=cpu)):
  %result : Float(1, 1, 25, 25, strides=[625, 625, 25, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%data_b, %weight) # <ipython-input-21-39130c03685a>:5:0
  return (%result)



In [42]:
# export net_a weights

In [49]:
a_nhwc_weight = net_a.conv2d.weight.detach().permute(0,2,3,1).flatten().numpy()

In [50]:
a_bias = net_a.conv2d.bias.detach().flatten().numpy()

In [57]:
np.concatenate((a_nhwc_weight,a_bias)).tofile("./a_weight.bin")

In [71]:
a_input = inp.permute(0,2,3,1).flatten().numpy()

In [60]:
w.permute(0,2,3,1).flatten()

tensor([ 0.4684,  0.3819,  0.2101,  0.6987,  0.4790,  0.1251,  0.8746,  0.8312,
         0.0488,  0.7421,  0.4958, -0.0594,  0.6784,  0.5604, -0.0710,  0.4312,
         0.2661,  0.4668,  0.4594,  0.2708,  0.1293,  0.9763,  0.4617,  0.4410,
         0.8330,  0.1306,  0.1065,  0.8635,  0.3693,  0.0995,  0.3490,  0.2730,
        -0.1319,  0.5298,  0.4553,  0.0352,  0.3478,  1.0269,  0.2782,  0.5578,
         0.5782, -0.1251,  0.8220,  0.3656,  0.1171,  0.5863,  0.0890,  0.0046,
         0.7409,  0.3795,  0.1152,  0.6472,  0.3727,  0.1081,  0.6846,  0.4498,
        -0.1167,  0.5570,  0.8083,  0.0186,  0.9817,  0.1612,  0.1780,  0.8837,
         0.5998, -0.3016,  0.8461,  0.3865,  0.1922,  0.8019,  0.4230,  0.3784,
         0.9045,  0.1809,  0.2087], grad_fn=<UnsafeViewBackward>)

In [61]:
# export net_b

In [73]:
nhwc_b_inut = b_input.permute(0,2,3,1).flatten().numpy()

In [64]:
b_weight = torch.ones(1,5,5,3).flatten()

In [66]:
b_bias = torch.zeros(1)

In [68]:
b_weight = torch.cat([b_weight, b_bias])

In [70]:
b_weight.numpy().tofile("./b_weight.bin")

In [72]:
w.numel()

75

In [74]:
nhwc_b_inut

array([0.11124557, 0.44433403, 0.71581495, ..., 0.68074006, 0.47817773,
       0.33011162], dtype=float32)

In [76]:
F.conv2d(b_input, b_weight[0:75].reshape(1,3,5,5), 
                     bias = b_bias, 
                     stride=1, 
                     padding=(w.size()[2] // 2, w.size()[3] // 2),
                     groups = w.size()[0])

tensor([[[[13.2039, 17.3155, 21.2508, 21.4130, 20.4439, 21.0159, 22.5136,
           23.6964, 23.1346, 23.3245, 23.9254, 23.5003, 22.6746, 23.5389,
           25.0276, 24.5513, 24.1395, 23.3710, 23.1418, 23.2321, 23.8614,
           21.6710, 22.1601, 17.1830, 11.3562],
          [18.4159, 24.2534, 30.3783, 28.9970, 28.1604, 27.8122, 28.4823,
           29.2159, 28.8141, 29.3912, 29.5282, 29.7866, 28.0009, 29.1255,
           30.7994, 30.6945, 30.9680, 31.0421, 30.9758, 30.9517, 31.5158,
           28.1393, 28.1058, 21.9343, 14.4457],
          [24.2239, 30.9239, 37.9786, 34.8496, 34.7190, 33.4696, 35.0866,
           36.9785, 37.7242, 37.6526, 37.9095, 37.7166, 35.2436, 36.1735,
           37.6385, 38.4818, 38.7949, 39.4045, 39.3030, 38.9407, 38.8848,
           36.3974, 35.9715, 28.2949, 19.8906],
          [25.2019, 33.1748, 41.2283, 37.8031, 36.7260, 35.3983, 35.5188,
           36.1904, 36.6772, 36.9401, 37.9389, 38.0075, 35.6485, 36.6654,
           38.5151, 38.9053, 40.3946, 40.3

In [77]:
output

tensor([[[[ 5.4961,  6.8551,  8.3747,  8.7307,  8.5264,  9.6899,  9.5749,
            9.6173,  9.1007,  9.4329,  9.5412,  8.3283,  8.9312,  8.3385,
           10.0003, 10.5802,  9.7305,  9.6543,  9.8884,  9.9378,  9.2975,
            9.7990, 10.0263,  6.6547,  4.5719],
          [ 7.4614, 10.2890, 12.5760, 11.5146, 11.6665, 11.1385, 11.7131,
           12.2821, 12.4094, 12.6501, 11.7847, 11.4821, 10.2969, 11.3709,
           11.9165, 13.0318, 13.0865, 12.9180, 13.2160, 12.7111, 11.5527,
           12.0262, 11.1488,  8.2633,  5.7773],
          [10.3200, 13.6143, 15.2815, 14.0160, 14.6831, 14.7817, 15.0078,
           15.6305, 16.1564, 15.5633, 15.0553, 13.9882, 13.6277, 14.1233,
           15.4279, 16.0757, 16.5630, 15.5966, 17.3684, 17.0856, 15.8438,
           16.2632, 15.1762, 10.2496,  8.2386],
          [10.7912, 14.9201, 16.8964, 16.6248, 15.8041, 14.4866, 14.5298,
           16.1584, 14.5442, 15.2655, 13.7464, 13.8273, 12.8672, 14.8791,
           16.1083, 17.3394, 17.9928, 16.1