In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from timeit import default_timer as timer
from torch.utils.mobile_optimizer import optimize_for_mobile

In [2]:
class TestCase3DConv(nn.Module):
    """Test network: a 3D convolution stack followed by a 2D convolution stack.

    The 3D stack reduces 32 input channels to a single channel. That channel
    axis is squeezed away, a softmax is taken over the axis that then sits at
    index 1, and the result is fed to the 2D stack, whose first convolution
    expects 24 input channels and whose last produces 16 channels.
    """

    def __init__(self):
        super().__init__()

        # Conv3d -> BN -> ReLU twice (32 channels), then a Conv3d down to 1 channel.
        layers_3d = [
            nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(num_features=32),
            nn.ReLU(),
            nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(num_features=32),
            nn.ReLU(),
            nn.Conv3d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1),
        ]
        self.conv_3d = nn.Sequential(*layers_3d)

        # Conv2d -> BN -> ReLU twice: 24 -> 16 -> 16 channels.
        layers_2d = [
            nn.Conv2d(in_channels=24, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
        ]
        self.conv_2d = nn.Sequential(*layers_2d)

    def forward(self, feature_3d):
        """Run the 3D stack, collapse its single channel, softmax, then the 2D stack.

        Args:
            feature_3d: 5-D tensor with 32 channels on axis 1; axis 2 must have
                size 24 so that it can serve as the channel axis of ``conv_2d``.

        Returns:
            The output of ``conv_2d`` (16 channels on axis 1).
        """
        volume = self.conv_3d(feature_3d)
        # Drop the singleton channel axis; the former axis 2 is now axis 1.
        collapsed = volume.squeeze(1)
        weights = F.softmax(collapsed, dim=1)
        return self.conv_2d(weights)



In [3]:
# Instantiate the 3D->2D test network and switch it to eval mode
# (BatchNorm layers then use their running statistics instead of batch statistics).
model = TestCase3DConv().eval()
# Uniform-random input: 32 channels on axis 1 (matches the first Conv3d) and size 24
# on axis 2 (becomes the 24 input channels of conv_2d after the channel squeeze).
data = torch.rand(1, 32, 24, 68, 120)

In [4]:
# Time a single forward pass of `model` on `data` and report it in milliseconds.
# This is an inference benchmark, so run it under no_grad: otherwise autograd
# records a graph for every layer and that bookkeeping is included in the timing.
# NOTE: this is one un-warmed run, not an average over repetitions.
with torch.no_grad():
    start = timer()
    model(data)
    end = timer()

print(((end - start) * 1000),"ms")

1276.5387949999995 ms


In [5]:
# Total number of elements in the forward output; 130560 == 1 * 16 * 68 * 120.
model(data).numel()

130560

In [6]:
# Export the 3D->2D test network to ONNX by tracing it with `data`;
# verbose=True prints the exported graph.
torch.onnx.export(
    model, (data,), "test_case_3dconv.onnx",
    input_names=["3d_feature"], output_names=["2d_feature"],
    verbose=True,
)

graph(%3d_feature : Float(1:6266880, 32:195840, 24:8160, 68:120, 120:1),
      %conv_3d.0.weight : Float(32:864, 32:27, 3:9, 3:3, 3:1),
      %conv_3d.0.bias : Float(32:1),
      %conv_3d.1.weight : Float(32:1),
      %conv_3d.1.bias : Float(32:1),
      %conv_3d.1.running_mean : Float(32:1),
      %conv_3d.1.running_var : Float(32:1),
      %conv_3d.3.weight : Float(32:864, 32:27, 3:9, 3:3, 3:1),
      %conv_3d.3.bias : Float(32:1),
      %conv_3d.4.weight : Float(32:1),
      %conv_3d.4.bias : Float(32:1),
      %conv_3d.4.running_mean : Float(32:1),
      %conv_3d.4.running_var : Float(32:1),
      %conv_3d.6.weight : Float(1:864, 32:27, 3:9, 3:3, 3:1),
      %conv_3d.6.bias : Float(1:1),
      %conv_2d.0.weight : Float(16:216, 24:9, 3:3, 3:1),
      %conv_2d.0.bias : Float(16:1),
      %conv_2d.1.weight : Float(16:1),
      %conv_2d.1.bias : Float(16:1),
      %conv_2d.1.running_mean : Float(16:1),
      %conv_2d.1.running_var : Float(16:1),
      %conv_2d.3.weight : Float(16:144, 

### Conv2D模型

In [7]:
class Test2DConv(nn.Module):
    """Minimal 2D test network: one Conv2d (3 -> 6 channels) + BatchNorm2d + ReLU."""

    def __init__(self):
        super().__init__()
        # 3x3 convolution with stride 1 and padding 1, so spatial size is preserved.
        block = (
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=6),
            nn.ReLU(),
        )
        self.conv_2d = nn.Sequential(*block)

    def forward(self, feature_2d):
        """Apply the conv/BN/ReLU block to a 4-D tensor with 3 channels on axis 1."""
        activated = self.conv_2d(feature_2d)
        return activated

In [8]:
# Rebinds `model` to the small 2D test network in eval mode; the TestCase3DConv
# instance bound to this name earlier is no longer reachable through it.
model = Test2DConv().eval()

In [9]:
# Rebinds `data` to a tiny standard-normal input: batch 1, 3 channels (matches Conv2d(3, 6, ...)), 4x4.
data = torch.randn(1,3,4,4)

In [10]:
# Display the input flattened after moving the channel axis last (axes 0,2,3,1).
# NOTE(review): presumably the element order expected by the external runtime under test — confirm.
data.permute(0,2,3,1).flatten()

tensor([-0.7444,  0.2161, -0.4212,  1.4429,  1.4878,  0.3429,  0.8672,  0.3219,
         0.9948,  0.3635, -0.2095,  1.1769,  0.7711,  1.2644,  0.5776, -0.5469,
        -0.3175, -0.7794,  0.7725,  0.2944, -0.6526,  0.4034,  1.8545, -0.6677,
         0.3432,  0.0643, -0.6606, -1.8605,  0.3472, -0.0051, -0.1759, -1.0716,
        -0.0441,  0.0968, -0.0951, -1.3599, -0.2377, -0.2915, -0.0776, -0.4318,
         1.3253, -0.8543,  0.5081,  0.8726,  1.9042, -0.7956,  1.1603,  0.4063])

In [11]:
# Forward pass through Test2DConv: 6 output channels, 4x4 spatial size preserved.
output = model(data) # shape: (1, 6, 4, 4)

In [12]:
# Export the 2D test network to ONNX (opset 9) by tracing it with `data`;
# verbose=True prints the exported graph.
torch.onnx.export(
    model, (data,), "test_2dconv.onnx",
    input_names=["input"], output_names=["output"],
    opset_version=9,
    verbose=True,
)

graph(%input : Float(1:48, 3:16, 4:4, 4:1),
      %conv_2d.0.weight : Float(6:27, 3:9, 3:3, 3:1),
      %conv_2d.0.bias : Float(6:1),
      %conv_2d.1.weight : Float(6:1),
      %conv_2d.1.bias : Float(6:1),
      %conv_2d.1.running_mean : Float(6:1),
      %conv_2d.1.running_var : Float(6:1)):
  %8 : Float(1:96, 6:16, 4:4, 4:1) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%input, %conv_2d.0.weight, %conv_2d.0.bias) # /opt/anaconda3/envs/d2l/lib/python3.7/site-packages/torch/nn/modules/conv.py:416:0
  %9 : Float(1:96, 6:16, 4:4, 4:1) = onnx::BatchNormalization[epsilon=1.0000000000000001e-05, momentum=0.90000000000000002](%8, %conv_2d.1.weight, %conv_2d.1.bias, %conv_2d.1.running_mean, %conv_2d.1.running_var) # /opt/anaconda3/envs/d2l/lib/python3.7/site-packages/torch/nn/functional.py:2016:0
  %output : Float(1:96, 6:16, 4:4, 4:1) = onnx::Relu(%9) # /opt/anaconda3/envs/d2l/lib/python3.7/site-packages/torch/nn/functional.py:1119:0
  retu

In [13]:
# Display the output flattened with the channel axis moved last, in the same order as the input dump.
output.permute(0,2,3,1).flatten()

tensor([0.0000e+00, 1.3679e-01, 1.1203e-01, 3.0004e-01, 9.5306e-02, 0.0000e+00,
        0.0000e+00, 3.3896e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.9857e-01, 0.0000e+00, 0.0000e+00, 6.4937e-02, 1.3274e-01, 0.0000e+00,
        2.0866e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.1403e-01, 1.3269e-01,
        0.0000e+00, 8.8072e-01, 4.1674e-01, 5.2274e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 9.4353e-01, 7.8012e-01, 6.1580e-01, 1.2396e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 2.0682e-01, 5.4254e-02, 4.9262e-04,
        5.8872e-02, 2.9787e-02, 0.0000e+00, 4.1843e-01, 8.2621e-02, 0.0000e+00,
        4.1931e-01, 0.0000e+00, 0.0000e+00, 7.0217e-01, 5.4993e-04, 4.4035e-02,
        1.8465e-01, 0.0000e+00, 4.4121e-02, 2.0012e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 5.5886e-01, 3.6826e-01, 5.3106e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 6.7663e-02, 0.0000e+

#### 这里的模型是test_3dconv里用到的模型

In [14]:
# Stand-alone Conv3d layer: 32 -> 32 channels, kernel 3, stride 1, padding 1 (randomly initialised).
m_model = nn.Conv3d(32, 32, 3, 1, 1)

In [15]:
# The weight has shape (32, 32, 3, 3, 3) = (out, in, kD, kH, kW). Move the input-channel
# axis last -> (out, kD, kH, kW, in) and flatten to 1-D; detach() drops autograd tracking.
p_weight = m_model.weight.detach().permute(0,2,3,4,1).flatten()

In [16]:
# Detached copy of the 32-element bias; flatten() leaves the already 1-D tensor's shape unchanged.
p_bias = m_model.bias.detach().flatten()

In [17]:
# Inspect the bias values.
p_bias

tensor([-0.0173,  0.0239,  0.0225, -0.0118,  0.0156,  0.0337, -0.0125, -0.0091,
         0.0282, -0.0268, -0.0295,  0.0339,  0.0018,  0.0304, -0.0088, -0.0209,
         0.0327, -0.0293,  0.0073, -0.0145, -0.0020, -0.0194,  0.0272, -0.0303,
         0.0272,  0.0269,  0.0311, -0.0005,  0.0247, -0.0321,  0.0312, -0.0031])

In [18]:
# Single 1-D parameter vector: all (reordered) weights first, then the biases.
p_final = torch.cat((p_weight, p_bias), dim=0)

In [19]:
# Size check: 32 * 32 * 3 * 3 * 3 weights + 32 biases = 27680 values.
p_final.numel()

27680

In [34]:
# Dump weights followed by biases as raw binary; tofile() writes the values only (no header or shape).
# NOTE(review): presumably float32, the default parameter dtype — confirm against the consumer.
p_final.numpy().tofile("single_layer_3dconv.bin")

In [20]:
# Uniform-random input for the single Conv3d layer: batch 1, 32 channels, 24 x 68 x 120.
inpu = torch.rand(1, 32, 24, 68, 120)

In [23]:
# Save the model
m_model.eval()

# Trace the layer with the example input to obtain a TorchScript module.
traced_script_module = torch.jit.trace(m_model, inpu)

# Apply the mobile optimisation passes to the traced module.
optimized_m = optimize_for_mobile(traced_script_module)

# Serialise the optimised TorchScript module to disk.
optimized_m.save("bytenn_3dconv_test.pt")

In [24]:
# Export the single Conv3d layer to ONNX by tracing it with `inpu`;
# verbose=True prints the exported graph.
# NOTE(review): the output is a 5-D tensor but is named "2d_feature" (same names as
# the full test-case export) — confirm whether downstream tooling relies on this name.
torch.onnx.export(
    m_model, (inpu,), "single_layer_conv3d_test.onnx",
    input_names=["3d_feature"], output_names=["2d_feature"],
    verbose=True,
)

graph(%3d_feature : Float(1:6266880, 32:195840, 24:8160, 68:120, 120:1),
      %weight : Float(32:864, 32:27, 3:9, 3:3, 3:1),
      %bias : Float(32:1)):
  %2d_feature : Float(1:6266880, 32:195840, 24:8160, 68:120, 120:1) = onnx::Conv[dilations=[1, 1, 1], group=1, kernel_shape=[3, 3, 3], pads=[1, 1, 1, 1, 1, 1], strides=[1, 1, 1]](%3d_feature, %weight, %bias) # /opt/anaconda3/envs/d2l/lib/python3.7/site-packages/torch/nn/modules/conv.py:567:0
  return (%2d_feature)



In [25]:
# Time one forward pass of the single Conv3d layer and keep its result in `res`.
# Inference only: run under no_grad so autograd does not record a graph, which
# would otherwise add bookkeeping to the measured time and attach a grad_fn to `res`.
# NOTE: this is one un-warmed run, not an average over repetitions.
with torch.no_grad():
    begin_time = timer()
    res = m_model(inpu)
    end_time = timer()

In [26]:
# Elapsed wall-clock time of the forward pass, converted from seconds to milliseconds.
(end_time-begin_time) * 1000 

483.13810199999807

In [27]:
# With kernel 3, stride 1, padding 1 and 32 -> 32 channels, the output shape equals the input shape.
res.shape

torch.Size([1, 32, 24, 68, 120])

In [28]:
# Reorder the result so the channel axis is last (axes 0,2,3,4,1) and flatten to 1-D,
# the same element order used for the weight and input dumps.
to_bin = res.detach().permute(0,2,3,4,1).flatten()

In [29]:
# Peek at the first 10 output values in dump order.
to_bin[0:10]

tensor([ 0.0668, -0.4403, -0.2054, -0.2988, -0.1230,  0.3138, -0.3023,  0.1028,
        -0.1982, -0.5449])

In [30]:
# Write the reference output as raw binary (values only, no header or shape).
to_bin.numpy().tofile("res.bin")

In [31]:
# Write the input as raw binary in the same channel-last flattened order as the output dump.
inpu.permute(0,2,3,4,1).flatten().numpy().tofile("inpu.bin")

In [36]:
# First 10 input values in dump order, for spot-checking inpu.bin.
inpu.permute(0,2,3,4,1).flatten()[0:10]

tensor([0.7651, 0.5260, 0.8956, 0.4452, 0.8368, 0.3923, 0.2981, 0.8860, 0.8949,
        0.5104])

In [33]:
# First 100 output values in dump order, for spot-checking res.bin.
to_bin[0:100]

tensor([ 6.6766e-02, -4.4028e-01, -2.0536e-01, -2.9880e-01, -1.2297e-01,
         3.1380e-01, -3.0230e-01,  1.0279e-01, -1.9824e-01, -5.4493e-01,
        -3.0706e-01,  8.1129e-02,  1.0991e-03,  8.0735e-02, -3.2745e-01,
        -1.9374e-02,  1.7550e-01, -5.6417e-01, -3.0997e-01, -1.7843e-01,
         3.5548e-01, -1.4415e-01,  5.8936e-02, -1.1469e-01,  1.6564e-02,
        -5.2975e-02, -1.1905e-01,  1.6320e-01,  1.8762e-02,  1.7630e-01,
         7.2471e-02, -6.5588e-02,  1.9695e-01, -2.5943e-01, -1.2165e-01,
        -2.0772e-01, -3.1516e-01,  1.9693e-01, -2.6290e-01,  2.0130e-01,
        -1.8074e-01, -6.7184e-01, -9.0122e-02, -1.7983e-01,  2.5916e-01,
        -1.9926e-01, -2.2750e-01, -3.4565e-01,  1.8816e-02, -3.3267e-01,
        -9.3340e-02,  1.1969e-01, -2.8708e-02, -2.6147e-02,  8.1535e-02,
         3.4648e-02,  6.4717e-02,  6.6186e-02,  4.9889e-02, -8.0315e-02,
        -8.4436e-02,  1.0755e-01, -9.1310e-02,  4.3830e-02,  3.2533e-02,
        -5.7686e-02, -2.3388e-01,  1.0065e-01, -1.1