In [1]:
import torch

In [2]:
class ConvBlock(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> optional ReLU.

    Args:
        inp: number of input channels.
        out: number of output channels.
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: zero-padding added to each spatial side.
        bias: whether the convolution has a learnable bias term.
        act: if True, a ReLU is appended after the batch norm.
    """
    def __init__(self, inp, out, kernel, stride, padding, bias, act):
        super().__init__()
        # Keyword arguments are required here: the 6th positional parameter of
        # Conv2d is `dilation`, not `bias`, so passing `bias` positionally set
        # dilation=bias (and dilation=0 when bias=False) while the bias stayed on.
        layers = [
            torch.nn.Conv2d(inp, out, kernel_size=kernel, stride=stride,
                            padding=padding, bias=bias),
            torch.nn.BatchNorm2d(out),
        ]
        if act:
            layers.append(torch.nn.ReLU())
        # Layer indices (0 = conv, 1 = batchnorm, 2 = relu) match the previous
        # layout, so state_dict keys are unchanged.
        self.conv_block = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_block(x)

In [8]:
class DeConvBlock(torch.nn.Module):
    """Transposed convolution followed by BatchNorm2d and ReLU.

    Args:
        inp: number of input channels.
        out: number of output channels.
        kernel: transposed-convolution kernel size.
        stride: transposed-convolution stride (the upsampling factor).
        padding: padding argument forwarded to ConvTranspose2d.
    """
    def __init__(self, inp, out, kernel, stride, padding):
        super().__init__()
        self.conv_transpose = torch.nn.ConvTranspose2d(
            in_channels=inp,
            out_channels=out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )
        self.batchnorm = torch.nn.BatchNorm2d(out)
        self.relu = torch.nn.ReLU()

    def forward(self, x, out):
        """Upsample ``x`` to the requested size.

        Args:
            x: input feature map.
            out: target size forwarded as ``output_size`` to ConvTranspose2d;
                callers in this notebook pass a full ``tensor.size()``.
        """
        upsampled = self.conv_transpose(x, output_size=out)
        normalized = self.batchnorm(upsampled)
        return self.relu(normalized)

In [9]:
class Encoder(torch.nn.Module):
    """LinkNet encoder block: halves the spatial size and maps ``inp`` -> ``out`` channels.

    Two residual units are applied in sequence:
      1. ``h = block1(x) + residue(x)``; both branches use stride 2.
      2. ``h + block2(h)``; block2 keeps the spatial size.

    Args:
        inp: number of input channels.
        out: number of output channels.
    """
    def __init__(self, inp, out):
        super().__init__()
        self.block1 = torch.nn.Sequential(
            ConvBlock(inp=inp, out=out, kernel=3, stride=2, padding=1, bias=True, act=True),
            ConvBlock(inp=out, out=out, kernel=3, stride=1, padding=1, bias=True, act=True)
        )

        self.block2 = torch.nn.Sequential(
            ConvBlock(inp=out, out=out, kernel=3, stride=1, padding=1, bias=True, act=True),
            ConvBlock(inp=out, out=out, kernel=3, stride=1, padding=1, bias=True, act=True)
        )

        # Projection shortcut for the first unit: matches block1's stride and channels.
        self.residue = ConvBlock(inp=inp, out=out, kernel=3, stride=2, padding=1, bias=True, act=True)

    def forward(self, x):
        # First residual unit: main path and shortcut both downsample by 2.
        unit1 = self.block1(x) + self.residue(x)
        # Second residual unit: its identity shortcut must carry the full input
        # of block2 (block1 + residue). Adding only block1 here dropped the
        # first unit's shortcut from the output.
        return unit1 + self.block2(unit1)

In [96]:
class DeCoder(torch.nn.Module):
    """LinkNet decoder block: doubles the spatial size and maps ``inp`` -> ``out`` channels.

    A 1x1 conv squeezes the channels to ``inp // 4``, a stride-2 transposed conv
    upsamples to ``output_size``, and a second 1x1 conv expands to ``out`` channels.

    Args:
        inp: number of input channels.
        out: number of output channels.
    """
    def __init__(self, inp, out):
        super().__init__()
        bottleneck = inp // 4
        self.conv_block1 = ConvBlock(inp, bottleneck, kernel=1, stride=1, padding=0, bias=True, act=True)
        self.de_conv_block1 = DeConvBlock(inp=bottleneck, out=bottleneck, kernel=3, stride=2, padding=1)
        self.conv_block2 = ConvBlock(bottleneck, out, kernel=1, stride=1, padding=0, bias=True, act=True)

    def forward(self, x, output_size):
        """Decode ``x``; ``output_size`` fixes the exact upsampled size (e.g. an encoder tensor's ``.size()``)."""
        squeezed = self.conv_block1(x)
        upsampled = self.de_conv_block1(squeezed, out=output_size)
        return self.conv_block2(upsampled)

In [113]:
class LinkNet(torch.nn.Module):
    """LinkNet-style encoder/decoder segmentation network.

    The input is downsampled 4x by a stride-2 7x7 conv and a stride-2 max-pool,
    then a further 16x by four Encoder blocks. Four DeCoder blocks upsample back.
    The matching encoder feature map is added after decoders 4, 3 and 2. Two
    final transposed-conv stages restore the input resolution.

    Args:
        in_channels: channels of the input image (default 3, as before).
        num_classes: channels of the output map (default 2, as before).
    """
    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        self.init_conv = ConvBlock(inp=in_channels, out=64, kernel=7, stride=2, padding=3, bias=True, act=True)
        self.init_maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.encoder1 = Encoder(inp=64, out=64)
        self.encoder2 = Encoder(inp=64, out=128)
        self.encoder3 = Encoder(inp=128, out=256)
        self.encoder4 = Encoder(inp=256, out=512)

        self.decoder4 = DeCoder(inp=512, out=256)
        self.decoder3 = DeCoder(inp=256, out=128)
        self.decoder2 = DeCoder(inp=128, out=64)
        self.decoder1 = DeCoder(inp=64, out=64)

        self.final_deconv1 = DeConvBlock(inp=64, out=32, kernel=3, stride=2, padding=1)
        self.final_conv = ConvBlock(inp=32, out=32, kernel=3, stride=1, padding=1, bias=True, act=True)
        # NOTE(review): DeConvBlock ends in BatchNorm2d + ReLU, so the returned
        # class scores are never negative. For logits fed to a softmax or
        # cross-entropy loss, a plain ConvTranspose2d is probably intended --
        # confirm before training (changing it alters the state_dict keys).
        self.final_deconv2 = DeConvBlock(inp=32, out=num_classes, kernel=2, stride=2, padding=0)

    def forward(self, x):
        """Map an image batch ``x`` (N, in_channels, H, W) to (N, num_classes, H, W).

        Each upsampling step is given the size of the tensor it mirrors, so the
        original spatial size is restored. Verified for 256x256 inputs;
        presumably requires H and W to be multiples of 64 -- TODO confirm for
        other sizes.
        """
        init_conv = self.init_conv(x) # downsampling by 2
        init_maxpool = self.init_maxpool(init_conv) # downsampling by 2

        e1 = self.encoder1(init_maxpool) # downsampling by 2
        e2 = self.encoder2(e1) # downsampling by 2
        e3 = self.encoder3(e2) # downsampling by 2
        e4 = self.encoder4(e3) # downsampling by 2

        # Decoder outputs are summed with the same-size encoder features (skip links).
        d4 = self.decoder4(e4, e3.size()) + e3 # upsampling by 2
        d3 = self.decoder3(d4, e2.size()) + e2 # upsampling by 2
        d2 = self.decoder2(d3, e1.size()) + e1 # upsampling by 2
        d1 = self.decoder1(d2, init_maxpool.size()) # upsampling by 2

        final_deconv1 = self.final_deconv1(d1, init_conv.size()) # upsampling by 2
        final_conv = self.final_conv(final_deconv1) # same size
        final_deconv2 = self.final_deconv2(final_conv, x.size()) # upsampling by 2

        return final_deconv2

In [115]:
# Shape smoke test: deterministic weights/input, inference mode, no autograd graph.
torch.manual_seed(0)
model = LinkNet().eval()  # eval(): BatchNorm uses running stats instead of updating them
# torch.Tensor(*sizes) allocates *uninitialized* memory (may hold NaN/inf), so use a defined random input.
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    y = model(x)
assert y.shape == (1, 2, 256, 256), y.shape
y.shape

torch.Size([1, 2, 256, 256])