In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
import numpy as np
from colorama import init, Fore
init(autoreset=True)

In [2]:
class ResNetBasicblock(nn.Module):
    """Basic residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, plus shortcut, then ReLU.

    Args:
        inplanes: channels of the incoming feature map.
        planes: output channels of both 3x3 convolutions.
        Ifbias: whether both 3x3 convolutions carry a bias term.
        stride: stride of the first convolution only (the second always uses 1).
        downsample: optional module mapping the input onto the shortcut; when
            ``None`` the input is added unchanged, so its shape must already
            match the main path's output.
    """

    expansion = 1

    def __init__(self, inplanes, planes, Ifbias=False, stride=1, downsample=None):
        super(ResNetBasicblock, self).__init__()

        conv_kwargs = dict(kernel_size=3, padding=1, bias=Ifbias)
        self.conv_a = nn.Conv2d(inplanes, planes, stride=stride, **conv_kwargs)
        self.bn_a = nn.BatchNorm2d(planes)
        self.conv_b = nn.Conv2d(planes, planes, stride=1, **conv_kwargs)
        self.bn_b = nn.BatchNorm2d(planes)

        self.downsample = downsample

        # Initialise every module reachable from here, including anything inside
        # `downsample`: non-1x1 convs get N(0, sqrt(2 / fan_out)) weights, BN
        # layers start as the identity affine map; 1x1 convs keep their defaults.
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Conv2d) and layer.kernel_size != (1, 1):
                kh, kw = layer.kernel_size
                fan_out = kh * kw * layer.out_channels
                nn.init.normal_(layer.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))

    def forward(self, x):
        out = F.relu(self.bn_a(self.conv_a(x)), inplace=True)
        out = self.bn_b(self.conv_b(out))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(shortcut + out, inplace=True)
    
# Confirm the cell ran and the block class is available to later cells.
print("The ResNetBasicblock class is defined.")


The ResNetBasicblock class is defined.


In [3]:
def ST_Structural_Transformation(BasicBlock):
    """Structural Transformation (ST): fold each BatchNorm into its preceding conv.

    ``BasicBlock.conv_a`` / ``conv_b`` are replaced **in place** by new biased
    convolutions (same geometry: kernel size, stride, padding, dilation, groups,
    padding mode) whose weight/bias reproduce ``bn(conv(x))`` evaluated with the
    BN's running statistics. ``bn_a`` / ``bn_b`` are not modified or removed and
    still run after the new convs.
    NOTE(review): in training mode the retained BN re-normalises with batch
    statistics, which absorbs the per-channel affine change; confirm that
    keeping the BNs is intended.

    Args:
        BasicBlock: module exposing ``conv_a``, ``bn_a``, ``conv_b`` and ``bn_b``
            (e.g. a ``ResNetBasicblock``).

    Returns:
        The same ``BasicBlock`` object, mutated.
    """
    def Conv_BN(Conv, BN):
        """Return ``(weight, bias)`` of one conv equal to ``BN(Conv(x))`` using BN's running stats."""
        # The fused tensors are plain values; do not record an autograd graph
        # linking them to the BN / conv parameters.
        with torch.no_grad():
            scale = BN.weight / (BN.running_var + BN.eps).sqrt()  # gamma / std, one value per output channel
            weight = Conv.weight * scale.reshape(-1, 1, 1, 1)
            bias = BN.bias - BN.running_mean * scale
            if Conv.bias is not None:
                # The conv bias is scaled per output channel, like the weight
                # (a diagonal-matrix product reduces to this element-wise multiply).
                bias = bias + scale * Conv.bias
        return weight, bias

    def fuse(Conv, BN):
        """Build a biased conv with ``Conv``'s geometry that computes ``BN(Conv(x))``."""
        weight, bias = Conv_BN(Conv, BN)
        fused = nn.Conv2d(Conv.in_channels, Conv.out_channels,
                          kernel_size=Conv.kernel_size, stride=Conv.stride,
                          padding=Conv.padding, dilation=Conv.dilation,
                          groups=Conv.groups, bias=True, padding_mode=Conv.padding_mode)
        # Assigning ``.data`` also adopts the device/dtype of the fused tensors.
        fused.weight.data = weight
        fused.bias.data = bias
        return fused

    BasicBlock.conv_a = fuse(BasicBlock.conv_a, BasicBlock.bn_a)
    BasicBlock.conv_b = fuse(BasicBlock.conv_b, BasicBlock.bn_b)
    return BasicBlock

