In [153]:
import torch
from torch import nn
# from https://arxiv.org/pdf/1606.02147

In [170]:
def expect_shape_to_match(x:torch.Tensor,shape):
    """Raise ``AssertionError`` unless ``x``'s per-sample shape equals ``shape``.

    Args:
        x: tensor whose dimensions after the first (batch) dimension are checked.
        shape: expected value of ``x.shape[1:]``; any sequence of ints
            (tuple, list or ``torch.Size``) is accepted.

    Raises:
        AssertionError: on mismatch. Raised explicitly rather than through an
            ``assert`` statement so the check still runs under ``python -O``.
    """
    found = tuple(x.shape[1:])
    # tuple() on both sides so a list `shape` compares equal to a torch.Size
    if found != tuple(shape):
        raise AssertionError(f"Expected : {shape} but found: {found}")
class HelperModule(nn.Module):
    """``nn.Module`` base class whose ``custom_forward`` verifies the output shape.

    Args:
        expected_shape: per-sample output shape (everything after the batch
            dimension) that ``custom_forward`` checks.
    """
    def __init__(self,expected_shape) -> None:
        super().__init__()
        self.expected_shape = expected_shape

    def custom_forward(self,*args,**kwargs):
        """Run the module on ``*args, **kwargs`` and check its output shape.

        The module is invoked through ``self(...)`` (``nn.Module.__call__``)
        rather than ``self.forward(...)`` so registered forward hooks and
        pre-hooks run. When the result is a tuple, only its first element is
        shape-checked; the full result is returned unchanged.
        """
        res = self(*args, **kwargs)
        expect_shape_to_match(res[0] if isinstance(res, tuple) else res, self.expected_shape)
        return res
