In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(16, 3, 32, 32) # 배치 사이즈 16, 채널 3, 사이즈 32x32 랜덤 텐서

############## Train ##############

conv0 = torch.nn.Conv2d(3, 32, kernel_size=1, padding=0) # 1x1 conv
k0 = conv0.weight
b0 = conv0.bias

conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1) # 3x3 conv
k1 = conv1.weight
b1 = conv1.bias

y0 = conv0(x)
y1 = conv1(x)
output_origin = y0+y1 # 1x1 conv와 3x3 conv 각각의 branch로부터 나온 결괏값들 합함

############## Inference ##############

k0_pad = F.pad(k0, [1,1,1,1]) # 1x1 커널을 패딩해서 3x3 커널로 만듦

new_k = k0_pad+k1 # 패딩된 1x1 커널과 원래 3x3 커널을 더해서 새로운 3x3 커널 만듦
new_b = b0+b1 # 아웃풋 채널 수가 두 conv 모두 같으므로 (같아야 하고) bias는 그냥 더해주면 됨

# 새로운 커널과 bias를 이용해서 Inference때 사용할 새로운 conv 제작
# 1x1 conv + 3x3 conv -> fused 3x3 conv로부터 나온 결괏값
output_fusion = F.conv2d(input=x, weight=new_k, bias=new_b, stride=1, padding=1)

In [2]:
# Display the fused 3x3 kernel (zero-padded 1x1 kernel + original 3x3 kernel).
new_k