def SE_Structural_Expansion(BasicBlock):
    """Structural Expansion (SE): wrap a block in an ``LDE`` module with a second
    pathway and 1x1 "fusion selector" convolutions.

    At construction the new pathway's convs are all-zero, its BNs are fresh
    (weight 1, bias 0, running mean 0), and every fusion selector is an identity
    1x1 conv, so the added branch initially outputs zeros. The block's
    ``conv_a`` / ``conv_b`` are reused (shared, not copied) as the old pathway,
    each followed by a *fresh* BatchNorm; the block's own ``bn_a`` / ``bn_b`` are
    not carried over.
    NOTE(review): this presumes those BNs were already folded into the convs
    (e.g. by ``ST_Structural_Transformation``) — confirm.
    The block's ``downsample`` shortcut, if it has one, is kept and shared.

    Args:
        BasicBlock: module exposing ``conv_a`` and ``conv_b`` (e.g. a
            ``ResNetBasicblock``), optionally with a ``downsample`` attribute.

    Returns:
        An ``LDE`` module.
    """
    class LDE(nn.Module):
        def __init__(self, BasicBlock):
            super(LDE, self).__init__()
            conv_a, conv_b = BasicBlock.conv_a, BasicBlock.conv_b
            planes_a, planes_b = conv_a.out_channels, conv_b.out_channels

            # The new pathway: same geometry as the old convs, but bias-free.
            self.conv_a_2 = nn.Conv2d(conv_a.in_channels, planes_a,
                                      kernel_size=conv_a.kernel_size, stride=conv_a.stride,
                                      padding=conv_a.padding, bias=False)
            self.bn_a_2 = nn.BatchNorm2d(planes_a)
            self.conv_b_2 = nn.Conv2d(conv_b.in_channels, planes_b,
                                      kernel_size=conv_b.kernel_size, stride=conv_b.stride,
                                      padding=conv_b.padding, bias=False)
            self.bn_b_2 = nn.BatchNorm2d(planes_b)
            # Fresh BNs for the old pathway (the block's own BNs are not reused).
            self.bn_a_1 = nn.BatchNorm2d(planes_a)
            self.bn_b_1 = nn.BatchNorm2d(planes_b)

            # The Fusion Selectors: bias-free 1x1 convs applied to each pathway.
            self.fs_a_1 = nn.Conv2d(planes_a, planes_a, kernel_size=1, bias=False)
            self.fs_a_2 = nn.Conv2d(planes_a, planes_a, kernel_size=1, bias=False)
            self.fs_b_1 = nn.Conv2d(planes_b, planes_b, kernel_size=1, bias=False)
            self.fs_b_2 = nn.Conv2d(planes_b, planes_b, kernel_size=1, bias=False)

            # Initialise only the modules created above (the shared modules are
            # attached afterwards so they are left untouched): zero non-1x1
            # convs, identity 1x1 selectors, identity-affine BNs.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    if m.kernel_size == (1, 1):
                        feature_num = m.weight.size(0)
                        m.weight.data = torch.eye(feature_num).reshape(feature_num, feature_num, 1, 1)
                    else:
                        m.weight.data.fill_(0)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

            # The new modules were created on the default device/dtype; match the
            # existing convs so forward also works for GPU or float64 blocks.
            self.to(device=conv_a.weight.device, dtype=conv_a.weight.dtype)

            # The old pathway and the shortcut are shared with BasicBlock.
            self.conv_a_1 = conv_a
            self.conv_b_1 = conv_b
            self.downsample = getattr(BasicBlock, "downsample", None)

        def forward(self, x):
            # Stage a: both pathways, each through BN and its fusion selector, summed.
            x1up = self.fs_a_1(self.bn_a_1(self.conv_a_1(x)))
            x1down = self.fs_a_2(self.bn_a_2(self.conv_a_2(x)))
            x1f = F.relu(x1up + x1down)

            # Stage b: same pattern, no activation before the residual add.
            x2up = self.fs_b_1(self.bn_b_1(self.conv_b_1(x1f)))
            x2down = self.fs_b_2(self.bn_b_2(self.conv_b_2(x1f)))
            x2f = x2up + x2down

            # Identity shortcut unless the wrapped block had a projection.
            residual = x if self.downsample is None else self.downsample(x)
            return F.relu(residual + x2f)

    return LDE(BasicBlock)

