**卷积残差模块算子融合**
**要复现的结构图**
![image.png](../add_pic/1.png)

In [12]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import time

In [13]:
# 参数设置
in_channels = 2
out_channels = 2
kernel_size = 3
w = 9
h = 9

In [14]:
# res_block = 3*3 conv + 1*1 conv + input
x = torch.ones(1, in_channels, w, h)

In [15]:
# 方法1：原生写法
t1 = time.time()
conv_2d = nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
conv_2d_pointwise = nn.Conv2d(in_channels, out_channels, 1)
result1 = conv_2d(x) + conv_2d_pointwise(x) + x
t2 = time.time()

result1

tensor([[[[1.3271, 1.3578, 1.3578, 1.3578, 1.3578, 1.3578, 1.3578, 1.3578,
           1.4942],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.7121, 1.6625, 1.6625, 1.6625, 1.6625, 1.6625, 1.6625, 1.6625,
           1.3619]],

         [[0.6278, 0.8810, 0.8810, 0.8810, 0.8810, 0.8810, 0.8810, 0.8810,
           0.8605],
          [0.6599, 1.1470, 1.1470, 1.1470, 1.147

In [16]:
# 方法2：算子融合
# 把point_wise卷积和x本身都写成3*3卷积的形式
# 最终将三个卷积写成一个卷积

# (1)改造
# 2*2*1*1->2*2*3*3
pointwise_to_conv_weight = F.pad(conv_2d_pointwise.weight, [1, 1, 1, 1, 0, 0, 0, 0])
conv_2d_for_pointwise = nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
conv_2d_for_pointwise.weight = nn.Parameter(pointwise_to_conv_weight)
conv_2d_for_pointwise.bias = nn.Parameter(conv_2d_pointwise.bias)

# 2*2*3*3
zeros = torch.unsqueeze(torch.zeros(kernel_size, kernel_size), 0)
stars = torch.unsqueeze(F.pad(torch.ones(1, 1), [1, 1, 1, 1]), 0)
stars_zeros = torch.unsqueeze(torch.cat([stars, zeros], 0), 0)
zeros_stars = torch.unsqueeze(torch.cat([zeros, stars], 0), 0)
identity_to_conv_weight = torch.cat([stars_zeros, zeros_stars], 0)
identity_to_conv_bias = torch.zeros(out_channels)
conv_2d_for_identity = nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
conv_2d_for_identity.weight = nn.Parameter(identity_to_conv_weight)
conv_2d_for_identity.bias = nn.Parameter(identity_to_conv_bias)

result2 = conv_2d(x) + conv_2d_for_pointwise(x) + conv_2d_for_identity(x)

print(result2)
print(torch.allclose(result1, result2))

# （2）融合
t3 = time.time()
conv_2d_for_fusion = nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
conv_2d_for_fusion.weight = nn.Parameter(conv_2d.weight.data + pointwise_to_conv_weight.data + identity_to_conv_weight.data)
conv_2d_for_fusion.bias = nn.Parameter(conv_2d.bias.data + conv_2d_pointwise.bias.data + identity_to_conv_bias.data)
t4 = time.time()

result3 = conv_2d_for_fusion(x)
print(result3)
print(torch.allclose(result2, result3))

tensor([[[[1.3271, 1.3578, 1.3578, 1.3578, 1.3578, 1.3578, 1.3578, 1.3578,
           1.4942],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.5642, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190, 1.5190,
           1.4437],
          [1.7121, 1.6625, 1.6625, 1.6625, 1.6625, 1.6625, 1.6625, 1.6625,
           1.3619]],

         [[0.6278, 0.8810, 0.8810, 0.8810, 0.8810, 0.8810, 0.8810, 0.8810,
           0.8605],
          [0.6599, 1.1470, 1.1470, 1.1470, 1.147

In [17]:
print("原生写法耗时：", t2 - t1)
print("算子融合耗时：", t4 - t3)

原生写法耗时： 0.0010006427764892578
算子融合耗时： 0.0009989738464355469
