In [None]:
import torch
import torch.nn as nn

class SLIM_CONVOLUTION_1D(nn.Module):
    """1-D SlimConv-style block: channel reweighting, weight flipping and two conv branches.

    A squeeze-excite style stack produces a per-channel weight ``W`` from the
    input. The input is multiplied by ``W`` and by its channel-flipped copy
    ``W.flip(1)``. Each reweighted tensor is folded to half its channels by
    adding its first and second channel halves, then passed through its own
    branch:

    * branch 1: one ``kernel_size`` convolution, ``C//2 -> C//2`` channels;
    * branch 2: a pointwise (kernel 1) convolution ``C//2 -> C//2`` followed by
      a ``kernel_size`` convolution ``C//2 -> int(C / reduction)`` channels.

    The branch outputs are concatenated on the channel axis. Assuming the
    selected conv class honours ``out_channels`` (TODO confirm for every
    variant), the result has ``C//2 + int(C / reduction)`` channels.

    Args:
        in_channels: Number of input channels ``C``; ``forward`` requires it even.
        reduction: Divisor giving branch 2's output channels.
        kernel_size: Kernel size of the spatial convolution in each branch.
        stride, padding, dilation: Applied to the ``kernel_size`` convolutions
            of both branches (not to the pointwise/squeeze-excite layers).
        groups, bias: Forwarded to every convolution of both branches.
        convolution_type: Which convolution implementation to use; one of
            ``"DSC"``, ``"LRC"``, ``"FUSED"``, ``"CLASSIC"``.
        rank: Forwarded to every convolution (presumably only meaningful for
            the low-rank variant — TODO confirm).
        *args: Accepted and ignored.
        **kwargs: Forwarded to every convolution.

    Raises:
        ValueError: If ``convolution_type`` is not a supported key.
    """

    def __init__(self,
                 in_channels: int,
                 reduction: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 convolution_type: str = "CLASSIC",
                 rank: int = 1,
                 bias: bool = False,
                 *args, **kwargs) -> None:
        super().__init__()

        convolutions = {"DSC": DEPTHWISE_SEPERABLE_CONVOLUTION_1D,
                        "LRC": LOW_RANK_CONVOLUTION_1D,
                        "FUSED": FUSED_CONVOLUTION_BN_1D,
                        "CLASSIC": CONVOLUTION_1D}

        self.TYPE_C = convolution_type
        if self.TYPE_C not in convolutions:
            raise ValueError("Only implemented for [DSC, LRC, FUSED, CLASSIC]")
        conv = convolutions[self.TYPE_C]

        self.INPUT_CHANNEL = in_channels
        self.REDUCTION = reduction
        self.KERNEL_SIZE = kernel_size
        self.RANK = rank
        self.BIAS = bias
        self.STRIDE = stride
        self.DILATION = dilation
        self.PADDING = padding
        self.GROUPS = groups

        half = self.INPUT_CHANNEL // 2
        reduced = int(self.INPUT_CHANNEL / self.REDUCTION)

        # Global average over the length axis, then C -> C//2 -> C pointwise convs.
        # NOTE(review): unlike a standard squeeze-excite / SlimConv gate there is
        # no nn.Sigmoid() after the last conv, so W is not bounded to (0, 1)
        # unless the conv class applies its own activation — confirm intended.
        self.SQUEEZE_EXCITE = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            conv(in_channels=self.INPUT_CHANNEL,
                 out_channels=half,
                 kernel_size=1,
                 rank=self.RANK,
                 **kwargs),
            nn.ReLU(),
            conv(in_channels=half,
                 out_channels=self.INPUT_CHANNEL,
                 kernel_size=1,
                 rank=self.RANK,
                 **kwargs))

        self.BRANCH1_CONV = conv(in_channels=half,
                                 out_channels=half,
                                 kernel_size=self.KERNEL_SIZE,
                                 rank=self.RANK,
                                 bias=self.BIAS,
                                 stride=self.STRIDE,
                                 padding=self.PADDING,
                                 dilation=self.DILATION,
                                 groups=self.GROUPS,
                                 **kwargs)

        self.BRANCH2_CONV = nn.Sequential(
            # Pointwise layer uses neutral stride/padding/dilation: applying the
            # block's padding or stride here as well made branch 2 pad/stride
            # twice, so its output length differed from branch 1's and the
            # torch.cat in forward() failed (assuming nn.Conv1d-style length
            # arithmetic in the conv classes).
            conv(in_channels=half,
                 out_channels=half,
                 kernel_size=1,
                 rank=self.RANK,
                 bias=self.BIAS,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=self.GROUPS,
                 **kwargs),
            conv(in_channels=half,
                 out_channels=reduced,
                 kernel_size=self.KERNEL_SIZE,
                 rank=self.RANK,
                 bias=self.BIAS,
                 stride=self.STRIDE,
                 padding=self.PADDING,
                 dilation=self.DILATION,
                 groups=self.GROUPS,
                 **kwargs))

    def forward(self, INPUT):
        """Apply the block to ``INPUT``, a batched (N, C, L) tensor with even C.

        Raises:
            ValueError: If the channel dimension (dim 1) is odd.
        """
        if INPUT.shape[1] % 2 != 0:
            raise ValueError("Channel must be even")
        half = INPUT.shape[1] // 2

        W = self.SQUEEZE_EXCITE(INPUT)
        # Flipping along channels pairs each channel with the weight of its mirror.
        W_F = W.flip(1)

        # Fold each reweighted tensor to C//2 channels by summing its two halves.
        top_a, top_b = (W * INPUT).split(half, dim=1)
        O1 = self.BRANCH1_CONV(top_a + top_b)

        bottom_a, bottom_b = (W_F * INPUT).split(half, dim=1)
        O2 = self.BRANCH2_CONV(bottom_a + bottom_b)

        return torch.cat([O1, O2], dim=1)