def SC_Structural_Compression(LDEBolck):
    """Structural Compression (SC): collapse an ``LDE`` block into one ``ResNetBasicblock``.

    For each stage (``a`` and ``b``) both pathways ``fs(bn(conv(x)))`` are folded
    into a single conv weight/bias using each BN's running statistics, and the
    two pathways' results are summed element-wise. The returned block uses these
    merged parameters in biased convs (same kernel size, stride, padding and
    groups as the old-pathway convs), followed by freshly constructed BN layers
    (weight 1, bias 0, default running statistics). A ``downsample`` shortcut
    found on ``LDEBolck`` is reused.

    Args:
        LDEBolck: an ``LDE`` module as built by ``SE_Structural_Expansion``
            (needs ``conv_*_1/2``, ``bn_*_1/2`` and ``fs_*_1/2`` for ``*`` in a, b).

    Returns:
        A new ``ResNetBasicblock``.
    """
    def Conv_BN(Conv, BN):
        """Return ``(weight, bias)`` of one conv equal to ``BN(Conv(x))`` using BN's running stats."""
        scale = BN.weight / (BN.running_var + BN.eps).sqrt()  # gamma / std per output channel
        weight = Conv.weight * scale.reshape(-1, 1, 1, 1)
        bias = BN.bias - BN.running_mean * scale
        if Conv.bias is not None:
            bias = bias + scale * Conv.bias
        return weight, bias

    def Conv_fs(k1, b1, k2):
        """Compose a conv (k1, b1) with a following bias-free 1x1 conv of weight k2."""
        # Treat k1's input-channel axis as a batch so a 1x1 conv with k2 mixes
        # k1's output channels: k[o] = sum_c k2[o, c] * k1[c].
        k = F.conv2d(k1.permute(1, 0, 2, 3), k2, padding=0).permute(1, 0, 2, 3)
        b = torch.matmul(k2[:, :, 0, 0], b1)
        return k, b

    def merge_stage(conv_1, bn_1, fs_1, conv_2, bn_2, fs_2):
        """Fold both pathways of one stage and add them into one (weight, bias) pair."""
        w_1, b_1 = Conv_fs(*Conv_BN(conv_1, bn_1), fs_1.weight)
        w_2, b_2 = Conv_fs(*Conv_BN(conv_2, bn_2), fs_2.weight)
        # Plain element-wise sums. NB: ``sum(b_1, b_2)`` would treat b_2 as the
        # start value and add the scalar *total* of b_1 to every channel.
        return w_1 + w_2, b_1 + b_2

    def build_conv(template, weight, bias):
        """New biased conv with ``template``'s geometry holding the merged parameters."""
        conv = nn.Conv2d(template.in_channels, template.out_channels,
                         kernel_size=template.kernel_size, stride=template.stride,
                         padding=template.padding, groups=template.groups, bias=True)
        # Assigning ``.data`` also adopts the device/dtype of the merged tensors.
        conv.weight.data = weight
        conv.bias.data = bias
        return conv

    lde = LDEBolck
    # The merged tensors are plain values; no autograd graph is needed.
    with torch.no_grad():
        a_weight, a_bias = merge_stage(lde.conv_a_1, lde.bn_a_1, lde.fs_a_1,
                                       lde.conv_a_2, lde.bn_a_2, lde.fs_a_2)
        b_weight, b_bias = merge_stage(lde.conv_b_1, lde.bn_b_1, lde.fs_b_1,
                                       lde.conv_b_2, lde.bn_b_2, lde.fs_b_2)

    Original_Block = ResNetBasicblock(lde.conv_a_1.in_channels, lde.conv_a_1.out_channels, True)
    # Re-assigning existing attribute names keeps the module registration order.
    Original_Block.conv_a = build_conv(lde.conv_a_1, a_weight, a_bias)
    Original_Block.bn_a = nn.BatchNorm2d(lde.conv_a_1.out_channels)
    Original_Block.conv_b = build_conv(lde.conv_b_1, b_weight, b_bias)
    Original_Block.bn_b = nn.BatchNorm2d(lde.conv_b_1.out_channels)

    # The fresh BN layers were created on the default device/dtype; align them
    # with the merged parameters.
    Original_Block.to(device=a_weight.device, dtype=a_weight.dtype)

    # Keep the shortcut projection, if the LDE block carries one.
    Original_Block.downsample = getattr(LDEBolck, "downsample", None)
    return Original_Block

