In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

In [2]:
class ConvBn2d(nn.Module):
    """Depthwise square convolution followed by BatchNorm2d.

    The convolution uses ``groups=channels`` (one filter per channel) and
    ``padding=kernel_size // 2`` with the default stride of 1, so the output
    has the same number of channels and the same spatial size as the input.

    Args:
        channels: number of input (and output) channels.
        kernel_size: side length of the square kernel. Must be odd: symmetric
            ``kernel_size // 2`` padding only preserves the spatial size for
            odd kernels.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, channels, kernel_size):
        super().__init__()
        if kernel_size % 2 == 0:
            # An even kernel with k//2 padding grows H and W by one pixel, which
            # would break the element-wise sum of parallel branches in LargeKernel.
            raise ValueError(f"kernel_size must be odd to keep the spatial size, got {kernel_size}")
        # NOTE(review): the conv bias is redundant in front of BatchNorm; kept so
        # existing state_dicts (which contain conv.bias) still load.
        self.conv = nn.Conv2d(channels, channels, kernel_size, groups=channels, padding=kernel_size // 2)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        return self.bn(self.conv(x))


class LargeKernel(nn.Module):
    """Sum of a large depthwise ConvBn2d branch and optional smaller parallel branches.

    Every branch receives the same input, and the outputs are added
    element-wise. This requires all branches to produce the same shape.

    Args:
        channels: number of channels for every branch.
        kernel: kernel size of the main branch (submodule ``dw_large``).
        small_kernels: kernel sizes of the extra branches. Each one is
            registered as the submodule ``dw_small_{k}``.

    Raises:
        ValueError: if ``small_kernels`` contains the same size more than once.
    """

    def __init__(self, channels, kernel, small_kernels=()):
        super().__init__()
        self.dw_large = ConvBn2d(channels, kernel)

        # Materialise once: a generator would otherwise be exhausted by the
        # registration loop below, and forward() would silently skip the branches.
        self.small_kernels = tuple(small_kernels)
        if len(set(self.small_kernels)) != len(self.small_kernels):
            # Submodules are keyed by kernel size, so a repeated size would
            # overwrite the earlier branch and reuse the survivor in forward().
            raise ValueError(f"small_kernels must not contain duplicates, got {self.small_kernels}")
        for k in self.small_kernels:
            setattr(self, f"dw_small_{k}", ConvBn2d(channels, k))

    def forward(self, inp):
        outp = self.dw_large(inp)
        for k in self.small_kernels:
            # Out-of-place add: avoid mutating a tensor returned by a submodule.
            outp = outp + getattr(self, f"dw_small_{k}")(inp)
        return outp


class ReduceDimBlock(nn.Module):
    """Large-kernel depthwise mixing followed by a 1x1 conv that drops one channel.

    Pipeline: ``LargeKernel`` -> ``activation`` -> ``Conv2d(channels, channels - 1, 1, 1, 0)``.
    The 1x1 projection (stride 1, no padding) leaves the spatial size unchanged.

    Args:
        channels: number of input channels. The output has ``channels - 1``.
        kernel: large kernel size forwarded to ``LargeKernel``.
        small_kernels: extra kernel sizes forwarded to ``LargeKernel``.
        activation: activation class, instantiated with no arguments.

    Raises:
        ValueError: if ``channels < 2`` (the projection would have no output channels).
    """

    def __init__(self, channels, kernel, small_kernels=(), activation=nn.SiLU):
        super().__init__()
        if channels < 2:
            raise ValueError(f"ReduceDimBlock needs at least 2 input channels, got {channels}")
        self.lk = LargeKernel(channels, kernel, small_kernels)
        self.lk_act = activation()
        self.pw2 = nn.Conv2d(channels, channels - 1, 1, 1, 0)

    def forward(self, x):
        lk_out = self.lk_act(self.lk(x))
        return self.pw2(lk_out)


class RouteBlock(nn.Module):
    """Chain of ReduceDimBlocks, each removing one channel.

    With the default ``num_reductions=6``, a 7-channel input is reduced to a
    single channel (7 -> 6 -> 5 -> 4 -> 3 -> 2 -> 1). In general the output has
    ``channels - num_reductions`` channels. Stages are registered as
    ``reduce1`` ... ``reduce{num_reductions}``, which matches the original
    fixed six-stage layout (same state_dict keys).

    Args:
        channels: number of input channels.
        kernel: large kernel size used by every stage.
        small_kernels: extra kernel sizes used by every stage.
        activation: activation class used by every stage.
        num_reductions: number of stages (default 6, the original fixed depth).

    Raises:
        ValueError: if ``num_reductions < 1`` or fewer than one output channel would remain.
    """

    def __init__(self, channels, kernel, small_kernels=(), activation=nn.SiLU, num_reductions=6) -> None:
        super().__init__()
        if num_reductions < 1 or channels - num_reductions < 1:
            raise ValueError(
                f"need num_reductions >= 1 and channels - num_reductions >= 1, "
                f"got channels={channels}, num_reductions={num_reductions}"
            )
        # Shared by every stage; a generator would otherwise only reach the first one.
        small_kernels = tuple(small_kernels)
        self.num_reductions = num_reductions
        for i in range(num_reductions):
            # Stage i+1 maps (channels - i) -> (channels - i - 1) channels.
            setattr(self, f"reduce{i + 1}", ReduceDimBlock(channels - i, kernel, small_kernels, activation))

    def forward(self, x):
        for i in range(1, self.num_reductions + 1):
            x = getattr(self, f"reduce{i}")(x)
        return x


class LargeKernelNet(nn.Module):
    """Parallel RouteBlocks with different large kernels, fused by a bias-free 1x1 conv.

    One ``RouteBlock`` per entry of ``kernels`` (registered as ``route_k{k}``,
    each with an extra 3x3 branch) processes the same input. The route outputs
    are concatenated along dim 1 and linearly mixed into one output channel.

    Args:
        channels: number of input channels. Must be at least 7, because each
            RouteBlock (default depth) removes six channels.
        kernels: large kernel sizes, one route per entry. Must be non-empty and
            unique, since routes are keyed by kernel size.

    Raises:
        ValueError: if ``channels < 7`` or ``kernels`` is empty or has duplicates.
    """

    # Number of channels a default-depth RouteBlock removes (its six reduce stages).
    _ROUTE_REDUCTIONS = 6

    def __init__(self, channels, kernels=(13, 9, 7, 5)) -> None:
        super().__init__()
        self.large_kernels = tuple(kernels)
        if not self.large_kernels:
            raise ValueError("kernels must not be empty")
        if len(set(self.large_kernels)) != len(self.large_kernels):
            # A repeated size would overwrite the earlier route module.
            raise ValueError(f"kernels must not contain duplicates, got {self.large_kernels}")
        route_out_channels = channels - self._ROUTE_REDUCTIONS
        if route_out_channels < 1:
            raise ValueError(f"channels must be >= {self._ROUTE_REDUCTIONS + 1}, got {channels}")

        for k in self.large_kernels:
            setattr(self, f'route_k{k}', RouteBlock(channels,
                                                        kernel=k,
                                                        small_kernels=(3,)))

        # Input width is the total concatenated route output. This equals
        # len(kernels) when channels == 7; larger inputs leave more channels per route.
        self.conv_linear = nn.Conv2d(len(self.large_kernels) * route_out_channels, 1, 1, 1, 0, bias=False)

    def forward(self, x):
        outputs = [getattr(self, f"route_k{k}")(x) for k in self.large_kernels]
        y = torch.cat(outputs, dim=1)
        return self.conv_linear(y)

In [3]:
# Smoke test: one forward pass on random input to check the output shape.
torch.manual_seed(0)  # reproducible input and weight initialisation
sample = torch.rand((4, 7, 1024, 1024))
model = LargeKernelNet(7, kernels=(9, 7, 5, 3))
# Only the shape is inspected, so skip building the autograd graph; otherwise
# every intermediate activation (4 routes x 6 stages at 1024x1024) is kept alive.
with torch.no_grad():
    res = model(sample)

In [4]:
# summary() prints the table itself and also returns an object whose repr is the
# same table; the trailing ';' stops Jupyter from displaying it a second time.
summary(model);

Layer (type:depth-idx)                   Param #
├─RouteBlock: 1-1                        --
|    └─ReduceDimBlock: 2-1               --
|    |    └─LargeKernel: 3-1             672
|    |    └─SiLU: 3-2                    --
|    |    └─Conv2d: 3-3                  48
|    └─ReduceDimBlock: 2-2               --
|    |    └─LargeKernel: 3-4             576
|    |    └─SiLU: 3-5                    --
|    |    └─Conv2d: 3-6                  35
|    └─ReduceDimBlock: 2-3               --
|    |    └─LargeKernel: 3-7             480
|    |    └─SiLU: 3-8                    --
|    |    └─Conv2d: 3-9                  24
|    └─ReduceDimBlock: 2-4               --
|    |    └─LargeKernel: 3-10            384
|    |    └─SiLU: 3-11                   --
|    |    └─Conv2d: 3-12                 15
|    └─ReduceDimBlock: 2-5               --
|    |    └─LargeKernel: 3-13            288
|    |    └─SiLU: 3-14                   --
|    |    └─Conv2d: 3-15                 8
|    └─ReduceDimBlock: 

Layer (type:depth-idx)                   Param #
├─RouteBlock: 1-1                        --
|    └─ReduceDimBlock: 2-1               --
|    |    └─LargeKernel: 3-1             672
|    |    └─SiLU: 3-2                    --
|    |    └─Conv2d: 3-3                  48
|    └─ReduceDimBlock: 2-2               --
|    |    └─LargeKernel: 3-4             576
|    |    └─SiLU: 3-5                    --
|    |    └─Conv2d: 3-6                  35
|    └─ReduceDimBlock: 2-3               --
|    |    └─LargeKernel: 3-7             480
|    |    └─SiLU: 3-8                    --
|    |    └─Conv2d: 3-9                  24
|    └─ReduceDimBlock: 2-4               --
|    |    └─LargeKernel: 3-10            384
|    |    └─SiLU: 3-11                   --
|    |    └─Conv2d: 3-12                 15
|    └─ReduceDimBlock: 2-5               --
|    |    └─LargeKernel: 3-13            288
|    |    └─SiLU: 3-14                   --
|    |    └─Conv2d: 3-15                 8
|    └─ReduceDimBlock: 

In [5]:
res.shape  # expected torch.Size([4, 1, 1024, 1024]): one channel, input H and W kept

torch.Size([4, 1, 1024, 1024])