## Implementation for the paper RepVGG: Making VGG-style ConvNets Great Again
I only implemented the inference-time model for the idea of RepVGG. I wrote this code just for the beauty of some deep-learning ideas and the fun of playing with PyTorch. The idea of RepVGG is shown in the following screenshot, which is taken from the original paper. Specifically, the goal is to use a single 3x3 convolution to represent the sum of three branches [3x3 conv + 1x1 conv + identity]. This saves computation at inference time because the three operations are folded into one weight tensor (ignoring the biases, which are summed separately).
![RepVGG](./RepVGG.png)
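
The merge works because convolution is linear in its kernel. As a short sketch (the symbols are notation introduced here, not taken from the paper): let $*$ be a 3x3 convolution with "same" padding, $W_3, b_3$ the weights and bias of the 3x3 branch, $W_1, b_1$ those of the 1x1 branch, $\operatorname{pad}(W_1)$ the 1x1 kernel zero-padded to 3x3, and $I$ the identity kernel with $I_{o,i,u,v} = 1$ when $o = i$ and $(u,v)$ is the kernel centre, and $0$ otherwise.

```latex
\begin{aligned}
y &= (W_3 * x + b_3) + (\operatorname{pad}(W_1) * x + b_1) + I * x \\
  &= \bigl(W_3 + \operatorname{pad}(W_1) + I\bigr) * x + (b_3 + b_1)
\end{aligned}
```

The first line is the multi-branch form with every branch written as a 3x3 convolution; the second line is the single fused convolution.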

In [103]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Generate one random image: batch size 1, 2 channels, 256x256 pixels
data = torch.randn(1,2,256,256)
in_channels = 2
ou_channels = 2  # must equal in_channels, otherwise the identity branch cannot be added
kernel_size = 3

# original (multi-branch) method: 3x3 conv + 1x1 conv + identity
conv_layer = nn.Conv2d(in_channels, ou_channels, kernel_size, padding="same")
pointwise_layer = nn.Conv2d(in_channels, ou_channels, 1) # kernel_size=1
out1 = conv_layer(data) + pointwise_layer(data) + data

In [104]:
# express every branch as a 3x3 convolution
## transform the 1*1 kernel of the pointwise layer into a 3*3 kernel by zero-padding around the centre
pointwise_weight = F.pad(pointwise_layer.weight.data, pad=(1,1,1,1)) # 2*2*1*1 -> 2*2*3*3
conv2d_pointwise = nn.Conv2d(in_channels, ou_channels, kernel_size, padding="same")
conv2d_pointwise.weight = nn.Parameter(pointwise_weight)
conv2d_pointwise.bias = pointwise_layer.bias  # the bias is unchanged by the padding

## perform the identity operation via a 3*3 convolution:
## output channel i has a 1 at the kernel centre for input channel i and zeros everywhere else
ones = torch.unsqueeze(F.pad(torch.ones(1,1), pad=(1,1,1,1)), dim=0)  # 1*3*3, centre = 1
zeros = torch.unsqueeze(torch.zeros(kernel_size,kernel_size), dim=0)  # 1*3*3, all zeros
out_channel1_identity = torch.unsqueeze(torch.cat([ones,zeros], dim=0), dim=0)  # 1*2*3*3
out_channel2_identity = torch.unsqueeze(torch.cat([zeros,ones], dim=0), dim=0)  # 1*2*3*3
identity_layer_weight = torch.cat([out_channel1_identity,out_channel2_identity], dim=0)  # 2*2*3*3
identity_layer_bias = torch.zeros(ou_channels)  # the identity branch has no bias

conv2d_identity_layer = nn.Conv2d(in_channels, ou_channels, kernel_size, padding="same")
conv2d_identity_layer.weight = nn.Parameter(identity_layer_weight)
conv2d_identity_layer.bias = nn.Parameter(identity_layer_bias)

out2 = conv_layer(data) + conv2d_pointwise(data) + conv2d_identity_layer(data)
torch.allclose(out1, out2)

True
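
The identity kernel above is written out by hand for two channels. As a minimal sketch of how the same two transformations could be generalised to any channel count, the helpers below (`pad_1x1_to_3x3` and `identity_kernel_3x3` are hypothetical names, not part of the original notebook or of PyTorch) build the padded pointwise kernel and the identity kernel, then check them on a random input.

```python
import torch
import torch.nn.functional as F

def pad_1x1_to_3x3(weight_1x1: torch.Tensor) -> torch.Tensor:
    """Zero-pad an (out, in, 1, 1) kernel to (out, in, 3, 3); the value stays at the centre."""
    return F.pad(weight_1x1, pad=(1, 1, 1, 1))

def identity_kernel_3x3(channels: int) -> torch.Tensor:
    """Build a (channels, channels, 3, 3) kernel whose 'same' convolution returns its input."""
    weight = torch.zeros(channels, channels, 3, 3)
    idx = torch.arange(channels)
    weight[idx, idx, 1, 1] = 1.0  # output channel i copies input channel i at the kernel centre
    return weight

# Usage: check both helpers on a random 4-channel input (sizes chosen only for illustration).
channels = 4
x = torch.randn(1, channels, 16, 16)
w1 = torch.randn(channels, channels, 1, 1)

out_1x1 = F.conv2d(x, w1)
out_padded = F.conv2d(x, pad_1x1_to_3x3(w1), padding=1)
out_identity = F.conv2d(x, identity_kernel_3x3(channels), padding=1)

same_as_1x1 = torch.allclose(out_1x1, out_padded, atol=1e-5)
same_as_input = torch.allclose(out_identity, x, atol=1e-6)
print(same_as_1x1, same_as_input)
```

The identity branch only makes sense when the input and output channel counts match, which is why `identity_kernel_3x3` takes a single `channels` argument.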

In [105]:
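# fusion: convolution is linear, so the three 3*3 kernels (and the three biases) can be summed into one layer
fusion_layer = nn.Conv2d(in_channels, ou_channels, kernel_size, padding="same")
fusion_layer.weight = nn.Parameter(conv_layer.weight.data + conv2d_pointwise.weight.data + conv2d_identity_layer.weight.data)
fusion_layer.bias = nn.Parameter(conv_layer.bias.data + conv2d_pointwise.bias.data + conv2d_identity_layer.bias.data)
out3 = fusion_layer(data)
# allclose uses the default tolerances rtol=1e-05 and atol=1e-08, so float32 rounding differences can make it return False
torch.allclose(out2,out3)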

False

In [106]:
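# largest signed element-wise difference between the original output and the three-3x3-branch output
torch.max(out1.detach()-out2.detach())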

tensor(0.)

In [107]:
# largest signed element-wise difference between the original output and the fused output
torch.max(out1.detach()-out3.detach())

tensor(9.5367e-07)
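
The `False` from `torch.allclose(out2, out3)` fits with a difference of about 1e-6: summing the kernels before convolving changes the order of the float32 additions, and the default `atol=1e-08` is stricter than that rounding error for outputs near zero. The sketch below reproduces the comparison on stand-in random weights (all names and sizes here are illustrative assumptions, not the notebook's layers) and uses the maximum absolute error plus a looser `atol`, which is one reasonable way to check equivalence in float32.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the sketch is repeatable
x = torch.randn(1, 2, 64, 64)
w3, b3 = 0.1 * torch.randn(2, 2, 3, 3), 0.1 * torch.randn(2)  # stand-in 3x3 branch
w1, b1 = 0.1 * torch.randn(2, 2, 1, 1), 0.1 * torch.randn(2)  # stand-in 1x1 branch

# Multi-branch output: 3x3 conv + 1x1 conv + identity.
branch_out = F.conv2d(x, w3, b3, padding=1) + F.conv2d(x, w1, b1) + x

# Fused kernel: 3x3 weights + zero-padded 1x1 weights + identity kernel.
w_id = torch.zeros(2, 2, 3, 3)
w_id[0, 0, 1, 1] = 1.0
w_id[1, 1, 1, 1] = 1.0
fused_out = F.conv2d(x, w3 + F.pad(w1, (1, 1, 1, 1)) + w_id, b3 + b1, padding=1)

# Absolute error (not the signed difference) and a tolerance suited to float32.
max_abs_err = (branch_out - fused_out).abs().max().item()
close_loose = torch.allclose(branch_out, fused_out, atol=1e-5)
print(f"max abs error: {max_abs_err:.2e}, allclose(atol=1e-5): {close_loose}")
```

Taking `.abs()` before `.max()` matters: the maximum of a signed difference can hide large negative errors.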

In [110]:
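# compare their inference times (one wall-clock run each, so the numbers are noisy)
# the first printed time is the original multi-branch method, the second is the fused layer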
import time
t1 = time.time()
out1 = conv_layer(data) + pointwise_layer(data) + data # original method
t2 = time.time()
print(f"Time: {t2-t1}")

t1 = time.time()
out3 = fusion_layer(data) # fusion method
t2 = time.time()
print(f"Time: {t2-t1}")

Time: 0.006000041961669922
Time: 0.0019991397857666016
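
A single call timed with `time.time()` includes one-off start-up costs and is sensitive to timer resolution, so the two numbers above are only a rough indication. As a hedged sketch of a steadier measurement on CPU, the helper below (`time_forward` is a hypothetical name) runs a few untimed warm-up calls, disables autograd, and averages many calls with `time.perf_counter()`. The layers are freshly initialised stand-ins with the same shapes as in the notebook, which is enough to compare cost. On a GPU the timing would additionally need `torch.cuda.synchronize()` around the timed region.

```python
import time
import torch
import torch.nn as nn

def time_forward(fn, warmup: int = 5, repeats: int = 50) -> float:
    """Return the mean wall-clock seconds per call of fn() (CPU timing sketch)."""
    with torch.no_grad():            # no autograd bookkeeping during inference
        for _ in range(warmup):      # warm-up calls are not timed
            fn()
        start = time.perf_counter()
        for _ in range(repeats):
            fn()
        return (time.perf_counter() - start) / repeats

data = torch.randn(1, 2, 256, 256)
conv_layer = nn.Conv2d(2, 2, 3, padding="same")
pointwise_layer = nn.Conv2d(2, 2, 1)
# Stand-in for the fused layer: same shape, weights not fused, since only the cost is compared here.
fusion_layer = nn.Conv2d(2, 2, 3, padding="same")

t_branches = time_forward(lambda: conv_layer(data) + pointwise_layer(data) + data)
t_fused = time_forward(lambda: fusion_layer(data))
print(f"multi-branch: {t_branches:.6f} s/call, fused: {t_fused:.6f} s/call")
```

The absolute numbers depend on the hardware and the PyTorch build, so only the ratio between the two is meaningful.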