tensor([[[[-0.1001,  0.1565, -0.1029],
          [ 0.0247,  0.4433,  0.1309],
          [-0.1859,  0.1246,  0.1633]],

         [[-0.1495,  0.1340,  0.1003],
          [-0.0668, -0.1929,  0.0959],
          [ 0.1016,  0.0417, -0.1624]],

         [[-0.1198, -0.0966,  0.0507],
          [ 0.0199, -0.5000, -0.0058],
          [-0.0605, -0.1510, -0.0696]]],


        [[[-0.1904,  0.0736, -0.0234],
          [ 0.0851, -0.5330,  0.1326],
          [-0.0267,  0.0332, -0.0900]],

         [[ 0.0426, -0.0613,  0.0510],
          [-0.0277, -0.0714, -0.0146],
          [ 0.1321,  0.1521,  0.1501]],

         [[-0.0280, -0.0672,  0.0424],
          [-0.0800, -0.3976, -0.0798],
          [-0.0072, -0.0226, -0.0293]]],


        [[[ 0.0086,  0.1800, -0.0855],
          [-0.0326, -0.4812, -0.1867],
          [-0.1310, -0.1376, -0.1196]],

         [[ 0.1790, -0.0923,  0.1621],
          [-0.1070,  0.3512,  0.1564],
          [ 0.1729,  0.1046, -0.1352]],

         [[ 0.0691, -0.0097,  0.1074],
     

In [3]:
# Weight shape of the 1x1 conv (3 input channels, 32 output channels, 1x1 kernel).
k0.size()

torch.Size([32, 3, 1, 1])

In [4]:
# Bias shape of the 1x1 conv: one value per output channel.
b0.size()

torch.Size([32])

In [5]:
# Output shape of the 1x1 branch: the 32x32 spatial size of the input is kept.
y0.size()

torch.Size([16, 32, 32, 32])

In [6]:
# Weight shape of the 3x3 conv (3 input channels, 32 output channels, 3x3 kernel).
k1.size()

torch.Size([32, 3, 3, 3])

In [7]:
# After zero-padding, the 1x1 kernel has the same shape as the 3x3 kernel.
k0_pad.size()

torch.Size([32, 3, 3, 3])

In [8]:
# The fused kernel has the same shape as the original 3x3 kernel.
new_k.size()

torch.Size([32, 3, 3, 3])

In [9]:
# The fused bias still has one value per output channel.
new_b.size()

torch.Size([32])

In [10]:
output_origin.size() # shape of the sum of the outputs of the two branches (1x1 and 3x3)

torch.Size([16, 32, 32, 32])

In [11]:
# The fused conv output has the same shape as the two-branch sum.
output_fusion.size()

torch.Size([16, 32, 32, 32])

In [12]:
# Element-wise difference between the two-branch output and the fused-conv output.
output_origin-output_fusion # nearly identical: the residuals shown below are on the order of 1e-7 or smaller

tensor([[[[ 2.9802e-08,  7.4506e-08,  2.3842e-07,  ...,  2.3842e-07,
            2.3842e-07,  5.9605e-08],
          [ 4.4703e-08, -1.1921e-07,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00, -1.1921e-07],
          [ 0.0000e+00,  5.9605e-08,  1.1921e-07,  ...,  2.9802e-08,
            2.9802e-08,  0.0000e+00],
          ...,
          [ 0.0000e+00,  2.3842e-07,  1.1921e-07,  ...,  8.9407e-08,
           -5.9605e-08,  0.0000e+00],
          [ 1.1921e-07,  7.4506e-08,  0.0000e+00,  ..., -1.1921e-07,
            2.9802e-08,  8.9407e-08],
          [-5.9605e-08,  5.9605e-08,  8.9407e-08,  ...,  2.3842e-07,
           -1.7881e-07,  0.0000e+00]],

         [[ 0.0000e+00,  0.0000e+00,  5.9605e-08,  ...,  0.0000e+00,
           -5.9605e-08,  0.0000e+00],
          [ 2.2352e-08, -2.3842e-07, -5.9605e-08,  ..., -7.4506e-09,
           -1.1921e-07, -5.9605e-08],
          [-2.3842e-07, -1.1921e-07,  0.0000e+00,  ...,  5.9605e-08,
           -1.1921e-07, -5.9605e-08],
          ...,
     

In [13]:
# Display the two-branch (1x1 + 3x3) output for visual comparison with the fused output.
output_origin

tensor([[[[-1.7764e-01, -2.0330e-01,  2.5941e+00,  ...,  1.1043e+00,
            1.3402e+00,  8.2920e-01],
          [-1.9669e-01,  9.6393e-01,  1.2838e+00,  ...,  7.6413e-02,
            1.0246e+00,  1.9002e+00],
          [ 1.4694e+00,  6.2848e-01,  1.2943e+00,  ...,  2.5196e-01,
           -1.2381e-01,  1.3040e+00],
          ...,
          [ 5.2429e-01,  1.1881e+00, -7.7276e-01,  ...,  4.1324e-01,
            2.2090e-01, -3.1272e-01],
          [-9.0398e-01,  8.2959e-02,  9.2825e-01,  ...,  1.5933e+00,
           -3.9362e-01, -2.3434e-01],
          [ 5.6297e-01, -5.0293e-01,  3.7871e-01,  ...,  2.0294e+00,
            6.4556e-01, -6.8282e-01]],

         [[ 5.1293e-01,  1.2906e+00,  9.6032e-01,  ...,  1.0500e+00,
            1.9571e-01,  1.0575e+00],
          [ 1.0344e-01,  1.2015e+00,  6.1199e-01,  ...,  4.0366e-02,
            8.8140e-01, -7.0363e-01],
          [ 4.0065e-01,  6.6336e-01,  7.2762e-01,  ..., -2.8419e-01,
            1.1839e+00,  5.9885e-01],
          ...,
     

In [14]:
# Display the fused 3x3 conv output for visual comparison with the two-branch output.
output_fusion

tensor([[[[-1.7764e-01, -2.0330e-01,  2.5941e+00,  ...,  1.1043e+00,
            1.3402e+00,  8.2920e-01],
          [-1.9669e-01,  9.6393e-01,  1.2838e+00,  ...,  7.6413e-02,
            1.0246e+00,  1.9002e+00],
          [ 1.4694e+00,  6.2848e-01,  1.2943e+00,  ...,  2.5196e-01,
           -1.2381e-01,  1.3040e+00],
          ...,
          [ 5.2429e-01,  1.1881e+00, -7.7276e-01,  ...,  4.1324e-01,
            2.2090e-01, -3.1272e-01],
          [-9.0398e-01,  8.2959e-02,  9.2825e-01,  ...,  1.5933e+00,
           -3.9362e-01, -2.3434e-01],
          [ 5.6297e-01, -5.0293e-01,  3.7871e-01,  ...,  2.0294e+00,
            6.4556e-01, -6.8282e-01]],

         [[ 5.1293e-01,  1.2906e+00,  9.6032e-01,  ...,  1.0500e+00,
            1.9571e-01,  1.0575e+00],
          [ 1.0344e-01,  1.2015e+00,  6.1199e-01,  ...,  4.0366e-02,
            8.8140e-01, -7.0363e-01],
          [ 4.0065e-01,  6.6336e-01,  7.2762e-01,  ..., -2.8419e-01,
            1.1839e+00,  5.9885e-01],
          ...,
     