class BottleNeck(HelperModule):
    """ENet bottleneck: a convolutional residual branch added to a main branch.

    Residual branch: projection -> 3x3 conv (transposed 2x2 conv for
    ``"upsampling"``) -> 1x1 expansion -> Spatial Dropout (``nn.Dropout2d``).

    Main branch, by ``type``:

    * ``"normal"`` / ``"dilated"`` / ``"asymmetric"``: the input itself, so
      these types need ``inc == outc`` for the final addition.
    * ``"downsampling"``: 2x2 max-pool, then zero channel padding up to the
      residual branch's channel count; ``forward`` returns
      ``(output, pool_indices)``.
    * ``"upsampling"``: 1x1 conv + BatchNorm, then max-unpool using the
      indices supplied to ``forward``.

    Args:
        inc: input channels.
        outc: output channels.
        expected_shape: per-sample ``(C, H, W)`` checked by ``custom_forward``.
        type: one of the block types listed above.
        dilation: dilation (and matching padding) of the 3x3 conv.
        p: drop probability of the Spatial Dropout regularizer.

    Raises:
        ValueError: if ``type`` is not one of the supported names.
    """
    def __init__(self,inc,outc,expected_shape,type="normal",dilation=1,p=0.1) -> None:
        super().__init__(expected_shape)
        if type == "downsampling":
            self.max_pool = nn.MaxPool2d(2, stride=2, return_indices=True)
            proj_layers = [nn.Conv2d(inc, outc, kernel_size=2, stride=2, bias=False)]
        elif type == "asymmetric":
            # Pad only along each kernel's long axis so H and W are preserved by
            # each conv; padding=1 on both axes shrank H by 2 and the second conv
            # then re-padded it with all-zero border rows.
            # NOTE(review): in the ENet paper the 5x1/1x5 pair replaces the main
            # 3x3 conv; here it forms the projection and the 3x3 conv still
            # runs — confirm this is intended.
            proj_layers = [
                nn.Conv2d(inc, outc, kernel_size=(5, 1), padding=(2, 0), bias=False),
                # second conv consumes the first conv's `outc` channels
                nn.Conv2d(outc, outc, kernel_size=(1, 5), padding=(0, 2), bias=False),
            ]
        elif type in ["normal", "dilated"]:
            proj_layers = [nn.Conv2d(inc, outc, kernel_size=1, bias=False)]
        elif type == "upsampling":
            self.main_conv1 = nn.Sequential(
                nn.Conv2d(inc, outc, kernel_size=1, bias=False),
                nn.BatchNorm2d(outc),
            )
            self.max_unpool = nn.MaxUnpool2d(2)
            proj_layers = [nn.Conv2d(inc, outc, kernel_size=1, bias=False)]
        else:
            raise ValueError(f"Invalid type {type}")
        if type == "upsampling":
            conv = nn.ConvTranspose2d(outc, outc, kernel_size=2, stride=2, bias=False)
        else:
            # padding == dilation keeps H and W unchanged for a 3x3 kernel
            conv = nn.Conv2d(outc, outc, kernel_size=3, dilation=dilation, padding=dilation, bias=False)

        self.conv_projection = nn.Sequential(
            *proj_layers,
            nn.BatchNorm2d(outc),
            nn.PReLU(),
        )
        self.conv = nn.Sequential(
            conv,
            nn.BatchNorm2d(outc),
            nn.PReLU(),
        )
        self.conv_expansion = nn.Sequential(
            nn.Conv2d(outc, outc, kernel_size=1),
        )
        self.regularizer = nn.Dropout2d(p=p)
        # NOTE(review): this instance attribute shadows nn.Module.type(dst_type),
        # so calling `.type(...)` on a BottleNeck fails; kept for compatibility.
        self.type = type

    def forward(self,x,indices=None,output_size=None):
        """Apply the bottleneck.

        Args:
            x: input batch.
            indices: max-pool indices from the matching downsampling block;
                required for ``"upsampling"``, ignored otherwise.
            output_size: optional target size for ``"upsampling"``, forwarded
                to both the transposed conv and ``MaxUnpool2d`` (lets the block
                undo a pool over an odd-sized input). Ignored for other types.

        Returns:
            ``(output, indices)`` for ``"downsampling"``, otherwise ``output``.
        """
        # residual (convolutional) branch
        x2 = self.conv_projection(x)
        if self.type == "upsampling":
            # call the transposed conv directly so it can honour output_size
            x2 = self.conv[0](x2, output_size=output_size)
            x2 = self.conv[1:](x2)
        else:
            x2 = self.conv(x2)
        x2 = self.conv_expansion(x2)
        x2 = self.regularizer(x2)
        # main (shortcut) branch
        if self.type == "upsampling":
            assert indices is not None, "upsampling BottleNeck requires pooling indices"
            x1 = self.main_conv1(x)
            x1 = self.max_unpool(x1, indices=indices, output_size=output_size)
        elif self.type == "downsampling":
            x1, indices = self.max_pool(x)
            n = x2.shape[1] - x1.shape[1]
            # zero-pad the channel dim (last pair of the pad tuple) to match x2;
            # split n unevenly when odd so the total is exactly n extra channels
            x1 = torch.nn.functional.pad(x1, (0, 0, 0, 0, n // 2, n - n // 2))
        else:
            x1 = x
        x = x1 + x2
        return (x, indices) if self.type == "downsampling" else x
class InitialBlock(HelperModule):
    """ENet initial block: concatenation of a strided 3x3 conv and a 2x2 max-pool.

    The conv produces ``outc - inc`` channels and the pool passes the ``inc``
    input channels through, so the output has ``outc`` channels at half the
    input resolution (rounded up).

    Args:
        inc: input channels (e.g. 3 for RGB).
        outc: output channels; must exceed ``inc``.
        expected_shape: per-sample ``(C, H, W)`` checked by ``custom_forward``.
    """
    def __init__(self, inc, outc, expected_shape) -> None:
        super().__init__(expected_shape)
        # consumes `inc` channels so the pooled branch and the conv agree
        self.conv = nn.Conv2d(inc, outc - inc, kernel_size=3, stride=2, padding=1)
        # ceil_mode gives ceil(H/2) x ceil(W/2), matching the padded stride-2
        # conv for odd sizes as well (identical to floor mode for even sizes)
        self.max_pool = nn.MaxPool2d(2, ceil_mode=True)

    def forward(self,x):
        """Return ``cat(conv(x), max_pool(x))`` along the channel dimension."""
        x1 = self.conv(x)
        x2 = self.max_pool(x)
        return torch.cat((x1, x2), dim=1)
class Upsampling(BottleNeck):
    """Shorthand for a ``BottleNeck`` whose block type is ``"upsampling"``.

    Args:
        inc: input channels.
        outc: output channels.
        expected_shape: per-sample ``(C, H, W)`` checked by ``custom_forward``.
    """

    def __init__(self, inc, outc, expected_shape) -> None:
        block_kind = "upsampling"
        super().__init__(inc, outc, expected_shape, type=block_kind)
class ENet(nn.Module):
    """ENet segmentation network (Paszke et al., arXiv:1606.02147), simplified.

    Encoder: initial block, stage 1 (one downsampling + one regular
    bottleneck), stage 2 (one downsampling + eight regular/dilated/asymmetric
    bottlenecks). Decoder: two upsampling stages and a final transposed conv
    back to the input resolution. Every block's output shape is checked
    through ``custom_forward``.

    Args:
        C: number of output channels (classes).
        input_size: ``(H, W)`` of the images passed to ``forward``; both sides
            must be divisible by 8 (three 2x reductions). Defaults to the
            previously hard-coded ``(512, 512)``.

    Raises:
        ValueError: if a side of ``input_size`` is not divisible by 8.
    """
    def __init__(self, C, input_size=(512, 512)) -> None:
        super().__init__()
        H, W = input_size
        if H % 8 or W % 8:
            raise ValueError(f"input_size sides must be divisible by 8, got {tuple(input_size)}")
        h2, w2 = H // 2, W // 2
        h4, w4 = H // 4, W // 4
        h8, w8 = H // 8, W // 8

        self.initial = InitialBlock(3, 16, (16, h2, w2))

        # Spatial Dropout rates per the paper: p=0.01 before bottleneck2.0,
        # p=0.1 (the BottleNeck default) from bottleneck2.0 onwards.
        self.bottleneck1_0 = BottleNeck(16, 64, (64, h4, w4), type="downsampling", p=0.01)
        self.bottleneck1_x = BottleNeck(64, 64, (64, h4, w4), p=0.01)

        self.bottleneck2_0 = BottleNeck(64, 128, (128, h8, w8), type="downsampling")
        self.bottleneck2_1 = BottleNeck(128, 128, (128, h8, w8))
        self.bottleneck2_2 = BottleNeck(128, 128, (128, h8, w8), type="dilated", dilation=2)
        self.bottleneck2_3 = BottleNeck(128, 128, (128, h8, w8), type="asymmetric")
        self.bottleneck2_4 = BottleNeck(128, 128, (128, h8, w8), type="dilated", dilation=4)
        self.bottleneck2_5 = BottleNeck(128, 128, (128, h8, w8))
        self.bottleneck2_6 = BottleNeck(128, 128, (128, h8, w8), type="dilated", dilation=8)
        self.bottleneck2_7 = BottleNeck(128, 128, (128, h8, w8), type="asymmetric")
        self.bottleneck2_8 = BottleNeck(128, 128, (128, h8, w8), type="dilated", dilation=16)

        self.bottleneck4_0 = Upsampling(128, 64, (64, h4, w4))
        self.bottleneck4_1 = BottleNeck(64, 64, (64, h4, w4))
        self.bottleneck4_2 = BottleNeck(64, 64, (64, h4, w4))

        self.bottleneck5_0 = Upsampling(64, 16, (16, h2, w2))
        self.bottleneck5_1 = BottleNeck(16, 16, (16, h2, w2))

        self.full_conv = nn.ConvTranspose2d(
            16,
            C,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )

        self.expected_shape = (C, H, W)

    def forward(self,x):
        """Map an ``(N, 3, H, W)`` batch, ``(H, W) == input_size``, to ``(N, C, H, W)``."""
        bs = x.size(0)
        x = self.initial.custom_forward(x)
        x, indices1 = self.bottleneck1_0.custom_forward(x)  # downsampling
        x = self.bottleneck1_x.custom_forward(x)
        x, indices2 = self.bottleneck2_0.custom_forward(x)  # downsampling
        x = self.bottleneck2_1.custom_forward(x)
        x = self.bottleneck2_2.custom_forward(x)  # dilated 2
        x = self.bottleneck2_3.custom_forward(x)  # asymmetric 5
        x = self.bottleneck2_4.custom_forward(x)  # dilated 4
        x = self.bottleneck2_5.custom_forward(x)
        x = self.bottleneck2_6.custom_forward(x)  # dilated 8
        x = self.bottleneck2_7.custom_forward(x)  # asymmetric 5
        x = self.bottleneck2_8.custom_forward(x)  # dilated 16

        x = self.bottleneck4_0.custom_forward(x, indices2)  # upsampling
        x = self.bottleneck4_1.custom_forward(x)
        x = self.bottleneck4_2.custom_forward(x)

        x = self.bottleneck5_0.custom_forward(x, indices1)  # upsampling
        x = self.bottleneck5_1.custom_forward(x)

        # output_size resolves the stride-2 transposed conv's ambiguity to exactly (H, W)
        x = self.full_conv(x, output_size=(bs, *self.expected_shape))
        return x

    def nparams(self):
        """Return the total number of parameters, in millions."""
        return sum(p.numel() for p in self.parameters()) / 1e6
# Smoke test: build a 16-class ENet and report output shape and parameter count (millions).
model = ENet(16)
with torch.no_grad():  # shape check only; no autograd graph needed
    # call the module (not .forward) so hooks run; randn avoids the
    # uninitialized, possibly NaN/Inf, memory returned by torch.empty
    out = model(torch.randn(2, 3, 512, 512))
out.shape, model.nparams()

(torch.Size([2, 16, 512, 512]), 2.166924)