# Announce the three reparameterization helpers defined in this cell.
# NOTE: the trailing backslash continues the string literal, so the leading
# spaces of the next source line become part of the printed text.
print("Three structural reparameterization methods are implemented in this file: \n \
      ST_Structural_Transformation, SE_Structural_Expansion, and SC_Structural_Compression.")


Three structural reparameterization methods are implemented in this file: 
       ST_Structural_Transformation, SE_Structural_Expansion, and SC_Structural_Compression.


In [4]:
# Define Random Input
# Seed first so the reported losses are reproducible on Restart & Run All.
torch.manual_seed(0)

inplane = 64
outplane = 64
batch_size = 1
input = torch.randn(batch_size, inplane, 32, 32)  # NB: shadows the builtin `input`; later cells reuse this name

Initial_Structure = ResNetBasicblock(inplane, outplane)

# NOTE(review): no .eval() is called anywhere, so all blocks run in training
# mode: BatchNorm normalises with batch statistics and updates its running
# statistics on every forward pass. The equivalence checks below are therefore
# training-mode checks.
output_1 = Initial_Structure(input)

# Apply ST to the initial structure. ST mutates its argument and returns it, so
# `Structural_after_ST` and `Initial_Structure` are the same object afterwards.
Structural_after_ST = ST_Structural_Transformation(Initial_Structure)

output_2 = Structural_after_ST(input)
st_loss = ((output_1 - output_2) ** 2).sum().item()
print("Output loss After ST:")
print(Fore.RED + "{:.3e}".format(st_loss))
# Check the claim instead of printing it unconditionally.
assert st_loss < 1e-7, f"ST changed the block's output (loss={st_loss:.3e})"
print("which is less than 1e-7, indicating that it is lossless.")

Output loss After ST:
2.644e-08
which is less than 1e-7, indicating that it is lossless.


In [5]:
# Apply SE to block: the added pathway starts as all-zero convs with identity
# fusion selectors, so the output should be unchanged.
Structural_after_SE = SE_Structural_Expansion(Structural_after_ST)

output_3 = Structural_after_SE(input)
se_loss = ((output_2 - output_3) ** 2).sum().item()
print("Output loss After SE:")
print(Fore.RED + "{:.3e}".format(se_loss))
# Check the claim instead of printing it unconditionally.
assert se_loss < 1e-7, f"SE changed the block's output (loss={se_loss:.3e})"
print("which is less than 1e-7, indicating that it is lossless.")


Output loss After SE:
0.000e+00
which is less than 1e-7, indicating that it is lossless.


In [6]:
# Apply SC to the LDE block, collapsing it back into a single ResNetBasicblock.
Structural_after_SC = SC_Structural_Compression(Structural_after_SE)

output_4 = Structural_after_SC(input)
sc_loss = ((output_3 - output_4) ** 2).sum().item()

print("Output loss After SC:")
print(Fore.RED + "{:.3e}".format(sc_loss))
# Check the claim instead of printing it unconditionally.
assert sc_loss < 1e-7, f"SC changed the block's output (loss={sc_loss:.3e})"
print("which is less than 1e-7, indicating that it is lossless.")
print("Verify whether the structure after SC is the same as the original network structure.")
# NOTE: ST mutated `Initial_Structure` in place (it replaced its convs with
# biased ones), so this compares against the post-ST block, not the bias-free
# block as first constructed.
same_structure = str(Initial_Structure) == str(Structural_after_SC)
print(same_structure)
assert same_structure, "structure after SC differs from the (post-ST) initial block"

Output loss After SC:
3.903e-08
which is less than 1e-7, indicating that it is lossless.
Verify whether the structure after SC is the same as the original network structure.
True
